# Authors: Lukas Gemein <l.gemein@gmail.com>
#
#
# License: BSD-3

import torch
import numpy as np

from braindecode.models.tcn import TCN
from braindecode.util import set_random_seeds


def test_tcn():
    """Regression test: braindecode's TCN must reproduce the original model.

    A seeded ``TCN`` is built with fixed hyper-parameters and put in train
    mode. It is fed one seeded random input, and its output is compared
    against a reference array. The reference was produced by the original
    (non-braindecode) TCN implementation with the same constructor
    arguments.

    NOTE: the test is sensitive to the exact order of random-number
    consumption. The sequence is: seed, build the model (weight init), draw
    the input with ``torch.rand``, then run the forward pass (dropout is
    active in train mode). Reordering these statements or adding extra RNG
    calls between them changes the output and breaks the comparison.
    """
    # Seed all RNGs before anything random happens (the second argument is
    # presumably the CUDA flag -- confirm against braindecode.util).
    set_random_seeds(0, False)
    tcn = TCN(
        n_in_chans=21,
        n_outputs=2,
        n_filters=55,
        n_blocks=5,
        kernel_size=16,
        drop_prob=0.05270154233150525,
        add_log_softmax=True
    )
    # braindecode models are always in eval mode after initialization;
    # the original model implementation was not. Switch to train mode so
    # dropout is applied in the forward pass, as in the reference run.
    tcn.train()
    # One example with 21 channels (matching n_in_chans) and 1000 samples.
    # The trailing singleton dimension is presumably the input layout
    # braindecode models expect -- confirm against TCN.forward.
    x = torch.rand(1, 21, 1000, 1)
    out = tcn(x)
    # this is the output of the original model implementation using the same
    # initialization arguments as above.
    # Shape is (1, 2, 70): one example, n_outputs=2 rows, 70 values each
    # (presumably one prediction per output time step -- verify).
    # Because add_log_softmax=True, each column of two values should be
    # log-probabilities whose exponentials sum to ~1.
    expected = np.array(
        [[[-0.5504, -0.5304, -0.6023, -0.5231, -0.5387, -0.5522, -0.5323,
           -0.5540, -0.5297, -0.5333, -0.5743, -0.5330, -0.5117, -0.5051,
           -0.5523, -0.5507, -0.5724, -0.5380, -0.5697, -0.4871, -0.5400,
           -0.4986, -0.5502, -0.5524, -0.5263, -0.5440, -0.5464, -0.5005,
           -0.5404, -0.5098, -0.5197, -0.5578, -0.5419, -0.5601, -0.5031,
           -0.5616, -0.5205, -0.5378, -0.5472, -0.4897, -0.5216, -0.5560,
           -0.5480, -0.5488, -0.5258, -0.5637, -0.5318, -0.5134, -0.5460,
           -0.5294, -0.5513, -0.5310, -0.5307, -0.5326, -0.5270, -0.5156,
           -0.5569, -0.5416, -0.5279, -0.5553, -0.5589, -0.5166, -0.5108,
           -0.5076, -0.5279, -0.5208, -0.5367, -0.5557, -0.5690, -0.5494],
          [-0.8597, -0.8877, -0.7931, -0.8982, -0.8758, -0.8573, -0.8849,
           -0.8549, -0.8887, -0.8834, -0.8280, -0.8839, -0.9150, -0.9250,
           -0.8572, -0.8593, -0.8305, -0.8769, -0.8340, -0.9530, -0.8741,
           -0.9350, -0.8600, -0.8570, -0.8935, -0.8685, -0.8652, -0.9319,
           -0.8735, -0.9179, -0.9031, -0.8497, -0.8714, -0.8466, -0.9280,
           -0.8447, -0.9019, -0.8771, -0.8640, -0.9489, -0.9003, -0.8521,
           -0.8630, -0.8619, -0.8942, -0.8419, -0.8856, -0.9124, -0.8658,
           -0.8890, -0.8585, -0.8867, -0.8872, -0.8845, -0.8924, -0.9092,
           -0.8509, -0.8718, -0.8912, -0.8531, -0.8482, -0.9077, -0.9163,
           -0.9212, -0.8912, -0.9016, -0.8787, -0.8525, -0.8349,
           -0.8611]]])
    # The reference values are rounded to 4 decimals, so the comparison uses
    # a 1e-3 relative and absolute tolerance rather than exact equality.
    # detach() is needed because the forward pass ran with autograd enabled.
    np.testing.assert_allclose(
        out.detach().numpy(), expected, rtol=1e-3, atol=1e-3)
