#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2015 jaidev <jaidev@newton>
#
# Distributed under terms of the MIT license.
"""
Base time-frequency representation class.
"""
import numpy as np
import matplotlib.pyplot as plt
class BaseTFRepresentation(object):
    """Base class for time-frequency representations.

    ``self.tfr`` is initialised to zeros here; subclasses are expected to
    populate it.  Plotting methods read ``self.name`` (used as the plot
    title), which this class does not define, so subclasses must provide it.
    Affine subclasses (``isaffine = True``) must additionally provide
    ``fmin``, ``fmax`` and ``n_voices`` for contour plotting.
    """

    # Whether the representation is affine; switches how the spectrum,
    # contour levels and axis limits are computed when plotting.
    isaffine = False

    def __init__(self, signal, **kwargs):
        """Create a base time-frequency representation object.

        :param signal: Signal to be analyzed. A 2-D array with a singleton
            dimension (row or column vector) is flattened to 1-D.
        :param **kwargs: Other arguments required for performing the analysis.
            Recognized here: ``timestamps`` (default
            ``arange(len(signal))``), ``n_fbins`` (default ``len(signal)``)
            and ``fwindow`` (default: see ``_make_window``).
        :type signal: array-like
        :return: BaseTFRepresentation object
        :rtype: BaseTFRepresentation
        """
        signal = np.asarray(signal)
        if (signal.ndim == 2) and (1 in signal.shape):
            signal = signal.ravel()
        self.signal = signal

        timestamps = kwargs.get('timestamps')
        if timestamps is None:
            timestamps = np.arange(signal.shape[0])
        self.ts = self.timestamps = np.asarray(timestamps)

        n_fbins = kwargs.get('n_fbins')
        if n_fbins is None:
            n_fbins = signal.shape[0]
        self.n_fbins = n_fbins

        fwindow = kwargs.get('fwindow')
        if fwindow is None:
            fwindow = self._make_window()
        self.fwindow = fwindow

        # Normalized frequencies in FFT order: 0, 1, ..., then the negative
        # bins, divided by n_fbins.  ``fftfreq`` returns exactly n_fbins
        # values for both even and odd sizes.
        self.freqs = np.fft.fftfreq(int(self.n_fbins))
        self.tfr = np.zeros((self.n_fbins, self.ts.shape[0]), dtype=complex)

    def _get_spectrum(self):
        """Return the power spectrum shown next to the TFR plot.

        Non-affine: fftshift-ed squared magnitude of the FFT of the whole
        signal.  Affine: squared magnitude of a ``2 * tfr.shape[0]``-point FFT
        of the signal between the first and last timestamps, keeping only the
        first ``tfr.shape[0]`` bins.
        """
        if not self.isaffine:
            return np.fft.fftshift(np.abs(np.fft.fft(self.signal)) ** 2)
        nf2 = self.tfr.shape[0]
        # Slice bounds must be integers even if timestamps are floats.
        start = int(self.ts.min())
        stop = int(self.ts.max()) + 1
        spec = np.abs(np.fft.fft(self.signal[start:stop], 2 * nf2)) ** 2
        return spec[:nf2]

    def _make_window(self):
        """Make a Hamming window function.

        The window length is ``floor(n_fbins / 4)``, bumped up by one when
        that is even, so the length is always odd (and at least 1).

        :return: Hamming window function.
        :rtype: numpy.ndarray
        """
        h = np.floor(self.n_fbins / 4.0)
        h += 1 - np.remainder(h, 2)
        # ``scipy.hamming`` has been removed from recent SciPy releases;
        # ``np.hamming`` computes the same symmetric Hamming window.
        # The window is intentionally not normalized.
        return np.hamming(int(h))

    def _plot_tfr(self, ax, kind, extent, contour_x=None, contour_y=None,
                  levels=None, show_tf=True, cmap=None):
        """Draw ``self.tfr`` on ``ax`` as an image ("cmap") or contours.

        :param cmap: Colormap for the "cmap" kind; ``None`` means gray.
        """
        if cmap is None:
            cmap = plt.cm.gray
        if kind == "cmap":
            # matplotlib only accepts "upper"/"lower" for ``origin``.
            ax.imshow(self.tfr, cmap=cmap, origin="lower", extent=extent,
                      aspect='auto')
        elif kind == "contour":
            if contour_x is None:
                contour_x = self.ts
            if contour_y is None:
                if show_tf:
                    if self.isaffine:
                        contour_y = np.linspace(self.fmin, self.fmax,
                                                self.n_voices)
                    else:
                        contour_y = np.linspace(0, 0.5,
                                                self.signal.shape[0])
                else:
                    contour_y = np.linspace(0, 0.5, self.tfr.shape[0])
            contour_x, contour_y = np.meshgrid(contour_x, contour_y)
            if levels is not None:
                ax.contour(contour_x, contour_y, self.tfr, levels)
            elif self.isaffine:
                # Skip levels below ``threshold * max`` (set by ``plot``).
                maxi = np.amax(self.tfr)
                mini = max(np.amin(self.tfr), maxi * self._viz_threshold)
                levels = np.linspace(mini, maxi, 65)
                ax.contour(contour_x, contour_y, self.tfr, levels=levels)
            else:
                ax.contour(contour_x, contour_y, self.tfr)

    def _annotate_tfr(self, ax):
        """Add grid, axis labels and a title to the TFR axes."""
        ax.grid(True)
        ax.set_xlabel("Time")
        ax.set_ylabel("Normalized Frequency")
        ax.set_title(self.name.upper())
        ax.yaxis.set_label_position("right")

    def _plot_signal(self, ax):
        """Plot the real part of the signal."""
        ax.plot(np.real(self.signal))

    def _annotate_signal(self, ax):
        """Annotate the time-domain signal axes."""
        ax.set_xticklabels([])
        ax.set_xlim(0, self.signal.shape[0])
        ax.set_ylabel('Real part')
        ax.set_title('Signal in time')
        ax.grid(True)

    def _plot_spectrum(self, ax, freq_x, freq_y):
        """Plot the spectrum sideways (power on X, frequency on Y)."""
        k = int(np.floor(self.signal.shape[0] / 2.0))
        if freq_x is None:
            freq_x = self._get_spectrum()[::-1][:k]
        if freq_y is None:
            if self.isaffine:
                freq_y = self.freqs
            else:
                freq_y = np.arange(k)
        ax.plot(freq_x, freq_y)
        if not self.isaffine:
            # ``len`` also works when the caller passes a plain list.
            ax.set_ylim(0, len(freq_y) - 1)
        else:
            ax.set_ylim(freq_y[0], freq_y[-1])

    def _annotate_spectrum(self, ax):
        """Annotate the spectrum axes."""
        ax.set_ylabel('Spectrum')
        ax.set_yticklabels([])
        ax.set_xticklabels([])
        ax.grid(True)
        if not self.isaffine:
            ax.invert_xaxis()
            ax.invert_yaxis()

    @staticmethod
    def _get_3d_axes(ax=None):
        """Return ``ax`` if it is a 3-D axes, else a 3-D axes on a new figure."""
        if ax is not None and getattr(ax, "name", None) == "3d":
            return ax
        from mpl_toolkits.mplot3d import Axes3D  # NOQA: registers "3d"
        fig = plt.figure()
        return fig.add_subplot(111, projection="3d")

    def plot(self, ax=None, kind='cmap', show=True, default_annotation=True,
             show_tf=False, scale="linear", threshold=0.05, **kwargs):
        """Visualize the time frequency representation.

        :param ax: Axes object to draw the plot on. For "surf"/"wireframe" it
            is used only if it is a 3-D axes; otherwise a new figure is made.
        :param kind: One of "cmap" (default), "contour", "surf" or
            "wireframe". Any other value falls back to a contour plot.
        :param show: Whether to call ``plt.show()``.
        :param default_annotation: Whether to make default annotations for the
            plot. Default annotations consist of setting the X and Y axis
            labels to "Time" and "Normalized Frequency" respectively, and
            setting the title to the name of the particular time-frequency
            distribution.
        :param show_tf: Whether to show the signal and its spectrum along with
            the plot. If this is True, the ``ax`` argument is ignored.
        :param scale: Currently unused.
        :param threshold: Fraction of the maximum below which contour levels
            are skipped for affine representations.
        :param **kwargs: Parameters to be passed to the plotting function.
            The keys ``extent``, ``contour_x``, ``contour_y``, ``levels``,
            ``freq_x``, ``freq_y``, ``cmap`` and ``grid`` are consumed here.
        :type ax: matplotlib.axes.Axes object
        :type kind: str
        :type show: bool
        :type default_annotation: bool
        :return: None
        :rtype: None
        """
        self._viz_threshold = threshold
        extent = kwargs.pop('extent', None)
        if extent is None:
            extent = [self.ts.min(), self.ts.max(), self.freqs.min(),
                      self.freqs.max()]
        contour_x = kwargs.pop('contour_x', None)
        contour_y = kwargs.pop('contour_y', None)
        levels = kwargs.pop('levels', None)
        freq_x = kwargs.pop('freq_x', None)
        freq_y = kwargs.pop('freq_y', None)
        cmap = kwargs.pop("cmap", None)
        # ``grid`` only controls annotation; it must not be forwarded to the
        # matplotlib drawing call, which would reject it.
        grid = kwargs.pop('grid', True)
        if show_tf:
            fig, axTF = plt.subplots(figsize=(10, 8))
            self._plot_tfr(axTF, kind, extent, contour_x, contour_y, levels,
                           show_tf, cmap=cmap)
            from mpl_toolkits.axes_grid1 import make_axes_locatable
            divider = make_axes_locatable(axTF)
            axTime = divider.append_axes("top", 1.2, pad=0.5)
            self._plot_signal(axTime)
            axSpec = divider.append_axes("left", 1.2, pad=0.5)
            self._plot_spectrum(axSpec, freq_x, freq_y)
            if default_annotation:
                self._annotate_tfr(axTF)
                self._annotate_signal(axTime)
                self._annotate_spectrum(axSpec)
        else:
            if kind in ("surf", "wireframe"):
                ax = self._get_3d_axes(ax)
                # Grid sized from the TFR itself so it always matches Z.
                x = np.arange(self.tfr.shape[1])
                y = np.linspace(0, 0.5, self.tfr.shape[0])
                X, Y = np.meshgrid(x, y)
                surf_cmap = plt.cm.jet if cmap is None else cmap
                if kind == "surf":
                    ax.plot_surface(X, Y, np.abs(self.tfr), cmap=surf_cmap)
                    if default_annotation:
                        ax.set_zlabel("Amplitude")
                else:
                    ax.plot_wireframe(X, Y, np.abs(self.tfr),
                                      cmap=surf_cmap, rstride=3, cstride=3)
            else:
                if ax is None:
                    fig = plt.figure()
                    ax = fig.add_subplot(111)
                if cmap is not None:
                    kwargs['cmap'] = cmap
                if kind == "cmap":
                    ax.imshow(self.tfr, aspect='auto', origin='lower',
                              extent=extent, **kwargs)
                else:
                    t, f = np.meshgrid(self.ts,
                                       np.linspace(0, 0.5, self.tfr.shape[0]))
                    ax.contour(t, f, self.tfr, **kwargs)
            if default_annotation:
                ax.grid(grid)
                ax.set_xlabel("Time")
                ax.set_ylabel("Normalized Frequency")
                ax.set_title(self.name.upper())
        if show:
            plt.show()