import igraph as ig
import networkx as nx
import numpy as np
import torch
import random

from abc import ABCMeta, abstractmethod
from utils._data import DataSimulator
from utils._utils import directed_np2nx
from utils._random_graphs import erdos_renyi_m, erdos_renyi_p, barabasi_albert_out, gaussian_random_partition, fully_connected

###################### Base data generator ######################
class DataGenerator(metaclass=ABCMeta):
    """Base class for data generation.

    DataGenerator is implemented by a subclass for each different experimental
    scenario; every subclass implements ``generate_data``.

    Parameters
    ----------
    graph_type : str
        Graph model: ER, SF, GRP or FC (fully connected).
    num_nodes : int
        Number of nodes in the graph.
    graph_size : str
        Size of the graph: one of [small, medium, large20, large30, large50].
    graph_density : str
        Density of the graph: [sparse, dense]. The values "cluster" (GRP graphs)
        and "full" (FC graphs) are also accepted.
    num_samples : int
        Number of samples in the dataset.
    noise_distr : str
        Distribution of the noise terms (Gauss or Random).
    noise_std_support : tuple[float]
        Std. deviation is sampled from a Uniform supported in the
        noise_std_support interval.
    seed : int
        Seed for reproducibility (applied by ``set_seeds``).
    GP : bool
        If True, non linearities are functions sampled from a Gaussian process.
    lengthscale : float
        Bandwidth of the Gaussian process.

    Attributes
    ----------
    adjacency : np.array
        Adjacency matrix of the DAG sampled by ``simulate_dag``; None until
        ``simulate_dag`` has been called.
    """
    def __init__(
        self,
        graph_type : str,
        num_nodes : int,
        graph_size : str,
        graph_density : str,
        num_samples : int,
        noise_distr : str,
        noise_std_support : tuple[float],
        seed = 42,
        GP = True,
        lengthscale=1
    ):
        self.graph_type = graph_type
        self.num_nodes = num_nodes
        self.graph_size = graph_size
        self.graph_density = graph_density
        self.num_samples = num_samples
        self.noise_distr = noise_distr
        self.noise_std_support = noise_std_support
        self.seed = seed
        self.GP = GP
        self.lengthscale = lengthscale
        self.adjacency = None


    @abstractmethod
    def generate_data(
        self
    ):
        raise NotImplementedError


    def get_density_param(self):
        """Return the density parameter for the configured graph size and density.

        For small graphs the parameter is the probability of an edge; for the
        other sizes it is the average number of links per node. The parameter
        scales with dimension, to account for the polynomial increase of
        possible edges.

        Configs
        -------
        small graphs (edge probability p):
            sparse: p = 0.1
            dense:  p = 0.4
        medium graphs (links per node m):
            sparse: m = 1
            dense:  m = 2
        large20 / large30 graphs:
            sparse: m = 1
            dense:  m = 4
        large50 graphs:
            sparse: m = 2
            dense:  m = 8
        any size:
            cluster: 0.4  (used as p_in for GRP graphs)
            full:    None (FC graphs need no density parameter)

        Raises
        ------
        ValueError
            If ``self.graph_size`` is not one of the sizes above.
        KeyError
            If ``self.graph_density`` is not one of the densities above.
        """
        if self.graph_size == "small":
            density_configs = {"sparse": .1, "dense": .4}
        elif self.graph_size == "medium":
            density_configs = {"sparse": 1, "dense": 2}
        elif self.graph_size in ("large20", "large30"):
            density_configs = {"sparse": 1, "dense": 4}
        elif self.graph_size == "large50":
            density_configs = {"sparse": 2, "dense": 8}
        else:
            raise ValueError(
                f"Unknown size of the graph {self.graph_size}. "
                "Provide a value between [small, medium, large20, large30, large50]"
            )

        # Handle self.graph_density = "cluster" for GRP graphs
        density_configs["cluster"] = 0.4

        # Handle self.graph_density = "full" for FC graphs
        density_configs["full"] = None

        return density_configs[self.graph_density]


    def simulate_dag(self):
        """Simulate a random DAG with the algorithm selected by ``self.graph_type``.

        The sampled adjacency matrix is stored in ``self.adjacency``; nothing is
        returned. Nodes are not permuted, so the matrix is upper triangular.

        Raises
        ------
        ValueError
            If ``self.graph_type`` is unknown, or if an SF / GRP graph is
            requested with ``graph_size == "small"`` (not supported).
        """
        density_param = self.get_density_param()

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.graph_type in ("SF", "GRP") and self.graph_size == "small":
            raise ValueError(
                f"Graph type {self.graph_type} is not supported for small graphs"
            )

        if self.graph_type == "ER":
            if self.graph_size == "small":
                A = erdos_renyi_p(self.num_nodes, density_param)
            else:
                A = erdos_renyi_m(self.num_nodes, density_param)
        elif self.graph_type == "SF":
            A = barabasi_albert_out(self.num_nodes, density_param) # High OUT degree
        elif self.graph_type == "GRP":
            # TODO: remove hard coding of p_in and p_out. In the paper plot how they result in practice
            p_out = 0.06 if self.num_nodes <= 20 else 0.03
            A = gaussian_random_partition(self.num_nodes, p_in=density_param, p_out=p_out)
        elif self.graph_type == "FC": # Fully connected
            A = fully_connected(self.num_nodes)
        else:
            raise ValueError(
                f"Unknown graph type {self.graph_type}. Provide a value between [ER, SF, GRP, FC]"
            )

        # Do not permute nodes: else sampling is wrong! (internal invariant)
        assert(np.allclose(A, np.triu(A)))

        self.adjacency = A


    def set_seeds(self):
        """Seed the Python, NumPy and torch global RNGs with ``self.seed``."""
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)


    def full_DAG(self, top_order):
        """Return the fully connected DAG consistent with a topological order.

        Every node gets an edge towards each node that follows it in
        ``top_order``.

        Parameters
        ----------
        top_order : sequence of int
            Permutation of the node indices, in topological order.

        Returns
        -------
        A : np.array
            d x d adjacency matrix with A[u, v] = 1 iff u precedes v.
        """
        d = len(top_order)
        A = np.zeros((d, d))
        for i, var in enumerate(top_order):
            A[var, top_order[i+1:]] = 1
        return A


