# Source code for tftb.tests.test_base

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2015 jaidev <jaidev@newton>
#
# Distributed under terms of the MIT license.

"""Base class for tests."""

import unittest
import numpy as np
from scipy import angle
from tftb.utils import is_linear
from skimage.measure import compare_ssim


class TestBase(unittest.TestCase):
    """Base test case providing signal- and TFR-specific assertion helpers."""

    def assert_is_linear(self, signal, decimals=5):
        """Assert that the signal is linear.

        Delegates to ``tftb.utils.is_linear``; ``decimals`` is forwarded to
        it unchanged.
        """
        self.assertTrue(is_linear(signal, decimals=decimals))

    def assert_is_analytic(self, signal, amlaw=None):
        """Assert that signal is analytic.

        The signal is rebuilt from its phase alone as ``exp(1j * angle)``,
        optionally multiplied by the amplitude law ``amlaw``. The real and
        imaginary parts of that reconstruction must each match the
        corresponding parts of ``signal`` (per ``np.allclose``). Without
        ``amlaw`` this therefore requires ``signal`` to have unit magnitude.

        Raises:
            AssertionError: if either part of the reconstruction differs.
        """
        # np.angle instead of scipy.angle: SciPy's re-exports of NumPy
        # functions are deprecated and have been removed in recent releases.
        omega = np.angle(signal)
        recons = np.exp(1j * omega)
        if amlaw is not None:
            recons = recons * amlaw
        real_identical = np.allclose(np.real(recons), np.real(signal))
        imag_identical = np.allclose(np.imag(recons), np.imag(signal))
        if not (imag_identical and real_identical):
            raise AssertionError("Signal is not analytic.")

    def assert_is_concave(self, signal):
        """Assert that every second difference of ``signal`` is strictly negative.

        Signals shorter than three samples have no second differences and
        pass vacuously.
        """
        second_derivative = np.diff(np.diff(signal))
        if not np.all(second_derivative < 0):
            raise AssertionError("Signal is not concave.")

    def assert_is_convex(self, signal):
        """Assert that every second difference of ``signal`` is strictly positive.

        Signals shorter than three samples have no second differences and
        pass vacuously.
        """
        second_derivative = np.diff(np.diff(signal))
        if not np.all(second_derivative > 0):
            raise AssertionError("Signal is not convex.")

    def assert_is_monotonic_increasing(self, signal):
        """Assert that ``signal`` never decreases (consecutive steps are >= 0)."""
        derivative = np.diff(signal)
        if not np.all(derivative >= 0):
            raise AssertionError("Signal is not monotonically increasing.")

    def assert_is_monotonic_decreasing(self, signal):
        """Assert that ``signal`` never increases (consecutive steps are <= 0)."""
        derivative = np.diff(signal)
        if not np.all(derivative <= 0):
            raise AssertionError("Signal is not monotonically decreasing.")

    def assert_is_hermitian(self, x):
        """Assert that the input is a Hermitian matrix (equal to its conjugate transpose)."""
        conj_trans = np.conj(x).T
        np.testing.assert_allclose(x, conj_trans)

    def assert_tfr_equal(self, x, y, sqmod=True, threshold=0.05, tol=0.99):
        """Assert that TFRs x and y are qualitatively equivalent.

        Parameters:
            x, y: the two time-frequency representations to compare.
            sqmod: if true, compare squared magnitudes ``|x|**2`` and
                ``|y|**2``; otherwise compare the values as given.
            threshold: fraction of each TFR's own maximum; entries at or
                below that level are zeroed before comparison.
            tol: minimum structural similarity (``compare_ssim``) required.

        The caller's arrays are never modified.
        """
        if sqmod:
            # np.abs already yields fresh arrays, so in-place edits are safe.
            x = np.abs(x) ** 2
            y = np.abs(y) ** 2
        else:
            # Copy before the in-place thresholding below; previously this
            # branch zeroed entries of the caller's own arrays.
            x = np.array(x, copy=True)
            y = np.array(y, copy=True)
        x_thresh = np.amax(x) * threshold
        x[x <= x_thresh] = 0.0
        y_thresh = np.amax(y) * threshold
        y[y <= y_thresh] = 0.0
        # compare_ssim expects C-contiguous input.
        x = np.ascontiguousarray(x)
        y = np.ascontiguousarray(y)
        similarity = compare_ssim(x, y)
        self.assertTrue(similarity >= tol)