import sys
from inspect import getfullargspec

import numpy as np

from tensorly import backend as T


def assert_array_equal(a, b, *args, **kwargs):
    """Assert that two tensors are exactly equal, element-wise.

    Both inputs are converted to NumPy arrays with ``T.to_numpy`` and then
    handed to :func:`numpy.testing.assert_array_equal`; any additional
    positional or keyword arguments are forwarded to it unchanged.
    """
    first = T.to_numpy(a)
    second = T.to_numpy(b)
    np.testing.assert_array_equal(first, second, *args, **kwargs)


def assert_array_almost_equal(a, b, *args, **kwargs):
    """Assert that two tensors are element-wise equal up to a given precision.

    ``a`` and ``b`` are converted to NumPy arrays via ``T.to_numpy``; the
    comparison itself, including the meaning of any extra positional or
    keyword arguments, is delegated to
    :func:`numpy.testing.assert_array_almost_equal`.
    """
    left = T.to_numpy(a)
    right = T.to_numpy(b)
    np.testing.assert_array_almost_equal(left, right, *args, **kwargs)


def assert_allclose(
    actual,
    desired,
    rtol: float = 1e-07,
    atol: float = 0.0,
    equal_nan: bool = True,
    err_msg: str = "",
    verbose: bool = True,
):
    """Check that two tensors are equal up to the given relative and absolute tolerance.

    Both inputs are converted to NumPy arrays with ``T.to_numpy`` and compared
    with :func:`numpy.testing.assert_allclose`, which checks element-wise that
    ``abs(actual - desired) <= atol + rtol * abs(desired)``.

    See the `NumPy documentation <https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html>`_ for more details.

    Parameters
    ----------
    actual : tensor
        Tensor obtained.
    desired : tensor
        Tensor expected.
    rtol : float, default is 1e-7
        Relative tolerance.
    atol : float, default is 0.0
        Absolute tolerance.
    equal_nan : bool, default is True
        If True, NaNs in matching positions are considered equal.
    err_msg : str, default is ""
        Error message to report on failure.
    verbose : bool, default is True
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If ``actual`` and ``desired`` are not equal up to the given tolerance.
    """
    np.testing.assert_allclose(
        T.to_numpy(actual),
        T.to_numpy(desired),
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        err_msg=err_msg,
        verbose=verbose,
    )


def assert_equal(actual, desired, *args, **kwargs):
    """Assert that two objects are equal, converting backend tensors first.

    Any argument recognised by ``T.is_tensor`` is turned into a NumPy array;
    if that array has shape ``(1,)`` it is unwrapped to its single element
    (presumably so a one-element tensor compares equal to a plain scalar).
    Non-tensor arguments are passed through untouched. The comparison is done
    by :func:`numpy.testing.assert_equal`, to which extra arguments are
    forwarded.
    """

    def _as_comparable(value):
        if not T.is_tensor(value):
            return value
        array = T.to_numpy(value)
        if array.shape == (1,):
            return array[0]
        return array

    lhs = _as_comparable(actual)
    rhs = _as_comparable(desired)
    np.testing.assert_equal(lhs, rhs, *args, **kwargs)


def _get_defaultkwargs(func):
    """Returns a dictionary containing all of the input function's arguments with default values."""
    argspec = getfullargspec(func)

    arguments = argspec.args
    defaults = argspec.defaults
    kwonlydefaults = argspec.kwonlydefaults
    if defaults is None:
        defaults = tuple()
    if kwonlydefaults is None:
        kwonlydefaults = {}

    start_defaults_idx = len(arguments) - len(defaults)
    arguments = arguments[start_defaults_idx:]
    default_args = {argument: default for argument, default in zip(arguments, defaults)}

    return {
        **default_args,
        **kwonlydefaults,
    }


