import unittest

import numpy as np
import torch

from main import prepare_model_input
from model.causal_inducer import CausalInducer
from model.discovery_result import DiscoveryResult


class TestLosses(unittest.TestCase):
    """Tests for ``prepare_model_input`` and ``CausalInducer``.

    Covers input preparation, the look-ahead attention mask and
    autoregressive prediction.

    NOTE(review): despite the class name, no loss function is exercised here;
    the name is kept so existing test selectors keep working.
    """

    def test_prepare_input(self):
        """``prepare_model_input`` keeps data/target and builds a shifted target.

        The returned tensors must live on the requested device.
        """
        num_sample = 100
        num_nodes = 10
        data = torch.randn((1, num_sample, num_nodes))
        target = torch.tensor(np.arange(num_nodes)).unsqueeze(0)
        device = 'cpu'
        start_t = 42
        data_new, target_new, shifted_target = prepare_model_input(data, target, device, start_t)
        self.assertTrue(torch.all(data == data_new).item())
        self.assertTrue(torch.all(target == target_new).item())
        # target is 0..n-1, so shifting it right by one step and inserting the
        # start token gives [start_t, 0, 1, ..., n-2].
        exp_shift_target = torch.tensor(np.arange(num_nodes) - 1).unsqueeze(0)
        exp_shift_target[0, 0] = start_t
        self.assertTrue(torch.all(shifted_target == exp_shift_target).item())
        self.assertEqual(data_new.device.type, device)
        self.assertEqual(target_new.device.type, device)
        self.assertEqual(shifted_target.device.type, device)

    def test_attention_mask(self):
        """The look-ahead mask is 1 strictly above the diagonal and 0 elsewhere."""
        num_nodes = 2
        model = CausalInducer(num_nodes, 1, 8, 8, 8, 8, 8)
        mask = model._create_look_ahead_mask('cpu')
        # The mask side length is num_nodes ** 2 (4 for two nodes); presumably
        # one position per adjacency-matrix entry.
        expected_mask = torch.tensor([[0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]])
        self.assertTrue(torch.all(mask == expected_mask).item())

    def test_mask_works(self):
        """The look-ahead mask hides later decoder positions from position 0.

        The two decoder inputs agree only at position 0. With the mask, the
        outputs at position 0 must match. Without it, they must differ.
        """
        num_sample = 100
        num_nodes = 10
        data = torch.randn((1, num_sample, num_nodes))
        model = CausalInducer(num_nodes, 1, 8, 8, 8, 8, 8)
        # Freshly built modules start in train mode. Switch to eval so any
        # stochastic layers (e.g. dropout, if the model has them) cannot make
        # the two decoder runs differ for reasons unrelated to masking.
        model.eval()
        mask = model._create_look_ahead_mask('cpu')
        un_masked = torch.zeros_like(mask)
        # Both sequences are 0 at position 0 and differ at every later position.
        seq_zeros = torch.zeros((1, num_nodes ** 2))
        seq_one = torch.ones_like(seq_zeros)
        seq_one[0, 0] = 0
        with torch.no_grad():
            mem = model.encoder(data)
            masked_zeros = model.decoder(seq_zeros, mem, mask)[0, 0]
            masked_ones = model.decoder(seq_one, mem, mask)[0, 0]
            open_zeros = model.decoder(seq_zeros, mem, un_masked)[0, 0]
            open_ones = model.decoder(seq_one, mem, un_masked)[0, 0]
        # Use allclose instead of assertEqual. assertEqual relies on
        # Tensor.__bool__, which raises for multi-element tensors, and exact
        # float equality is brittle.
        self.assertTrue(torch.allclose(masked_zeros, masked_ones))
        self.assertFalse(torch.allclose(open_zeros, open_ones))

    def test_predict_autoregressive(self):
        """Autoregressive output with the echo mock reflects the start token.

        The mock decoder returns its input sequence unchanged. Assuming
        ``predict_autoregressive`` seeds the sequence with ``start_token``,
        every output entry should equal that token.
        """
        num_sample = 100
        num_nodes = 10
        model = MockCausalInducer(num_nodes)
        data = torch.randn((1, num_sample, num_nodes))
        model.start_token = 1
        output = model.predict_autoregressive(data)
        self.assertTrue(torch.all(output.output == 1).item())
        model.start_token = 0
        output = model.predict_autoregressive(data)
        self.assertTrue(torch.all(output.output == 0).item())


class MockCausalInducer(CausalInducer):
    """Lightweight ``CausalInducer`` stand-in with stub encoder and decoder.

    The parent constructor is intentionally not called, so no real networks
    are built. The instance only carries the attributes assigned in
    ``__init__``.
    """

    def __init__(self, num_nodes):
        self.num_nodes = num_nodes
        self.batch_size = 1
        self.threshold = .5
        # Tests overwrite this to check what the autoregressive loop emits.
        self.start_token = 1

        def identity_encoder(x):
            # The "memory" is simply the raw input.
            return x

        def echo_decoder(adj, mem, mask):
            # Ignore memory and mask; hand back the sequence unchanged.
            return adj

        self.encoder = identity_encoder
        self.decoder = echo_decoder

    def forward(self, x, adj):
        """Ignore ``x`` and pass ``adj`` straight into a ``DiscoveryResult``."""
        result = DiscoveryResult(adj, None, self.num_nodes, self.threshold)
        return result


if __name__ == '__main__':
    # Executed as a script: discover and run the TestCases defined above.
    unittest.main(module='__main__')
