"""
This script computes context-specific graphs for time series data with endogenous contexts, using SAC-PCMCI, PAC-PCMCI, M-PCMCI, B-PCMCI and P-PCMCI causal discovery methods. It supports parallel execution via MPI and sequential execution.
"""

import numpy as np

import time
import sys
import socket
import os
import mpi
import copy
import dill as pickle
import argparse

sys.path.append('./../')

from matplotlib import pyplot as plt
from numpy.random import SeedSequence

from config_generator import generate_name_from_params, generate_configurations, Config, generate_string_from_params

from tigramite.pcmci import PCMCI
from tigramite.toymodels import structural_causal_processes as toys
from tigramite import data_processing as pp
import tigramite.plotting as tp
import generate_data as gd

from dataclasses import dataclass, fields

from tigramite.independence_tests.parcorr import ParCorr
from tigramite.independence_tests.robust_parcorr import RobustParCorr
from tigramite.independence_tests.cmiknn import CMIknn
from tigramite.independence_tests.cmiknn_mixed import CMIknnMixed

from tigramite.independence_tests.regressionCI import RegressionCI

import importlib
import copy
from endo_regime_pcmci.persistent_endo_cit import PersistentEndoCIT
from endo_regime_pcmci.sparse_endo_cit import SparseEndoCIT
from endo_regime_pcmci.mask_generator import MaskGenerator
from endo_regime_pcmci.mixed_test_pcmci import MixedTestPCMCI, PersistentEndoPCMCI, SparseEndoPCMCI
from utils import count_persistence


########## SET THESE BEFORE RUNNING!
# sequential: unpack_params() reads this flag to decide whether it is handed a
#             bare parameter string (True) or a (parameter string, sample) pair (False).
# plot_data:  forwarded by process_chunks() to calculate_graphs() to switch plotting on.
sequential = False
plot_data = False

# Command line layout: <script> <num_cpus> <samples> <config string> [<config string> ...]
# Every setting below is bound on BOTH the success and the fallback path, so
# later code never hits a NameError depending on how the script was started.
verbosity = 2
try:
    arg = sys.argv
    num_cpus = int(arg[1])
    samples = int(arg[2])
    config_list = list(arg)[3:]
except (IndexError, ValueError) as e:
    # Missing or non-integer arguments: report and fall back to small defaults.
    print(e)
    arg = ''
    num_cpus = 2
    samples = 100
    config_list = []
num_configs = len(config_list)

time_start = time.time()


def _parse_bool(text):
    """Return True only if *text* spells ``true`` (case-insensitive).

    ``bool(text)`` must not be used for this: every non-empty string,
    including ``'False'``, is truthy.
    """
    return text.lower() in ['true']


def unpack_params(params_str):
    """Turn a dash-separated parameter string into a ``Config`` object.

    Parameters
    ----------
    params_str : str or tuple
        If the module-level flag ``sequential`` is set, the parameter string
        itself; otherwise a ``(parameter string, sample)`` pair, of which only
        the string is used here.

    Returns
    -------
    Config
        Built positionally from the 18 fields of the string, in this order:
        N, density, max_lag, pc_alpha, sample_size, regime_children_known,
        nb_changed_links, nb_regimes, nb_repeats, cycles_only, remove_only,
        use_cmiknnmixed, imbalance_factor, contemp_fraction, regime_autocorr,
        endo_regime, contemp_context, save_folder.

    Raises
    ------
    IndexError
        If the string has fewer than 18 dash-separated fields.
    ValueError
        If a numeric field cannot be converted.
    """
    if sequential:
        para_setup_string = params_str
    else:
        para_setup_string, _ = params_str

    # NOTE: fields are split on '-', so no field may itself contain a dash.
    paras = para_setup_string.split('-')
    print('PARAS', paras)

    N = int(paras[0])
    density = float(paras[1])
    max_lag = int(paras[2])
    pc_alpha = float(paras[3])
    sample_size = int(paras[4])
    # Kept as a string: calculate_graphs compares it against 'True' / 'and_parents'.
    regime_children_known = str(paras[5])
    nb_changed_links = int(paras[6])
    nb_regimes = int(paras[7])
    nb_repeats = int(paras[8])
    cycles_only = _parse_bool(paras[9])
    remove_only = _parse_bool(paras[10])
    use_cmiknnmixed = _parse_bool(paras[11])
    imbalance_factor = float(paras[12])
    contemp_fraction = float(paras[13])
    regime_autocorr = float(paras[14])
    # These two were previously read with bool(...), which made 'False' come out as True.
    endo_regime = _parse_bool(paras[15])
    contemp_context = _parse_bool(paras[16])
    save_folder = str(paras[17])

    return Config(N,
                  density,
                  max_lag,
                  pc_alpha,
                  sample_size,
                  regime_children_known,
                  nb_changed_links,
                  nb_regimes,
                  nb_repeats,
                  cycles_only,
                  remove_only,
                  use_cmiknnmixed,
                  imbalance_factor,
                  contemp_fraction,
                  regime_autocorr,
                  endo_regime,
                  contemp_context,
                  save_folder)