###################### Vanilla setting generator ######################
class VanillaGenerator(DataGenerator):
    """Baseline scenario: samples are drawn directly from the simulated SCM."""

    def __init__(
        self,
        graph_type: str,
        num_nodes: int,
        graph_size: str,
        graph_density: str,
        num_samples: int,
        noise_distr: str,
        noise_std_support: tuple[float],
        seed=42,
        GP=True,
        lengthscale=1,
    ):
        super().__init__(
            graph_type, num_nodes, graph_size, graph_density, num_samples,
            noise_distr, noise_std_support, seed, GP, lengthscale,
        )

    def generate_data(self):
        """Seed the RNGs, sample a DAG and draw a dataset from it.

        Return
        ------
        X : torch.tensor
            Dataset
        A : np.array
            Groundtruth adjacency
        """
        self.set_seeds()
        self.simulate_dag()  # fills self.adjacency

        simulator = DataSimulator(
            self.num_samples,
            self.num_nodes,
            self.noise_std_support,
            self.noise_distr,
            self.adjacency,
            GP=self.GP,
            lengthscale=self.lengthscale,
        )
        return simulator.sample(), self.adjacency


###################### Confounded setting generator ######################
class ConfundedGenerator(DataGenerator):
    """Scenario with latent confounders between pairs of observed nodes."""

    def __init__(
        self,
        rho: float,
        graph_type: str,
        num_nodes: int,
        graph_size: str,
        graph_density: str,
        num_samples: int,
        noise_distr: str,
        noise_std_support: tuple[float],
        seed=42,
        GP=True,
        lengthscale=1,
    ):
        """
        rho : float
            Probability of adding a confounder between two nodes, sampled
            independently for each pair of nodes.
        """
        super().__init__(
            graph_type, num_nodes, graph_size, graph_density, num_samples,
            noise_distr, noise_std_support, seed, GP, lengthscale,
        )
        self.rho = rho
        self.confounded_adjacency = None  # set by generate_data()

    def generate_data(self):
        """Generate a confounded dataset according to the arguments specification.

        Returns
        -------
        X : torch.tensor
            Samples of the observed nodes only (latent confounder columns dropped).
        A : np.array
            Groundtruth adjacency among the observed nodes (confounders excluded).
        """
        self.set_seeds()
        self.simulate_dag()  # fills self.adjacency

        # The latent confounders occupy the first d rows/columns of this matrix.
        self.confounded_adjacency = self.confound_adjacency(self.adjacency)

        # NOTE(review): num_nodes (the observed count d) is passed alongside the
        # 2d x 2d confounded adjacency; confirm DataSimulator sizes its samples
        # from the adjacency, since the observed columns are sliced from d on.
        simulator = DataSimulator(
            self.num_samples,
            self.num_nodes,
            self.noise_std_support,
            self.noise_distr,
            self.confounded_adjacency,
            GP=self.GP,
            lengthscale=self.lengthscale,
        )
        full_sample = simulator.sample()

        # Hide the confounders: keep only the columns after the first d.
        num_observed, _ = self.adjacency.shape
        return full_sample[:, num_observed:], self.adjacency

    def confound_adjacency(self, A):
        """Extend A with d latent source nodes acting as confounders.

        For each pair i < j of observed nodes, a Bernoulli(rho) draw decides
        whether the pair is confounded; if so, a confounder is picked uniformly
        among the d latent nodes and made a parent of both i and j.

        Parameters
        ----------
        A : np.array
            Unconfounded d x d adjacency matrix.

        Returns
        -------
        np.array
            2d x 2d adjacency [[0, C], [0, A]]: the first d nodes are the
            confounders (no parents), the last d are the original nodes.
        """
        num_observed, _ = A.shape

        latent_to_observed = np.zeros((num_observed, num_observed))
        for i in range(num_observed):
            for j in range(i + 1, num_observed):
                if np.random.binomial(n=1, p=self.rho) != 1:
                    continue
                latent = random.choice(range(num_observed))
                latent_to_observed[latent, [i, j]] = 1

        no_parents = np.zeros((num_observed, num_observed))
        A_confounded = np.block([
            [no_parents, latent_to_observed],
            [no_parents, A],
        ])

        assert(np.allclose(A_confounded, np.triu(A_confounded)))
        return A_confounded
    

