import networkx as nx
import numpy as np
import logging


class EnhancedDagModuleGenerator:
    """
    Create the structure of a module graph, with a guarantee of
    unique parents and children for each factor (module).

    Nodes are split into ``features`` feature nodes and ``modules`` module
    nodes placed along a random causal order. Every module receives parents
    from the feature nodes that precede it and children from the feature
    nodes that follow it; the resulting graph is checked for acyclicity.

    Args:
    features (int): Number of features in the graph to generate.
    modules (int): Number of modules
    """

    def __init__(
        self,
        features=100,
        modules=10,
        mean_par=5,
        mean_chr=5,
        min_unique_par=1,
        min_unique_chr=1,
        min_upstream=-1,
        min_downstream=-1,
        verbose=False,
    ) -> None:
        """

        Args:
            features (int, optional): Number of nodes in the f-DAG. Defaults to 100.
            modules (int, optional): Number of modules (factors) in the f-DAG. Defaults to 10.
            mean_par (int, optional): Poisson mean of number of parents. Defaults to 5.
            mean_chr (int, optional): Poisson mean of number of children. Defaults to 5.
            min_unique_par (int, optional): Minimum number of unique parents for a module. Defaults to 1.
            min_unique_chr (int, optional): Minimum number of unique children for a module. Defaults to 1.
            min_upstream (int, optional): Minimum number of nodes in the first partition.
                A negative value means ``min_unique_par``. Defaults to -1.
            min_downstream (int, optional): Minimum number of nodes in the last partition.
                A negative value means ``min_unique_chr``. Defaults to -1.
            verbose (bool, optional): Enable debugging printing. Defaults to False.

        Raises:
            AssertionError: If there are too few features for the requested
                minimum numbers of unique parents/children.
        """
        self.features = features
        self.modules = modules
        self.total_vertices = self.features + self.modules
        self.adjacency_matrix = np.zeros((self.total_vertices, self.total_vertices))
        self.mean_par = mean_par
        self.mean_chr = mean_chr
        self.min_unique_par = min_unique_par
        self.min_unique_chr = min_unique_chr
        if min_upstream < 0:
            min_upstream = min_unique_par
        if min_downstream < 0:
            min_downstream = min_unique_chr
        self.min_upstream = min_upstream
        self.min_downstream = min_downstream
        assert (
            max(min_unique_chr, min_unique_par) * self.modules
            + min(min_unique_chr, min_unique_par)
            <= self.features
        ), "Not enough features for the requested minimum unique parents/children."
        self.logger = logging.getLogger(__name__)
        logging.basicConfig()
        if verbose:
            self.logger.setLevel(logging.DEBUG)
        # Consecutive sanity-check failures across recursive regenerations.
        self.n_failure = 0

    def _assign_unique_nodes(self, causal_order, modules_ind_index):
        """
        Sample the unique parents and children of every module.

        ``self.partition_u`` / ``self.partition_v`` are reset at the start of
        every call and then filled with nodes claimed by exactly one module.
        The sampled per-module totals of parents/children are stored in
        ``self.n_par_sampler`` / ``self.n_chr_sampler`` for later use by
        ``_assign_common_nodes``.

        Args:
            causal_order (np.ndarray): Node ids in causal order.
            modules_ind_index (list[int]): Increasing positions in
                ``causal_order`` occupied by the modules.

        Returns:
            tuple: ``(common_parents, common_children, success)``. The two sets
            hold feature nodes not claimed as unique parents/children;
            ``success`` is False when the attempt should be retried.
        """
        n_order = self.features + self.modules

        def segment_before(i):
            # Nodes strictly between module i-1 (or the start) and module i.
            start = 0 if i == 0 else modules_ind_index[i - 1] + 1
            return set(causal_order[start:modules_ind_index[i]])

        def segment_after(i):
            # Nodes strictly between module i and module i+1; for the last module,
            # up to position features + modules (later positions are excluded).
            end = n_order if i == self.modules - 1 else modules_ind_index[i + 1]
            return set(causal_order[modules_ind_index[i] + 1:end])

        # Fresh, distinct sets on every attempt, so a failed attempt cannot leak
        # partial assignments into the next one (``[set()] * n`` would also alias
        # a single set object).
        self.partition_u = [set() for _ in range(self.modules)]
        self.partition_v = [set() for _ in range(self.modules)]

        # Pools of shareable parents/children; unique picks are removed below.
        common_parents = set().union(*(segment_before(i) for i in range(self.modules)))
        common_children = set().union(*(segment_after(i) for i in range(self.modules)))

        # Number of feature nodes before / after each module in the causal order.
        positions = np.asarray(modules_ind_index)
        ranks = np.arange(self.modules)
        n_prev_nodes = positions - ranks
        n_next_nodes = self.features - positions + ranks

        n_par_sampler = np.random.poisson(self.mean_par, self.modules)
        n_par_sampler[n_par_sampler == 0] = 1
        n_chr_sampler = np.random.poisson(self.mean_chr, self.modules)
        n_chr_sampler[n_chr_sampler == 0] = 1
        self.n_par_sampler = np.minimum(n_par_sampler, n_prev_nodes)
        self.n_chr_sampler = np.minimum(n_chr_sampler, n_next_nodes)

        # A module with no feature before/after it can never get a parent/child.
        # (n_prev_nodes is an ndarray, so this comparison is element-wise.)
        if np.any(n_prev_nodes <= 0) or np.any(n_next_nodes <= 0):
            return common_parents, common_children, False

        # Unique parents, walking forward along the causal order.
        par_candidates = set()
        for i in range(self.modules):
            par_candidates |= segment_before(i)
            # Randomly draw the number of unique parents (upper bound exclusive).
            if self.min_unique_par < self.n_par_sampler[i] and self.modules > 1:
                n_unique_par = np.random.randint(
                    self.min_unique_par, self.n_par_sampler[i]
                )
            else:
                n_unique_par = self.n_par_sampler[i]
            if not par_candidates:
                self.logger.warning("No possible unique parent. Degrade to common parent.")
                return common_parents, common_children, False
            if n_unique_par == 0 and self.modules > 1:
                self.logger.warning(f"Module {i} has no possible parent. Randomly pick one.")
                return common_parents, common_children, False
            if len(par_candidates) < n_unique_par and self.modules > 1:
                self.logger.warning(f"Module {i} has fewer possible parents than unique parents. Set unique parents to 1.")
                # Draw from the remaining candidates so the pick stays unique to
                # this module and is never another module node.
                n_unique_par = 1
            self.partition_u[i] = set(
                np.random.choice(list(par_candidates), size=n_unique_par, replace=False)
            )
            common_parents -= self.partition_u[i]
            par_candidates -= self.partition_u[i]

        # Unique children, walking backward along the causal order.
        chr_candidates = set()
        for i in reversed(range(self.modules)):
            chr_candidates |= segment_after(i)
            if self.min_unique_chr < self.n_chr_sampler[i] and self.modules > 1:
                n_unique_chr = np.random.randint(
                    self.min_unique_chr, self.n_chr_sampler[i]
                )
            else:
                n_unique_chr = self.n_chr_sampler[i]
            if not chr_candidates:
                self.logger.warning("No possible child. Degrade to common parent.")
                return common_parents, common_children, False
            if n_unique_chr == 0 and self.modules > 1:
                self.logger.warning(f"Module {i} has no possible child.")
                return common_parents, common_children, False
            if len(chr_candidates) < n_unique_chr and self.modules > 1:
                self.logger.warning(f"Module {i} has less possible children than unique children. Set unique children to 1.")
                n_unique_chr = 1
            self.partition_v[i] = set(
                np.random.choice(list(chr_candidates), size=n_unique_chr, replace=False)
            )
            common_children -= self.partition_v[i]
            chr_candidates -= self.partition_v[i]
        return common_parents, common_children, True

    def _assign_common_nodes(self, causal_order, modules_ind_index, common_parents, common_children):
        """
        Top up each module's parents/children with shared ("common") nodes.

        Module ``i`` receives up to ``self.n_par_sampler[i]`` parents in total,
        drawn from ``common_parents`` that precede it in ``causal_order``. It
        also receives up to ``self.n_chr_sampler[i]`` children, drawn from
        ``common_children`` that follow it. The common pools are not depleted,
        so a common node may be attached to several modules.
        """
        n_order = self.features + self.modules
        for i in range(self.modules):
            position = modules_ind_index[i]

            # parents
            par_candidates = set(causal_order[:position]) & common_parents
            n_par = min(self.n_par_sampler[i] - len(self.partition_u[i]), len(par_candidates))
            if n_par > 0:
                self.partition_u[i] = self.partition_u[i] | set(
                    np.random.choice(list(par_candidates), size=n_par, replace=False)
                )

            # children
            chr_candidates = set(causal_order[position + 1:n_order]) & common_children
            n_chr = min(self.n_chr_sampler[i] - len(self.partition_v[i]), len(chr_candidates))
            if n_chr > 0:
                self.partition_v[i] = self.partition_v[i] | set(
                    np.random.choice(list(chr_candidates), size=n_chr, replace=False)
                )

    def _partition_nodes(self, causal_order, modules_ind_index, max_iter=100):
        """
        Assign parent and child feature nodes to every module.

        Fills ``self.partition_u[i]`` (parents of module ``i``) and
        ``self.partition_v[i]`` (children of module ``i``). The unique
        assignment is retried up to ``max_iter`` extra times; afterwards the
        partitions are topped up with common nodes.
        """
        # One distinct set per module (``[set()] * n`` would alias one object).
        self.partition_u = [set() for _ in range(self.modules)]
        self.partition_v = [set() for _ in range(self.modules)]

        common_parents, common_children, success = self._assign_unique_nodes(causal_order, modules_ind_index)
        count = 0
        while not success and count < max_iter:
            common_parents, common_children, success = self._assign_unique_nodes(causal_order, modules_ind_index)
            count += 1
        if not success:
            self.logger.warning(
                "Unique parent/child assignment failed after %d retries; "
                "continuing with the last partial assignment.",
                max_iter,
            )

        # assign the rest of the nodes as parents or children
        if common_parents or common_children:
            self._assign_common_nodes(causal_order, modules_ind_index, common_parents, common_children)

    def _bipartite_half_info(self):
        """Print degree statistics of the feature graph ``bipartite_half`` (U @ V > 0)."""
        deg = np.sum(self.bipartite_half, 0) + np.sum(self.bipartite_half, 1)
        n_disconnected = np.sum(deg == 0)
        print('---------------------------------')
        print('Bipartite half info:')
        print(f'Max degree: {np.max(deg)}, Min degree: {np.min(deg)}, Avg degree: {np.mean(deg)}')
        print(f'Q1: {np.quantile(deg, 0.25)}, Median {np.median(deg)}, Q3: {np.quantile(deg, 0.75)}')
        print('Disconnected nodes:', n_disconnected)
        print('---------------------------------')

    def _bipartite_half_sanity(self):
        """Return True when at most 25% of features have zero degree in ``bipartite_half``."""
        deg = np.sum(self.bipartite_half, 0) + np.sum(self.bipartite_half, 1)
        n_disconnected = np.sum(deg == 0)
        return n_disconnected <= 0.25 * self.features

    def _sample_module_positions(self, max_fail=100):
        """
        Sample strictly increasing positions of the modules in the causal order.

        Gaps between consecutive modules are Poisson distributed with mean
        ``(total_vertices - min_upstream - min_downstream) / (modules + 1)``,
        so that modules are roughly evenly spaced. The first module is placed
        at least ``min_upstream`` positions in. Every module stays below
        ``total_vertices - min_downstream``.

        Args:
            max_fail (int): Retry budget for the first position, and a shared
                budget for all subsequent positions.

        Returns:
            list[int]: The module positions.

        Raises:
            ValueError: If no valid position could be sampled.
        """
        limit = self.total_vertices - self.min_downstream
        delta = (self.total_vertices - self.min_upstream - self.min_downstream) / (self.modules + 1)

        first = np.random.poisson(delta) + self.min_upstream
        n_fail = 0
        while first >= limit and n_fail < max_fail:
            first = np.random.poisson(delta) + self.min_upstream
            n_fail += 1
        if first >= limit:
            raise ValueError("Cannot find a valid starting point")
        positions = [first]

        n_fail = 0
        for _ in range(1, self.modules):
            prev = positions[-1]
            nxt = np.random.poisson(delta) + prev
            # Parenthesised so the retry budget bounds *both* failure conditions;
            # otherwise a position past ``limit`` could be retried forever.
            while (nxt >= limit or nxt <= prev) and n_fail < max_fail:
                nxt = np.random.poisson(delta) + prev
                n_fail += 1
            if nxt >= limit or nxt <= prev:
                raise ValueError("Cannot find a valid partition length.")
            positions.append(nxt)
        return positions

    def __call__(self):
        """
        Sample a random module DAG.

        Returns:
            tuple: ``(g, causal_order)``, where ``g`` is an ``nx.DiGraph`` whose
            nodes carry a ``"type"`` attribute (``"node"`` or ``"module"``) and
            ``causal_order`` is the sampled node permutation.

        Raises:
            ValueError: If module positions cannot be sampled, or if the sanity
                check fails more than 20 times in a row.
        """
        # create partition of nodes into modules and features
        causal_order = np.random.permutation(np.arange(self.total_vertices))

        modules_ind_index = self._sample_module_positions()
        assert len(modules_ind_index) == self.modules, "Incorrect number of modules"

        modules_ind_index.sort()
        modules_ind = causal_order[modules_ind_index]  # in causal order

        features_ind = np.setdiff1d(causal_order, modules_ind, assume_unique=True)
        features_ind.sort()

        self._partition_nodes(causal_order, modules_ind_index)

        self.adjacency_matrix = np.zeros((self.total_vertices, self.total_vertices))
        for i in range(self.modules):
            # i-th module in the causal order
            module = causal_order[modules_ind_index[i]]
            if self.partition_u[i]:
                self.adjacency_matrix[np.array(sorted(self.partition_u[i])), module] = 1
            if self.partition_v[i]:
                self.adjacency_matrix[module, np.array(sorted(self.partition_v[i]))] = 1

        self.U = self.adjacency_matrix[features_ind][:, modules_ind]
        self.V = self.adjacency_matrix[modules_ind][:, features_ind]
        self.bipartite_half = np.where(np.dot(self.U, self.V) > 0, 1, 0)

        self.module_list = modules_ind
        self.g = nx.DiGraph(self.adjacency_matrix)
        for node in features_ind:
            self.g.nodes[node]["type"] = "node"
        for module in modules_ind:
            self.g.nodes[module]["type"] = "module"

        # Explicit check (not ``assert``) so it also runs under ``python -O``.
        if not nx.is_directed_acyclic_graph(self.g):
            print("Regenerating, graph non valid...")
            return self()

        if not self._bipartite_half_sanity():
            self.n_failure += 1
            if self.n_failure > 20:
                self._bipartite_half_info()
                # Reset so the next call on this instance gets a fresh budget.
                self.n_failure = 0
                raise ValueError("Too many failures")
            print("Regenerating, sanity check failed...")
            return self()
        self.n_failure = 0

        self.modules_ind = modules_ind
        self._bipartite_half_info()
        return self.g, causal_order


