import argparse
import os
import numpy as np
import sklearn

import tigramite
import pickle
from tqdm import tqdm
from matplotlib import pyplot as plt
from tigramite import data_processing as pp
from tigramite.toymodels import structural_causal_processes as toys
from tigramite import plotting as tp
from tigramite.pcmci import PCMCI

from tigramite.independence_tests.parcorr import ParCorr
from tigramite.independence_tests.gpdc import GPDC
from tigramite.independence_tests.cmiknn import CMIknn
from tigramite.independence_tests.cmisymb import CMIsymb


class CausalDataGenerator:
    """Simulate multivariate time series with tigramite's structural causal process.

    If no ``links`` dict is supplied, a fixed 9-variable default graph built
    by ``_get_default_links`` is used.

    Args:
        seed: Seed for both the noise ``RandomState`` held by this object and
            the ``seed`` argument forwarded to
            ``toys.structural_causal_process``.
        auto_coeff: Autoregressive coefficient used by the default graph.
        coeff: Cross-link coefficient magnitude used by the default graph.
        T: Number of time steps requested from the process.
        links: Optional tigramite-style links dict
            ``{target: [((parent, -lag), coefficient, func), ...]}``.
    """

    def __init__(self, seed=7, auto_coeff=0.95, coeff=0.4, T=500, links=None):
        self.seed = seed
        self.auto_coeff = auto_coeff
        self.coeff = coeff
        self.T = T
        self.links = links
        # Noise samplers draw from this stream, so repeated generate_data()
        # calls on one instance continue the same stream rather than restart it.
        self.random_state = np.random.RandomState(seed)

    def _get_default_links(self):
        """Return the built-in 9-variable links dict (all links linear; node 8 has no parents)."""
        def lin(x): return x

        return {
            0: [((0, -1), self.auto_coeff, lin),
                ((1, -1), self.coeff, lin)],
            1: [((1, -1), self.auto_coeff, lin)],
            2: [((2, -1), self.auto_coeff, lin),
                ((3, 0), -self.coeff, lin)],
            3: [((3, -1), self.auto_coeff, lin),
                ((1, -2), self.coeff, lin)],
            4: [((4, -1), self.auto_coeff, lin),
                ((3, 0), self.coeff, lin)],
            5: [((5, -1), 0.5*self.auto_coeff, lin),
                ((6, 0), self.coeff, lin)],
            6: [((6, -1), 0.5*self.auto_coeff, lin),
                ((5, -1), -self.coeff, lin)],
            7: [((7, -1), self.auto_coeff, lin),
                ((8, 0), -self.coeff, lin)],
            8: []
        }

    def _noise_functions(self, noises_type):
        """Return one noise sampler per variable in ``self.links``.

        tigramite calls each sampler with the number of samples as its only
        positional argument. ``RandomState.normal(n)`` / ``exponential(n)``
        would treat ``n`` as ``loc`` / ``scale`` and return a single scalar,
        so the size-first ``standard_normal`` / ``standard_exponential`` are
        used instead (N(0, 1) and Exp(scale=1) respectively).

        Raises:
            ValueError: if ``noises_type`` is not ``'gauss'`` or ``'exp'``.
        """
        samplers = {
            'gauss': self.random_state.standard_normal,
            'exp': self.random_state.standard_exponential,
        }
        try:
            sampler = samplers[noises_type]
        except KeyError:
            raise ValueError(
                f"unknown noises_type {noises_type!r}; expected one of {sorted(samplers)}"
            ) from None
        return [sampler for _ in self.links]

    def generate_data(self, noises_type):
        """Simulate data from ``self.links`` with the requested noise family.

        Falls back to (and stores) the default graph when ``self.links`` is None.

        Args:
            noises_type: ``'gauss'`` for standard normal or ``'exp'`` for
                standard exponential noise.

        Returns:
            Whatever ``toys.structural_causal_process`` returns (per tigramite's
            docs, a ``(data, nonstationary_indicator)`` pair).

        Raises:
            ValueError: for an unknown ``noises_type``.
        """
        if self.links is None:
            self.links = self._get_default_links()
        noises = self._noise_functions(noises_type)
        return toys.structural_causal_process(
            links=self.links,
            T=self.T,
            noises=noises,
            seed=self.seed
        )

