import networkx as nx
import numpy as np


class DagModuleGenerator:
    """
    Create the structure of a module graph.

    The graph has ``features + modules`` vertices. Edges only ever join a
    feature to a module or a module to a feature, and always point from a
    vertex placed later in a random causal order to one placed earlier, so
    every sampled graph is acyclic by construction.

    Args:
    features (int): Number of features in the graph to generate.
    modules (int): Number of modules.
    expected_density (int): Expected number of edges per node, used to derive
        the default connection probabilities.
    p_vertex (float, optional): Probability that a feature receives an edge
        from each module placed after it in the causal order. Defaults to
        ``2 * modules * expected_density / (features - 1)``.
    p_module (float, optional): Probability that a module receives an edge
        from each feature placed after it in the causal order. Defaults to
        ``2 * expected_density / (modules - 1)``.
    """

    # Sanity-check failures tolerated within one call before giving up.
    max_failures = 20

    def __init__(
        self, features=100, modules=10, expected_density=3, p_vertex=None, p_module=None
    ) -> None:
        self.features = features
        self.modules = modules
        self.total_vertices = self.features + self.modules
        self.adjacency_matrix = np.zeros((self.total_vertices, self.total_vertices))
        self.expected_density = expected_density

        # Compare against None so that an explicit probability of 0 is honoured.
        if p_vertex is not None:
            self.p_vertex = p_vertex
        else:
            self.p_vertex = (
                2 * self.modules * self.expected_density / (self.features - 1)
            )
        if p_module is not None:
            self.p_module = p_module
        else:
            self.p_module = 2 * self.expected_density / (self.modules - 1)
        self.n_failure = 0

    def _bipartite_half_degrees(self):
        """Return in-degree + out-degree of every feature in ``bipartite_half``."""
        return np.sum(self.bipartite_half, 0) + np.sum(self.bipartite_half, 1)

    def _bipartite_half_info(self):
        """Print degree statistics of the feature-to-feature graph."""
        deg = self._bipartite_half_degrees()
        n_disconnected = np.sum(deg == 0)
        print('---------------------------------')
        print('Bipartite half info:')
        print(f'Max degree: {np.max(deg)}, Min degree: {np.min(deg)}, Avg degree: {np.mean(deg)}')
        print(f'Q1: {np.quantile(deg, 0.25)}, Median {np.median(deg)}, Q3: {np.quantile(deg, 0.75)}')
        print('Disconnected nodes:', n_disconnected)
        print('---------------------------------')

    def _bipartite_half_sanity(self):
        """Return True when at most 25% of features have degree 0 in ``bipartite_half``."""
        deg = self._bipartite_half_degrees()
        n_disconnected = np.sum(deg == 0)
        return n_disconnected <= 0.25 * self.features

    def _sample_adjacency(self):
        """Sample a fresh adjacency matrix into ``self.adjacency_matrix``.

        Returns:
            tuple: ``(causal_order, features_ind, modules_ind)``, where
            ``causal_order`` is a permutation of all vertex ids and the two
            index arrays are the sorted feature / module vertex ids.
        """
        # Start from an empty matrix so edges from a previous (rejected)
        # attempt, sampled under a different causal order, never leak in.
        self.adjacency_matrix = np.zeros((self.total_vertices, self.total_vertices))

        causal_order = np.random.permutation(np.arange(self.total_vertices))
        # The first and last vertices of the causal order are always features.
        modules_ind = np.random.choice(
            causal_order[1:-1], size=(self.modules), replace=False
        )
        modules_ind.sort()
        features_ind = np.setdiff1d(causal_order, modules_ind, assume_unique=True)
        features_ind.sort()

        is_module = np.zeros(self.total_vertices, dtype=bool)
        is_module[modules_ind] = True

        for i, vertex in enumerate(causal_order[:-1]):
            # Parents are always drawn from vertices later in the causal order.
            later = causal_order[i + 1:]
            if is_module[vertex]:
                # A module's parents must be features.
                possible_parents = np.intersect1d(later, features_ind)
                prob_connection = self.p_module
            else:
                # A feature's parents must be modules.
                possible_parents = np.intersect1d(later, modules_ind)
                prob_connection = self.p_vertex
            num_parents = np.random.binomial(
                n=possible_parents.shape[0], p=prob_connection
            )
            parents = np.random.choice(
                possible_parents, size=num_parents, replace=False
            )
            self.adjacency_matrix[parents, vertex] = 1

        return causal_order, features_ind, modules_ind

    def _build_graph(self, features_ind, modules_ind):
        """Derive ``U``, ``V``, ``bipartite_half`` and the labelled DiGraph ``g``."""
        # U[f, m] = edge feature -> module; V[m, f] = edge module -> feature.
        self.U = self.adjacency_matrix[features_ind][:, modules_ind]
        self.V = self.adjacency_matrix[modules_ind][:, features_ind]
        # bipartite_half[f1, f2] = 1 iff f1 -> m -> f2 for some module m.
        self.bipartite_half = np.where(np.dot(self.U, self.V) > 0, 1, 0)
        self.module_list = modules_ind
        self.g = nx.DiGraph(self.adjacency_matrix)
        for node in features_ind:
            self.g.nodes[node]["type"] = "node"
        for module in modules_ind:
            self.g.nodes[module]["type"] = "module"

    def __call__(self):
        """Sample graphs until one passes the sanity check.

        Returns:
            tuple: ``(g, causal_order)`` -- the labelled ``networkx.DiGraph``
            and the causal order used to sample it.

        Raises:
            ValueError: If more than ``max_failures`` consecutive samples fail
                the sanity check.
        """
        self.n_failure = 0
        while True:
            causal_order, features_ind, modules_ind = self._sample_adjacency()
            self._build_graph(features_ind, modules_ind)
            # Explicit check (not ``assert``) so it survives ``python -O``.
            if not nx.is_directed_acyclic_graph(self.g):
                print("Regenerating, graph non valid...")
                continue
            if not self._bipartite_half_sanity():
                self.n_failure += 1
                if self.n_failure > self.max_failures:
                    self._bipartite_half_info()
                    raise ValueError("Too many failures")
                print("Regenerating, sanity check failed...")
                continue
            self.n_failure = 0
            self._bipartite_half_info()
            return self.g, causal_order