def _save_and_close(path):
    """Save the current matplotlib figure to *path* and close it.

    Closing matters because one worker process handles many trials; pyplot
    keeps every figure alive until it is closed explicitly.
    """
    plt.savefig(path)
    plt.close()


def _build_link_assumptions(config_vals, regime_indicator, regime_children, true_regime_links):
    """Build the ``link_assumptions`` dict handed to the pooled PCMCI+ run.

    Returns ``None`` unless ``config_vals.reg_children_known`` is ``'True'``
    or ``'and_parents'``.  In both of those modes every (variable, lag) pair
    starts as ``'o?o'`` and the links between the regime indicator and its
    children are then fixed; ``'and_parents'`` additionally fixes the links
    from the indicator's parents (taken from the first regime's links).
    """
    # NOTE(review): unpack_params passes this value to Config under the name
    # `regime_children_known`; confirm the Config attribute really is `reg_children_known`.
    mode = config_vals.reg_children_known
    if mode not in ('True', 'and_parents'):
        return None

    n_vars = config_vals.N
    link_assumptions = {j: {(i, -tau): 'o?o'
                            for i in range(n_vars)
                            for tau in range(config_vals.max_lag + 1)}
                        for j in range(n_vars)}

    if mode == 'True':
        # Only the links added below remain allowed for the regime indicator.
        link_assumptions[regime_indicator[0]] = {}
    else:
        true_regime_parents = [val[0] for val in true_regime_links[0][regime_indicator[0]]]
        for var, lag in true_regime_parents:
            link_assumptions[regime_indicator[0]][(var, lag)] = '-->'
            link_assumptions[var][(regime_indicator[0], lag)] = '<--'

    for var, lag in regime_children:
        link_assumptions[var][regime_indicator] = '-->'
        link_assumptions[regime_indicator[0]][(var, regime_indicator[-1])] = '<--'

    return link_assumptions


