import argparse
import json
import logging
import os
import random
import sys
import time
from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import torch

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from benchmark.utils.causal_graphs import MixedGraph
from data.generate_data import get_confounded_datasets
from utils.algo_wrappers import CAMUV, \
    NoGAMWrapper, \
    SCAMWrapper, \
    ScoreFCI, \
    FCI, \
    RandomMixedGraph, \
    RCDLINGAM, \
    LINGAM, \
    RandomPAG, OracleWrapper, OAFASWrapper, RESIT, FullyRandomMixedGraph
from utils.cache_source_files import copy_referenced_files_to

def parse_args(argv=None):
    """Parse the command-line options of the increasing-sample-size experiment.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``.

    Returns:
        A dict mapping each option name to its parsed value.
    """
    parser = argparse.ArgumentParser(description='Experiment with increasing sample size.')
    parser.add_argument('--algorithms', nargs='+', default=['fci', 'score_fci', 'random_pag'])
    # type=int keeps command-line values consistent with the integer defaults,
    # so params.json and the 'num_samples' result column never mix str and int.
    parser.add_argument('--num_samples', nargs='+', type=int, default=[300, 500, 700, 900])
    parser.add_argument('--num_hidden', default=1, type=int)
    parser.add_argument('--num_nodes', default=5, type=int)
    parser.add_argument('--num_datasets', default=5, type=int)
    parser.add_argument('--alpha_others', default=0.01, type=float)
    parser.add_argument('--alpha_confounded_leaf', default=0.001, type=float)
    parser.add_argument('--alpha_orientations', default=0.001, type=float)
    parser.add_argument('--alpha_separations', default=0.1, type=float)
    parser.add_argument('--p_edge', default=None, type=float)
    # The script reads params['mechanism'] in two places, so the option has to
    # exist. It is required instead of defaulted because the set of values that
    # get_confounded_datasets accepts is not visible from this file.
    parser.add_argument('--mechanism', required=True, type=str,
                        help="Data-generating mechanism forwarded to get_confounded_datasets; "
                             "the value 'linear' changes the flag passed to eval_all_metrics.")
    return vars(parser.parse_args(argv))


def _make_algorithm(algo_name, params, ground_truth):
    """Instantiate the algorithm wrapper selected by ``algo_name``.

    Args:
        algo_name: One of the names accepted by ``--algorithms``.
        params: Parsed options as returned by :func:`parse_args`.
        ground_truth: Ground-truth graph of the current dataset; only handed
            to the ``'oracle'`` wrapper.

    Returns:
        A freshly constructed wrapper object.

    Raises:
        NotImplementedError: If ``algo_name`` is not a known algorithm.
    """
    # SCAMWrapper variants differ only in the ``regression`` keyword; ``None``
    # means the keyword is omitted so the wrapper's own default applies.
    scam_regressions = {
        'scam': None,
        'ridge': 'kernel_ridge',
        'falkon': 'falkon',
        'xgboost': 'xgboost',
        'linear': 'linear',
    }
    if algo_name in scam_regressions:
        scam_kwargs = {
            'alpha_orientations': params['alpha_orientations'],
            'alpha_confounded_leaf': params['alpha_confounded_leaf'],
            'alpha_separations': params['alpha_separations'],
        }
        regression = scam_regressions[algo_name]
        if regression is not None:
            scam_kwargs['regression'] = regression
        return SCAMWrapper(**scam_kwargs)

    # Factories are lazy so that only the requested wrapper gets constructed.
    factories = {
        'oafas': lambda: OAFASWrapper(alpha=params['alpha_others']),
        'camuv': lambda: CAMUV(alpha=params['alpha_others']),
        'rcd': lambda: RCDLINGAM(alpha=params['alpha_others']),
        'nogam': lambda: NoGAMWrapper(alpha=params['alpha_others']),
        'lingam': lambda: LINGAM(),
        'fci': lambda: FCI(alpha=params['alpha_others'], indep_test='kci'),
        'score_fci': lambda: ScoreFCI(alpha=params['alpha_separations']),
        'resit': lambda: RESIT(alpha=params['alpha_others']),
        'random': lambda: RandomMixedGraph(num_hidden=params['num_hidden'], erdos_p=params['p_edge']),
        'random_pag': lambda: RandomPAG(num_hidden=params['num_hidden'], erdos_p=params['p_edge']),
        'fully_random': lambda: FullyRandomMixedGraph(),
        'oracle': lambda: OracleWrapper(ground_truth),
    }
    try:
        factory = factories[algo_name]
    except KeyError:
        raise NotImplementedError(algo_name) from None
    return factory()


