import unittest

import cdt
import networkx as nx
import numpy as np
import torch

from losses import aux_loss, my_cross_entropy


class TestLosses(unittest.TestCase):
    """Sanity checks for ``aux_loss`` and ``my_cross_entropy``.

    Each trial draws a random DAG and runs two checks:

    * a near-perfect prediction (the ground truth squashed into
      ``{eps, 1 - eps}``) must give a loss of ~0;
    * a row-shuffled copy of that prediction must give a strictly larger loss.
    """

    NUM_NODES = 10
    NUM_TRIALS = 10
    # Keeps every predicted probability strictly inside (0, 1) so a log-based
    # loss never evaluates log(0). (The original literal 10e-7 equals 1e-6.)
    EPS = 1e-6

    @staticmethod
    def _random_dag_adjacency(num_nodes):
        """Return the dense adjacency matrix (ndarray) of a random Erdos-Renyi DAG."""
        generator = cdt.data.AcyclicGraphGenerator('linear', 'uniform', nodes=num_nodes,
                                                   npoints=3,
                                                   noise_coeff=1.,
                                                   expected_degree=1, dag_type='erdos')
        _, ground_truth = generator.generate()
        # np.asarray: with the scipy sparse *matrix* API, todense() yields an
        # np.matrix, which is pinned to 2-D and breaks the 3-D reshaping below.
        return np.asarray(nx.adjacency_matrix(ground_truth).todense())

    @staticmethod
    def _soften(binary, eps):
        """Map a 0/1 array onto {eps, 1 - eps}: a confident but non-saturated prediction."""
        return eps + binary - 2 * eps * binary

    @staticmethod
    def _shuffle_rows(matrix):
        """Return a row-permuted copy of *matrix* that differs from it.

        Returns ``None`` when every row is identical (e.g. an edgeless graph),
        because then no permutation can change the matrix.
        """
        if (matrix == matrix[0]).all():
            return None
        # At least two distinct rows exist, so a differing permutation exists
        # and each random draw finds one with probability >= 1/2.
        while True:
            shuffled = matrix[np.random.permutation(matrix.shape[0])]
            if not np.array_equal(shuffled, matrix):
                return shuffled

    def test_aux_loss(self):
        """aux_loss is ~0 for the true graph and larger for a row-shuffled one."""
        for _ in range(self.NUM_TRIALS):
            adj = self._random_dag_adjacency(self.NUM_NODES)
            # Aux format: [adj | adj.T], i.e. row i holds adj[i, :] followed by
            # adj[:, i]; shape (num_nodes, 2 * num_nodes).
            aux = np.concatenate([adj, adj.T], axis=1)
            # Target: flattened adjacency with a leading batch axis.
            target = torch.tensor(np.expand_dims(adj.flatten(), axis=0))

            aux_gt = torch.tensor(np.expand_dims(self._soften(aux, self.EPS), axis=0))
            loss_gt = aux_loss(aux_gt, target, self.NUM_NODES)
            self.assertAlmostEqual(0, loss_gt.item(), 3)

            aux_shuffled = self._shuffle_rows(aux)
            if aux_shuffled is None:
                continue  # nothing to perturb for this graph
            aux_shuffled = torch.tensor(
                np.expand_dims(self._soften(aux_shuffled, self.EPS), axis=0))
            loss_shuffled = aux_loss(aux_shuffled, target, self.NUM_NODES)
            self.assertGreater(loss_shuffled.item(), loss_gt.item())

    def test_cross_entropy_loss(self):
        """my_cross_entropy is ~0 for the true graph and larger for a row-shuffled one."""
        for _ in range(self.NUM_TRIALS):
            adj = self._random_dag_adjacency(self.NUM_NODES)
            # Target: flattened adjacency with a leading batch axis.
            target = torch.tensor(np.expand_dims(adj.flatten(), axis=0))

            pred_gt = torch.tensor(
                np.expand_dims(self._soften(adj, self.EPS).flatten(), axis=0))
            loss_gt = my_cross_entropy(pred_gt, target)
            self.assertAlmostEqual(0, loss_gt.item(), 3)

            adj_shuffled = self._shuffle_rows(adj)
            if adj_shuffled is None:
                continue  # nothing to perturb for this graph
            # Same (1, num_nodes ** 2) layout as pred_gt, so the same loss applies.
            pred_shuffled = torch.tensor(
                np.expand_dims(self._soften(adj_shuffled, self.EPS).flatten(), axis=0))
            loss_shuffled = my_cross_entropy(pred_shuffled, target)
            self.assertGreater(loss_shuffled.item(), loss_gt.item())


if __name__ == '__main__':
    # Entry point for running this file directly: unittest.main() collects and
    # runs the TestCase classes defined in this module.
    unittest.main()