def calculate_graphs(params_str, seedSeq, plot_data=False, sam=None):
    """Generate one regime-dependent data set and run all discovery variants on it.

    The steps are: draw a regime model and data via ``generate_data``, then
    estimate (1) one graph per regime with a y-masked ``MixedTestPCMCI``,
    (2) one per regime with ``PersistentEndoPCMCI``, (3) one per regime with
    ``SparseEndoPCMCI``, (4) the union of each of those sets, and (5) a single
    pooled graph with ``MixedTestPCMCI`` on the unmasked data.

    Parameters
    ----------
    params_str : str or tuple
        Parameter description, decoded by ``unpack_params``.
    seedSeq : object with a ``spawn`` method (e.g. ``numpy.random.SeedSequence``)
        Source of the child seeds handed to the data generator.
    plot_data : bool
        If True, ground-truth and estimated graphs are written as PNG files.
        If False no plotting happens at all, so a plotting failure cannot
        discard computed results and the reported times exclude plotting.
    sam : any
        Label used in the figure folder and file names.

    Returns
    -------
    None
        If model or data generation yields nothing, or the data contain NaN.
    str
        The exception message, if any of the discovery steps raises.
    dict
        Otherwise: ground-truth links, per-method computation times,
        per-regime graphs, union graphs, the pooled graph and the two
        persistence statistics.
    """
    config_vals = unpack_params(params_str)

    folder_name = generate_name_from_params(config_vals)
    result_path = config_vals.save_folder + '/' + folder_name
    metrics_result_path = result_path + '/metrics'
    figure_path = result_path + '/figures/'
    graphs_path = figure_path + '/g_' + str(sam) + '/'
    trial_suffix = '_trial_' + str(sam) + '.png'

    # exist_ok=True: several workers may create the same folders at the same
    # time, and a separate exists() check would race with them.
    for path in (result_path, metrics_result_path, figure_path, graphs_path):
        os.makedirs(path, exist_ok=True)

    child_seeds = seedSeq.spawn(config_vals.nb_regimes + 3)
    dep_coeffs = np.asarray([0.6, 0.7, 0.5, -0.4, -0.5, -0.6, -0.7, 0.4])

    # The regime indicator is one of the values returned by the model generator.
    (links_base, true_regime_links, links_with_regime_children,
     regime_children, causal_order, regime_indicator) = gd.generate_regime_model(
        config_vals.N, config_vals.density, child_seeds,
        config_vals.nb_regimes, config_vals.max_lag,
        config_vals.nb_changed_links,
        remove_only=config_vals.remove_only,
        cycles_only=config_vals.cycles_only,
        dep_coeffs=dep_coeffs,
        auto_coeffs=[0.3, -0.4, 0.4, 0.2, -0.3, -0.2],
        contemp_fraction=config_vals.contemp_fraction,
        regime_autocorr=config_vals.regime_autocorr,
        regime_endo=config_vals.endo_regime,
        contemp_context=config_vals.contemp_context)

    if links_base is None:
        return None

    true_union_links = gd.unionize_links(true_regime_links, regime_children,
                                         config_vals.nb_regimes, regime_indicator)
    true_union_links_str = {var: [(val[0], val[1], "lin_f") for val in parents]
                            for var, parents in true_union_links.items()}

    # generate data
    data, data_type = gd.generate_regime_data(config_vals.sample_size, true_regime_links,
                                              config_vals.nb_regimes,
                                              child_seeds, regime_indicator,
                                              config_vals.max_lag, causal_order=causal_order,
                                              imbalance_factor=config_vals.imbalance_factor,
                                              regime_thresholds=None)
    if data is None:
        return None

    persistence_metric, max_length = count_persistence(data, regime_indicator)

    if plot_data:
        tp.plot_time_series_graph(toys.links_to_graph(links_base))
        _save_and_close(graphs_path + 'base_' + trial_suffix)
        for r in range(len(true_regime_links)):
            tp.plot_time_series_graph(toys.links_to_graph(true_regime_links[r]))
            _save_and_close(graphs_path + 'graph_r_no_regind_' + str(r) + trial_suffix)
            tp.plot_time_series_graph(toys.links_to_graph(links_with_regime_children[r]))
            _save_and_close(graphs_path + 'graph_r_' + str(r) + trial_suffix)

        tp.plot_time_series_graph(toys.links_to_graph(true_union_links_str))
        _save_and_close(graphs_path + 'union_graph' + trial_suffix)

    if np.isnan(data).any():
        return None

    # Regimes that actually occur in the data; this may be fewer than nb_regimes.
    unique_regimes = np.unique(data[:, regime_indicator[0]])

    mask_generator = MaskGenerator('persistent')
    masks = [
        mask_generator.generate_lagged_mask_persistent(data.T,
                                                       [regime_indicator[0]],
                                                       tau_max=config_vals.max_lag,
                                                       values=[regime]).T
        for regime in unique_regimes
    ]

    # One masked dataframe per OBSERVED regime (aligned with unique_regimes);
    # indexing by nb_regimes raised IndexError when a regime never occurred.
    dataframes_ymask = [pp.DataFrame(data=data,
                                     mask=np.concatenate([mask] * config_vals.N, axis=-1),
                                     data_type=data_type,
                                     datatime={0: np.arange(len(data))})
                        for mask in masks]

    dataframe = pp.DataFrame(data=data,
                             data_type=data_type,
                             datatime={0: np.arange(len(data))})

    link_assumptions = _build_link_assumptions(config_vals, regime_indicator,
                                               regime_children, true_regime_links)

    # NOTE(review): only the pooled run further down receives pc_alpha and
    # link_assumptions; the regime-specific runs are called with tau_max only.
    # Confirm that this is intended.

    # calculate y-masked graphs
    time_start_ymask = time.time()
    try:
        results_ymask = {}
        for regime, dataframe_ymask in zip(unique_regimes, dataframes_ymask):
            pcmci_ymask = MixedTestPCMCI(dataframe=dataframe_ymask,
                                         cond_ind_test=RobustParCorr(mask_type='y'),
                                         disc_cond_ind_test=RegressionCI(mask_type='y'),
                                         verbosity=0)

            results_ymask[regime] = pcmci_ymask.run_pcmciplus(tau_max=config_vals.max_lag)

            if plot_data:
                tp.plot_time_series_graph(
                    val_matrix=results_ymask[regime]['val_matrix'],
                    graph=results_ymask[regime]['graph'])
                _save_and_close(graphs_path + 'ymask' + str(regime) + trial_suffix)
    except Exception as e:
        # Best effort: report the failure as the result of this trial.
        return str(e)
    time_end_ymask = time.time()

    # calculate regime-graphs (adapted z-mask)
    time_start_persistent = time.time()
    try:
        results_persistent_regimes = dict()
        for regime in unique_regimes:
            print(f'Running PCMCI for regime {regime}')

            cond_ind_test = PersistentEndoCIT(mixed_cit=RegressionCI(),
                                              cont_cit=RobustParCorr(),
                                              context_vars=[regime_indicator[0]],
                                              context_values=[regime])

            pcmci_regimes = PersistentEndoPCMCI(dataframe=dataframe,
                                                cond_ind_test=cond_ind_test,
                                                verbosity=0)

            results_persistent_regimes[regime] = pcmci_regimes.run_pcmciplus_fullpcmci(
                tau_max=config_vals.max_lag)

            if plot_data:
                tp.plot_time_series_graph(
                    val_matrix=results_persistent_regimes[regime]['val_matrix'],
                    graph=results_persistent_regimes[regime]['graph'],
                    var_names=list(range(data.shape[1])))
                _save_and_close(graphs_path + 'persistent_regime' + str(regime) + trial_suffix)
    except Exception as e:
        return str(e)
    time_end_persistent = time.time()

    time_start_sparse = time.time()
    try:
        results_sparse_regimes = dict()
        for regime in unique_regimes:
            print(f'Running sparse PCMCI for regime {regime}')

            cond_ind_test = SparseEndoCIT(mixed_cit=RegressionCI(),
                                          cont_cit=RobustParCorr(),
                                          context_vars=[regime_indicator[0]],
                                          context_values=[regime])

            pcmci_regimes = SparseEndoPCMCI(dataframe=dataframe,
                                            cond_ind_test=cond_ind_test,
                                            verbosity=0)

            results_sparse_regimes[regime] = pcmci_regimes.run_pcmciplus(
                tau_max=config_vals.max_lag)

            if plot_data:
                tp.plot_time_series_graph(
                    val_matrix=results_sparse_regimes[regime]['val_matrix'],
                    graph=results_sparse_regimes[regime]['graph'],
                    var_names=list(range(data.shape[1])))
                _save_and_close(graphs_path + 'sparse_regime' + str(regime) + trial_suffix)
    except Exception as e:
        return str(e)
    time_end_sparse = time.time()

    # assemble the union graph of each set of regime-specific graphs
    try:
        union_graph_ymask = gd.unionize_graphs(
            [res['graph'] for res in results_ymask.values()], config_vals.nb_regimes)
        persistent_union_graph_regimes = gd.unionize_graphs(
            [res['graph'] for res in results_persistent_regimes.values()], config_vals.nb_regimes)
        sparse_union_graph_regimes = gd.unionize_graphs(
            [res['graph'] for res in results_sparse_regimes.values()], config_vals.nb_regimes)
    except Exception as e:
        return str(e)

    # pooled run on the unmasked data
    time_start_union = time.time()

    cond_ind_test_union = RobustParCorr()
    if config_vals.use_cmiknnmixed:
        disc_cond_ind_test_union = CMIknnMixed(sig_samples=100,
                                               knn_type='global',
                                               knn=0.2)
    else:
        disc_cond_ind_test_union = RegressionCI()

    pcmci_union = MixedTestPCMCI(dataframe=dataframe,
                                 cond_ind_test=cond_ind_test_union,
                                 disc_cond_ind_test=disc_cond_ind_test_union,
                                 verbosity=0)

    try:
        results_union = pcmci_union.run_pcmciplus(tau_max=config_vals.max_lag,
                                                  pc_alpha=config_vals.pc_alpha,
                                                  link_assumptions=link_assumptions)
        if plot_data:
            tp.plot_time_series_graph(
                val_matrix=results_union['val_matrix'],
                graph=results_union['graph'])
            _save_and_close(graphs_path + 'found_pcmci_union' + trial_suffix)
    except Exception as e:
        return str(e)
    time_end_union = time.time()

    return {
        'base_links': links_base,
        'causal_order': causal_order,
        'true_regime_links': true_regime_links,
        'true_regime_children': regime_children,
        'true_regime_links_with_regime_ind': links_with_regime_children,
        'true_union_links': true_union_links_str,
        'regime_indicator': regime_indicator,

        'computation_time_ymask': time_end_ymask - time_start_ymask,
        'computation_time_persistent_regimes': time_end_persistent - time_start_persistent,
        'computation_time_sparse_regimes': time_end_sparse - time_start_sparse,
        'computation_time_pcmci': time_end_union - time_start_union,

        'graphs_ymask': {regime: res['graph'] for regime, res in results_ymask.items()},
        'graphs_persistent_regimes': {regime: res['graph']
                                      for regime, res in results_persistent_regimes.items()},
        'graphs_sparse_regimes': {regime: res['graph']
                                  for regime, res in results_sparse_regimes.items()},

        'union_graph_ymask': union_graph_ymask,
        'union_graph_persistent_regimes': persistent_union_graph_regimes,
        'union_graph_sparse_regimes': sparse_union_graph_regimes,
        'union_graph_pcmci': results_union['graph'],

        'persistence_metric': persistence_metric,
        'max_length': max_length,
    }
 
