import copy
import torch

class DAG:
    """ Sorted DAG.

    A DAG whose nodes are labelled 0..n-1 in topological order, so that every
    link (i, j) satisfies 0 <= i < j.

    Attributes:
    -----------
    parents_list : list of lists of ints
        Parents' indices for each node (a private copy of the constructor input).
    nodes : list of ints
        Node indices 0..n-1 of the sorted DAG.

    Methods:
    --------
    get_parents(node)
        Return the list of parents of a node.
    get_nonparents(node)
        Return the nodes that are neither parents of `node` nor `node` itself.
    rewire(node, parents_new)
        Return a new DAG, rewired at `node` by providing a new set of parents.
    """

    def __init__(self, parents_list):
        """
        Arguments:
        ----------
        parents_list : list of lists of ints
            List of node's parents' indices for each node.
                Example: [[], [0], [0, 1]] (fully connected graph)
                Example: [[], [0], [1]] (chain graph)

        Raises:
        -------
        ValueError
            If a parent index is negative or not smaller than its child's index.
        """
        # Copy first: the DAG must not share inner lists with the caller, otherwise
        # later mutation outside this object could break the checked invariant.
        # Copying before checking also lets any iterable of parents be validated.
        parents_list = [list(parents) for parents in parents_list]
        self._check_DAG(parents_list)
        self.parents_list = parents_list
        self.nodes = list(range(len(self.parents_list)))

    def _check_DAG(self, parents_list):
        """ Raise ValueError if `parents_list` does not correspond to a sorted DAG.

        Every parent index must lie in [0, node).
        """
        for node, parents in enumerate(parents_list):
            for parent in parents:
                # Negative indices would otherwise pass the i<j test and silently
                # refer to nodes counted from the end of the list.
                if parent < 0:
                    raise ValueError(f'Parent indices must be non-negative. But node #{node} has node #{parent} as parent.')
                if parent >= node:
                    raise ValueError(f'Only links (i,j) with i<j are allowed in a sorted DAG. But node #{node} has node #{parent} as parent.')

    def get_parents(self, node):
        """ Return list of parents of given node.
        """

        return self.parents_list[node]

    def get_nonparents(self, node):
        """ Return, in ascending order, all nodes that are neither parents of
        `node` nor `node` itself.
        """

        excluded = set(self.get_parents(node))
        excluded.add(node)

        return [other for other in range(len(self.parents_list)) if other not in excluded]

    def rewire(self, node, parents_new):
        """ Return a new DAG, rewired at specified node by providing new set of parents.

        The current DAG is left unchanged.

        Arguments:
        ----------
        node : int
            Node to rewire.
        parents_new : list of ints
            List of node's new parents.

        Returns:
        --------
        DAG
            The rewired DAG.

        Raises:
        -------
        ValueError
            If the rewiring would break the sorted-DAG property.
        """
        parents_list_new = list(self.parents_list)
        parents_list_new[node] = parents_new

        # The constructor copies every parents list and validates the result,
        # so no separate check (or deep copy) is needed here.
        return DAG(parents_list_new)


class ErdosRenyiDAG(DAG):
    """ Erdos-Renyi DAG.

    Every candidate link (i, j) with i < j is included independently with the
    same probability, i.e. each edge indicator is an iid Bernoulli draw.
    """

    def __init__(self, n, p):
        """
        Arguments:
        ----------
        n : int
            Number of nodes
        p : float
            Probability of edge
        """

        super().__init__(self._generate_parents_list(n, p))

    def _generate_parents_list(self, n, p):
        """ Sample the edges and return, for each node, the list of its parents.
        """

        bernoulli = torch.distributions.bernoulli.Bernoulli(p)
        draws = bernoulli.sample(sample_shape=(n, n))
        # Only entries strictly above the diagonal are kept: adjacency[i, j] == 1
        # encodes the link i -> j, which guarantees i < j.
        adjacency = torch.triu(draws, diagonal=1)

        parents_list = []
        for child in range(n):
            # Column `child` marks which nodes point into `child`.
            incoming = adjacency[:, child] == 1
            parents_list.append(torch.nonzero(incoming).flatten().tolist())

        return parents_list

class BipartiteDAG(DAG):
    """ Erdos-Renyi Bipartite DAG.

    Each link (i, j) with i < j is sampled iid from a Bernoulli distribution,
    provided i and j belong to two distinct sets; links between nodes of the
    same set are never present.
    """

    def __init__(self, set1, set2, p):
        """
        Arguments:
        ----------
        set1 : list of ints
            List of nodes in first set.
        set2 : list of ints
            List of nodes in second set.
        p : float
            Probability of edge

        Raises:
        -------
        ValueError
            If set1 and set2 do not partition the nodes 0..len(set1)+len(set2)-1.
        """
        self.set1 = set1
        self.set2 = set2
        parents_list = self._generate_parents_list(p)
        super().__init__(parents_list)

    def _generate_parents_list(self, p):
        """ Return the DAG's parents' list.

        Samples an Erdos-Renyi DAG over all nodes, then drops every link whose
        endpoints lie in the same set. Each node's parents are sorted ascending.
        """
        n = len(self.set1) + len(self.set2)
        # Build the sets once: membership tests on the raw lists are O(len) per node.
        set1, set2 = set(self.set1), set(self.set2)

        # Every node 0..n-1 must belong to exactly one set. A node in neither set
        # would otherwise get no entry at all, shifting every later node's index.
        # An equal-size union also rules out overlaps and duplicate entries.
        if set1 | set2 != set(range(n)):
            raise ValueError(f'set1 and set2 must be disjoint and together contain exactly the nodes 0..{n - 1}.')

        dag = ErdosRenyiDAG(n, p)

        parents_list = []
        for node, node_parents in enumerate(dag.parents_list):
            same_set = set1 if node in set1 else set2
            parents_list.append(sorted(set(node_parents) - same_set))

        return parents_list
        

    