import argparse
import json
import logging
import math
import os
import random
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Tuple, List

import networkx as nx
import numpy as np
import pandas as pd
import torch

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from benchmark.utils.custom_causally import NonAdditiveNoiseModel
from benchmark.utils.causal_graphs import MixedGraph
from benchmark.utils.algo_wrappers import CAMUV, \
    NoGAMWrapper, \
    SCAMWrapper, \
    ScoreFCI, \
    FCI, \
    RandomMixedGraph, \
    LINGAM, \
    RCDLINGAM, \
    RandomPAG, OAFASWrapper, OracleWrapper, RESIT, FullyRandomMixedGraph
from benchmark.utils.cache_source_files import copy_referenced_files_to

# Causally imports
import causally.scm.scm as scm
import causally.graph.random_graph as rg
import causally.scm.noise as noise
import causally.scm.causal_mechanism as cm
from causally.scm.context import SCMContext


# Values accepted by ``get_causally_datasets`` for ``noise_dist`` and ``scm_type``.
_NOISE_DISTS = ("gauss", "mlp", "uniform")
_SCM_TYPES = ("nonlinear", "non-additive", "linear")


def get_causally_datasets(
        num_datasets: int,
        num_samples: int,
        num_observed_nodes: int,
        num_hidden: int,
        scm_type: str,
        noise_dist: str,
        expected_degree: int = None,
        p_edge = None,
        standardize = True,
        seed:int = None
) -> List[Tuple[pd.DataFrame, nx.DiGraph]]:
    """Sample synthetic datasets, with optional hidden variables, from random SCMs.

    Every dataset gets a fresh Erdos-Renyi graph over
    ``num_observed_nodes + num_hidden`` nodes named ``V1 ... Vn``. After
    sampling, ``num_hidden`` randomly chosen columns are dropped from the data.
    The returned ground-truth graph still contains every node, hidden ones
    included.

    Args:
        num_datasets: Number of datasets to generate.
        num_samples: Number of rows sampled for each dataset.
        num_observed_nodes: Number of variables kept in the returned data.
        num_hidden: Number of variables that are sampled and then dropped.
        scm_type: ``"nonlinear"`` (``scm.AdditiveNoiseModel``),
            ``"non-additive"`` (``NonAdditiveNoiseModel``); both use
            ``cm.NeuralNetMechanism``. Or ``"linear"`` (``scm.LinearModel``).
        noise_dist: ``"gauss"``, ``"mlp"`` or ``"uniform"``.
            ``"gauss"`` is rejected for ``scm_type == "linear"``.
        expected_degree: Expected degree for the Erdos-Renyi generator.
            Exactly one of ``expected_degree`` / ``p_edge`` must be set.
        p_edge: Edge probability for the Erdos-Renyi generator.
        standardize: If True, divide every column by its marginal std.
        seed: Base seed. Dataset ``i`` is built with ``seed + i`` so that the
            datasets differ from one another. ``None`` passes ``None``
            through, leaving randomness to the library/global RNG state.

    Returns:
        A list of ``(data, ground_truth)`` pairs. ``data`` is a DataFrame whose
        columns are the observed ``V*`` names. ``ground_truth`` is a DiGraph
        built from the sampled graph matrix over all nodes.

    Raises:
        ValueError: On an invalid ``p_edge``/``expected_degree`` combination,
            an unknown ``scm_type`` or ``noise_dist``, or linear + gauss.
    """
    if p_edge is not None and expected_degree is not None:
        raise ValueError("Can not set both p_edge and m_edge. Only one")
    if p_edge is None and expected_degree is None:
        raise ValueError("You must explicitly set one value between p_edge and m_edge")
    # Validate everything up front so a bad argument never leaves
    # ``noise_generator`` / ``model`` unbound halfway through generation.
    if noise_dist not in _NOISE_DISTS:
        raise ValueError(f"Unknown noise_dist {noise_dist!r}; expected one of {_NOISE_DISTS}.")
    if scm_type not in _SCM_TYPES:
        raise ValueError(f"Unknown scm_type {scm_type!r}; expected one of {_SCM_TYPES}.")
    if scm_type == "linear" and noise_dist == "gauss":
        raise ValueError("Can not have gaussian noise for linear mechanisms.")

    num_nodes = num_observed_nodes + num_hidden
    # Column / node names: integer index k -> "V{k+1}".
    nodes_map = {node: f"V{node + 1}" for node in range(num_nodes)}

    datasets = []
    for i in range(num_datasets):
        # Reusing one seed for every dataset would make them all alike.
        dataset_seed = None if seed is None else seed + i

        # Erdos-Renyi graph generator
        graph_generator = rg.ErdosRenyi(
            num_nodes=num_nodes, expected_degree=expected_degree, p_edge=p_edge, min_num_edges=math.ceil(num_nodes / 2)
        )

        # Generator of the noise terms (noise_dist already validated above).
        if noise_dist == "gauss":
            noise_generator = noise.Normal()
        elif noise_dist == "mlp":
            noise_generator = noise.MLPNoise(a_weight=-1.5, b_weight=1.5)
        else:  # "uniform"
            noise_generator = noise.Uniform(-2, 2)  # Std ~ 1

        # Structural causal model (scm_type already validated above).
        context = SCMContext()  # context for assumptions
        if scm_type == "nonlinear":
            model = scm.AdditiveNoiseModel(
                num_samples=num_samples,
                graph_generator=graph_generator,
                noise_generator=noise_generator,
                causal_mechanism=cm.NeuralNetMechanism(),
                scm_context=context,
                seed=dataset_seed
            )
        elif scm_type == "non-additive":
            model = NonAdditiveNoiseModel(
                num_samples=num_samples,
                graph_generator=graph_generator,
                noise_generator=noise_generator,
                causal_mechanism=cm.NeuralNetMechanism(),
                scm_context=context,
                seed=dataset_seed
            )
        else:  # "linear"
            model = scm.LinearModel(
                num_samples=num_samples,
                graph_generator=graph_generator,
                noise_generator=noise_generator,
                scm_context=context,
                seed=dataset_seed,
                max_weight=3,
                min_weight=-3,
                min_abs_weight=.5,
            )

        X, y = model.sample()
        if standardize:
            # Column-wise division by the marginal std (broadcast over rows).
            X = X / np.std(X, axis=0)

        # Rename variables and build the networkx ground truth over ALL nodes.
        gt = nx.from_numpy_array(y, create_using=nx.DiGraph)
        gt = nx.relabel_nodes(gt, nodes_map)
        data = pd.DataFrame(X).rename(columns=nodes_map)

        # Drop hidden variables: pick num_hidden distinct node indices at random.
        if num_hidden > 0:
            hidden_idxs = np.random.choice(num_nodes, size=num_hidden, replace=False)
            data = data.drop([nodes_map[k] for k in hidden_idxs], axis=1)

        datasets.append((data, gt))

    return datasets