######## IMPORTANT: if using sequential, comment these out

def process_chunks(job_id, chunk, seed):
    """Evaluate every entry of one work chunk and report progress.

    Parameters
    ----------
    job_id : int
        Number of this chunk; shown in the progress line and forwarded to
        ``calculate_graphs`` as ``sam``.
    chunk : sequence
        Hashable work items; each is passed unchanged to ``calculate_graphs``
        and used as key of the returned dict.
    seed : object with a ``spawn`` method (e.g. ``numpy.random.SeedSequence``)
        One child seed per work item is spawned from it.

    Returns
    -------
    dict
        Maps each work item to whatever ``calculate_graphs`` returned for it.
    """
    total = len(chunk)
    item_seeds = seed.spawn(total)
    started = time.time()

    outcome = {}
    for position, (item, item_seed) in enumerate(zip(chunk, item_seeds), start=1):
        # NOTE(review): `sam` is the chunk number, not the sample index stored
        # in the work item, so plots of different samples handled by the same
        # chunk share one file name -- confirm this is intended.
        outcome[item] = calculate_graphs(item, item_seed, plot_data=plot_data, sam=job_id)

        # Elapsed time so far and a linear extrapolation to the whole chunk, in hours.
        elapsed_h = (time.time() - started) / 3600.
        projected_h = elapsed_h * total / float(position)
        print("job_id %d index %d/%d: %dh %.1fmin / %dh %.1fmin:  %s" % (
            job_id, position, total,
            int(elapsed_h), 60. * (elapsed_h % 1.),
            int(projected_h), 60. * (projected_h % 1.),
            item))
    return outcome