def lin(x):
    """Identity function, used as the linear functional form in link tuples."""
    return x

def generate_links(node_num, edge_num, p_orders, seed=0):
    """Randomly generate a lagged causal graph in tigramite ``links`` format.

    Candidate edges ``(target, parent, lag)`` are drawn uniformly at random
    until ``edge_num`` distinct edges have been accepted.

    - Lagged edges (lag ``1..p_orders``) are accepted for any pair, including
      self-loops.
    - Contemporaneous edges (lag 0) must join distinct nodes and are rejected
      if they would close a cycle, so the lag-0 subgraph stays a DAG.

    Each accepted edge gets a coefficient with magnitude in [0.1, 0.5] and a
    random sign.

    Args:
        node_num: Number of variables.
        edge_num: Number of distinct edges to generate.
        p_orders: Maximum lag (must be >= 0).
        seed: Seed for the ``numpy.random.RandomState`` used for sampling.

    Returns:
        Dict mapping every node ``x`` in ``range(node_num)`` to a list of
        ``((y, -lag), coefficient, lin)`` tuples.

    Raises:
        ValueError: if ``p_orders`` is negative, or if ``edge_num`` exceeds the
            number of edges that can exist
            (``node_num**2 * p_orders`` lagged plus ``node_num*(node_num-1)//2``
            contemporaneous). Without that check the sampling loop would never
            terminate.
    """
    if p_orders < 0:
        raise ValueError(f"p_orders must be non-negative, got {p_orders}")
    max_edges = node_num * node_num * p_orders + node_num * (node_num - 1) // 2
    if edge_num > max_edges:
        raise ValueError(
            f"edge_num={edge_num} exceeds the {max_edges} edges possible with "
            f"node_num={node_num} and p_orders={p_orders}"
        )

    rng = np.random.RandomState(seed)
    links = {i: [] for i in range(node_num)}
    used_edges = set()
    # reach[a, b] is True when a reaches b through lag-0 edges. The matrix is
    # kept transitively closed after every insertion.
    reach = np.zeros((node_num, node_num), dtype=bool)
    edge_count = 0

    while edge_count < edge_num:
        # The RNG call order below matches the original implementation, so a
        # given seed yields the same graph.
        x = rng.randint(0, node_num)
        y = rng.randint(0, node_num)
        order = rng.randint(0, p_orders + 1)
        edge = (x, y, order)
        if edge in used_edges:
            continue

        if order == 0:
            # Lag-0 edge y -> x: reject self-loops, and reject it if x already
            # reaches y (that would close a cycle).
            if x == y or reach[x, y]:
                continue
            # Incremental closure update: everything that reaches y (and y
            # itself) now reaches everything x reaches (and x itself).
            sources = reach[:, y].copy()
            sources[y] = True
            targets = reach[x, :].copy()
            targets[x] = True
            reach |= sources[:, None] & targets[None, :]

        used_edges.add(edge)
        eff = rng.rand() * 0.8 - 0.4
        # Push the coefficient away from zero: magnitude ends up in [0.1, 0.5].
        eff = eff + 0.1 if eff > 0 else eff - 0.1
        links[x].append(((y, -order), eff, lin))
        edge_count += 1
    return links