def _make_scam(params, regression=None, prune_kci=False):
    """Build a SCAMWrapper configured from the parsed CLI ``params``.

    ``regression`` is forwarded to SCAMWrapper only when it is not None, so the
    default keeps SCAMWrapper's own regression choice. ``prune_kci`` sets
    ``algo.algo.prune_kci = True`` on the wrapped instance.
    """
    scam_kwargs = dict(
        alpha_orientations=params['alpha_orientations'],
        alpha_confounded_leaf=params['alpha_confounded_leaf'],
        alpha_separations=params['alpha_separations'],
        alpha_campruning=params['alpha_cam'],
        cv=params["cv"],
    )
    if regression is not None:
        scam_kwargs['regression'] = regression
    algo = SCAMWrapper(**scam_kwargs)
    if prune_kci:
        algo.algo.prune_kci = True
    return algo


# Every accepted ``--algorithms`` name -> factory(params, ground_truth) -> algorithm.
# NOTE(review): the random baselines get erdos_p=None when --m_edge is used
# instead of --p_edge; confirm RandomMixedGraph / RandomPAG accept that.
_ALGORITHM_FACTORIES = {
    'scam': lambda p, gt: _make_scam(p),
    'scam_p_kci': lambda p, gt: _make_scam(p, prune_kci=True),
    'ridge': lambda p, gt: _make_scam(p, regression='kernel_ridge'),
    'falkon': lambda p, gt: _make_scam(p, regression='falkon'),
    'xgboost': lambda p, gt: _make_scam(p, regression='xgboost'),
    'linear': lambda p, gt: _make_scam(p, regression='linear'),
    'oafas': lambda p, gt: OAFASWrapper(alpha=p['alpha_others']),
    'camuv': lambda p, gt: CAMUV(alpha=p['alpha_others']),
    'rcd': lambda p, gt: RCDLINGAM(alpha=p['alpha_others']),
    'nogam': lambda p, gt: NoGAMWrapper(cv=p["cv"], alpha=p['alpha_others']),
    'lingam': lambda p, gt: LINGAM(),
    'fci': lambda p, gt: FCI(alpha=p['alpha_others'], indep_test='kci'),
    'score_fci': lambda p, gt: ScoreFCI(alpha=p['alpha_separations']),
    'resit': lambda p, gt: RESIT(alpha=p['alpha_others']),
    'random': lambda p, gt: RandomMixedGraph(num_hidden=p['num_hidden'], erdos_p=p['p_edge']),
    'random_pag': lambda p, gt: RandomPAG(num_hidden=p['num_hidden'], erdos_p=p['p_edge']),
    'fully_random': lambda p, gt: FullyRandomMixedGraph(),
    'oracle': lambda p, gt: OracleWrapper(gt),
}


