import GPy
import uuid
import os
import math
import igraph as ig
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import torch
from torch.distributions import Normal, Laplace, Gumbel, Exponential, Beta, Gamma, Pareto

from sklearn.ensemble import ExtraTreesRegressor
from sklearn.feature_selection import SelectFromModel

from cdt.metrics import retrieve_adjacency_matrix, SID

# Absolute path of the project's `src` directory; get_data() loads the Sachs
# dataset from `base_folder + '/../data/sachs/continuous/...'`.
# NOTE(review): machine-specific hard-coded path -- consider deriving it from
# __file__ or an environment variable.
base_folder = "/home/francescom/Research/DAS-Extension/src"

class Dist(object):
    """Additive-noise structural causal model over ``d`` variables.

    Each variable is ``X_i = f_i(X_parents(i)) + noise_i``, where the parents of
    ``i`` are the nonzero rows of column ``i`` of ``adjacency``. With ``GP=True``
    every ``f_i`` is a single draw from a Gaussian process with an RBF kernel
    (see ``sampleGP``); with ``GP=False`` ``f_i`` is the sum of ``X_p / 2`` over
    the parents ``p``.

    Args:
        d: number of variables.
        noise_std: scalar, sequence, array or tensor (one value per variable).
            Parameterizes the noise distribution; only the 'Gauss', 'Laplace'
            and 'Gumbel' branches rescale it so that it is the actual standard
            deviation.
        noise_type: 'Gauss', 'Laplace', 'Gumbel', 'Exponential', 'Beta',
            'Gamma' or 'Pareto'.
        adjacency: (d, d) strictly upper-triangular matrix, ``adjacency[p, i]
            != 0`` meaning p -> i. Defaults to the full DAG over the identity
            order.
        GP: sample the mechanisms from a Gaussian process.
        lengthscale: RBF lengthscale used when ``GP=True``.
        f_magn: RBF kernel variance used when ``GP=True``.
        GraNDAG_like: ignore ``noise_std`` and use unit base noise; on the GP
            path ``sample`` then rescales it per node (variance ~U(1, 2) for
            roots, ~U(0.4, 0.8) otherwise).

    Raises:
        NotImplementedError: unknown ``noise_type``.
        ValueError: ``adjacency`` is not strictly upper triangular.
    """

    def __init__(self, d, noise_std = 1, noise_type = 'Gauss', adjacency = None, GP = True, lengthscale = 1, f_magn = 1, GraNDAG_like = False):
        self.d = d
        self.GP = GP
        self.lengthscale = lengthscale
        self.f_magn = f_magn
        self.GraNDAG_like = GraNDAG_like

        if self.GraNDAG_like:
            # Unit base noise; sample() rescales it per node (GP path only).
            noise_std = torch.ones(d)
        else:
            # Accept Python/NumPy scalars, sequences, arrays and tensors, and
            # expand scalars to one value per variable so the noise distribution
            # has batch shape (d,) and sample((n,)) yields an n x d matrix.
            if not torch.is_tensor(noise_std):
                noise_std = torch.as_tensor(noise_std, dtype=torch.get_default_dtype())
            if noise_std.dim() == 0:
                noise_std = noise_std * torch.ones(self.d)

        if noise_type == 'Gauss':
            self.noise = Normal(0, noise_std)
        elif noise_type == 'Laplace': # NOTE: open question -- does SCORE fail for Laplace noise?
            # Laplace variance is 2 * b^2, hence b = std / sqrt(2).
            self.noise = Laplace(0, noise_std / np.sqrt(2))
        elif noise_type == 'Gumbel':
            # Gumbel variance is (pi * beta)^2 / 6, hence beta = std * sqrt(6) / pi.
            self.noise = Gumbel(0, noise_std * np.sqrt(6)/np.pi)
        elif noise_type == 'Exponential':
            # noise_std is used as the *rate* (std = 1 / rate).
            self.noise = Exponential(noise_std)
        elif noise_type == 'Beta':
            self.noise = Beta(noise_std, noise_std)
        elif noise_type == 'Gamma': # Chi2 = Gamma(alpha=0.5*df, beta=0.5)
            self.noise = Gamma(noise_std, noise_std)
        elif noise_type == 'Pareto':
            # torch's Pareto requires (scale, alpha); calling it without
            # arguments raises TypeError. NOTE(review): scale=1, alpha=noise_std
            # is an assumed parameterization -- confirm the intended tails.
            self.noise = Pareto(torch.ones_like(noise_std), noise_std)
        else:
            raise NotImplementedError("Unknown noise type.")

        self.adjacency = adjacency
        if adjacency is None:
            # Full DAG over the identity order: every i -> j with i < j.
            self.adjacency = np.ones((d,d))
            self.adjacency[np.tril_indices(d)] = 0

        # sample() fills nodes in index order, so every parent must have a
        # smaller index (strictly upper triangular, no self-loops). Raise rather
        # than assert so the check survives `python -O`.
        if not np.allclose(self.adjacency, np.triu(self.adjacency, k=1)):
            raise ValueError("adjacency must be strictly upper triangular")

    def sampleGP(self, X, lengthscale=1):
        """Draw one function sample from a zero-mean GP evaluated at the rows of X.

        Args:
            X: (m, k) array of inputs.
            lengthscale: RBF lengthscale; the kernel variance is ``self.f_magn``.

        Returns:
            Length-m NumPy array with the sampled function values.
        """
        ker = GPy.kern.RBF(input_dim=X.shape[1],lengthscale=lengthscale,variance=self.f_magn)
        C = ker.K(X,X) # Covariance matrix
        X_sample = np.random.multivariate_normal(np.zeros(len(X)),C)
        return X_sample

    def sample(self, n):
        """Draw ``n`` joint samples, filling variables in index order.

        Returns:
            X: (n, d) tensor of samples.
            R: (n, d) tensor with the noise actually added to each variable.
            noise_var: length-d NumPy array with the per-node noise variances
                drawn when both ``GP`` and ``GraNDAG_like`` are set; zeros
                otherwise.
        """
        noise = self.noise.sample((n,)) # n x d noise matrix
        X = torch.zeros(n, self.d)
        R = torch.zeros(n, self.d) # Real residuals

        # Index order is a valid topological order only because __init__
        # enforces a strictly upper-triangular adjacency.
        noise_var = np.zeros(self.d)
        if self.GP:
            for i in range(self.d):
                parents = np.nonzero(self.adjacency[:,i])[0]
                if self.GraNDAG_like:
                    if len(parents) == 0: # roots: noise variance ~ U(1, 2)
                        noise_var[i] = np.random.uniform(1,2)
                    else: # non-roots: noise variance ~ U(0.4, 0.8)
                        noise_var[i] = np.random.uniform(0.4,0.8)
                    scaled_noise = np.sqrt(noise_var[i]) * noise[:,i]
                    X[:, i] = scaled_noise
                    R[:, i] = scaled_noise
                else:
                    X[:, i] = noise[:,i]
                    R[:, i] = noise[:,i]
                if len(parents) > 0:
                    X_par = X[:,parents]
                    X[:, i] += torch.tensor(self.sampleGP(np.array(X_par), self.lengthscale))
        else:
            for i in range(self.d):
                X[:, i] = noise[:,i]
                R[:, i] = noise[:,i]
                parents = np.nonzero(self.adjacency[:,i])[0]
                for p in parents:
                    # The former `X[:,p]**1/2` parsed as `(X**1)/2`, i.e. X/2 --
                    # not a square root; written explicitly to keep behavior.
                    # NOTE(review): log_p and the exact_* helpers assume a
                    # sin(X_p) mechanism instead; confirm which one is intended.
                    X[:, i] += X[:,p] / 2
        return X, R, noise_var

    def log_p(self, X, active_nodes=None):
        """Unnormalized log-density of X under a sin mechanism with unit noise.

        For every row this computes ``-0.5 * sum_i (X_i - sum_{j -> i} sin(X_j))^2``
        with no normalizing constant. NOTE(review): ``sample`` with ``GP=False``
        uses an ``X_p / 2`` mechanism, not ``sin``.

        Args:
            X: (n, k) tensor whose column ``i`` holds node ``active_nodes[i]``.
            active_nodes: node ids of X's columns; defaults to
                ``range(X.shape[1])``. Parents outside this list are ignored.

        Returns:
            (n,) tensor of log-density values.

        Raises:
            NotImplementedError: if the model was built with ``GP=True``.
        """
        if self.GP:
            raise NotImplementedError("Score computation not available with GPs.")
        if active_nodes is None:
            active_nodes = list(range(X.shape[1]))
        n = X.shape[0]
        l = torch.zeros(n)
        for i, node_i in enumerate(active_nodes):
            fi = torch.zeros(n)
            for j, node_j in enumerate(active_nodes):
                if self.adjacency[node_j, node_i] != 0:
                    fi += torch.sin(X[:,j])
            l -= 0.5 * (X[:,i] - fi)**2
        return l