def _get_decomposition_checker(supposed_kwargs, output_length):
    """Factory function whose output asserts that all entries in ``supposed_kwargs`` match entries in the kwargs-dictionary.

    This is a utility function used to automate testing of the object oriented interface.

    Arguments
    ---------
    supposed_kwargs : dict
        All keyword arguments that should be in the kwargs dict whenever the output function is called
        and their supposed value.
    output_length : int
        The number of outputs from the function

    Returns
    -------
    function
        Function that iterates over the supposed_kwarg dictionary and checks that each key and value
        matches those of the function call.
    """

    def decomposition_function(*args, **kwargs):
        for argument, supposed_default in supposed_kwargs.items():
            np.testing.assert_(
                argument in kwargs,
                "All arguments with a default must be passed as keyword argument when the decomposition class calls the decomposition function",
            )
            np.testing.assert_(kwargs[argument] == supposed_default)
        return [None for _ in range(output_length)]

    return decomposition_function


def assert_class_wrapper_correctly_passes_arguments(
    monkeypatch,
    decomposition_function,
    DecompositionClass,
    ignore_args=None,
    decomposition_output_length=2,
    **extra_args
):
    """Ensure that a decomposition class passes all arguments correctly to the function it wraps.

    Every argument of ``decomposition_function`` that has a default value is given a sentinel
    value when constructing ``DecompositionClass``. The function is then replaced (via
    ``monkeypatch``) in the module where it is defined by a checker that asserts each of those
    arguments arrives as a keyword argument carrying the expected value. Finally ``fit(None)`` is
    called on the class instance.

    This code must be used in a test run with the PyTest framework.

    Parameters
    ----------
    monkeypatch : pytest.MonkeyPatch
        Monkeypatch fixture.
    decomposition_function : Function
        Decomposition function wrapped by the class.
    DecompositionClass : Class
        Class that wraps the function.
    ignore_args : iterable, optional
        Arguments that shouldn't be checked. Each must be an argument of
        ``decomposition_function`` that has a default value, otherwise a ``KeyError`` is raised.
    decomposition_output_length : int, default is 2
        Number of outputs from the decomposition function.
    **extra_args
        Extra keyword-arguments passed to the decomposition class. If one of them is also a
        defaulted argument of ``decomposition_function``, its value is used instead of the
        sentinel and is what the function is expected to receive.

    Notes
    -----
    The patch is applied to ``sys.modules[decomposition_function.__module__]``; presumably the
    class must look the function up through that module for the check to be effective.

    Examples
    --------
    Here is a simple example to check that the CP class' arguments match that of the parafac function.

    >>> from tensorly.decomposition import parafac, CP
    >>> def test_cp(monkeypatch):
    ...     assert_class_wrapper_correctly_passes_arguments(
    ...         monkeypatch, parafac, CP, ignore_args={'return_errors'}, rank=3
    ...     )
    """
    sentinel = "this_is_used_to_test_correct_passing_of_arguments"
    test_kwargs = {
        argument: sentinel for argument in _get_defaultkwargs(decomposition_function)
    }
    if ignore_args is not None:
        for arg in ignore_args:
            del test_kwargs[arg]

    # Explicitly supplied ``extra_args`` take precedence over the sentinel. Passing both to the
    # constructor would raise "got multiple values for keyword argument"; the explicit value is
    # then the one the decomposition function must receive.
    supposed_kwargs = {
        argument: extra_args.get(argument, value) for argument, value in test_kwargs.items()
    }
    # Same keyword order as ``DecompositionClass(**extra_args, **test_kwargs)`` when disjoint.
    constructor_kwargs = dict(extra_args)
    for argument, value in test_kwargs.items():
        constructor_kwargs.setdefault(argument, value)

    decomposition_checker = _get_decomposition_checker(
        supposed_kwargs, decomposition_output_length
    )

    decomposition_module = sys.modules[decomposition_function.__module__]
    monkeypatch.setattr(
        decomposition_module, decomposition_function.__name__, decomposition_checker
    )
    DecompositionClass(**constructor_kwargs).fit(None)


# Re-export NumPy's generic assertion helpers under this module's namespace.
from numpy.testing import assert_, assert_raises
