import unittest

import torch

from metrics import average_shd, num_dags, average_out_degree


class TestLosses(unittest.TestCase):
    """Unit tests for the graph metrics imported from ``metrics``.

    Adjacency matrices are written as nested lists, where entry ``[i][j] == 1``
    marks an edge i -> j. Each matrix is turned into a flattened row tensor of
    shape ``(1, num_nodes ** 2)`` by :meth:`_flat_adjacency` before being passed
    to the metric functions.
    """

    @staticmethod
    def _flat_adjacency(adjacency):
        """Return the nested-list *adjacency* as a ``(1, n * n)`` tensor."""
        return torch.tensor(adjacency).flatten().unsqueeze(0)

    def test_shd_edge_error(self):
        # Target graph is the chain 0 -> 1 -> 2.
        target = self._flat_adjacency([[0, 1, 0], [0, 0, 1], [0, 0, 0]])

        # Edge 1 -> 2 replaced by a self-loop 1 -> 1: one missing and one
        # extra entry, so two mismatches.
        output = self._flat_adjacency([[0, 1, 0], [0, 1, 0], [0, 0, 0]])
        shd = average_shd(output, target, batch_size=1)
        self.assertEqual(shd, 2)

        # Edge 1 -> 2 reversed to 2 -> 1: this test expects the reversal to
        # count as two mismatched entries, not as a single edit.
        output = self._flat_adjacency([[0, 1, 0], [0, 0, 0], [0, 1, 0]])
        shd = average_shd(output, target, batch_size=1)
        self.assertEqual(shd, 2)

        # A graph compared with itself has zero distance.
        shd = average_shd(target, target, batch_size=1)
        self.assertEqual(shd, 0)

    def test_num_dags(self):
        # Chain 0 -> 1 -> 2 (acyclic) and directed 3-cycle 0 -> 1 -> 2 -> 0.
        line_dag = self._flat_adjacency([[0, 1, 0], [0, 0, 1], [0, 0, 0]])
        circle = self._flat_adjacency([[0, 1, 0], [0, 0, 1], [1, 0, 0]])

        # NOTE(review): stacking (1, 9) rows along dim 0 yields a
        # (batch, 1, 9) tensor, unlike the (1, 9) inputs used in the other
        # tests -- confirm this is the layout num_dags expects.
        output = torch.stack([line_dag, circle], dim=0)
        num = num_dags(output, batch_size=2, num_nodes=3)
        self.assertEqual(num, 1)

        output = torch.stack([circle, circle], dim=0)
        num = num_dags(output, batch_size=2, num_nodes=3)
        self.assertEqual(num, 0)

        output = torch.stack([line_dag, line_dag, line_dag], dim=0)
        num = num_dags(output, batch_size=3, num_nodes=3)
        self.assertEqual(num, 3)

    def test_avg_degree(self):
        # Results are compared approximately: an exact check against 1/3
        # depends on whether the metric computes in float32 or float64.
        # float() accepts both Python numbers and single-element tensors.

        # Three edges over three nodes -> mean out-degree 1.
        output = self._flat_adjacency([[0, 1, 1], [0, 0, 1], [0, 0, 0]])
        degree = average_out_degree(output, batch_size=1, num_nodes=3)
        self.assertAlmostEqual(float(degree), 1.)

        # Empty graph -> mean out-degree 0.
        output = self._flat_adjacency([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
        degree = average_out_degree(output, batch_size=1, num_nodes=3)
        self.assertAlmostEqual(float(degree), 0.)

        # One edge over three nodes -> mean out-degree 1/3.
        output = self._flat_adjacency([[0, 1, 0], [0, 0, 0], [0, 0, 0]])
        degree = average_out_degree(output, batch_size=1, num_nodes=3)
        self.assertAlmostEqual(float(degree), 1. / 3.)


if __name__ == '__main__':
    # Run the test cases in this module when it is executed as a script;
    # verbosity=1 is unittest's default and is spelled out for clarity.
    unittest.main(verbosity=1)
