"""Unit test for module."""

####################################################################
############################ INFO ##################################
####################################################################

### Doc ###

# Run with pytest; every test_* function below is collected automatically.

####################################################################
####################### OUTSIDE SCOPE ##############################
####################################################################

### IMPORTS
import torch
from torch import nn
import numpy as np

from topobstruction.modules.module import Shallow
from topobstruction.modules.module import MLP
from topobstruction.modules.module import generate_randDoublyStochastic

### HYPERPARAMETERS

### CLASSES

### FUNCTIONS
def test_trivialTrue():
    """Sanity check that pytest collects and runs this module."""
    sanity = True
    assert sanity

# def test_trivialFalse():
#     """Test that False is False."""
#     assert False


def test_shallow():
    """Check that a Shallow forward pass maps a (1, 2) batch to a (1, 1) output."""
    net = Shallow(n_hidden=3, n_input=2)
    batch = torch.randn(1, 2)
    out = net.forward(batch)
    assert out.shape == (1, 1)

def test_computesquaredlayerNorms():
    """Check the nesting and shapes returned by Shallow.compute_squaredlayerNorms.

    The result is a pair (output norms, input norms). Each has one entry per
    layer of the [2 (input), 3 (hidden), 1 (output)] network.
    """
    net = Shallow(n_hidden=3, n_input=2)
    norms = net.compute_squaredlayerNorms()
    assert len(norms) == 2

    output_norms, input_norms = norms[0], norms[1]
    assert len(output_norms) == 3
    assert len(input_norms) == 3

    assert output_norms[0].shape == (2,)  # output norm, input layer
    assert output_norms[1].shape == (3,)  # output norm, hidden layer
    assert input_norms[1].shape == (3,)   # input norm, hidden layer
    assert input_norms[2].shape == (1,)   # input norm, output layer

def test_computehyperbolae():
    """Check Shallow.compute_hyperbolae returns one entry per layer.

    Only the hidden layer (index 1) carries a value. The input and output
    layers are None.
    """
    net = Shallow(n_hidden=3, n_input=2)
    hyperbolae = net.compute_hyperbolae()
    assert len(hyperbolae) == 3
    for boundary_idx in (0, 2):
        assert hyperbolae[boundary_idx] is None
    assert hyperbolae[1].shape == (3,)

def test_computehyperbolaeReturnNorms():
    """Check that compute_hyperbolae(return_norms=True) returns the same norms
    as compute_squaredlayerNorms, entry by entry."""
    net = Shallow(n_hidden=3, n_input=2)
    _, out_norms, in_norms = net.compute_hyperbolae(return_norms=True)
    reference = net.compute_squaredlayerNorms()

    def _assert_matches(got, expected):
        # ndarray equality is elementwise, so reduce it with .all().
        # Other entries (presumably None or scalars) are compared directly.
        if isinstance(got, np.ndarray):
            assert (got == expected).all()
        else:
            assert got == expected

    pairs = ((out_norms, reference[0]), (in_norms, reference[1]))
    for got_norms, expected_norms in pairs:
        for idx, got in enumerate(got_norms):
            _assert_matches(got, expected_norms[idx])

def test_rescaling():
    """Check that rescale_hidden(compute_alpha(c)) sets the hidden-layer
    hyperbola values to the requested c."""
    net = Shallow(n_hidden=3, n_input=2)
    target = np.random.rand(3,)
    alpha = net.compute_alpha(target)
    net.rescale_hidden(alpha)
    hidden_hyperbola = net.compute_hyperbolae()[1]
    assert np.isclose(hidden_hyperbola, target).all()

def test_mlpComputeSquareNorms():
    """Test the compute_squaredlayerNorms method of MLP class.

    Every weight matrix is checked in turn. This replaces picking one at
    random, which made failures depend on an unseeded RNG. Each matrix is
    filled with ``1/sqrt(fan_in)``, so every row has squared norm
    ``fan_in * (1/fan_in) == 1``. The input norms of the layer it feeds must
    then all be 1.
    """
    def make_model():
        return MLP(2, [3, 4, 1], bias=False, activation_layer=nn.ReLU)

    n_params = len(list(make_model().parameters()))
    for idx in range(n_params):
        model = make_model()  # fresh model so only one matrix is modified
        weight = list(model.parameters())[idx]
        # Assumes (out, in) weight layout, as the original test did, so
        # shape[1] is the number of entries per row.
        fan_in = weight.shape[1]
        # Fill in place under no_grad and stay in torch, rather than
        # routing a tensor through np.sqrt and reassigning `.data`.
        with torch.no_grad():
            weight.fill_(fan_in ** -0.5)
        # Input norms are indexed by layer (0 = input layer), so weight
        # matrix idx feeds layer idx + 1.
        input_norms = model.compute_squaredlayerNorms()[1]
        assert np.isclose(input_norms[idx + 1], 1).all(), f"parameter {idx}"

