import torch
import time
import ges
import statsmodels.api as sm
from sklearn.kernel_ridge import KernelRidge
from sklearn.model_selection import cross_val_predict, train_test_split

# GES
# import networkx as nx
# from cdt.causality.graph import GES

from modules.cam import CAM
from modules.hsic import RbfHSIC
from modules.utils import *

# Root folder of the project sources, hard-coded as an absolute path.
# NOTE(review): looks machine-specific and is not referenced in the visible part
# of this file -- confirm it is still needed before relying on it.
base_folder = "/home/francescom/Research/DAS-Extension/src"

def Stein_score(X, eta_G, s = None):
    """
    Stein kernel estimator of the score (gradient of log p_X) at the provided sample points

    Args:
        X (tensor): N x D matrix of the data
        eta_G (float): regularization coefficient added to the diagonal of the kernel matrix
        s (float, optional): kernel width. If None, the median of all the pairwise
            Euclidean distances (zero self-distances included) is used

    Return:
        G (tensor): N x D estimate of the score at each sample
    """
    n, _ = X.shape

    X_diff = X.unsqueeze(1)-X
    # Pairwise distances are needed both for the width heuristic and for the kernel: compute once
    dist = torch.norm(X_diff, dim=2, p=2)
    if s is None:
        s = dist.flatten().median()
    K = torch.exp(-dist**2 / (2 * s**2)) / s

    nablaK = -torch.einsum('kij,ik->kj', X_diff, K) / s**2
    # Build the identity with the dtype/device of K so that GPU and float64 inputs
    # are regularized without a device mismatch or a float32-rounded coefficient
    identity = torch.eye(n, dtype=K.dtype, device=K.device)
    G = torch.matmul(torch.inverse(K + eta_G * identity), nablaK)

    return G


def Stein_hess_diag(X, eta_G, eta_H, s = None):
    """
    Estimates the diagonal of the Hessian of log p_X at the provided samples points

    Args:
        X (tensor): N x D matrix of the data
        eta_G (float): regularization coefficient used for the score estimate
        eta_H (float): regularization coefficient used for the second order term
        s (float, optional): kernel width. If None, the median of all the pairwise
            Euclidean distances (zero self-distances included) is used

    Return:
        N x D tensor, entry (i, j) estimating the j-th diagonal Hessian term at sample i
    """
    n, _ = X.shape

    X_diff = X.unsqueeze(1)-X
    # Pairwise distances are needed both for the width heuristic and for the kernel: compute once
    dist = torch.norm(X_diff, dim=2, p=2)
    if s is None:
        s = dist.flatten().median()
    K = torch.exp(-dist**2 / (2 * s**2)) / s

    # Identity matching K's dtype/device, shared by the two regularized inversions below
    identity = torch.eye(n, dtype=K.dtype, device=K.device)

    nablaK = -torch.einsum('kij,ik->kj', X_diff, K) / s**2
    G = torch.matmul(torch.inverse(K + eta_G * identity), nablaK)

    nabla2K = torch.einsum('kij,ik->kj', -1/s**2 + X_diff**2/s**4, K)
    return -G**2 + torch.matmul(torch.inverse(K + eta_H * identity), nabla2K)


def Stein_hess_col(X_diff, G, K, v, s, eta, n):
    """
    See https://arxiv.org/pdf/2203.04413.pdf Section 2.2 and Section 3.2 (SCORE paper)
        Args:
            X_diff (tensor): X.unsqueeze(1)-X difference in the NxD matrix of the data X
            G (tensor): G stein estimator 
            K (tensor): evaluated gaussian kernel
            v (int): index of the Hessian column to estimate
            s (float): kernel width estimator
            eta (float): regularization coefficients
            n (int): number of input samples

        Return:
            Hess_v: estimator of the v-th column of the Hessian of log(p(X))
    """
    # Row-wise product G[i, v] * G[i, j]
    Gv = G[:, v].unsqueeze(1) * G
    nabla2vK = torch.einsum('ik,ikj,ik->ij', X_diff[:,:,v], X_diff, K) / s**4
    # Extra term that only appears on the (v, v) entry
    nabla2vK[:,v] -= K.sum(dim=1) / s**2
    # Identity matching K's dtype/device: avoids a device mismatch on GPU inputs and
    # a float32-rounded regularizer on float64 inputs
    identity = torch.eye(n, dtype=K.dtype, device=K.device)
    Hess_v = -Gv + torch.matmul(torch.inverse(K + eta * identity), nabla2vK)

    return Hess_v