class IntvDagModuleGenerator:
    """
    Create the structure of a module graph with intervention vertices.

    Vertex ids ``0 .. features + modules - 1`` are features and modules, placed
    in a random causal order. Ids ``features + modules .. total_vertices - 1``
    are intervention vertices, always placed last. Feature/module edges only
    join a feature to a module or a module to a feature. Intervention vertices
    only point to modules. Every edge points from a vertex later in the causal
    order to an earlier one, so the graph is acyclic by construction.

    Args:
    features (int): Number of features in the graph to generate.
    modules (int): Number of modules.
    intvs (int): Number of intervention vertices.
    expected_density (int): Expected number of edges per node, used to derive
        the default connection probabilities.
    p_vertex (float, optional): Probability that a feature receives an edge
        from each module placed after it in the causal order. Defaults to
        ``2 * modules * expected_density / (features - 1)``.
    p_module (float, optional): Probability that a module receives an edge
        from each feature placed after it in the causal order. Defaults to
        ``2 * expected_density / (modules - 1)``.
    """

    # Sanity-check failures tolerated within one call before giving up.
    max_failures = 20

    def __init__(
        self,
        features=100,
        modules=10,
        intvs=10,
        expected_density=3,
        p_vertex=None,
        p_module=None
    ) -> None:
        self.features = features
        self.modules = modules
        self.intvs = intvs
        self.total_vertices = self.features + self.modules + self.intvs
        self.adjacency_matrix = np.zeros((self.total_vertices, self.total_vertices))
        self.expected_density = expected_density

        # Compare against None so that an explicit probability of 0 is honoured.
        if p_vertex is not None:
            self.p_vertex = p_vertex
        else:
            self.p_vertex = (
                2 * self.modules * self.expected_density / (self.features - 1)
            )
        if p_module is not None:
            self.p_module = p_module
        else:
            self.p_module = 2 * self.expected_density / (self.modules - 1)
        self.n_failure = 0

    def _bipartite_half_degrees(self):
        """Return in-degree + out-degree of every feature in ``bipartite_half``."""
        return np.sum(self.bipartite_half, 0) + np.sum(self.bipartite_half, 1)

    def _bipartite_half_info(self):
        """Print degree and intervention-coverage statistics."""
        deg = self._bipartite_half_degrees()
        n_disconnected = np.sum(deg == 0)
        print('---------------------------------')
        print('Bipartite half info:')
        print(f'Max degree: {np.max(deg)}, Min degree: {np.min(deg)}, Avg degree: {np.mean(deg)}')
        print(f'Q1: {np.quantile(deg, 0.25)}, Median {np.median(deg)}, Q3: {np.quantile(deg, 0.75)}')
        print('Effective interventions', np.sum(np.any(self.W > 0, axis=1)))
        print('Affected modules', np.sum(np.any(self.W > 0, axis=0)))
        print('Disconnected nodes:', n_disconnected)
        print('---------------------------------')

    def _bipartite_half_sanity(self):
        """Return True when the sampled graph is well connected.

        Requires at most 25% of features to have degree 0 in
        ``bipartite_half``. Also requires more than 75% of interventions to
        target some module, and more than 75% of modules to be targeted.
        """
        deg = self._bipartite_half_degrees()
        n_disconnected = np.sum(deg == 0)
        check = n_disconnected <= 0.25 * self.features
        check_w = np.sum(np.any(self.W > 0, axis=1)) > 0.75 * self.intvs \
            and np.sum(np.any(self.W > 0, axis=0)) > 0.75 * self.modules
        return check and check_w

    def _sample_adjacency(self, intv_targets):
        """Sample a fresh adjacency matrix into ``self.adjacency_matrix``.

        Args:
            intv_targets: Sequence whose j-th entry lists the positions
                (0-based, counting modules in causal order) of the modules
                targeted by the j-th intervention.

        Returns:
            tuple: ``(causal_order, features_ind, modules_ind, intv_ind)``.
        """
        n_base = self.features + self.modules
        # Start from an empty matrix so edges from a previous (rejected)
        # attempt, sampled under a different causal order, never leak in.
        self.adjacency_matrix = np.zeros((self.total_vertices, self.total_vertices))

        # Feature/module vertices are shuffled; interventions always come last.
        causal_order = np.concatenate(
            [
                np.random.permutation(np.arange(n_base)),
                np.arange(n_base, self.total_vertices),
            ]
        )
        # As in the plain generator, the first and last feature/module
        # positions of the causal order are always features.
        modules_ind = np.random.choice(
            causal_order[1:n_base - 1], size=(self.modules), replace=False
        )
        modules_ind.sort()
        # Slice with a non-negative bound so intvs == 0 does not yield [].
        features_ind = np.setdiff1d(causal_order[:n_base], modules_ind, assume_unique=True)
        features_ind.sort()
        intv_ind = np.arange(n_base, self.total_vertices)

        is_module = np.zeros(self.total_vertices, dtype=bool)
        is_module[modules_ind] = True

        for i in range(n_base - 1):
            vertex = causal_order[i]
            # Parents are always drawn from vertices later in the causal order.
            later = causal_order[i + 1:]
            if is_module[vertex]:
                # A module's parents must be features.
                possible_parents = np.intersect1d(later, features_ind)
                prob_connection = self.p_module
            else:
                # A feature's parents must be modules.
                possible_parents = np.intersect1d(later, modules_ind)
                prob_connection = self.p_vertex
            num_parents = np.random.binomial(
                n=possible_parents.shape[0], p=prob_connection
            )
            parents = np.random.choice(
                possible_parents, size=num_parents, replace=False
            )
            self.adjacency_matrix[parents, vertex] = 1

        # Debug checks: no edges between interventions and features/modules yet.
        assert np.all(self.adjacency_matrix[intv_ind][:, modules_ind] == 0)
        assert np.all(self.adjacency_matrix[modules_ind][:, intv_ind] == 0)
        assert np.all(self.adjacency_matrix[intv_ind][:, features_ind] == 0)
        assert np.all(self.adjacency_matrix[features_ind][:, intv_ind] == 0)

        # Module vertex ids listed in causal order; intv_targets index into this.
        modules_ordered = [node for node in causal_order if is_module[node]]
        for j, targets in enumerate(intv_targets):
            for m in targets:
                # Edge intervention -> targeted module vertex (already a vertex id).
                self.adjacency_matrix[intv_ind[j], modules_ordered[m]] = 1

        return causal_order, features_ind, modules_ind, intv_ind

    def _build_graph(self, features_ind, modules_ind, intv_ind):
        """Derive ``U``, ``V``, ``W``, the bipartite halves and the DiGraph ``g``."""
        # U[f, m]: feature -> module; V[m, f]: module -> feature;
        # W[j, m]: intervention -> module.
        self.U = self.adjacency_matrix[features_ind][:, modules_ind]
        self.V = self.adjacency_matrix[modules_ind][:, features_ind]
        self.W = self.adjacency_matrix[intv_ind][:, modules_ind]
        self.bipartite_half = np.where(np.dot(self.U, self.V) > 0, 1, 0)
        self.intv_half = np.where(np.dot(self.W, self.V) > 0, 1, 0)
        self.module_list = modules_ind
        self.g = nx.DiGraph(self.adjacency_matrix)
        for node in features_ind:
            self.g.nodes[node]["type"] = "node"
        for module in modules_ind:
            self.g.nodes[module]["type"] = "module"

    def __call__(self, intv_targets):
        """Sample graphs until one passes the sanity check.

        Args:
            intv_targets: Iterable whose j-th entry lists the positions
                (0-based, counting modules in causal order) of the modules
                targeted by the j-th intervention.

        Returns:
            tuple: ``(g, causal_order)`` -- the labelled ``networkx.DiGraph``
            and the causal order used to sample it.

        Raises:
            ValueError: If more than ``max_failures`` consecutive samples fail
                the sanity check.
            IndexError: If a target position or intervention index is out of
                range.
        """
        # Materialize once so retries see the same targets even for iterators.
        intv_targets = [list(targets) for targets in intv_targets]
        self.n_failure = 0
        while True:
            causal_order, features_ind, modules_ind, intv_ind = (
                self._sample_adjacency(intv_targets)
            )
            self._build_graph(features_ind, modules_ind, intv_ind)
            # Explicit check (not ``assert``) so it survives ``python -O``.
            if not nx.is_directed_acyclic_graph(self.g):
                print("Regenerating, graph non valid...")
                continue
            if not self._bipartite_half_sanity():
                self.n_failure += 1
                if self.n_failure > self.max_failures:
                    self._bipartite_half_info()
                    raise ValueError("Too many failures")
                print("Regenerating, sanity check failed...")
                continue
            self.n_failure = 0
            self._bipartite_half_info()
            return self.g, causal_order