import igraph as ig
import numpy as np
import networkx as nx

################################################################
# -------- Gaussian Random Partition Graphs Generator -------- #
################################################################
class GaussianRandomPartition:
    """Sampler of clustered DAGs in the spirit of Gaussian Random Partition graphs.

    Nodes are split into clusters. Each cluster is sampled as an Erdos-Renyi
    DAG with edge probability ``p_in``; edges from already-placed clusters to
    each newly added cluster are sampled with probability ``p_out``.
    """

    def __init__(
        self,
        d,
        p_in,
        p_out
    ):
        """
        Custom generator of Gaussian Random Partition Graphs
        Parameters
        ----------
        d : int
            Number of nodes
        p_in : float
            Probability of edge connection with nodes in the cluster
        p_out : float
            Probability of edge connection with nodes in different clusters
        """
        self.d = d
        self.p_in = p_in
        self.p_out = p_out

    def _sample_number_of_clusters(self):
        """Sample the number of clusters.

        Graphs with at most 10 nodes always get 2 clusters.
        Graphs with fewer than 50 nodes get 3, 4 or 5 clusters (uniformly),
        capped at ``d // 3`` so that every cluster can hold at least 3 nodes.
        Graphs with 50 nodes or more get 4, 5 or 6 clusters (uniformly).

        Returns
        -------
        int
            The number of clusters.
        """
        num_clusters = [3, 4, 5]
        # One-hot multinomial draw (n=1); argmax recovers the drawn index.
        c = np.random.multinomial(n=1, pvals=[1/3, 1/3, 1/3]).argmax()
        if self.d <= 10:  # small graph: always 2 clusters
            return 2
        if self.d < 50:
            # Without the cap, e.g. d=12 with 5 clusters could never satisfy
            # the "at least 3 nodes per cluster" constraint enforced later.
            return min(num_clusters[c], self.d // 3)
        return num_clusters[c] + 1

    def _sample_cluster_sizes(self, n_clusters):
        """Sample the size of each cluster.

        Sizes are drawn from a multinomial distribution with equal
        probabilities and then rebalanced (one node at a time, from the
        largest to the smallest cluster) until every cluster has at least
        3 nodes.

        Parameters
        ----------
        n_clusters : int
            The number of clusters in the graph

        Returns
        -------
        np.ndarray
            Integer array of length ``n_clusters`` summing to ``self.d``.

        Raises
        ------
        ValueError
            If ``self.d`` nodes cannot be split into ``n_clusters`` clusters
            of at least 3 nodes each (the rebalancing would never terminate).
        """
        if n_clusters < 1 or self.d < 3 * n_clusters:
            raise ValueError(
                f"Cannot split {self.d} nodes into {n_clusters} clusters "
                "of at least 3 nodes each."
            )
        cluster_sizes = np.random.multinomial(self.d, pvals=[1/n_clusters for _ in range(n_clusters)])
        # At least 3 elements per cluster
        while np.min(cluster_sizes) < 3:
            argmax = np.argmax(cluster_sizes)
            argmin = np.argmin(cluster_sizes)
            cluster_sizes[argmax] -= 1
            cluster_sizes[argmin] += 1
        return cluster_sizes

    def _sample_er_cluster(self, cluster_size):
        """Sample one cluster as an Erdos-Renyi DAG with edge probability ``p_in``.
        """
        A = erdos_renyi_p(cluster_size, self.p_in)
        return A

    def _add_inter_cluster_edges(self, A, n_old):
        """Sample edges from the first ``n_old`` nodes to all later nodes, in place.

        Every pair (i, j) with ``i < n_old <= j`` receives an edge i -> j
        with probability ``p_out``. Edges only point towards higher indices.

        Parameters
        ----------
        A : np.array
            Square adjacency matrix; modified in place.
        n_old : int
            Number of leading nodes that belong to the previous clusters.

        Returns
        -------
        np.array
            The same matrix ``A``.
        """
        n_total = A.shape[1]
        for i in range(n_old):
            # The target range covers the WHOLE new cluster for every source
            # node i; it must not depend on i.
            for j in range(n_old, n_total):
                if np.random.binomial(n=1, p=self.p_out) == 1:
                    print(f"edge {(i, j)} between clusters!")
                    A[i, j] = 1
        return A

    def _disjoint_union(self, A, c_size):
        """
        Merge current adjacency matrix A with a new ER cluster of c_size nodes
        and connect the existing nodes to the new cluster with probability
        ``p_out``. New nodes are appended after the existing ones and all
        added edges point from lower to higher indices.

        Parameters
        ----------
        A : np.array
            Current adjacency matrix
        c_size : int
            Size of the cluster to generate

        Returns
        -------
        np.array
            Adjacency matrix with ``A.shape[0] + c_size`` nodes.
        """
        # Join the graphs by block matrices: [[A, 0], [0, er_cluster]]
        n = A.shape[0]
        er_cluster = self._sample_er_cluster(cluster_size=c_size)
        er_cluster = np.hstack([np.zeros((c_size, n)), er_cluster])
        A = np.hstack([A, np.zeros((n, c_size))])
        A = np.vstack([A, er_cluster])

        # Add connections among clusters from A to er_cluster
        return self._add_inter_cluster_edges(A, n)

    def gaussian_random_partition_graph(self):
        """Sample the adjacency matrix of a clustered DAG with ``self.d`` nodes.

        Returns
        -------
        np.array
            Adjacency matrix obtained by joining all sampled clusters.
        """
        n_clusters = self._sample_number_of_clusters()
        size_of_clusters = self._sample_cluster_sizes(n_clusters)
        print(f"size of the clusters: {size_of_clusters}")

        # Initialize with the first cluster
        A = self._sample_er_cluster(size_of_clusters[0])

        # Join the remaining clusters one at a time
        for c_size in size_of_clusters[1:]:
            A = self._disjoint_union(A, c_size)
            assert nx.is_directed_acyclic_graph(nx.from_numpy_array(A, create_using=nx.DiGraph)), "Error: graph is not a DAG!"

        return A
    

################################################################
# -------------------- Graphs Simulations -------------------- #
################################################################

def _acyclic_orientation(A):
    """Return the strictly upper-triangular part of adjacency matrix A.

    Keeping only entries above the main diagonal leaves edges i -> j with
    i < j, so the result has no self-loops and no cycles.
    """
    strictly_upper = np.triu(A, 1)
    return strictly_upper

def _ig_to_adjmat(G : ig.Graph):
    """Convert an igraph graph into its adjacency matrix as a numpy array."""
    adjacency = G.get_adjacency()
    return np.array(adjacency.data)

def graph_viz(A : np.array):
    """Draw the directed graph encoded by adjacency matrix A using networkx."""
    nx.draw_networkx(nx.from_numpy_array(A, create_using=nx.DiGraph))

def erdos_renyi_p(d, p):
    """
    Generate ER DAG with d nodes and edge probability p.

    An undirected Erdos-Renyi graph is sampled and oriented by keeping only
    the strictly upper-triangular part of its adjacency matrix. Sampling is
    repeated until the resulting DAG has at least 2 edges.

    Parameters
    ----------
    d : int
        Number of nodes. Must be at least 3, otherwise 2 edges are impossible.
    p : float
        Probability of connecting each pair of nodes. Must be positive.

    Returns
    -------
    np.ndarray
        Adjacency matrix of the sampled DAG (non-zero only above the diagonal).

    Raises
    ------
    ValueError
        If ``d < 3`` or ``p <= 0``, for which the rejection loop below
        could never terminate.
    """
    if d < 3 or p <= 0:
        raise ValueError(
            f"Cannot sample a DAG with at least 2 edges for d={d}, p={p}: "
            "need d >= 3 and p > 0."
        )
    A = np.zeros((d, d))
    # Reject samples with fewer than 2 edges.
    while np.sum(A) < 2:
        G_und = ig.Graph.Erdos_Renyi(n=d, p=p)
        A_und = _ig_to_adjmat(G_und)
        A = _acyclic_orientation(A_und)
    assert nx.is_directed_acyclic_graph(nx.from_numpy_array(A, create_using=nx.DiGraph)), "Error: graph is not a DAG!"
    return A

def erdos_renyi_m(d, m):
    """
    Generate ER DAG with d nodes and d*m edges.

    An undirected Erdos-Renyi graph with ``m * d`` requested edges is sampled
    and oriented by keeping only the strictly upper-triangular part of its
    adjacency matrix. Sampling is repeated until the DAG has at least 2 edges.

    Parameters
    ----------
    d : int
        Number of nodes.
    m : int
        Edges per node; ``m * d`` edges are requested in total.

    Returns
    -------
    np.ndarray
        Adjacency matrix of the sampled DAG (non-zero only above the diagonal).

    Raises
    ------
    ValueError
        If ``m * d < 2``, for which the rejection loop below could never
        terminate.
    """
    n_edges = m * d
    if n_edges < 2:
        raise ValueError(
            f"Cannot sample a DAG with at least 2 edges for d={d}, m={m}: "
            "need m * d >= 2."
        )
    A = np.zeros((d, d))
    # Reject samples with fewer than 2 edges.
    while np.sum(A) < 2:
        G_und = ig.Graph.Erdos_Renyi(n=d, m=n_edges)
        A_und = _ig_to_adjmat(G_und)
        A = _acyclic_orientation(A_und)

    assert nx.is_directed_acyclic_graph(nx.from_numpy_array(A, create_using=nx.DiGraph)), "Error: graph is not a DAG!"
    return A


def barabasi_albert_in(d, m):
    """
    Generate Barabasi Albert SF graph with d nodes and m (expected) edges per node.
    Hub nodes have large in-degree (i.e. a lot of colliders)

    The directed graph sampled by igraph is used as is (no re-orientation).
    Sampling is repeated until the graph has at least 2 edges.

    Parameters
    ----------
    d : int
        Number of nodes. Must be at least 3.
    m : int
        Expected edges per node. Must be at least 1.

    Returns
    -------
    np.ndarray
        Adjacency matrix of the sampled graph.

    Raises
    ------
    ValueError
        If ``d < 3`` or ``m < 1``, for which the rejection loop below
        could never terminate.
    """
    if d < 3 or m < 1:
        raise ValueError(
            f"Cannot sample a graph with at least 2 edges for d={d}, m={m}: "
            "need d >= 3 and m >= 1."
        )
    A = np.zeros((d, d))
    # Reject samples with fewer than 2 edges.
    while np.sum(A) < 2:
        G = ig.Graph.Barabasi(n=d, m=m, directed=True)
        A = _ig_to_adjmat(G)

    assert nx.is_directed_acyclic_graph(nx.from_numpy_array(A, create_using=nx.DiGraph)), "Error: graph is not a DAG!"
    return A


def barabasi_albert_out(d, m):
    """
    Generate Barabasi Albert SF graph with d nodes and m (expected) edges per node.
    Hub nodes have large out-degree.

    The graph is obtained by sampling with barabasi_albert_in() and
    reversing every edge (transposing the adjacency matrix).
    This is equivalent to paul_barabasi_albert().

    Parameters
    ----------
    d : int
        Number of nodes.
    m : int
        Expected edges per node.
    """
    in_hub_adjacency = barabasi_albert_in(d, m)
    # Swapping the two axes flips the direction of every edge.
    return np.transpose(in_hub_adjacency, axes=(1, 0))


def paul_barabasi_albert(d, m):
    """
    Generate Barabasi Albert as in Rolland et al. 2022.
    This is basically equivalent to barabasi_albert_out

    An undirected Barabasi-Albert graph is sampled and oriented by keeping
    only the strictly upper-triangular part of its adjacency matrix.
    Sampling is repeated until the DAG has at least 2 edges.

    Parameters
    ----------
    d : int
        Number of nodes. Must be at least 3.
    m : int
        Expected edges per node. Must be at least 1.

    Returns
    -------
    np.ndarray
        Adjacency matrix of the sampled DAG (non-zero only above the diagonal).

    Raises
    ------
    ValueError
        If ``d < 3`` or ``m < 1``, for which the rejection loop below
        could never terminate.
    """
    if d < 3 or m < 1:
        raise ValueError(
            f"Cannot sample a DAG with at least 2 edges for d={d}, m={m}: "
            "need d >= 3 and m >= 1."
        )
    A = np.zeros((d, d))
    # Reject samples with fewer than 2 edges.
    while np.sum(A) < 2:
        G_und = ig.Graph.Barabasi(n=d, m=m)
        A_und = _ig_to_adjmat(G_und)
        A = _acyclic_orientation(A_und)

    assert nx.is_directed_acyclic_graph(nx.from_numpy_array(A, create_using=nx.DiGraph)), "Error: graph is not a DAG!"
    return A


def fully_connected(d : int):
    """Return the fully connected DAG on d nodes.

    At the end of the day, we are interested in the ability to infer the topological ordering,
    the rest is just variable selection (under causal sufficiency). A method that performs well
    here is a method that is capable of good inference of the topological ordering.

    The topological ordering of the returned graph is fixed to 0, 1, ..., d-1:
    every node i has an edge to every node j > i. No random permutation of
    the nodes is applied.

    Parameters
    ----------
    d : int
        Number of nodes

    Returns
    -------
    np.ndarray
        (d, d) strictly upper-triangular matrix of ones.
    """
    # NOTE: a random permutation used to be sampled here but was never applied
    # to the matrix; the dead draw has been removed. The returned matrix is
    # unchanged, but the global numpy RNG is no longer advanced by this call.
    return np.triu(np.ones((d, d)), k=1)


def gaussian_random_partition(d, p_in, p_out) -> np.array:
    """Sample the adjacency matrix of a Gaussian Random Partition DAG.

    Parameters
    ----------
    d : int
        Number of nodes.
    p_in : float
        Probability of an edge between nodes of the same cluster.
    p_out : float
        Probability of an edge between nodes of different clusters.
    """
    adjacency = GaussianRandomPartition(d, p_in, p_out).gaussian_random_partition_graph()
    # Sanity check on the final graph before handing it to the caller.
    as_digraph = nx.from_numpy_array(adjacency, create_using=nx.DiGraph)
    assert nx.is_directed_acyclic_graph(as_digraph), "Error: graph is not a DAG!"
    return adjacency


# def gaussian_random_partition(d, p_in, p_out):
#     grp = GaussianRandomPartition(d, p_in, p_out)
#     n_clusters = grp._sample_number_of_clusters()
#     s = d/n_clusters # avg size of the cluster
#     v = s # shape parameter. Variance of cluster size is s/v
#     G = nx.gaussian_random_partition_graph(d, s, v, p_in, p_out, directed=False)
#     A_und = nx.to_numpy_array(G)
#     A = _acyclic_orientation(A_und)
#     assert nx.is_directed_acyclic_graph(nx.from_numpy_array(A, create_using=nx.DiGraph)), "Error: graph is not a DAG!"
#     return A