def master():
    """Distribute all (configuration, sample) tasks and store their results.

    Builds one task per entry of the module-level ``config_list`` and per
    sample index below ``samples``, splits the tasks into chunks, submits each
    chunk as a ``process_chunks`` call through the ``mpi`` module, then
    collects the results in submission order and pickles every single result
    to ``<save folder>/<config name>/<config name>_<running number>.dat``.

    Returns
    -------
    None
    """
    print("Starting with num_cpus = ", num_cpus, config_list)

    job_list = [(conf, i) for i in range(samples) for conf in config_list]
    num_tasks = len(job_list)
    if num_tasks == 0:
        # Without this guard the chunking below would divide by zero.
        print("num_tasks 0 - no configurations given, nothing to do")
        return

    # One process is the master itself; always keep at least one chunk so that
    # num_cpus <= 1 does not lead to a division by zero either.
    num_jobs = max(1, min(num_cpus - 1, num_tasks))

    def split(a, n):
        """Cut list *a* into *n* consecutive parts whose lengths differ by at most one."""
        k, m = divmod(len(a), n)
        return [a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)]

    config_chunks = split(job_list, num_jobs)

    print("num_tasks %s" % num_tasks)
    print("num_jobs %s" % num_jobs)

    ss = SeedSequence(12345)

    # Spawn one child SeedSequence per chunk to pass to the worker processes.
    child_seeds = ss.spawn(len(config_chunks))

    ## Send
    for job_id, chunk in enumerate(config_chunks):
        print("submit %d / %d" % (job_id, len(config_chunks)))
        mpi.submit_call("process_chunks", (job_id, chunk, child_seeds[job_id]), id=job_id)

    ## Retrieve
    id_num = 0  # running number over all stored results, used in the file names
    for job_id in range(len(config_chunks)):
        print("\nreceive %s" % job_id)
        tmp = mpi.get_result(id=job_id)
        for conf_sam, result in tmp.items():
            config = conf_sam[0]
            print('CONFIG', config)
            # TODO: save as you get results
            # NOTE(review): here generate_name_from_params receives a tuple of
            # strings, whereas calculate_graphs passes a Config object --
            # confirm both produce the same folder name.
            para_setup_str = tuple(config.split("-"))
            config_string = generate_name_from_params(para_setup_str)
            result_path = para_setup_str[-1] + '/' + config_string
            file_name = result_path + '/%s_%s' % (config_string, id_num)
            file_name_cleaned = file_name.replace("'", "").replace('"', '') + '.dat'
            print('writing... ', file_name_cleaned)
            # The context manager closes the file even if pickling fails.
            with open(file_name_cleaned, 'wb') as out_file:
                pickle.dump(result, out_file, protocol=-1)
            id_num += 1

    time_end = time.time()
    print('Run time in hours ', (time_end - time_start) / 3600.)


