import time 
import numpy as np
import torch 
import experiments.utils

# DAG-TFRC (Ours)
from dagTFRC.dagTFRC import dagTFRC_solver

# SparseRC
from experiments.methods.sparserc.sparserc import sparserc_solver

# LiNGAM 
from experiments.methods.lingam import lingam
from experiments.methods.lingam.lingam import ICALiNGAM, DirectLiNGAM

# # DYNOTEARS
# causalnex imports
from causalnex.structure import StructureModel
from causalnex.structure import dynotears
from causalnex.structure.data_generators import wrappers
from causalnex.structure.dynotears import from_pandas_dynamic
from causalnex.network import BayesianNetwork
import pandas as pd
from causalnex.evaluation import roc_auc

# NTS-NOTEARS imports
from experiments.methods.NTSNOTEARS.notears.locally_connected import LocallyConnected
from experiments.methods.NTSNOTEARS.notears.lbfgsb_scipy import LBFGSBScipy
from experiments.methods.NTSNOTEARS.notears.trace_expm import trace_expm
from experiments.methods.NTSNOTEARS.notears.utils import *
import experiments.methods.NTSNOTEARS.notears.utils as ut

############ tsFCI, TiMINO 
import subprocess # to run R scripts
# Command used to launch R scripts. execute_R splices it verbatim into a shell
# command line, hence the embedded quotes around the "Program Files" component.
# NOTE(review): hard-coded Windows install path -- adjust to the local R setup.
Rscript = 'C:/"Program Files"/R/R-4.2.1/bin/Rscript'
# tsFCI: R script to run, and the directory where execute_R writes the input
# files (data.csv, omega.txt, nlags.txt, sig_level.txt) and reads result.csv.
path_to_tsfci = './experiments/methods/tsFCI/tsfci.R'
path_to_data_tsfci = './experiments/methods/tsFCI/data_tsfci'
# TiMINO: same layout as for tsFCI.
path_to_timino = './experiments/methods/TiMINO/timino.R'
path_to_data_timino = './experiments/methods/TiMINO/data_timino'

##################### PCMCI ####################
from tigramite.pcmci import PCMCI
from tigramite.lpcmci import LPCMCI
from tigramite.independence_tests.parcorr import ParCorr
from tigramite.independence_tests.cmiknn import CMIknn
import tigramite.data_processing as pp
# pcmci-Omega
from experiments.methods.pcmci_omega.PCMCI_OMEGA_ContinuousData import algorithm_v2_mci_

############## only summary graph #############
# TCDF 
from experiments.methods.TCDF.runTCDF import runTCDF
###############################################


# Torch device handed to train_NTS_NOTEARS: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def execute_R(data, d, omega, number_of_lags, Rscript, path, datapath, method, sig_level=0.01):
    """Run an R-based causal discovery script and load the matrix it produces.

    The inputs are exchanged through files in ``datapath``: ``data.csv`` (only
    written for the "timino" and "tsfci" methods), ``omega.txt``, ``nlags.txt``
    and ``sig_level.txt``. The script is expected to leave its answer in
    ``result.csv`` in the same directory.

    Args:
        data: pandas DataFrame with the observations.
        d: number of nodes.
        omega: threshold value, written with three decimals.
        number_of_lags: number of lags, written as an integer.
        Rscript: command that launches the R interpreter (used in a shell line).
        path: path of the R script to execute.
        datapath: directory used for the exchanged files.
        method: "timino" or "tsfci"; selects the layout of ``data.csv``.
        sig_level: significance level, written with three decimals.

    Returns:
        Tuple ``(W_est, T)``. ``W_est`` is the first ``d`` rows of
        ``result.csv`` when the script exits with code 0, otherwise an all-zero
        array of shape ``(d, (number_of_lags + 1) * d)``. ``T`` is the
        wall-clock duration of the R call in seconds.
    """
    # Write the data in the layout the respective R script reads.
    csv_file = datapath + "/data.csv"
    if method == "tsfci":
        # tsFCI gets space-separated columns named X1, X2, ...
        renamed = {}
        for col in range(d * (number_of_lags + 1)):
            renamed[col] = 'X{}'.format(col + 1)
        data = data.rename(columns=renamed)
        data.to_csv(csv_file, sep=" ", index=False)
    elif method == "timino":
        data.to_csv(csv_file, index=False)

    # Scalar arguments are handed over as one small text file each.
    scalar_inputs = (
        ('/omega.txt', '{:.3f}', omega),
        ('/nlags.txt', '{:.0f}', number_of_lags),
        ('/sig_level.txt', '{:.3f}', sig_level),
    )
    for suffix, fmt, value in scalar_inputs:
        with open(datapath + suffix, 'w') as handle:
            handle.write(fmt.format(value))

    # Launch the R script through the shell and time it.
    command = '{} {} {}/data.csv {}/omega.txt {}/nlags.txt {}/sig_level.txt'.format(
        Rscript, path, datapath, datapath, datapath, datapath)
    t0 = time.time()
    completed = subprocess.run(command, shell=True, capture_output=True, text=True)
    elapsed = time.time() - t0
    # Echo whatever the script printed.
    print(completed.stdout)

    if completed.returncode != 0:
        # The script failed: report stderr and fall back to an empty estimate.
        print('R Error:\n {0}'.format(completed.stderr))
        return np.zeros((d, (number_of_lags + 1) * d)), elapsed

    # Keeping the first d rows works for tsFCI; TiMINO yields a summary graph.
    estimate = pd.read_csv(datapath + "/result.csv").to_numpy()[:d, :]
    return estimate, elapsed