############## DAG ##############

# Only for GP = False
def exact_residuals(X, A):
    """Return the true noise terms of a sin-mechanism SCM as a NumPy array.

    Assumes ``X_j = sum_{p in pa(j)} sin(X_p) + noise_j`` with parents read from
    column ``j`` of ``A``, and returns ``X_j - sum_p sin(X_p)`` for every node.
    Intended only for data generated with ``GP = False``.

    Args:
        X: (n, d) tensor of samples.
        A: (d, d) adjacency matrix, ``A[p, j] != 0`` meaning p -> j.

    Returns:
        (n, d) NumPy array of residuals.
    """
    num_samples, num_nodes = X.shape
    residuals = torch.zeros((num_samples, num_nodes))
    for node in range(num_nodes):
        parent_idx = np.nonzero(A[:, node])[0]
        mechanism = torch.sin(X[:, parent_idx]).sum(dim=1)
        residuals[:, node] = X[:, node] - mechanism
    return residuals.numpy()

# Only for GP = False
def exact_hess_diag(X, A, noise_std=1):
    """Diagonal of the Hessian of log p(x) for a sin-mechanism Gaussian SCM.

    Model: ``X_j = sum_{p in pa(j)} sin(X_p) + N(0, noise_std^2)``. For node j::

        d^2 log p / dx_j^2 = -1/s^2
            + sum_{c in ch(j)} [ -sin(x_j) (x_c - f_c) / s^2 - cos(x_j)^2 / s^2 ]

    Intended only for data generated with ``GP = False``.

    Args:
        X: (n, d) tensor of samples.
        A: (d, d) adjacency matrix, ``A[p, j] != 0`` meaning p -> j.
        noise_std: noise standard deviation ``s`` shared by all nodes.

    Returns:
        (n, d) tensor; entry (i, j) is the j-th Hessian diagonal at sample i.
    """
    n, d = X.shape
    var = noise_std**2

    def mechanism(node):
        # f_node(x) = sum of sin over the parents of `node`.
        parent_idx = np.nonzero(A[:, node])[0]
        return torch.sin(X[:, parent_idx]).sum(dim=1)

    H = torch.ones((n, d))
    for node in range(d):
        H[:, node] = -1 / var
        for child in np.nonzero(A[node, :])[0]:
            child_residual = X[:, child] - mechanism(child)
            H[:, node] += -torch.sin(X[:, node]) * child_residual / var - (torch.cos(X[:, node]) / noise_std)**2
    return H