###################### Linear Mechanisms SCM generator ######################
class LinearSCMGenerator(DataGenerator):
    """Scenario where some of the SCM mechanisms are linear."""

    def __init__(
        self,
        p_linear: float,
        graph_type: str,
        num_nodes: int,
        graph_size: str,
        graph_density: str,
        num_samples: int,
        noise_distr: str,
        noise_std_support: tuple[float],
        seed=42,
        GP=True,
        lengthscale=1,
    ):
        """
        p_linear : float
            Probability of a linear mechanism; must lie in the open interval (0, 1).
        """
        super().__init__(
            graph_type, num_nodes, graph_size, graph_density, num_samples,
            noise_distr, noise_std_support, seed, GP, lengthscale,
        )
        assert 0 < p_linear < 1, "Probability of linear function outside of (0, 1) interval!"
        self.p_linear = p_linear

    def generate_data(self):
        """Generate a dataset whose mechanisms are linear with probability p_linear.

        Returns the samples and the groundtruth adjacency matrix.
        """
        self.set_seeds()
        self.simulate_dag()  # fills self.adjacency

        simulator = DataSimulator(
            self.num_samples,
            self.num_nodes,
            self.noise_std_support,
            self.noise_distr,
            self.adjacency,
            GP=self.GP,
            lengthscale=self.lengthscale,
        )
        samples = simulator.sample(p_linear=self.p_linear)
        return samples, self.adjacency
    