def test_createBoundaryShallow():
    """Check the boundary matrix B of a shallow [2 (input), 3, 1 (output)] MLP.

    There is one row per hidden neuron. The 9 columns are the 2*3
    input->hidden weights followed by the 3 hidden->output weights. +1 marks
    a weight into the neuron and -1 a weight out of it.
    """
    expected = torch.tensor([
        [1, 1, 0, 0, 0, 0, -1,  0,  0],
        [0, 0, 1, 1, 0, 0,  0, -1,  0],
        [0, 0, 0, 0, 1, 1,  0,  0, -1],
    ])
    net = MLP(in_channels=2, hidden_channels=[3, 1],
              bias=False, skip_connections=None)
    matches = net.B == expected
    assert matches.all().item()

def test_createBoundaryMLP():
    """Check the boundary matrix B of an MLP [2 (input), 3, 4, 1 (output)].

    The rows are the 3 + 4 hidden neurons. The 22 columns are 2*3 input
    weights, 3*4 hidden->hidden weights, and 4*1 output weights. +1 marks a
    weight into the neuron and -1 a weight out of it.
    """
    expected = torch.tensor([
        [1, 1, 0, 0, 0, 0, -1,  0,  0, -1,  0,  0, -1,  0,  0, -1,  0,  0,  0,  0,  0,  0],
        [0, 0, 1, 1, 0, 0,  0, -1,  0,  0, -1,  0,  0, -1,  0,  0, -1,  0,  0,  0,  0,  0],
        [0, 0, 0, 0, 1, 1,  0,  0, -1,  0,  0, -1,  0,  0, -1,  0,  0, -1,  0,  0,  0,  0],
        [0, 0, 0, 0, 0, 0,  1,  1,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0, -1,  0,  0,  0],
        [0, 0, 0, 0, 0, 0,  0,  0,  0,  1,  1,  1,  0,  0,  0,  0,  0,  0,  0, -1,  0,  0],
        [0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  0,  0,  0,  0,  0, -1,  0],
        [0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  0,  0,  0, -1],
    ])
    net = MLP(in_channels=2, hidden_channels=[3, 4, 1],
              bias=False, skip_connections=None)
    matches = net.B == expected
    assert matches.all().item()

def test_createBoundaryMLPWithBias():
    """Check B for an MLP [2 (input), 2, 1 (output)] with biases.

    The columns are 4 input->hidden weights, 2 hidden->output weights, then
    one bias column per hidden neuron (+1 on that neuron's row).
    """
    expected = torch.tensor([
        [1, 1, 0, 0, -1,  0, 1, 0],
        [0, 0, 1, 1,  0, -1, 0, 1],
    ])
    net = MLP(in_channels=2, hidden_channels=[2, 1], bias=True)
    matches = net.B == expected
    assert matches.all().item()

def test_createBoundaryMLPWithSingleSkips():
    """Check B for an MLP [1 (input), 2, 1, 2, 1 (output)] with one skip
    connection, from hidden layer 0 to hidden layer 2.

    The last two columns are the skip edges. Each is -1 on the source neuron
    and +1 on the matching target neuron.
    """
    expected = torch.tensor([
        [1, 0, -1,  0,  0,  0,  0,  0, -1,  0],
        [0, 1,  0, -1,  0,  0,  0,  0,  0, -1],
        [0, 0,  1,  1, -1, -1,  0,  0,  0,  0],
        [0, 0,  0,  0,  1,  0, -1,  0,  1,  0],
        [0, 0,  0,  0,  0,  1,  0, -1,  0,  1],
    ])
    net = MLP(in_channels=1, hidden_channels=[2, 1, 2, 1],
              bias=False, skip_connections=[(0, 2)])
    matches = net.B == expected
    assert matches.all().item()

def test_createBoundaryMLPWithMultipleSkips():
    """Check the skip-connection columns of B for an MLP
    [1 (input), 2, 3, 2, 3, 1 (output)] with skips (0, 2) and (1, 3).

    The rows are the 2 + 3 + 2 + 3 = 10 hidden neurons. Each residual column
    links a neuron of a source hidden layer (-1) to the neuron at the same
    position in its target hidden layer (+1). The trailing block of B is
    therefore [-I; +I] with I the 5x5 identity.
    """
    identity = torch.eye(5, dtype=torch.int64)
    expected_residual = torch.cat((-identity, identity), dim=0)

    net = MLP(in_channels=1, hidden_channels=[2, 3, 2, 3, 1],
              bias=False, skip_connections=[(0, 2), (1, 3)])
    # n_skip_connections is presumably the number of residual edges (5 here).
    n_skip = net.n_skip_connections
    assert (net.B[:, -n_skip:] == expected_residual).all().item()
    assert (net.B[:, -5:] == expected_residual).all().item()