class IntvEnhancedDagModuleGenerator(EnhancedDagModuleGenerator):
    """
    Module-graph generator that additionally attaches intervention nodes.

    Intervention nodes get the last ``intvs`` node ids and only receive
    edges pointing into the modules they target (see ``__call__``).
    """

    def __init__(
        self,
        features,
        modules,
        intvs,
        mean_par=5,
        mean_chr=5,
        min_unique_par=1,
        min_unique_chr=1,
        min_upstream=-1,
        min_downstream=-1,
        verbose=False,
    ) -> None:
        """
        Args:
            features (int): Number of feature nodes.
            modules (int): Number of modules (factors).
            intvs (int): Number of intervention nodes.
            mean_par (int, optional): Poisson mean of number of parents. Defaults to 5.
            mean_chr (int, optional): Poisson mean of number of children. Defaults to 5.
            min_unique_par (int, optional): Minimum number of unique parents for a module. Defaults to 1.
            min_unique_chr (int, optional): Minimum number of unique children for a module. Defaults to 1.
            min_upstream (int, optional): Minimum number of nodes before the first module
                (negative means ``min_unique_par``). Defaults to -1.
            min_downstream (int, optional): Minimum number of nodes after the last module
                (negative means ``min_unique_chr``). Defaults to -1.
            verbose (bool, optional): Enable debugging printing. Defaults to False.
        """
        super().__init__(
            features=features,
            modules=modules,
            mean_par=mean_par,
            mean_chr=mean_chr,
            min_unique_par=min_unique_par,
            min_unique_chr=min_unique_chr,
            min_upstream=min_upstream,
            min_downstream=min_downstream,
            verbose=verbose,
        )
        self.intvs = intvs
        # Intervention nodes are appended after the features and modules.
        self.total_vertices = self.features + self.modules + self.intvs
        self.adjacency_matrix = np.zeros((self.total_vertices, self.total_vertices))

    def _bipartite_half_info(self):
        """Print degree statistics of ``bipartite_half`` plus intervention coverage of ``W``."""
        deg = np.sum(self.bipartite_half, 0) + np.sum(self.bipartite_half, 1)
        n_disconnected = np.sum(deg == 0)
        print('---------------------------------')
        print('Bipartite half info:')
        print(f'Max degree: {np.max(deg)}, Min degree: {np.min(deg)}, Avg degree: {np.mean(deg)}')
        print(f'Q1: {np.quantile(deg, 0.25)}, Median {np.median(deg)}, Q3: {np.quantile(deg, 0.75)}')
        print('Effective interventions', np.sum(np.any(self.W > 0, axis=1)))
        print('Affected modules', np.sum(np.any(self.W > 0, axis=0)))
        print('Disconnected nodes:', n_disconnected)
        print('---------------------------------')

    def _bipartite_half_sanity(self):
        """
        Return True when the sampled graph is acceptable.

        Requires all of the following:
        - at most 25% of features are disconnected in ``bipartite_half``;
        - more than 75% of interventions target at least one module;
        - more than 75% of modules are targeted by at least one intervention.
        """
        deg = np.sum(self.bipartite_half, 0) + np.sum(self.bipartite_half, 1)
        n_disconnected = np.sum(deg == 0)
        check = n_disconnected <= 0.25 * self.features
        check_w = np.sum(np.any(self.W > 0, axis=1)) > 0.75 * self.intvs \
            and np.sum(np.any(self.W > 0, axis=0)) > 0.75 * self.modules
        return check and check_w

    def _sample_module_positions(self, max_fail=100):
        """
        Sample strictly increasing module positions within the feature/module range.

        Only positions ``0 .. features + modules - 1`` are used, because the
        intervention nodes occupy the remaining positions of the causal order.
        Out-of-range draws are retried up to ``max_fail`` times, after which a
        fallback position is used instead of raising.

        Returns:
            list[int]: The module positions.

        Raises:
            ValueError: If no free position is left after the previous module.
        """
        n_order = self.features + self.modules
        limit = n_order - self.min_downstream
        delta = (n_order - self.min_upstream - self.min_downstream) / (self.modules + 1)

        first = np.random.poisson(delta) + self.min_upstream
        n_fail = 0
        # Bounded so an impossible configuration cannot loop forever.
        while first >= limit and n_fail < max_fail:
            first = np.random.poisson(delta) + self.min_upstream
            n_fail += 1
        if first >= limit:
            first = 1
        positions = [first]

        n_fail = 0
        for _ in range(1, self.modules):
            prev = positions[-1]
            nxt = np.random.poisson(delta) + prev
            # Parenthesised so the retry budget bounds both failure conditions.
            while (nxt >= limit or nxt <= prev) and n_fail < max_fail:
                nxt = np.random.poisson(delta) + prev
                n_fail += 1
            if nxt >= limit or nxt <= prev:
                # Fallback: the next free slot. Positions stay strictly increasing,
                # so two modules can never share a position.
                nxt = prev + 1
                if nxt >= n_order:
                    raise ValueError("Cannot find a valid partition length.")
            positions.append(nxt)
        return positions

    def _add_intervention_edges(self, causal_order, modules_ind_index, intv_ind, intv_targets):
        """
        Add an edge from every intervention node to each module it targets.

        Args:
            causal_order (np.ndarray): Node ids in causal order.
            modules_ind_index (list[int]): Positions of the modules in ``causal_order``.
            intv_ind (np.ndarray): Node ids of the intervention nodes.
            intv_targets: Entry ``j`` lists the modules targeted by
                intervention ``j``. Modules are given by their rank
                (0..modules-1) in the causal order.
        """
        sources_per_module = [[] for _ in range(self.modules)]
        for j, targets in enumerate(intv_targets):
            for m in targets:
                sources_per_module[m].append(j)
        for i, sources in enumerate(sources_per_module):
            if not sources:
                # Untargeted module: an empty index list would become a float
                # array, which numpy rejects as an index.
                continue
            module = causal_order[modules_ind_index[i]]
            self.adjacency_matrix[intv_ind[sources], module] = 1

    def __call__(self, intv_targets):
        """
        Sample a random module DAG with intervention edges.

        Args:
            intv_targets: Entry ``j`` lists the modules targeted by intervention
                ``j``. Modules are given by their rank (0..modules-1) in the
                sampled causal order.

        Returns:
            tuple: ``(g, causal_order)``, where ``g`` is an ``nx.DiGraph`` whose
            nodes carry a ``"type"`` attribute (``"node"``, ``"module"`` or
            ``"intervention"``).

        Raises:
            ValueError: If module positions cannot be sampled.
        """
        n_order = self.features + self.modules
        # Features and modules are shuffled. Intervention node ids are then
        # appended at the END of the array. NOTE(review): interventions therefore
        # come after the modules they point to in ``causal_order``.
        causal_order = np.concatenate(
            [
                np.random.permutation(np.arange(n_order)),
                np.arange(n_order, self.total_vertices),
            ]
        )

        modules_ind_index = self._sample_module_positions()
        assert len(modules_ind_index) == self.modules, "Incorrect number of modules"

        modules_ind_index.sort()
        # Module node ids, sorted by id (not by causal rank).
        modules_ind = causal_order[modules_ind_index]
        modules_ind.sort()

        # Slice by count: ``causal_order[:-self.intvs]`` is empty when intvs == 0.
        features_ind = np.setdiff1d(causal_order[:n_order], modules_ind, assume_unique=True)
        features_ind.sort()

        intv_ind = np.arange(n_order, self.total_vertices)

        self._partition_nodes(causal_order, modules_ind_index)

        self.adjacency_matrix = np.zeros((self.total_vertices, self.total_vertices))
        for i in range(self.modules):
            # i-th module in the causal order
            module = causal_order[modules_ind_index[i]]
            if self.partition_u[i]:
                self.adjacency_matrix[np.array(sorted(self.partition_u[i])), module] = 1
            if self.partition_v[i]:
                self.adjacency_matrix[module, np.array(sorted(self.partition_v[i]))] = 1

        # Sanity check used for debugging: no edge touches an intervention node yet.
        assert np.all(self.adjacency_matrix[intv_ind][:, modules_ind] == 0)
        assert np.all(self.adjacency_matrix[modules_ind][:, intv_ind] == 0)
        assert np.all(self.adjacency_matrix[intv_ind][:, features_ind] == 0)
        assert np.all(self.adjacency_matrix[features_ind][:, intv_ind] == 0)

        self._add_intervention_edges(causal_order, modules_ind_index, intv_ind, intv_targets)

        self.U = self.adjacency_matrix[features_ind][:, modules_ind]
        self.V = self.adjacency_matrix[modules_ind][:, features_ind]
        self.W = self.adjacency_matrix[intv_ind][:, modules_ind]
        self.bipartite_half = np.where(np.dot(self.U, self.V) > 0, 1, 0)
        self.intv_half = np.where(np.dot(self.W, self.V) > 0, 1, 0)

        self.module_list = modules_ind
        self.g = nx.DiGraph(self.adjacency_matrix)
        for node in features_ind:
            self.g.nodes[node]["type"] = "node"
        for module in modules_ind:
            self.g.nodes[module]["type"] = "module"
        for intv in intv_ind:
            self.g.nodes[intv]["type"] = "intervention"

        # Explicit check (not ``assert``) so it also runs under ``python -O``.
        if not nx.is_directed_acyclic_graph(self.g):
            print("Regenerating, graph non valid...")
            return self(intv_targets)

        if not self._bipartite_half_sanity():
            self.n_failure += 1
            if self.n_failure <= 20:
                print("Regenerating, sanity check failed...")
                return self(intv_targets)
            # Out of retries: keep (and report) the last generated graph.
            self._bipartite_half_info()
            print("Too many failures. Save the last generated graph.")
            print(self.U)
            print(self.V)
            print(self.W)
        self.n_failure = 0

        self.modules_ind = modules_ind
        self._bipartite_half_info()
        return self.g, causal_order
    