# Only for GP = False
def exact_score(X, A, noise_std=1):
    """Score (gradient of log p(x)) for a sin-mechanism Gaussian SCM.

    Model: ``X_j = sum_{p in pa(j)} sin(X_p) + N(0, noise_std^2)``. For node j::

        d log p / dx_j = -(x_j - f_j) / s^2
            + sum_{c in ch(j)} cos(x_j) (x_c - f_c) / s^2

    Intended only for data generated with ``GP = False``.

    Args:
        X: (n, d) tensor of samples.
        A: (d, d) adjacency matrix, ``A[p, j] != 0`` meaning p -> j.
        noise_std: noise standard deviation ``s`` shared by all nodes.

    Returns:
        (n, d) tensor of score values.
    """
    n, d = X.shape
    var = noise_std**2

    def mechanism(parent_idx):
        return torch.sin(X[:, parent_idx]).sum(dim=1)

    score = torch.zeros((n, d))
    for node in range(d):
        parent_idx = np.nonzero(A[:, node])[0]
        score[:, node] = - X[:, node] / var
        if len(parent_idx) > 0:
            score[:, node] += mechanism(parent_idx) / var
        for child in np.nonzero(A[node, :])[0]:
            child_residual = X[:, child] - mechanism(np.nonzero(A[:, child])[0])
            score[:, node] += torch.cos(X[:, node]) * child_residual / var
    return score