def plot_pcmci_result(data, links, tau_max, save_path, alpha_level=0.001, pc_alpha=0.01):
    """Run PCMCI (partial-correlation tests) on ``data`` and save the time-series graph.

    Links that are absent from the ground-truth ``links`` are drawn with the
    ``'spurious'`` link attribute.

    Args:
        data: 2-D array of samples. It is assumed to be (T, N), which is the
            layout tigramite's ``DataFrame`` expects; one variable name is
            created per column.
        links: Ground-truth links dict in the format accepted by
            ``toys.links_to_graph``.
        tau_max: Maximum lag tested by PCMCI and used for the ground-truth graph.
        save_path: Output image path passed to ``plt.savefig``.
        alpha_level: Significance level forwarded to ``run_pcmci``.
        pc_alpha: Currently unused. ``run_pcmci`` is called with
            ``pc_alpha=None``. The parameter is kept for backward
            compatibility.
    """
    data = np.asarray(data)
    # Derive the variable count from the data itself rather than a
    # script-level global, so the function also works when imported.
    var_names = [r'$X^{%d}$' % j for j in range(data.shape[1])]
    parcorr = ParCorr(significance='analytic')
    pcmci = PCMCI(
        dataframe=pp.DataFrame(data, var_names=var_names),
        cond_ind_test=parcorr,
        verbosity=1)
    results = pcmci.run_pcmci(tau_min=0, tau_max=tau_max, pc_alpha=None, alpha_level=alpha_level)

    # Mark links that are not in the true graph as spurious (drawn grey).
    true_graph = toys.links_to_graph(links=links, tau_max=tau_max)
    link_attribute = np.ones(results['val_matrix'].shape, dtype='object')
    link_attribute[true_graph == ""] = 'spurious'
    link_attribute[true_graph != ""] = ''
    # Contemporaneous links are undirected in the plot: clear both (i, j) and (j, i).
    for (i, j) in zip(*np.where(true_graph[:, :, 0] != "")):
        link_attribute[i, j, 0] = link_attribute[j, i, 0] = ''

    tp.plot_time_series_graph(
        figsize=(8, 8),
        node_size=0.05,
        val_matrix=results['val_matrix'],
        graph=results['graph'],
        link_attribute=link_attribute,
        var_names=var_names,
        )
    plt.savefig(save_path)
    plt.close()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--node_num', type=int, default=5)
    parser.add_argument('--edge_num', type=int, default=20)
    parser.add_argument('--p_orders', type=int, default=3)
    # Validate early: CausalDataGenerator only supports these noise families.
    parser.add_argument('--noises_type', type=str, default='gauss', choices=['gauss', 'exp'])
    parser.add_argument('--output', type=str, default='simulated_data_init')
    parser.add_argument('--algs', type=str, nargs='*',
                        default=['simulated_data_and_init0', 'simulated_data_multiply_init0',
                                 'simulated_data_and_initdata', 'simulated_data_multiply_initdata'])

    args = parser.parse_args()

    REPEAT_NUM = 12
    for T in [250, 1000]:
        # --edge_num is multiplied by node_num to get the total edge count.
        edge_num = args.edge_num * args.node_num
        links = generate_links(args.node_num, edge_num, args.p_orders, seed=42)
        generator = CausalDataGenerator(seed=42, T=T)
        generator.links = links
        data, _indicator = generator.generate_data(noises_type=args.noises_type)
        # Persist only ((parent, lag), coefficient) per edge; drop the function.
        links_infos = {key: [(edge[0], edge[1]) for edge in value] for key, value in links.items()}
        # NOTE(review): data is generated once per T, so every repeat/alg
        # directory receives an identical copy. Confirm this is intended.
        for i in tqdm(range(REPEAT_NUM)):
            for alg in args.algs:
                save_path = f"{args.output}/{alg}/node{str(args.node_num).zfill(3)}_edge{str(edge_num).zfill(3)}_porders{str(args.p_orders).zfill(1)}_T{str(T).zfill(4)}_noise{args.noises_type}/repeat{str(i).zfill(1)}/"
                os.makedirs(save_path, exist_ok=True)
                with open(os.path.join(save_path, 'data.pkl'), 'wb') as f:
                    pickle.dump(data, f)
                with open(os.path.join(save_path, 'links_infos.pkl'), 'wb') as f:
                    pickle.dump(links_infos, f)
                # plot_pcmci_result(data, links, args.p_orders + 1, save_path + 'pcmci_result.png')
    print("Done!")