def Stein_hess_matrix(X, s, eta):
    """
    Compute the Stein Hessian estimator matrix for each sample in the dataset

    Args:
        X: N x D matrix of the data
        s: kernel width estimate
        eta: regularization coefficient

    Return:
        Hess: N x D x D hessian estimator of log(p(X)); Hess[:, v, :] is the
            output of Stein_hess_col for column v
    """
    n, d = X.shape

    X_diff = X.unsqueeze(1)-X
    K = torch.exp(-torch.norm(X_diff, dim=2, p=2)**2 / (2 * s**2)) / s

    nablaK = -torch.einsum('ikj,ik->ij', X_diff, K) / s**2
    # Identity matching K's dtype/device so GPU / float64 inputs are handled
    identity = torch.eye(n, dtype=K.dtype, device=K.device)
    G = torch.matmul(torch.inverse(K + eta * identity), nablaK)

    # Estimate every column, then stack once along dim 1 instead of growing the
    # result with repeated hstack calls (each of which copies the whole tensor)
    columns = [Stein_hess_col(X_diff, G, K, v, s, eta, n) for v in range(d)]
    return torch.stack(columns, dim=1)


def heuristic_kernel_width(X):
    """
    Median heuristic for the width parameter of the gaussian kernel

    Args:
        X (tensor): N x D matrix of the data

    Return:
        0-dim tensor holding the median of all the N x N pairwise Euclidean
        distances between the rows of X (the zero self-distances are included)
    """
    pairwise_gap = X.unsqueeze(1) - X
    distances = pairwise_gap.norm(p=2, dim=2)
    return distances.reshape(-1).median()


def estimate_residuals(X, alpha, gamma, n_cv):
    """
    Estimate the residuals of each variable given all the others with RBF kernel ridge regression.
    For each variable X_j, regress X_j on all the remaining variables of X and take the
    difference between X_j and its cross-validated prediction.

    Args:
        X: n x d matrix of the data
        alpha: regularization strength passed to KernelRidge
        gamma: RBF kernel coefficient passed to KernelRidge
        n_cv: value of the cv argument of cross_val_predict

    Return:
        n x d matrix of the residuals estimates
    """
    n_vars = X.shape[1]
    residual_columns = []
    for j in range(n_vars):
        target = X[:, j]
        # Every column but the j-th one acts as predictor
        others = np.hstack([X[:, :j], X[:, j+1:]])
        model = KernelRidge(kernel='rbf', gamma=gamma, alpha=alpha)
        fitted = cross_val_predict(model, others, target, cv=n_cv)
        residual_columns.append(target - fitted)

    # One row per variable after stacking: transpose back to n x d
    return np.vstack(residual_columns).T


def pred_err(X, Y, alpha, gamma, n_cv):
    """
    Cross-validated mean squared error of predicting each column of Y from the
    matching column of X with RBF kernel ridge regression.

    Args:
        X: matrix whose col-th column is the single predictor of Y[:, col]
        Y: matrix with one response per column
        alpha: regularization strength passed to KernelRidge
        gamma: RBF kernel coefficient passed to KernelRidge
        n_cv: value of the cv argument of cross_val_predict

    Return:
        list with one mean squared error per column of Y
    """
    _, n_targets = Y.shape
    errors = []
    for j in range(n_targets):
        target = Y[:, j]
        feature = X[:, j].reshape(-1, 1)  # single predictor as a column vector
        model = KernelRidge(kernel='rbf', gamma=gamma, alpha=alpha)

        # NOTE: a single 70/30 hold-out split was previously tried here in place of
        # cross-validation, and a relative MSE (normalized by the mean squared
        # prediction) was reported not to work.
        fitted = cross_val_predict(model, feature, target, cv=n_cv)
        residual = target - fitted
        errors.append((residual**2).mean().item())
    return errors


def regression_pvalues(X, y):
    """
    Fit a Gaussian GLM of y on X (after sm.add_constant) and return the p-values
    of the fit with the first entry dropped.

    Args:
        X: matrix of the regressors
        y: response vector

    Return:
        p-values of the fitted model from index 1 onward
        (presumably index 0 is the constant added by add_constant -- verify)
    """
    design = sm.add_constant(X)
    fitted = sm.GLM(y, design, family=sm.families.Gaussian()).fit()
    return fitted.pvalues[1:]


def top_order(X, eta_G, alpha, gamma, n_cv):
    """
    Estimate an ordering of the variables by iteratively removing one node at a time.

    At each step the Stein score and the regression residuals are estimated on the
    nodes still in play; the node whose score column is predicted from its residual
    column with the smallest cross-validated error is selected as leaf and removed.

    Args:
        X: N x D matrix of the data
        eta_G: regularization coefficient forwarded to Stein_score
        alpha, gamma, n_cv: forwarded to estimate_residuals and pred_err

    Return:
        list of the D node indices in reverse order of removal
        (the node left at the end comes first, the first removed leaf comes last)
    """
    _, d = X.shape

    candidates = list(range(d))
    np.random.shuffle(candidates) # account for trivial top order
    removal_sequence = []
    while len(candidates) > 1:
        score = Stein_score(X[:, candidates], eta_G=eta_G)
        residuals = estimate_residuals(X[:, candidates], alpha, gamma, n_cv)
        position = np.argmin(pred_err(residuals, score, alpha, gamma, n_cv))
        # Map the position among the candidates back to the original column index
        removal_sequence.append(candidates[position])
        candidates = candidates[:position] + candidates[position+1:]

    removal_sequence.append(candidates[0])
    return list(reversed(removal_sequence))
    