def get_data(graph_type='ER', d=None, s0=None, N=1000, noise_std = 1, noise_type = 'Gauss', GP = True, lengthscale=1, verbose=True):
    """
    Load the real Sachs dataset or simulate a synthetic one.

    Args:
        graph_type: 'Sachs' for the real dataset; any other value is forwarded
            to ``generate`` as the synthetic graph type (e.g. 'ER', 'SF', 'BP').
        d, s0, N, noise_std, noise_type, GP, lengthscale, verbose: forwarded to
            ``generate``; ignored for 'Sachs'.

    Returns:
        (X, R, A): data as a torch tensor, true residuals (``None`` for
        'Sachs'), and the ground-truth adjacency matrix.
    """
    real = ['Sachs']
    if graph_type in real:
        # Real data: loaded from .npy files under the project data folder.
        X = torch.tensor(np.load(base_folder + '/../data/sachs/continuous/data1.npy'))
        A = np.load(base_folder + '/../data/sachs/continuous/DAG1.npy')
        return X, None, A

    # Synthetic data.
    return generate(d, s0, N, noise_std, noise_type, graph_type, GP, lengthscale, verbose)


def generate(d=None, s0=None, N=1000, noise_std = 1, noise_type = 'Gauss', graph_type = 'ER', GP = True, lengthscale=1, verbose=True):
    """
    Simulate a random DAG and draw N samples from an additive-noise model on it.

    Args:
        d (int): number of nodes
        s0 (int): expected number of edges
        N (int): number of samples
        noise_std, noise_type, GP, lengthscale: forwarded to ``Dist``
        graph_type (str): ER, SF (or BP), forwarded to ``simulate_dag``
        verbose (bool): print progress messages

    Returns:
        (X, R, adjacency): samples, true noise terms, and the adjacency matrix
        produced by ``simulate_dag(..., triu=True)``.
    """
    if verbose:
        print("Generating data...", end=" ", flush=True)
    dag = simulate_dag(d, s0, graph_type, triu=True)
    samples, residuals, _ = Dist(
        d, noise_std, noise_type, dag, GP=GP, lengthscale=lengthscale
    ).sample(N)
    if verbose:
        print("Done")
    return samples, residuals, dag


