import unittest

import networkx as nx
import numpy as np
import pandas as pd
import scipy
import scipy.stats
import torch
from pingouin import partial_corr

from utils.graph_dataset import GraphDatasetGenerator, ConfoundedGraphDataset, marginalize


class TestLosses(unittest.TestCase):
    """Tests for ``marginalize`` and for the synthetic graph dataset generators.

    The first three tests pin the matrices ``marginalize`` returns for small
    hand-built DAGs. The remaining tests draw random datasets and check
    structural properties (shapes, average degree) and statistical properties
    (agreement between d-separation and correlation tests). The statistical
    tests are randomised, so they assert against loose thresholds.
    """

    @staticmethod
    def _graph_from_target(target_tensor, num_nodes):
        """Rebuild a directed graph from a flattened adjacency target.

        ``target_tensor`` is reshaped to ``(num_nodes, num_nodes)`` and handed
        to networkx as an adjacency matrix.
        """
        adjacency = target_tensor.reshape(num_nodes, num_nodes).numpy()
        return nx.from_numpy_array(adjacency, create_using=nx.DiGraph)

    @staticmethod
    def _d_separated(graph, x, y, z):
        """Return whether node set ``z`` d-separates ``x`` from ``y`` in ``graph``.

        networkx 3.3 introduced ``is_d_separator`` and deprecated
        ``d_separated`` (later removed); use whichever the installed version
        provides so the tests run on both old and new releases.
        """
        checker = getattr(nx, 'is_d_separator', None) or nx.d_separated
        return checker(graph, x, y, z)

    def test_marginalisation_mediation(self):
        """Keeping the ends of a chain yields one directed edge and no confounding."""
        graph = nx.DiGraph()
        graph.add_edges_from([('X', 'Y'), ('Y', 'Z'), ('Z', 'V'), ('V', 'W')])
        adj, conf = marginalize(graph, ['X', 'V'])
        # assert_array_equal also rejects a result of the wrong shape, which a
        # broadcasting ``np.all(a == b)`` could silently accept.
        np.testing.assert_array_equal(adj, np.array([[0, 1], [0, 0]]))
        np.testing.assert_array_equal(conf, np.zeros((2, 2)))

    def test_marginalisation_pure_confounding(self):
        """Two children of a dropped common parent are confounded but not adjacent."""
        graph = nx.DiGraph()
        graph.add_edges_from([('X', 'Y'), ('X', 'Z'), ('Z', 'V')])
        adj, conf = marginalize(graph, ['Y', 'Z'])
        np.testing.assert_array_equal(adj, np.zeros((2, 2)))
        np.testing.assert_array_equal(conf, np.array([[0, 1], [1, 0]]))

    def test_marginalisation_confounding(self):
        """A direct edge and a dropped common parent are both reported."""
        graph = nx.DiGraph()
        graph.add_edges_from([('X', 'Y'), ('X', 'Z'), ('Z', 'V'), ('Y', 'Z')])
        adj, conf = marginalize(graph, ['Y', 'Z'])
        np.testing.assert_array_equal(adj, np.array([[0, 1], [0, 0]]))
        np.testing.assert_array_equal(conf, np.array([[0, 1], [1, 0]]))

    def test_no_confounders(self):
        """With ``fraction_confounded_datasets=0`` the confounding part of every target is all zero."""
        num_nodes = 10
        train_set = ConfoundedGraphDataset(num_graphs=10, num_samples=3, num_nodes=num_nodes,
                                           fraction_confounded_datasets=0)
        # The target is split into three consecutive chunks of sizes
        # num_nodes**2, 1 and num_nodes**2; only the last one is checked here.
        chunk_sizes = [num_nodes ** 2, 1, num_nodes ** 2]
        for _, target in train_set:
            _, _, conf = torch.split(target, chunk_sizes)
            self.assertFalse(torch.any(conf).item())

    def test_num_variables(self):
        """Samples have ``num_nodes`` columns and targets have ``num_nodes ** 2`` entries."""
        num_nodes = 10
        train_set = GraphDatasetGenerator(num_graphs=3, num_samples=3, num_nodes=num_nodes)
        # Index-based access is kept on purpose: this file only demonstrates
        # ``len()`` and ``[]`` for GraphDatasetGenerator, not iteration.
        for i in range(len(train_set)):
            data, target = train_set[i]
            self.assertEqual(num_nodes, data.shape[1])
            self.assertEqual(num_nodes ** 2, target.shape[0])

    def test_degree(self):
        """The mean node degree over many generated graphs is close to 2."""
        num_nodes = 10
        expected_degree = 2.0
        node_degrees = []
        for _ in range(20):
            train_set = GraphDatasetGenerator(num_graphs=20, num_samples=3, num_nodes=num_nodes)
            for j in range(len(train_set)):
                _, target_tensor = train_set[j]
                graph = self._graph_from_target(target_tensor, num_nodes)
                # For a DiGraph, ``degree`` counts incoming plus outgoing edges.
                node_degrees.extend(graph.degree(n) for n in graph.nodes)
        # places=1: the difference must round to zero at one decimal place.
        self.assertAlmostEqual(expected_degree, np.mean(node_degrees), places=1)

    def test_markov_property(self):
        """Partial-correlation tests agree with d-separation in more than 70% of the cases.

        For every ordered pair of distinct nodes, the empty conditioning set
        and every single-node conditioning set are tried. The agreement rate
        is asserted separately for each generated graph.
        """
        num_nodes = 4
        train_set = GraphDatasetGenerator(num_graphs=3, num_samples=2000, num_nodes=num_nodes, noise='gaussian')
        for i in range(len(train_set)):
            data, target_tensor = train_set[i]
            graph = self._graph_from_target(target_tensor, num_nodes)
            nodes = list(graph.nodes)
            df = pd.DataFrame(data, columns=nodes)
            correct_separations = []
            for x in nodes:
                for y in nodes:
                    if x == y:
                        continue
                    conditioning_sets = [set()] + [{z} for z in nodes if z != x and z != y]
                    for z in conditioning_sets:
                        d_sep = self._d_separated(graph, {x}, {y}, z)
                        # partial_corr returns a one-row frame; take the
                        # scalar p-value rather than a one-element Series.
                        p_value = partial_corr(df, x, y, covar=list(z))['p-val'].iloc[0]
                        indep = p_value > .05
                        correct_separations.append(d_sep == indep)
            self.assertGreater(np.mean(correct_separations), .7)

    def test_alpha_level(self):
        """Marginally d-separated pairs are rarely flagged as correlated.

        For each generated graph, the share of ordered node pairs that are
        d-separated given the empty set yet have a Pearson p-value not above
        the significance level must stay below three times that level.
        """
        num_nodes = 4
        alpha = 0.05  # significance level of the correlation test
        train_set = GraphDatasetGenerator(num_graphs=3, num_samples=200, num_nodes=num_nodes, noise='gaussian')
        # Column k of the data belongs to node k, so build the labels once and
        # use them both for relabelling and as DataFrame columns.
        labels = ['V{}'.format(k) for k in range(num_nodes)]
        for i in range(len(train_set)):
            data, target_tensor = train_set[i]
            graph = self._graph_from_target(target_tensor, num_nodes)
            nx.relabel_nodes(graph, dict(enumerate(labels)), copy=False)
            df = pd.DataFrame(data, columns=labels)
            false_dep = 0
            num_tests = 0
            for x in labels:
                for y in labels:
                    if x == y:
                        continue
                    num_tests += 1
                    d_sep = self._d_separated(graph, {x}, {y}, set())
                    # pearsonr returns (statistic, p-value); index 1 is the p-value.
                    ind = scipy.stats.pearsonr(df[x], df[y])[1] > alpha
                    if d_sep and not ind:
                        false_dep += 1
            self.assertLess(false_dep / num_tests, 3 * alpha)


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