if __name__ == '__main__':
    # Benchmark: for each requested graph size, generate datasets, run every
    # requested algorithm on each one, save graphs/data, and write one CSV of
    # metrics per algorithm into a timestamped result directory.
    seed = 42
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(description='Experiment with increasing node size.')
    # choices: reject unknown algorithm names before any directory or data is created.
    parser.add_argument('--algorithms', nargs='+', default=['random'], choices=sorted(_ALGORITHM_FACTORIES))
    # type=int: values given on the command line would otherwise be strings
    # while the default is ints, making params.json / the metrics inconsistent.
    parser.add_argument('--num_nodes', nargs='+', default=[3, 5], type=int)
    parser.add_argument('--num_hidden', default=2, type=int)
    parser.add_argument('--num_samples', default=1000, type=int)
    parser.add_argument('--num_datasets', default=10, type=int)
    parser.add_argument('--p_edge', default=None, type=float)
    parser.add_argument('--m_edge', default=None, type=int)
    parser.add_argument('--noise', default="uniform", type=str)
    parser.add_argument('--alpha_others', default=0.05, type=float)
    parser.add_argument('--alpha_confounded_leaf', default=0.05, type=float)  # high alpha: enter in nogam regime
    parser.add_argument('--alpha_orientations', default=0.05, type=float)
    parser.add_argument('--alpha_separations', default=0.1, type=float)  # high: no pre-pruning.
    parser.add_argument('--alpha_cam', default=0.001, type=float)  # The usual value in CAM
    parser.add_argument('--scm', default='nonlinear')
    parser.add_argument('--cv', default=3, type=int)
    params = vars(parser.parse_args())

    result_dir = os.path.join('.', 'logs', 'paper-plots', 'incr_size_') + time.strftime('%y.%m.%d_%H.%M.%S')
    Path(result_dir).mkdir(parents=True, exist_ok=False)
    with open(os.path.join(result_dir, 'params.json'), 'w') as file:
        json.dump(params, file)
    copy_referenced_files_to(__file__, os.path.join(result_dir, "src_dump"))

    results = defaultdict(list)
    for num_nodes in params['num_nodes']:
        graph_dir = os.path.join(result_dir, 'graphs', 'num_nodes_{}'.format(num_nodes))
        data_dir = os.path.join(result_dir, 'data', 'num_nodes_{}'.format(num_nodes))
        Path(graph_dir).mkdir(parents=True, exist_ok=False)
        Path(data_dir).mkdir(parents=True, exist_ok=False)
        datasets = get_causally_datasets(
            num_datasets=params['num_datasets'],
            num_samples=params['num_samples'],
            num_observed_nodes=num_nodes,
            num_hidden=params["num_hidden"],
            scm_type=params["scm"],
            noise_dist=params["noise"],
            p_edge=params["p_edge"],
            expected_degree=params["m_edge"],
            standardize=True
        )
        for i, (data, ground_truth) in enumerate(datasets):
            print(f"\n##############\nDataset {i}")
            logging.info("Test dataset Nr. {}".format(i))
            MixedGraph(ground_truth).save_graph(os.path.join(graph_dir, 'ground_truth_{}.gml'.format(i)))
            data.to_csv(os.path.join(data_dir, 'data_{}.csv'.format(i)))
            for algo_name in params['algorithms']:
                algo = _ALGORITHM_FACTORIES[algo_name](params, ground_truth)

                # perf_counter: monotonic, high-resolution clock meant for durations.
                start_t = time.perf_counter()
                g_hat = algo.fit(data)
                end_t = time.perf_counter()

                metrics = g_hat.eval_all_metrics(ground_truth, params['scm'] != 'linear')
                metrics['time'] = end_t - start_t
                metrics['num_nodes'] = num_nodes
                results[algo_name].append(metrics)
                g_hat.save_graph(os.path.join(graph_dir, 'g_hat_{}_{}.gml'.format(algo_name, i)))

                print(f"Algorithm {algo_name}: time = {metrics['time']}s; SHD = {metrics['shd']}; skeleton f1 = {metrics['skeleton_f1']}")

    for algo_name in params['algorithms']:
        df = pd.DataFrame(results[algo_name])
        df.to_csv(os.path.join(result_dir, algo_name + '.csv'))