###################### Measure error generator ######################
class MeasureErrorGenerator(DataGenerator):
    """Random additive measurement error on the variables."""

    def __init__(
        self,
        gamma: float,
        graph_type: str,
        num_nodes: int,
        graph_size: str,
        graph_density: str,
        num_samples: int,
        noise_distr: str,
        noise_std_support: tuple[float],
        seed=42,
        GP=True,
        lengthscale=1,
    ):
        """
        gamma : float
            Ratio std(eps_i) / std(X_i) between the measurement error and the
            error-free variable, i.e. variance(eps_i) / variance(X_i) = gamma**2.
            Values accepted are in the interval (0, 1].
            gamma = 1 means that 50 percent of the variance of the measured X_i
            is due to the eps_i error.
        """
        super().__init__(
            graph_type, num_nodes, graph_size, graph_density, num_samples,
            noise_distr, noise_std_support, seed, GP, lengthscale,
        )
        assert 0 < gamma <= 1, "Signal to noise ratio outside of  (0, 1] interval!"
        self.gamma = gamma

    def add_noise(self, X):
        """Add zero-mean Gaussian measurement error to every column of X.

        Column i receives errors with std ``gamma * std(X[:, i])``, where the
        std is computed on X before any error is added. X is modified in place.

        Parameters
        ----------
        X : torch.tensor
            Dataset under perfect measurement.

        Return
        ------
        X : torch.tensor
            The same tensor, now holding the noisy version of the dataset.
        """
        num_rows, num_cols = X.shape
        column_std = torch.std(X, dim=0)
        for col in range(num_cols):
            scale = self.gamma * column_std[col]
            X[:, col] += scale * torch.randn((num_rows, ))
        return X

    def generate_data(self):
        """Generate data corrupted by random additive measurement error."""
        self.set_seeds()
        self.simulate_dag()  # fills self.adjacency

        simulator = DataSimulator(
            self.num_samples,
            self.num_nodes,
            self.noise_std_support,
            self.noise_distr,
            self.adjacency,
            GP=self.GP,
            lengthscale=self.lengthscale,
        )
        clean = simulator.sample()
        return self.add_noise(clean), self.adjacency
    

###################### Non independent samples generator ######################
class TiminoGenerator(DataGenerator):
    """Non independent samples scenario (time series in the spirit of TiMINo)."""
    def __init__(
        self,
        graph_type : str,
        num_nodes : int,
        graph_size : str,
        graph_density : str,
        num_samples : int,
        noise_distr : str,
        noise_std_support : tuple[float],
        seed = 42,
        GP = True,
        lengthscale=1
    ):
        """
        TiMINo model with lagged effect (1 timestep behind) and instantaneous effect.
        The noise terms distribution is stationary.
        Intended model, for X -> Y:
        X(t) = X(t-1)*c + Nx
        Y(t) = f(X(t)) + k*Y(t-1) + Ny

        X(0) = Nx
        Y(0) = f(X(0)) + Ny

        NOTE(review): as implemented, rows are first sampled i.i.d. from the SCM
        and the lagged term is added afterwards by make_timino. Assuming
        DataSimulator samples Y = f(X) + Ny, f is therefore applied to the
        i.i.d. sample of X(t), not to the lagged X(t) above.
        """
        super().__init__(graph_type, num_nodes, graph_size, graph_density, num_samples, noise_distr, noise_std_support, seed, GP, lengthscale)

    def make_timino(self, X : torch.tensor, linear_coeffs : np.array = None):
        """Add a lagged autoregressive term to every column of X, in place.

        For t = 1, ..., n-1 the rows are updated as
        ``X[t] <- X[t] + linear_coeffs * X[t-1]``, where ``X[t-1]`` is the
        already updated previous row (an AR(1)-style recursion per column).

        Parameters
        ----------
        X : torch.tensor
            n x d matrix of samples; modified in place.
        linear_coeffs : array-like of length d, optional
            Per-column lag coefficients (tensor, NumPy array or list). If None,
            they are sampled uniformly from [-1, 1) with NumPy's global RNG.

        Returns
        -------
        X : torch.tensor
            The same tensor, after the update.
        """
        n, d = X.shape
        if linear_coeffs is None:
            linear_coeffs = np.random.uniform(-1, 1, (d, ))
        # No-op for tensors; converts NumPy arrays / lists so the product below
        # is always a tensor-tensor operation.
        linear_coeffs = torch.as_tensor(linear_coeffs)
        for t in range(1, n):
            X[t] += X[t-1] * linear_coeffs
        return X

    def generate_data(self):
        """Generate a TiMINo dataset according to the arguments specification.

        Returns the time-correlated samples and the groundtruth adjacency.
        """
        self.set_seeds()
        self.simulate_dag() # Set self.adjacency

        teacher = DataSimulator(self.num_samples, self.num_nodes, self.noise_std_support, self.noise_distr, self.adjacency, GP=self.GP, lengthscale=self.lengthscale)
        X = teacher.sample()
        X = self.make_timino(X)
        return X, self.adjacency
    