def simulate_dag(d, s0, graph_type, triu=False):
    """Simulate a random DAG with a given expected number of edges.

    Args:
        d (int): number of nodes
        s0 (int): expected number of edges
        graph_type (str): 'ER' (Erdos-Renyi), 'SF' (Barabasi-Albert scale-free)
            or 'BP' (bipartite)
        triu (bool): if True, orient ER/SF edges from lower to higher index and
            skip the final random relabeling of the nodes.

    Returns:
        B (np.ndarray): [d, d] binary adjacency matrix of the DAG.

    Raises:
        ValueError: unknown ``graph_type``.
    """
    def _shuffle_nodes(mat):
        # Conjugate by a random permutation matrix (np.random.permutation only
        # shuffles the first axis of the identity).
        perm = np.random.permutation(np.eye(mat.shape[0]))
        return perm.T @ mat @ perm

    def _to_adjacency(graph):
        return np.array(graph.get_adjacency().data)

    def _orient(undirected):
        # Keep one triangle of a symmetric matrix to obtain an acyclic graph.
        if not triu:
            return np.tril(_shuffle_nodes(undirected), k=-1)
        return np.triu(undirected, k=1)

    if graph_type == 'ER':
        B = _orient(_to_adjacency(ig.Graph.Erdos_Renyi(n=d, m=s0)))
    elif graph_type == 'SF':
        edges_per_node = int(round(s0 / d))
        B = _orient(_to_adjacency(ig.Graph.Barabasi(n=d, m=edges_per_node, directed=False)))
    elif graph_type == 'BP':
        # Bipartite, Sec 4.1 of (Gu, Fu, Zhou, 2018): 20% top nodes.
        n_top = int(0.2 * d)
        B = _to_adjacency(ig.Graph.Random_Bipartite(n_top, d - n_top, m=s0, directed=True, neimode=ig.OUT))
    else:
        raise ValueError('unknown graph type')

    if not triu:
        B = _shuffle_nodes(B)
    assert ig.Graph.Adjacency(B.tolist()).is_dag()
    return B


def full_DAG(top_order):
    """Build the complete DAG consistent with a topological order.

    Every node gets an edge to every node that comes after it in
    ``top_order``.

    Args:
        top_order: sequence of node indices, earliest first.

    Returns:
        (d, d) float NumPy array with ``A[u, v] == 1`` iff u precedes v.
    """
    n_nodes = len(top_order)
    adj = np.zeros((n_nodes, n_nodes))
    for pos in range(n_nodes):
        successors = top_order[pos + 1:]
        adj[top_order[pos], successors] = 1
    return adj


############## PRUNING ##############


def pns_(model_adj, x, num_neighbors, thresh):
    """Preliminary neighborhood selection (PNS).

    For every node, fit an ExtraTreesRegressor predicting that node from all
    columns of ``x`` (the node's own column zeroed out). Keep as candidate
    parents only the features whose importance passes ``thresh * mean``, at
    most ``num_neighbors`` of them.

    Args:
        model_adj: (d, d) candidate adjacency matrix; entry ``[j, node]`` is
            kept only if feature j is selected for ``node``. Modified IN PLACE.
        x: (n, d) NumPy data matrix.
        num_neighbors: maximum number of candidate parents kept per node.
        thresh: multiplier of the mean feature importance used as threshold.

    Returns:
        The pruned ``model_adj`` (the same object that was passed in).
    """
    num_samples = x.shape[0]
    num_nodes = x.shape[1]
    print("PNS: num samples = {}, num nodes = {}".format(num_samples, num_nodes))
    for node in range(num_nodes):
        print("PNS: node " + str(node))
        x_other = np.copy(x)
        # Hide the target column so it cannot predict itself.
        x_other[:, node] = 0
        reg = ExtraTreesRegressor(n_estimators=500)
        reg = reg.fit(x_other, x[:, node])
        selected_reg = SelectFromModel(reg, threshold="{}*mean".format(thresh), prefit=True,
                                       max_features=num_neighbors)
        # `np.float` was removed in NumPy 1.24 (AttributeError); the builtin
        # `float` maps to the same float64 dtype.
        mask_selected = selected_reg.get_support(indices=False).astype(float)

        model_adj[:, node] *= mask_selected

    return model_adj


############## METRICS ##############