# Real-world datasets; for these X is a single 2-D array (rows x d) instead of
# n separate realizations.
_REAL_DATASETS = ("thames", "us_temps", "stocks", "finance", "fMRI", "S&P", "swiss_temps", "dream3")


def _match_true_lags(W, d, number_of_lags, algo_lags):
    """Crop W to d rows and min(number_of_lags, algo_lags) + 1 lag blocks, then
    zero-pad on the right so the result has number_of_lags + 1 blocks whenever
    the algorithm used fewer lags than the ground truth."""
    L = min(number_of_lags, algo_lags)
    W_est = W[:d, :(L + 1) * d]
    if number_of_lags > algo_lags:
        W_est = np.concatenate([W_est, np.zeros((d, d * (number_of_lags - algo_lags)))], axis=1)
    return W_est


def _unsupported_dataset(method, dataset):
    """Build the error raised when `method` has no code path for `dataset`."""
    return ValueError("dataset '{}' is not supported by method '{}'".format(dataset, method))


def execute_method(X, method, args, n, d, t, dataset="time_series", search_params=None, ground_truth=None):
    """Run one causal-discovery method on X and return its estimated graph.

    Args:
        X: the data. For dataset == "time_series" it has shape n x t x d; for
            the real-world datasets it is a 2-D array with d columns.
        method: one of 'dagTFRC', 'sparserc', 'varlingam', 'd_varlingam',
            'timino', 'dynotears', 'nts-notears', 'tsfci', 'pcmci', 'pcmci+',
            'lpcmci', 'pcmci_omega', 'TCDF'.
        args: namespace of run options. Depending on the method this function
            reads algo_lags, number_of_lags, sparserc_epochs, omega, lambda1
            and lambda2 from it.
        n: number of independent instantiations.
        t: length of the time-series.
        d: number of nodes.
        dataset: "time_series" (synthetic) or one of the real-world dataset
            names; it also selects the table of default hyperparameters.
        search_params: hyperparameters to use during a hyperparameter search;
            when None the best-performing defaults for `method` are used
            (if the method has an entry in the table).
        ground_truth: true adjacency matrix. tsFCI and PCMCI output edges with
            ambiguous orientation; those entries are copied from the ground
            truth so the method gets the benefit of the doubt.

    Returns:
        Tuple (B_est, W_est, T): boolean adjacency estimate, the estimate it
        was derived from, and the measured runtime in seconds.

    Raises:
        NotImplementedError: if `method` is not one of the names above.
        ValueError: if the chosen method has no code path (or no tuned
            settings) for `dataset`.
    """
    if dataset == 'time_series':
        best_params = {
            "dagTFRC" : {"lambda1" : 0.001, "lambda2" :  1, "omega": 0.09},
            "sparserc" : {"lambda1" : 0.001, "lambda2" :  1, "lambda3" : 0.001, "omega": 0.09},
            "dynotears" : {"lambda_w" : 0.01, "lambda_a" : 0.01, "omega": 0.09}, #{"lambda_w" : 0.05, "lambda_a" : 0.05},
            "nts-notears" : {"lambda1" : 0.002, "lambda2" :  0.01, "omega": 0.09}, # {"lambda1" : 0.001, "lambda2" :  0.05},
            "tsfci" : {"sig_level" : 0.1, "omega": 0.09}, # {"sig_level" : 0.01},
            "pcmci" : {"pc_alpha" : 0.1, "alpha_level" : 0.01, "omega": 0.09}, # {"pc_alpha" : 0.05, "alpha_level" : 0.05},
            "TCDF" : {"significance" :  1., "nrepochs" : 1000, "omega": 0.09}
        }
    else:
        best_params = {
            "dagTFRC" : {"lambda1" : 0.0001, "lambda2" :  1, "omega": 0.5}, #{"lambda1" : 0.001, "lambda2" :  1},
            "sparserc" : {"lambda1" : 0.0001, "lambda2" :  1, "lambda3" :  0.1, "omega": 0.3}, #{"lambda1" : 0.001, "lambda2" :  1},
            "varlingam" : {"omega": 0.5},
            "d_varlingam" : {"omega": 0.6},
            "dynotears" : {"lambda_w" : 0.05, "lambda_a" : 0.01, "omega": 0.3}, #{"lambda_w" : 0.05, "lambda_a" : 0.05},
            "nts-notears" : {"lambda1" : 0.001, "lambda2" :  1, "omega": 0.1}, # {"lambda1" : 0.001, "lambda2" :  0.05},
            "tsfci" : {"sig_level" : 0.001, "omega": 0.1}, # {"sig_level" : 0.01},
            "pcmci" : {"pc_alpha" : 0.1, "alpha_level" : 0.01, "omega": 0.1}, # {"pc_alpha" : 0.05, "alpha_level" : 0.05},
            "TCDF" : {"significance" :  0.8, "nrepochs" : 1000, "omega": 0.2}
        }

    # Defaults are only used when no search parameters were supplied; methods
    # without a table entry end up with params = search_params (possibly None).
    if search_params is None and method in best_params:
        params = best_params[method]
    else:
        params = search_params

    if method == 'dagTFRC':
        start = time.time()
        if dataset == "time_series":
            # NOTE(review): the threshold comes from args.omega here although
            # params also carries an "omega" entry -- confirm this is intended.
            W = dagTFRC_solver(X, lambda1=params["lambda1"], lambda2=params["lambda2"], time_lag=args.algo_lags, epochs=args.sparserc_epochs, omega=args.omega, T=t)
            W_est = _match_true_lags(W, d, args.number_of_lags, args.algo_lags)

        elif dataset in _REAL_DATASETS:
            # Cut the series into int(a / t) windows of length t (remainder dropped).
            a, _ = X.shape
            X = X[:int(a / t) * t, :]
            X = X.reshape((int(a / t), t, d))
            # Per-dataset regularisation settings.
            if dataset in ("thames", "us_temps", "stocks"):
                lambda1, lambda2, omega = 0.001, 1, args.omega
            elif dataset == "finance":
                lambda1, lambda2, omega = params["lambda1"], params["lambda2"], params["omega"]
            elif dataset == "fMRI":
                lambda1, lambda2, omega = 0.01, 2, args.omega
            else:  # "S&P", "swiss_temps", "dream3"
                lambda1, lambda2, omega = args.lambda1, args.lambda2, args.omega
            # Only dream3 is solved with algo_lags; the others use number_of_lags.
            time_lag = args.algo_lags if dataset == "dream3" else args.number_of_lags
            W = dagTFRC_solver(X, lambda1=lambda1, lambda2=lambda2, time_lag=time_lag, epochs=args.sparserc_epochs, omega=omega, T=t)
            W_est = W[:d, :(args.algo_lags + 1) * d]

        else:
            raise _unsupported_dataset(method, dataset)

        T = time.time() - start
        print(" Time for dagTFRC was {:.3f}".format(T))
        B_est = W_est != 0

    elif method == 'sparserc':
        start = time.time()
        if dataset == "time_series":
            X_past = experiments.utils.X_past(X, args.algo_lags)
            X_past = X_past.reshape((n * t, d * (args.algo_lags + 1)))
            W = sparserc_solver(X_past, lambda1=params["lambda1"], lambda2=params["lambda2"], epochs=args.sparserc_epochs, lambda3=params["lambda3"], omega=params["omega"], T = args.algo_lags + 1)
            W_est = _match_true_lags(W, d, args.number_of_lags, args.algo_lags)

        elif dataset in _REAL_DATASETS:
            # Cut the series into int(a / t) windows of length t (remainder dropped).
            a, _ = X.shape
            X = X[:int(a / t) * t, :]
            X = X.reshape((int(a / t), t, d))
            X_past = experiments.utils.X_past(X, args.algo_lags)
            X_past = X_past.reshape((int(a / t) * t, d * (args.algo_lags + 1)))
            if dataset == "finance":
                W = sparserc_solver(X_past, lambda1=params["lambda1"], lambda2=params["lambda2"], lambda3=params["lambda3"], epochs=args.sparserc_epochs, omega=params["omega"], T=args.number_of_lags + 1)
            elif dataset == "S&P":
                W = sparserc_solver(X_past, lambda1=0.0001, lambda2=10, lambda3=0.0001, epochs=args.sparserc_epochs, omega=args.omega, T=args.number_of_lags + 1)
            elif dataset == "dream3":
                W = sparserc_solver(X_past, lambda1=args.lambda1, lambda2=args.lambda2, epochs=args.sparserc_epochs, lambda3=0.001, omega=args.omega, T=args.number_of_lags + 1)
            else:
                # Only the three datasets above have SparseRC settings.
                raise ValueError("sparserc has no settings for dataset '{}'".format(dataset))
            W_est = W[:d, :(args.number_of_lags + 1) * d]

        else:
            raise _unsupported_dataset(method, dataset)

        T = time.time() - start
        print(" Time for sparserc was {:.3f}".format(T))
        B_est = W_est != 0

    ############### Functional causal model based
    elif method in ['varlingam', 'd_varlingam']:
        if method == 'varlingam':
            model = lingam.VARLiNGAM(lags=args.algo_lags, criterion=None, lingam_model=ICALiNGAM())
        else:
            # NOTE(review): fitted with number_of_lags while the output below is
            # cropped with algo_lags -- confirm this asymmetry is intended.
            model = lingam.VARLiNGAM(lags=args.number_of_lags, criterion=None, lingam_model=DirectLiNGAM())

        if dataset == "time_series":
            data = pd.DataFrame(X.reshape((n * t, d)))
            threshold = args.omega
        elif dataset in _REAL_DATASETS:
            data = pd.DataFrame(X)
            threshold = params["omega"]
        else:
            raise _unsupported_dataset(method, dataset)

        start = time.time()
        model.fit(data)
        T = time.time() - start
        # Transposed per-lag matrices placed side by side.
        W_est = np.concatenate([B.T for B in model.adjacency_matrices_], axis=1)
        W_est = np.where(np.abs(W_est) > threshold, W_est, 0) #thresholding

        W_est = _match_true_lags(W_est, d, args.number_of_lags, args.algo_lags)
        B_est = W_est != 0

    elif method == 'timino':
        # timino outputs summary graph
        # the code below requires modification to work properly
        data = pd.DataFrame(X.reshape((n * t, d)))
        # execute_R returns (W_est, T); the boolean graph is derived here.
        W_est, T = execute_R(data, d, args.omega, args.number_of_lags, Rscript, path_to_timino, path_to_data_timino, method=method)
        B_est = W_est != 0

    ############### Continuous Optimization
    elif method == "dynotears":
        if dataset == "time_series":
            time_series = [pd.DataFrame(X[i, :, :]) for i in range(X.shape[0])]
            assert len(time_series) == n # should have length equal to the n realizations of the time series
        elif dataset in _REAL_DATASETS:
            time_series = [pd.DataFrame(X)]
        else:
            raise _unsupported_dataset(method, dataset)

        start = time.time()
        # best params lambda_w=.05, lambda_a=.05
        if dataset in ("S&P", "dream3"):
            g_learnt = from_pandas_dynamic(time_series, p=1, lambda_w=0.1, lambda_a=0.1, w_threshold=0)
        else:
            g_learnt = from_pandas_dynamic(time_series, p=args.algo_lags, lambda_w=params["lambda_w"], lambda_a=params["lambda_a"], w_threshold=params["omega"])
        T = time.time() - start
        B_est, W_est = experiments.utils.network_to_numpy(g_learnt, d, args.number_of_lags)
        if dataset == "dream3":
            # Collapse the first two lag blocks into a single d x d graph.
            B_est = (B_est[:d, :d] + B_est[:d, d: 2 * d]) != 0
            W_est = B_est
        else:
            W_est = _match_true_lags(W_est, d, args.number_of_lags, args.algo_lags)
            B_est = W_est != 0

    elif method == "nts-notears":
        if dataset == "time_series":
            data = np.array(X.reshape((n * t, d)), dtype=np.float32)
        elif dataset in _REAL_DATASETS:
            data = np.array(X).astype(np.float32)
        else:
            raise _unsupported_dataset(method, dataset)

        start = time.time()
        # NOTE(review): NTS_NOTEARS and train_NTS_NOTEARS are not imported by
        # name in this file -- presumably provided by the star import from the
        # NTS-NOTEARS utils; verify.
        model = NTS_NOTEARS(dims=[d, 10, 1], bias=True, number_of_lags=args.number_of_lags,
                            prior_knowledge=None, variable_names_no_time=['X{}'.format(j) for j in range(1, d + 1)])

        # best params  lambda1=0.001, lambda2=0.05,
        W = train_NTS_NOTEARS(model, data, device=device, lambda1=params["lambda1"], lambda2=params["lambda2"],
                              w_threshold=params["omega"], h_tol=1e-60, verbose=1)
        T = time.time() - start
        # W = [ 0 0 0 B_tau
        #       0 0 0 B_{tau-1}
        #       ...
        #       0 0 0 B_1
        #       0 0 0 B_0 ]
        # so the lag-i block is read from the last d columns, bottom-up.
        W_est = np.zeros((d, (args.number_of_lags + 1) * d))
        for i in range(args.number_of_lags + 1):
            W_est[:, i * d : (i + 1) * d] = W[(args.number_of_lags - i) * d:(args.number_of_lags + 1 - i) * d, -d:]
        B_est = W_est != 0

    ############### Constraint-based methods
    elif method == 'tsfci':
        if dataset == "time_series":
            data = experiments.utils.get_lagged_data(X.reshape((n * t, d)), args.number_of_lags)
        elif dataset in _REAL_DATASETS:
            data = experiments.utils.get_lagged_data(X, args.number_of_lags)
        else:
            raise _unsupported_dataset(method, dataset)
        data = pd.DataFrame(data)
        W, T = execute_R(data, d, params["omega"], args.number_of_lags, Rscript, path_to_tsfci, path_to_data_tsfci, method=method, sig_level=params["sig_level"])
        # W contains 1 for arrow tail, 2 for arrow-head and 3 for circle.
        W_est = np.zeros((d, d * (args.number_of_lags + 1)))
        # `lag` (not `t`) so the series-length parameter is not shadowed.
        for lag in range(args.number_of_lags + 1):
            off = d * lag
            for i in range(d):
                for j in range(i):
                    mark_ij = W[i, j + off]
                    mark_ji = W[j, i + off]
                    # We let the PAG get the correct choice in ambiguous cases.
                    if mark_ij == 1 and mark_ji == 2: #case <--
                        W_est[j, i + off] = 1
                    elif mark_ij == 2 and mark_ji == 1: #case -->
                        W_est[i, j + off] = 1
                    elif mark_ij == 2 and mark_ji == 2: #case <->
                        W_est[i, j + off] = 0
                        W_est[j, i + off] = 0
                    elif mark_ij == 2 and mark_ji == 3: #case o->
                        W_est[i, j + off] = ground_truth[i, j + off]
                    elif mark_ij == 3 and mark_ji == 2: #case <-o
                        W_est[j, i + off] = ground_truth[j, i + off]
                    elif mark_ij == 3 and mark_ji == 3: #case o-o
                        W_est[i, j + off] = ground_truth[i, j + off]
                        W_est[j, i + off] = ground_truth[j, i + off]

        B_est = W_est != 0

    elif method in ['pcmci', 'pcmci+', 'lpcmci']:
        if dataset == "time_series":
            data = pp.DataFrame(X.reshape((n * t, d)))
        elif dataset in _REAL_DATASETS:
            data = pp.DataFrame(np.array(X))
        else:
            raise _unsupported_dataset(method, dataset)
        start = time.time()

        if method == 'pcmci':
            model = PCMCI(dataframe=data, cond_ind_test=ParCorr(), verbosity=0)
            # best params pc_alpha=0.05, alpha_level=0.05
            results = model.run_pcmci(tau_min=0, tau_max=args.number_of_lags, pc_alpha=params["pc_alpha"], alpha_level=params["alpha_level"])
        elif method == 'pcmci+':
            model = PCMCI(dataframe=data, cond_ind_test=ParCorr(), verbosity=0)
            results = model.run_pcmciplus(tau_min=0, tau_max=args.number_of_lags, pc_alpha=0.01)
        else:  # 'lpcmci'
            model = LPCMCI(dataframe=data, cond_ind_test=ParCorr(), verbosity=0)
            results = model.run_lpcmci(tau_min=0, tau_max=args.number_of_lags, pc_alpha=0.05)

        T = time.time() - start

        # causal link from :math:`i` to :math:`j` at lag :math:`\\tau`
        graph = np.where(np.isin(results['graph'], ['-->']), 1, 0)
        # denotes an unoriented, contemporaneous adjacency between :math:`i` and :math:`j` indicating Markov equivalence => we count it as correct
        graph_yes = np.where(np.isin(results['graph'], ['o-o']), 1, 0)
        # directionality is undecided due to conflicting orientation rules => No edge
        graph_no = np.where(np.isin(results['graph'], ['x-x']), 0, 1)

        # 1 = directed edge, 2 = ambiguous edge, 0 = no edge / conflict.
        total_graph = (graph + 2 * graph_yes) * graph_no
        W_est = np.concatenate([total_graph[:, :, i] for i in range(args.number_of_lags + 1)], axis=1)

        # Whenever there is ambiguity we assume PCMCI made the correct choice.
        W_est = np.where(W_est == 2, ground_truth, W_est)

        # for LPCMCI
        # A = np.concatenate([np.where(np.isin(results['graph'][:, :, i], ['-->', 'o->','x->', '<->', 'o-o', 'x-x']), 1, 0) for i in range(number_of_lags + 1)], axis=1)
        # B = np.concatenate([np.where(np.isin(results['graph'][:, :, i].T, ['<--','<-o', '<-x', '<->', 'o-o', 'x-x']), 1, 0) for i in range(number_of_lags + 1)], axis=1)
        # W_est = np.where(A | B, 1, 0)

        B_est = (W_est != 0)

    elif method == 'pcmci_omega':
        # cannot yet run because selected_links is deprecated in the new version
        data = X.reshape((n * t, d))
        start = time.time()
        tem_array, omega_hat4, superset_bool, elapsed_time = algorithm_v2_mci_(data, tau_max_pcmci=args.number_of_lags, search_omega_bound=1)
        T = time.time() - start
        print(tem_array.shape, omega_hat4, elapsed_time)
        B_est = (tem_array != 0)
        # No separate weighted estimate exists for this method; return the raw
        # output so that the final return statement has a W_est to hand back.
        W_est = tem_array

    ############## Neural Networks
    elif method == 'TCDF':
        start = time.time()
        if dataset == "time_series":
            data = pd.DataFrame(X.reshape((n * t, d)))
            print("Kernel size is {} ############".format(args.algo_lags + 1))
        elif dataset in _REAL_DATASETS:
            data = pd.DataFrame(X)
        else:
            raise _unsupported_dataset(method, dataset)

        allcauses, alldelays, allreallosses, allscores, columns = \
            runTCDF(data, kernel_size=args.algo_lags + 1, dilation_c=args.algo_lags + 1, nrepochs=params["nrepochs"], significance=params["significance"])

        T = time.time() - start
        W_est = np.zeros((d, (args.algo_lags + 1) * d))

        for effect in allcauses.keys():
            for cause in allcauses[effect]:
                delay = alldelays[(effect, cause)]
                if delay <= args.algo_lags: # just to prevent throwing error
                    W_est[cause, effect + d * delay] = 1 # Updating B_delay

        W_est = _match_true_lags(W_est, d, args.number_of_lags, args.algo_lags)
        B_est = W_est != 0
        print(B_est.shape, W_est.shape)

    else:
        # Fail loudly instead of falling through to an UnboundLocalError.
        raise NotImplementedError("method not implemented: {}".format(method))

    return B_est, W_est, T
    