def test_createBoundaryWithSkipsFarApart():
    """Check B for an MLP [1 (input), 2, 1, 1, 2, 1 (output)] with one skip
    connection between the two hidden layers of width 2 (layers 0 and 3).

    The last two columns are the skip edges.
    """
    expected = torch.tensor([
        [1, 0, -1,  0,  0,  0,  0,  0,  0, -1,  0],
        [0, 1,  0, -1,  0,  0,  0,  0,  0,  0, -1],
        [0, 0,  1,  1, -1,  0,  0,  0,  0,  0,  0],
        [0, 0,  0,  0,  1, -1, -1,  0,  0,  0,  0],
        [0, 0,  0,  0,  0,  1,  0, -1,  0,  1,  0],
        [0, 0,  0,  0,  0,  0,  1,  0, -1,  0,  1],
    ])
    net = MLP(in_channels=1, hidden_channels=[2, 1, 1, 2, 1],
              bias=False, skip_connections=[(0, 3)])
    matches = net.B == expected
    assert matches.all().item()

def test_createBoundaryWithSkipsAndBias():
    """Check B for an MLP [1 (input), 2, 2, 2, 1 (output)] with biases and a
    skip connection from hidden layer 0 to hidden layer 2.

    The columns are the weight edges, then one bias column per hidden neuron
    (columns 13-18, 1-based), then the two skip edges (columns 19-20).
    """
    expected = torch.tensor([
        [1, 0, -1,  0, -1,  0,  0,  0,  0,  0,  0,  0, 1, 0, 0, 0, 0, 0, -1,  0],
        [0, 1,  0, -1,  0, -1,  0,  0,  0,  0,  0,  0, 0, 1, 0, 0, 0, 0,  0, -1],
        [0, 0,  1,  1,  0,  0, -1,  0, -1,  0,  0,  0, 0, 0, 1, 0, 0, 0,  0,  0],
        [0, 0,  0,  0,  1,  1,  0, -1,  0, -1,  0,  0, 0, 0, 0, 1, 0, 0,  0,  0],
        [0, 0,  0,  0,  0,  0,  1,  1,  0,  0, -1,  0, 0, 0, 0, 0, 1, 0,  1,  0],
        [0, 0,  0,  0,  0,  0,  0,  0,  1,  1,  0, -1, 0, 0, 0, 0, 0, 1,  0,  1],
        # 1  2   3   4   5   6   7   8   9  10  11  12 13 14 15 16 17 18  19  20
    ])
    net = MLP(in_channels=1, hidden_channels=[2, 2, 2, 1],
              bias=True, skip_connections=[(0, 2)])
    matches = net.B == expected
    assert matches.all().item()

def test_randomDoublyStochastic():
    """Test that generate_randDoublyStochastic returns a doubly stochastic matrix.

    A doubly stochastic n x n matrix has non-negative entries, and every row
    and every column sums to 1.
    """
    n = 100
    M = np.asarray(generate_randDoublyStochastic(n))
    assert M.shape == (n, n)
    assert (M >= 0).all()
    # Compare the sums themselves to 1. The former `M.sum(axis=0).all() == 1`
    # reduced the sums to a single bool first, so it only checked that no sum
    # was zero (True == 1). atol matches the tolerance used by test_sendToCone.
    assert np.allclose(M.sum(axis=0), 1, atol=1e-4)
    assert np.allclose(M.sum(axis=1), 1, atol=1e-4)

def test_sendToCone():
    """Check that send_toCone moves the network onto the cone.

    Afterwards every non-None hyperbola entry must be ~0. The squared output
    and input norms of the hidden layers (all but the first and last
    entries) must be ~norm_target.
    """
    norm_target = 0.1
    tol = 1e-4
    net = MLP(in_channels=2, hidden_channels=[4, 4, 4, 1],
              bias=True, skip_connections=[(0, 2)])
    net.send_toCone(norm_target=norm_target)

    on_cone = [np.isclose(h, 0, atol=tol).all()
               for h in net.compute_hyperbolae() if h is not None]
    assert np.array(on_cone).all()

    out_norms, in_norms = net.compute_squaredlayerNorms()
    for norms in (out_norms, in_norms):
        assert np.isclose(np.array(norms[1:-1]), norm_target, atol=tol).all()