def predict(X, eta_G, delta=0.01, sigma=3, alpha = 0.1, gamma=0.1, n_cv=5):
    """
    Predict the adjacency matrix without acyclicity constraints. It gives best performance with CAM

    For every pair i < j the HSIC between the residual of one node and the score of the
    other is estimated in both directions. The larger of the two is compared against
    delta times the HSIC between residual and score of the node providing the score:
    A[i, j] is set when HSIC(R_j, S_i) wins and exceeds the threshold of node i,
    A[j, i] is set when HSIC(R_i, S_j) wins and exceeds the threshold of node j.

    Args:
        X (tensor): N x D matrix of the data (converted with X.numpy() for the regressions)
        eta_G (float): regularization coefficient forwarded to Stein_score
        delta (float): fraction of the self HSIC used as threshold
        sigma: argument of the RbfHSIC constructor
        alpha, gamma, n_cv: forwarded to estimate_residuals

    Return:
        A: D x D numpy array with 0/1 entries
    """
    # TODO: review to integrate topological ordering
    _, d = X.shape
    S = Stein_score(X, eta_G)
    R = estimate_residuals(X.numpy(), alpha, gamma, n_cv)
    A = np.zeros((d, d))

    S = S.type(torch.DoubleTensor)
    R = torch.tensor(R).type(torch.DoubleTensor)
    hsic = RbfHSIC(sigma)

    # Account for trivial top order
    nodes = list(range(d))
    np.random.shuffle(nodes)
    for i in nodes:
        thresh_i = hsic.unbiased_estimator(R[:, i], S[:, i])*delta
        for j in range(i+1, d):
            hsic_values = [hsic.unbiased_estimator(R[:, j], S[:, i]), hsic.unbiased_estimator(R[:, i], S[:, j])]
            c = np.argmax(hsic_values)
            if c == 1:
                # Keep the threshold of node j in its own variable: overwriting the
                # threshold of node i would leak into the comparisons of later j
                thresh_j = hsic.unbiased_estimator(R[:, j], S[:, j])*delta
                if hsic_values[c] > thresh_j:
                    A[j, i] = 1
            else:
                if hsic_values[c] > thresh_i:
                    A[i, j] = 1

    return A


def graph_inference(X, eta_G=0.001, alpha=0.1, gamma=0.1, n_cv=5, pruning = False, algorithm="DASExt", R=None):
    """
    Estimate adjacency matrix A and topological ordering of the variable in X from sample from data. Return estimations and execution times.

    Args:
        X (tensor): N x D matrix of the data
        eta_G, alpha, gamma, n_cv: forwarded to top_order (DASExt only)
        pruning (bool): for DASExt, prune the fully connected DAG of the estimated order with CAM;
            for CAM, forwarded to cam.run
        algorithm (str): Select the causal discovery algorithm to run. Implemented are {DASExt, CAM, GES}
        R: unused, kept for backward compatibility of the signature

    Return:
        A: estimated adjacency matrix (all zeros for DASExt without pruning)
        order: estimated ordering (all-zero placeholder for GES)
        order_time: seconds spent on the ordering (0 for GES)
        tot_time: total seconds spent in this call

    Raises:
        ValueError: if algorithm is not one of the implemented options
    """
    supported = ("DASExt", "CAM", "GES")
    if algorithm not in supported:
        # Fail early with a clear message: falling through every branch used to leave
        # order_time unbound and crash at the return statement
        raise ValueError(f"Unsupported algorithm {algorithm!r}: expected one of {supported}")

    _, d = X.shape
    A = np.zeros((d, d))
    order = np.zeros((d), dtype=np.int8)
    # perf_counter is monotonic, so elapsed times are immune to system clock adjustments
    start_time = time.perf_counter()

    if algorithm=="DASExt":
        order = top_order(X, eta_G, alpha, gamma, n_cv)
        order_time = time.perf_counter() - start_time
        if pruning:
            cam = CAM(X, 0.001)
            A = cam.pruning(full_DAG(order))

    elif algorithm=="CAM":
        cam = CAM(X, 0.001)
        A, order, order_time = cam.run(pruning=pruning)

    else:  # "GES"
        A, _ = ges.fit_bic(X.numpy())
        order_time = 0

    tot_time = time.perf_counter() - start_time

    return A, order, order_time, tot_time