import argparse
import glob
import json
import logging
import os
import random
import sys
import time
from collections import defaultdict

import numpy as np
import pandas as pd
import torch

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from benchmark.utils.causal_graphs import MixedGraph
from utils.algo_wrappers import CAMUV, \
    NoGAMWrapper, \
    SCAMWrapper, \
    ScoreFCI, \
    FCI, \
    RandomMixedGraph, \
    LINGAM, \
    RCDLINGAM, \
    RandomPAG, OAFASWrapper, OracleWrapper, RESIT, FullyRandomMixedGraph

# Algorithm names that map onto SCAMWrapper, with the regression back-end each one selects.
# ``None`` means "do not pass ``regression``", i.e. keep SCAMWrapper's own default.
_SCAM_REGRESSIONS = {
    'scam': None,
    'ridge': 'kernel_ridge',
    'falkon': 'falkon',
    'xgboost': 'xgboost',
    'linear': 'linear',
}

# All remaining algorithm names understood by ``_build_algorithm``.
_OTHER_ALGORITHMS = ('oafas', 'camuv', 'rcd', 'nogam', 'lingam', 'fci', 'score_fci', 'resit',
                     'random', 'random_pag', 'fully_random', 'oracle')

KNOWN_ALGORITHMS = frozenset(_SCAM_REGRESSIONS) | frozenset(_OTHER_ALGORITHMS)


def _parse_args():
    """Parse the command line into a plain dict of experiment settings."""
    parser = argparse.ArgumentParser(description='Experiment with increasing node size.')
    parser.add_argument('--algorithms', nargs='+', default=['resit'])
    parser.add_argument('--alpha_others', default=0.01, type=float)
    parser.add_argument('--alpha_confounded_leaf', default=0.05, type=float)
    parser.add_argument('--alpha_orientations', default=0.05, type=float)
    parser.add_argument('--alpha_separations', default=.1, type=float)
    parser.add_argument('--alpha_cam', default=0.001, type=float)  # The usual value in CAM
    parser.add_argument('--dir', default='.')
    parser.add_argument('--cv', default=3, type=int)
    # Only used by the 'random' / 'random_pag' baselines. Previously these read a key that the
    # parser never defined (always a KeyError). When omitted, main() falls back to params.json.
    parser.add_argument('--num_hidden', default=None, type=int)
    return vars(parser.parse_args())


def _build_algorithm(algo_name, params, ground_truth):
    """Instantiate the causal-discovery wrapper registered under ``algo_name``.

    Args:
        algo_name: one of ``KNOWN_ALGORITHMS``.
        params: parsed settings (alphas, ``cv``, ``num_hidden``) as produced by ``main``.
        ground_truth: the dataset's ``MixedGraph``; only the 'oracle' baseline uses it.

    Returns:
        An unfitted wrapper object exposing ``fit(data)``.

    Raises:
        NotImplementedError: if ``algo_name`` is not a known algorithm.
        ValueError: if a random baseline is requested but no ``num_hidden`` is available.
    """
    if algo_name in _SCAM_REGRESSIONS:
        scam_kwargs = dict(alpha_orientations=params['alpha_orientations'],
                           alpha_confounded_leaf=params['alpha_confounded_leaf'],
                           alpha_separations=params['alpha_separations'],
                           alpha_campruning=params['alpha_cam'],
                           cv=params['cv'])
        regression = _SCAM_REGRESSIONS[algo_name]
        if regression is not None:
            scam_kwargs['regression'] = regression
        return SCAMWrapper(**scam_kwargs)
    if algo_name == 'oafas':
        return OAFASWrapper(alpha=params['alpha_others'])
    if algo_name == 'camuv':
        return CAMUV(alpha=params['alpha_others'])
    if algo_name == 'rcd':
        return RCDLINGAM(alpha=params['alpha_others'])
    if algo_name == 'nogam':
        return NoGAMWrapper(alpha=params['alpha_others'])
    if algo_name == 'lingam':
        return LINGAM()
    if algo_name == 'fci':
        return FCI(alpha=params['alpha_others'], indep_test='kci')
    if algo_name == 'score_fci':
        return ScoreFCI(alpha=params['alpha_separations'])
    if algo_name == 'resit':
        return RESIT(alpha=params['alpha_others'])
    if algo_name in ('random', 'random_pag'):
        if params.get('num_hidden') is None:
            raise ValueError("Algorithm '{}' needs num_hidden: pass --num_hidden or set it in "
                             "params.json".format(algo_name))
        baseline_cls = RandomMixedGraph if algo_name == 'random' else RandomPAG
        return baseline_cls(num_hidden=params['num_hidden'])
    if algo_name == 'fully_random':
        return FullyRandomMixedGraph()
    if algo_name == 'oracle':
        return OracleWrapper(ground_truth.graph)
    raise NotImplementedError(algo_name)