###################### Unfaithful Generator ######################
class UnfaithfulGenerator(DataGenerator):
    """Scenario with faithfulness violations obtained through path cancellation."""
    def __init__(
        self,
        p_unfaithful : float,
        graph_type : str,
        num_nodes : int,
        graph_size : str,
        graph_density : str,
        num_samples : int,
        noise_distr : str,
        noise_std_support : tuple[float],
        seed = 42,
        GP = True,
        lengthscale=1
    ):
        """
        p_unfaithful : float
            Probability that a potential unfaithful structure gives path cancellation
        """
        super().__init__(graph_type, num_nodes, graph_size, graph_density, num_samples, noise_distr, noise_std_support, seed, GP, lengthscale)
        self.p_unfaithful = p_unfaithful
        self.unfaithful_adj = None  # set by generate_data()

    def is_a_collider(self, A : np.array, p1 : int, p2 : int, c : int):
        """Return True if both p1 -> c and p2 -> c are edges of A.

        Parameters
        ----------
        A : np.array
            Adj. matrix with potential collider
        p1 : int
            First parent of the potential collider
        p2 : int
            Second parent of the potential collider
        c : int
            Head of the potential collider
        """
        return A[p1, c] == 1 and A[p2, c] == 1


    def find_moral_colliders(self):
        """Find potentially unfaithful triplets: colliders whose parents are adjacent.

        Return
        ------
        moral_colliders_toporder : List[List[int]]
            Moralized colliders represented by their topological order.
            E.g. 1->0<-2, 1->2 is uniquely represented by [1, 2, 0].
        """
        moral_colliders_toporder = list()
        for child in range(self.num_nodes):
            parents = np.flatnonzero(self.adjacency[:, child])
            n_parents = len(parents)
            if n_parents < 2:  # a collider needs at least two parents
                continue
            for i in range(n_parents):
                for j in range(i+1, n_parents):
                    p_i, p_j = parents[i], parents[j]
                    # Moralized: the two parents are joined by exactly one edge.
                    is_moralized = self.adjacency[p_i, p_j] + self.adjacency[p_j, p_i] == 1
                    if not is_moralized:
                        continue
                    # Topological order of the triplet: source of the parent edge first.
                    if self.adjacency[p_j, p_i] == 1:
                        moral_colliders_toporder.append([p_j, p_i, child])
                    else:
                        moral_colliders_toporder.append([p_i, p_j, child])
        return moral_colliders_toporder


    def make_unfaithful_adj(self):
        """Make an unfaithful copy of self.adjacency.

        Each moral collider (p1, p2, c) that is still a collider, and whose edge
        p1 -> c has not been fixed by a previous deletion, has p1 -> c removed
        with probability ``self.p_unfaithful``.

        Return
        ------
        unfaithful_adj : np.array
            Groundtruth adjacency matrix with independences of the distribution
            unfaithful to the graph
        unfaithful_triplets_toporder : List[List[int]]
            Triplets whose p1 -> c edge was removed, in topological order.
            E.g. 1->0<-2, 1->2 is uniquely represented by [1, 2, 0].
        """
        moral_colliders_toporder = self.find_moral_colliders()
        unfaithful_adj = self.adjacency.copy()
        unfaithful_triplets_toporder = list()

        # For each child, if (p1, p2, c) led to the deletion of p1 -> c, then p2
        # can not be reused in position 1 for future deletions on c: record the
        # edge (p2, c) as fixed. A set gives O(1) membership tests.
        fixed_edges = set()

        for triplet in moral_colliders_toporder:
            p1, p2, child = triplet
            # Skip triplets already broken by a previous deletion, or whose
            # p1 -> child edge is fixed. The RNG is only drawn otherwise.
            if not self.is_a_collider(unfaithful_adj, p1, p2, child) or (p1, child) in fixed_edges:
                continue
            if np.random.binomial(n=1, p=self.p_unfaithful):
                unfaithful_adj[p1, child] = 0 # remove p1 -> c
                unfaithful_triplets_toporder.append(triplet)
                fixed_edges.add((p2, child))

        assert nx.is_directed_acyclic_graph(directed_np2nx(unfaithful_adj))
        return unfaithful_adj, unfaithful_triplets_toporder


    def make_unfaithful_dataset(self, X, X_noise, unfaithful_triplets_toporder):
        """Correct X in place according to the unfaithful SCM.

        For each triplet (p1, p2, c), X_noise[:, p2] is added to X[:, c]; the
        noise of a given p2 is added at most once per child c.

        Parameters
        ----------
        X : torch.tensor
            n x d matrix of the data; modified in place.
        X_noise:
            n x d matrix of the additive noise terms
        unfaithful_triplets_toporder : List[List[int]]
            Represent moralized colliders by their topological order.
            E.g. 1->0<-2, 1->2 is uniquely represented by [1, 2, 0] toporder of the triplet.
            To model unfaithfulness, add X_noise[:, 2] to X[:, 0]
        """
        # (source, target) pairs of every edge that differs between the two graphs.
        removed_edges = {
            tuple(edge)
            for edge in np.transpose(np.nonzero(self.unfaithful_adj - self.adjacency))
        }
        added_noise = dict()
        for p1, p2, child in unfaithful_triplets_toporder:
            child_added_noise = added_noise.setdefault(child, set())
            if p2 not in child_added_noise:
                X[:, child] += X_noise[:, p2]
                child_added_noise.add(p2)
            # Row-wise membership: `array in ndarray` would only test whether
            # any single element matches, which makes the check vacuous.
            assert (p1, child) in removed_edges

    def generate_data(self):
        """Generate an unfaithful dataset according to the arguments specification.

        Returns the samples and the faithful groundtruth ``self.adjacency``.

        TODO: store the adjacency matrix with ONLY the unfaithful edges.
              Much like I do when storing the confounded matrix
        """
        self.set_seeds()
        self.simulate_dag() # Set self.adjacency

        # Unfaithful adjacency with path canceling
        self.unfaithful_adj, unfaithful_triplets_toporder = self.make_unfaithful_adj()
        teacher = DataSimulator(self.num_samples, self.num_nodes, self.noise_std_support, self.noise_distr, self.unfaithful_adj, GP=self.GP, lengthscale=self.lengthscale)

        X = teacher.sample()
        additive_noise = teacher.noise

        # Correct samples accounting for unfaithful effects
        self.make_unfaithful_dataset(X, additive_noise, unfaithful_triplets_toporder)

        return X, self.adjacency


###################### PNL setting generator ######################
class PNLGenerator(VanillaGenerator):
    """Post-nonlinear scenario: vanilla data passed through an element-wise power."""

    def __init__(
        self,
        exponent: float,
        graph_type: str,
        num_nodes: int,
        graph_size: str,
        graph_density: str,
        num_samples: int,
        noise_distr: str,
        noise_std_support: tuple[float],
        seed=42,
        GP=True,
        lengthscale=1,
    ):
        super().__init__(
            graph_type, num_nodes, graph_size, graph_density, num_samples,
            noise_distr, noise_std_support, seed, GP, lengthscale,
        )
        self.exp = exponent  # power applied to every entry of the vanilla data

    def generate_data(self):
        """Generate PNL data by raising the vanilla samples to ``self.exp``.

        NOTE(review): ``torch.pow`` is an invertible map on all reals only for
        odd integer exponents; fractional exponents yield NaN on negative
        entries, and even exponents are not invertible.

        Return
        ------
        X : torch.tensor
            Dataset
        A : np.array
            Groundtruth adjacency
        """
        vanilla_X, adjacency = super().generate_data()
        return torch.pow(vanilla_X, self.exp), adjacency