def edge_errors(pred, target):
    """
    Count false-negative, false-positive and reversed edges of ``pred`` w.r.t. ``target``.

    A reversed edge would otherwise show up as one false negative plus one
    false positive, so it is removed from both of those counts.

    Returns:
        (fn, fp, rev)
    """
    truth = retrieve_adjacency_matrix(target)
    node_order = target.nodes() if isinstance(target, nx.DiGraph) else None
    estimate = retrieve_adjacency_matrix(pred, node_order)

    delta = truth - estimate

    # (i, j) and (j, i) both flag a reversal, hence the division by two.
    reversed_mask = ((delta + delta.transpose()) == 0) & (delta != 0)
    rev = reversed_mask.sum() / 2
    fn = (delta == 1).sum() - rev
    fp = (delta == -1).sum() - rev

    return fn, fp, rev

def precision(n_edges, fn, fp, zero_division=math.nan):
    """Edge precision TP / (TP + FP), with TP = n_edges - fn.

    Args:
        n_edges: number of edges in the ground-truth graph.
        fn: number of false-negative edges.
        fp: number of false-positive edges.
        zero_division: value returned when TP + FP == 0, where precision is
            undefined. Defaults to NaN. Without this guard the division raised
            ZeroDivisionError for Python ints or warned for NumPy scalars.

    Returns:
        The precision, or ``zero_division``.
    """
    tp = n_edges - fn
    predicted = tp + fp
    if predicted == 0:
        return zero_division
    return tp / predicted

def recall(n_edges, fn, zero_division=math.nan):
    """Edge recall TP / n_edges, with TP = n_edges - fn.

    Args:
        n_edges: number of edges in the ground-truth graph.
        fn: number of false-negative edges.
        zero_division: value returned for an empty ground-truth graph
            (``n_edges == 0``), where recall is undefined. Defaults to NaN.

    Returns:
        The recall, or ``zero_division``.
    """
    if n_edges == 0:
        return zero_division
    tp = n_edges - fn
    return tp / n_edges


def SHD(pred, target):
    """Structural Hamming distance: false negatives + false positives + reversed edges."""
    fn, fp, rev = edge_errors(pred, target)
    return fn + fp + rev


############## LOGGING ##############


def np_to_csv(array, save_path):
    """
    Write a NumPy array to a uniquely named CSV file and return its path.

    array: numpy array
        the data to write (no header row, no index column)
    save_path: str
        only its directory is used; the file is created there as
        ``tmp_<uuid4>.csv`` and is not deleted by this function
    Return the path to the csv file
    """
    target_dir = os.path.dirname(save_path)
    file_name = 'tmp_' + str(uuid.uuid4()) + '.csv'
    output = os.path.join(target_dir, file_name)

    pd.DataFrame(array).to_csv(output, header=False, index=False)
    return output


def pretty_evaluate(adj, A_SCORE, top_order_err, tot_time, sid, s0, alpha, gamma, n_cv, noise, algorithm):
    """Format a human-readable evaluation report for a predicted DAG.

    Args:
        adj: ground-truth graph (anything accepted by ``edge_errors``).
        A_SCORE: predicted adjacency matrix; ``A_SCORE.shape[0]`` is reported
            as the number of nodes.
        top_order_err: topological-order error count, reported unless
            ``algorithm == "GES"``.
        tot_time: total execution time in seconds (rounded to 2 decimals).
        sid: if truthy, also compute and report the SID.
        s0: number of ground-truth edges, used for precision and recall.
        alpha, gamma, n_cv: hyper-parameters, reported only when
            ``algorithm == "DASExt"``.
        noise: noise description, reported verbatim.
        algorithm: algorithm name; selects the optional sections above.

    Returns:
        The report as a single string.
    """
    fn, fp, rev = edge_errors(A_SCORE, adj)
    d = A_SCORE.shape[0]
    precision_metric = precision(s0, fn, fp)
    recall_metric = recall(s0, fn)
    # SHD is fn + fp + rev; reuse the counts instead of recomputing edge_errors.
    shd = fn + fp + rev
    pretty = """\n----------------------------------------------------"""

    if algorithm=="DASExt":
        # Must be an f-string: without the prefix the placeholders were
        # emitted literally as "{alpha}", "{gamma}" and "{n_cv}".
        pretty += f"""
alpha:                              {alpha}
gamma:                              {gamma}
Number of CV splits:                {n_cv}

----------------------------------------------------
        """

    pretty += f"""

Number of nodes:                    {d}
Number of edges:                    {s0}
Noise:                              {noise}

----------------------------------------------------

Total execution time:               {round(tot_time, 2)}s
False negative:                     {int(fn)}
False positive:                     {int(fp)}
Recall:                             {round(recall_metric, 2)}
Precision:                          {round(precision_metric, 2)}
Reversed:                           {int(rev)}
SHD:                                {shd}
"""

    if sid:
        pretty += f"""
SID:                                {int(SID(target=adj, pred=A_SCORE))}
""".lstrip()

    if not algorithm=="GES":
        pretty += f"""
Top order errors:                   {top_order_err}
        """

    return pretty