def main():
    """Run each requested algorithm on every dataset under ``<dir>/data`` and save the metrics.

    For every subdirectory ``<dir>/data/<subdir>`` and every index ``i`` below
    ``params.json['num_datasets']``, this reads ``data_<i>.csv`` and the ground truth
    ``<dir>/graphs/<subdir>/ground_truth_<i>.gml``. It then fits each algorithm and writes
    the estimated graph next to the ground truth. Finally it writes one ``<algo>.csv`` of
    per-dataset metrics (plus ``time`` and ``num_nodes``) into ``<dir>``.
    """
    seed = 42
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    logging.basicConfig(level=logging.INFO)

    params = _parse_args()
    # Fail before any (potentially slow) fitting. Results are only written at the very end,
    # so a typo discovered mid-run would otherwise discard everything computed so far.
    unknown = [name for name in params['algorithms'] if name not in KNOWN_ALGORITHMS]
    if unknown:
        raise NotImplementedError(', '.join(unknown))

    with open(os.path.join(params['dir'], 'params.json'), encoding='utf-8') as file:
        experiment_params = json.load(file)
    if params['num_hidden'] is None:
        # NOTE(review): assumes params.json may record 'num_hidden' for the generated data; confirm.
        params['num_hidden'] = experiment_params.get('num_hidden')

    results = defaultdict(list)
    # glob returns entries in arbitrary order; sort so CSV rows are reproducible, and skip
    # stray files that are not dataset subdirectories.
    data_subdirs = sorted(path for path in glob.glob(os.path.join(params['dir'], 'data', '*'))
                          if os.path.isdir(path))
    for data_subdir in data_subdirs:
        # structure is e.g. './data/num_nodes_i/data_j.csv'. We want 'num_nodes_i' to insert in other paths later
        subdir = os.path.basename(os.path.normpath(data_subdir))
        logging.info("Enter subdir '%s'", subdir)
        graph_dir = os.path.join(params['dir'], 'graphs', subdir)
        for i in range(experiment_params['num_datasets']):
            logging.info("Test dataset Nr. %s", i)
            ground_truth = MixedGraph.load_graph(os.path.join(graph_dir, 'ground_truth_{}.gml'.format(i)))
            data = pd.read_csv(os.path.join(params['dir'], 'data', subdir, 'data_{}.csv'.format(i)), index_col=0)
            for algo_name in params['algorithms']:
                algo = _build_algorithm(algo_name, params, ground_truth)

                # perf_counter is monotonic and meant for measuring durations (time.time is not).
                start_t = time.perf_counter()
                g_hat = algo.fit(data)
                elapsed = time.perf_counter() - start_t

                # NOTE(review): an unused `indicate_ucp` flag (mechanism != 'linear') used to be
                # computed here; pass it explicitly if eval_all_metrics is meant to receive it.
                metrics = g_hat.eval_all_metrics(ground_truth.graph)
                metrics['time'] = elapsed
                metrics['num_nodes'] = data.shape[1]
                results[algo_name].append(metrics)
                g_hat.save_graph(os.path.join(graph_dir, 'g_hat_{}_{}.gml'.format(algo_name, i)))

    for algo_name in params['algorithms']:
        df = pd.DataFrame(results[algo_name])
        df.to_csv(os.path.join(params['dir'], algo_name + '.csv'))


if __name__ == '__main__':
    main()