# Entry point of the parallel mode: hand control to the `mpi` helper module.
# NOTE(review): presumably this runs master() on one process and serves the
# "process_chunks" calls submitted there on the others -- confirm against the
# mpi module. It executes at import time, as there is no __main__ guard.
mpi.run(verbose=False)

##### IMPORTANT: if using sequential, uncomment the block below!

# if __name__=='__main__':

#     parser = argparse.ArgumentParser(description="Run tasks from a YAML file.")
#     parser.add_argument('yaml_path', type=str, help='Path to the YAML configuration file.')
#     args = parser.parse_args()
    
#     config_path = args.yaml_path
#     # config_path = 'timeseries_configs/test_timeseries.yaml'
    
#     results_folder, all_configurations = generate_configurations(config_path)
#     nb_repeats = all_configurations[0].nb_repeats
    
#     print('results_folder', results_folder)
#     if not os.path.exists(results_folder):
#         os.makedirs(results_folder)
    
#     already_there = []
#     configurations = []
    
#     for configuration in all_configurations:
#         # print('configuration', configuration)
#         # config_params = configuration[0]
#         # print(config_params)
#         # suffix = configuration[1]
#         # save_folder = config_params[-1] + '/' + suffix
#         save_folder = configuration.save_folder + '/' + generate_name_from_params(configuration)
#         # print('save folder', save_folder)
#         if not os.path.exists(save_folder):
#             os.makedirs(save_folder)
        

#         # print('configuration', configuration)
#         # config_params = configuration[0]
#         # print(config_params)
#         # suffix = configuration[1]
#         # save_folder = config_params[-1] + '/' + suffix
#         # print('save folder', save_folder)
#         # if not os.path.exists(save_folder):
#         #     os.makedirs(save_folder)
        
#         current_results_files = [f for f in os.listdir(save_folder) if os.path.isfile(os.path.join(save_folder, f))]

#         # print('saving to', current_results_files)
        
#         # if suffix not in configurations:
#         configurations.append(configuration)
    
#     num_configs = len(configurations)  # min(num_jobs, num_configs)  # num_configs/num_jobs
    
#     print("number of todo configs ", num_configs)
#     print("number of existing configs ", len(already_there))
#     print("cpus %s" % num_cpus)
    
#     print("Shuffle configs to create equal computation time chunks ")
#     if num_configs == 0:
#         raise ValueError("No configs to do...")
    
    
#     seedSeq = SeedSequence(12334567)
    
#     for sam in range(nb_repeats): # change here to number 
#         for config in configurations:
#             # print('config', config  )
#             # para_setup_str = config[-1].split("-")
#             # print('para_setup_str', para_setup_str  )
#             config_string = generate_name_from_params(config)
#             # print('config string' , config_string)
#             # print('config result', config[0][-1])
#             result_path = config.save_folder + '/' + config_string
#             # print('result_path', result_path)
#             file_name = result_path + '/%s_%s' % (config_string, sam)
#             # print('file_name', file_name)
#             file_name_cleaned = file_name.replace("'", "").replace('"', '') + '.dat'
#             # print('file_name_cleaned', file_name_cleaned)
#             # print('writing... ', file_name_cleaned)
#             res = calculate_graphs(config, seedSeq, sam=sam, plot_data=plot_data)
#             file = open(file_name_cleaned, 'wb')
#             # # print('res', res)
#             pickle.dump(res, file)
#             # # , protocol=-1)
#             file.close()