def main(argv=None):
    """Run every selected algorithm on datasets of increasing sample size.

    For each value of ``--num_samples`` the script generates
    ``--num_datasets`` datasets, stores each dataset and its ground-truth
    graph, fits every algorithm from ``--algorithms``, stores the estimated
    graph, and finally writes one CSV of metrics per algorithm. All output
    goes below a new time-stamped directory ``../data/incr_samples_<time>``
    (relative to the current working directory).

    Args:
        argv: Optional list of argument strings, forwarded to
            :func:`parse_args`.
    """
    seed = 42
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    logging.basicConfig(level=logging.INFO)

    params = parse_args(argv)

    result_dir = os.path.join('..', 'data', 'incr_samples_') + time.strftime('%y.%m.%d_%H.%M.%S')
    # exist_ok=False: never write into the directory of an earlier run.
    Path(result_dir).mkdir(parents=True, exist_ok=False)
    with open(os.path.join(result_dir, 'params.json'), 'w') as file:
        json.dump(params, file)
    copy_referenced_files_to(__file__, os.path.join(result_dir, "src_dump"))

    # NOTE(review): the second argument of eval_all_metrics is True for every
    # mechanism other than 'linear'; its meaning is defined by the graph class.
    non_linear = params['mechanism'] != 'linear'

    results = defaultdict(list)
    for num_samples in params['num_samples']:
        graph_dir = os.path.join(result_dir, 'graphs', 'num_samples_{}'.format(num_samples))
        Path(graph_dir).mkdir(parents=True, exist_ok=False)
        data_dir = os.path.join(result_dir, 'data', 'num_samples_{}'.format(num_samples))
        Path(data_dir).mkdir(parents=True, exist_ok=False)

        # NOTE(review): the literal 2 is passed positionally; its meaning is
        # defined by get_confounded_datasets - confirm against that function.
        datasets = get_confounded_datasets(params['num_datasets'],
                                           params['num_nodes'],
                                           params['num_hidden'],
                                           num_samples,
                                           2,
                                           params['mechanism'])
        for i, (data, ground_truth) in enumerate(datasets):
            logging.info("Test dataset Nr. {}".format(i))
            MixedGraph(ground_truth).save_graph(os.path.join(graph_dir, 'ground_truth_{}.gml'.format(i)))
            data.to_csv(os.path.join(data_dir, 'data_{}.csv'.format(i)))
            for algo_name in params['algorithms']:
                algo = _make_algorithm(algo_name, params, ground_truth)

                # Only the fit itself is timed, not construction or evaluation.
                start_t = time.time()
                g_hat = algo.fit(data)
                end_t = time.time()

                metrics = g_hat.eval_all_metrics(ground_truth, non_linear)
                metrics['time'] = end_t - start_t
                metrics['num_samples'] = num_samples
                results[algo_name].append(metrics)
                g_hat.save_graph(os.path.join(graph_dir, 'g_hat_{}_{}.gml'.format(algo_name, i)))

    for algo_name in params['algorithms']:
        df = pd.DataFrame(results[algo_name])
        df.to_csv(os.path.join(result_dir, algo_name + '.csv'))


if __name__ == '__main__':
    main()