def plot_correlation(X, y, k, l, path):
    """Scatter every column of X against y on a k x l grid and save the figure.

    Args:
        X: (n, d) array-like; one panel per column.
        y: length-n values shared by all panels.
        k: number of grid rows; increased automatically when k * l < d.
        l: number of grid columns.
        path: output file passed to ``Figure.savefig``.

    When ``k`` or ``l`` is missing (None/0) the 3 x 4 layout is used.
    """
    _, d = X.shape
    # k and l used to be overwritten with 3, 4: the arguments were ignored and
    # any X with more than 12 columns raised IndexError.
    if not k or not l:
        k, l = 3, 4
    k = max(k, math.ceil(d / l))
    # squeeze=False keeps `ax` 2-D even for a single row or column.
    fig, ax = plt.subplots(k, l, squeeze=False)
    for j in range(d):
        row, col = divmod(j, l)
        ax[row, col].scatter(X[:, j], y, s=0.5, alpha=0.6)

    fig.tight_layout()
    fig.savefig(path)
    plt.close('all')


def plot_res_distance(R_dist, path):
    """Save one scatter plot per column of R_dist (sample index vs. value).

    Args:
        R_dist: (n, d) torch tensor (converted with ``.numpy()``).
        path: directory prefix; column i is saved as ``path + "/R{i}"``.
    """
    values = R_dist.numpy()
    n_rows, n_cols = values.shape
    sample_index = list(range(n_rows))
    for i in range(n_cols):
        plt.scatter(sample_index, values[:, i], s=0.5, alpha=0.6)
        plt.savefig(path + f"/R{i}")
        # Start a fresh figure for the next column.
        plt.close('all')


############## TOP ORDERING ############## 
def num_errors(order, adj):
    """Count edges of ``adj`` that point backwards with respect to ``order``.

    An edge u -> v (``adj[u, v] != 0``) is an error when u appears after v in
    ``order``. Edge weights are summed, so for a binary ``adj`` this is the
    number of violated edges.
    """
    return sum(
        adj[order[pos + 1:], node].sum()
        for pos, node in enumerate(order)
    )


def fullAdj2Order(A):
    """Recover a topological order from a full (complete) DAG adjacency matrix.

    In a complete DAG the k-th node of the order has d - 1 - k children, so
    sorting nodes by decreasing out-degree (row sum) yields the order.
    """
    out_degree = A.sum(axis=1)
    return list(out_degree.argsort()[::-1])


############## SERGIO-TESTING ##############
def ground_truth(d, path):
    """Build a (d, d) adjacency matrix from a headerless CSV edge list.

    Each CSV row is ``src,dest`` (integer node indices) and sets
    ``A[src, dest] = 1``.
    """
    edges = pd.read_csv(path, header=None).to_numpy()
    adjacency = np.zeros((d, d))
    for src, dest in edges:
        adjacency[src, dest] = 1
    return adjacency