# Source code for tftb.tests.test_base
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2015 jaidev <jaidev@newton>
#
# Distributed under terms of the MIT license.
"""Base class for tests."""
import unittest
import numpy as np
from scipy import angle
from tftb.utils import is_linear
from skimage.measure import compare_ssim
class TestBase(unittest.TestCase):
    """Base class for tests, providing assertions about signal properties."""

    def assert_is_linear(self, signal, decimals=5):
        """Assert that the signal is linear.

        :param signal: Signal to check.
        :param decimals: Forwarded to ``is_linear``; presumably the rounding
            precision that check uses (verify in ``tftb.utils``).
        :raises AssertionError: If ``is_linear`` returns a falsy value.
        """
        self.assertTrue(is_linear(signal, decimals=decimals))

    def assert_is_analytic(self, signal, amlaw=None):
        """Assert that signal is analytic.

        The signal is rebuilt from its own phase, ``exp(1j * angle(signal))``,
        scaled by ``amlaw`` when one is given. The real and imaginary parts of
        that reconstruction are then compared with those of ``signal`` using
        ``np.allclose``. For a real, non-negative ``amlaw`` this amounts to
        checking that ``abs(signal)`` matches ``amlaw``. When ``amlaw`` is
        None, the expected amplitude is 1.

        :param signal: Complex signal to check.
        :param amlaw: Expected amplitude law, either a scalar or an array
            broadcastable against ``signal``. Defaults to unit amplitude.
        :raises AssertionError: If either component of the reconstruction
            differs from the signal.
        """
        # Use np.angle directly. The ``scipy.angle`` alias is deprecated and
        # removed in newer SciPy releases.
        phase = np.angle(signal)
        recons = np.exp(1j * phase)
        if amlaw is not None:
            recons = recons * amlaw
        real_identical = np.allclose(np.real(recons), np.real(signal))
        imag_identical = np.allclose(np.imag(recons), np.imag(signal))
        if not (imag_identical and real_identical):
            raise AssertionError("Signal is not analytic.")

    def assert_is_concave(self, signal):
        """Assert that every second difference of ``signal`` is negative.

        The check is strict: a zero second difference, as in a straight line,
        fails.

        :raises AssertionError: If any second difference is >= 0.
        """
        second_derivative = np.diff(np.diff(signal))
        if not np.all(second_derivative < 0):
            raise AssertionError("Signal is not concave.")

    def assert_is_convex(self, signal):
        """Assert that every second difference of ``signal`` is positive.

        The check is strict: a zero second difference, as in a straight line,
        fails.

        :raises AssertionError: If any second difference is <= 0.
        """
        second_derivative = np.diff(np.diff(signal))
        if not np.all(second_derivative > 0):
            raise AssertionError("Signal is not convex.")

    def assert_is_monotonic_increasing(self, signal):
        """Assert that consecutive samples never decrease.

        Equal neighbours are allowed.

        :raises AssertionError: If any first difference is negative.
        """
        derivative = np.diff(signal)
        if not np.all(derivative >= 0):
            raise AssertionError("Signal is not monotonically increasing.")

    def assert_is_monotonic_decreasing(self, signal):
        """Assert that consecutive samples never increase.

        Equal neighbours are allowed.

        :raises AssertionError: If any first difference is positive.
        """
        derivative = np.diff(signal)
        if not np.all(derivative <= 0):
            raise AssertionError("Signal is not monotonically decreasing.")

    def assert_is_hermitian(self, x):
        """Assert that the input is a Hermitian matrix.

        ``x`` is compared with its conjugate transpose using
        ``np.testing.assert_allclose``. That call raises AssertionError on
        mismatch.
        """
        conj_trans = np.conj(x).T
        np.testing.assert_allclose(x, conj_trans)

    def assert_tfr_equal(self, x, y, sqmod=True, threshold=0.05, tol=0.99):
        """Assert that TFRs x and y are qualitatively equivalent.

        The comparison proceeds in three steps:

        1. Optionally convert both representations to squared magnitude.
        2. Zero out values at or below ``threshold`` times each array's own
           maximum.
        3. Compare the results with the structural similarity index
           (``compare_ssim``).

        The caller's arrays are never modified.

        :param x: First time-frequency representation.
        :param y: Second time-frequency representation.
        :param sqmod: If True, compare ``abs(.) ** 2`` of the inputs instead
            of the inputs themselves.
        :param threshold: Fraction of each array's maximum at or below which
            values are zeroed before comparison.
        :param tol: Minimum similarity required for the assertion to pass.
        :raises AssertionError: If the similarity is below ``tol``.
        """
        if sqmod:
            x = np.abs(x) ** 2
            y = np.abs(y) ** 2
        else:
            # Copy so the in-place thresholding below cannot modify the
            # caller's arrays. The squared-modulus branch already produces
            # new arrays.
            x = np.array(x, copy=True)
            y = np.array(y, copy=True)
        x_thresh = np.amax(x) * threshold
        x[x <= x_thresh] = 0.0
        y_thresh = np.amax(y) * threshold
        y[y <= y_thresh] = 0.0
        x = np.ascontiguousarray(x)
        y = np.ascontiguousarray(y)
        # NOTE(review): newer scikit-image releases replace compare_ssim with
        # skimage.metrics.structural_similarity. Confirm against the
        # scikit-image version this project pins.
        similarity = compare_ssim(x, y)
        self.assertTrue(
            similarity >= tol,
            msg="SSIM {:.4f} is below tolerance {}".format(similarity, tol))