import argparse
import json
import logging
import os
import random
import sys
import time
from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from cdt.metrics import SHD

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from data.generate_data import get_confounded_datasets
from utils.algo_wrappers import CAMUV, NoGAMWrapper, SCAMWrapper
from utils.cache_source_files import copy_referenced_files_to
from utils.metrics import direct_edge_precision, \
    direct_edge_recall, \
    direct_edge_f1, \
    bi_edge_precision, \
    bi_edge_recall, bi_edge_f1

if __name__ == '__main__':
    seed = 42
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(description='Experiment with increasing node size.')
    parser.add_argument('--algorithms', nargs='+', default=['xgboost'])
    parser.add_argument('--num_nodes', default=7, type=int)
    parser.add_argument('--num_hidden', default=1, type=int)
    parser.add_argument('--num_samples', default=1000, type=int)
    parser.add_argument('--num_datasets', default=5, type=int)
    parser.add_argument('--parameter', default='alpha_separations')
    parser.add_argument('--values', nargs='+', default=[0.5, .1, .01, .001])
    parser.add_argument('--metric', default='shd')
    params = vars(parser.parse_args())

    result_dir = os.path.join('..', 'data', 'cv_') + time.strftime('%y.%m.%d_%H.%M.%S')
    Path(result_dir).mkdir(parents=True, exist_ok=False)
    with open(os.path.join(result_dir, 'params.json'), 'w') as file:
        json.dump(params, file)
    copy_referenced_files_to(__file__, os.path.join(result_dir, "src_dump"))

    metrics = [direct_edge_precision, direct_edge_recall, direct_edge_f1, bi_edge_precision, bi_edge_recall, bi_edge_f1,
               SHD]
    metric_func = {m.__name__: m for m in metrics}

    results = {}
    for algo in params['algorithms']:
        results[algo] = defaultdict(lambda: [])

    for i, (data, ground_truth) in enumerate(get_confounded_datasets(params['num_datasets'],
                                                                     params['num_nodes'],
                                                                     params['num_hidden'],
                                                                     params['num_samples']
                                                                     )
                                             ):
        logging.info("Test dataset Nr. {}".format(i))
        for value in params['values']:
            algo_params = {params['parameter']: value}
            for algo_name in params['algorithms']:
                if algo_name == 'scam':
                    algo = SCAMWrapper(**algo_params)
                elif algo_name == 'ridge':
                    algo = SCAMWrapper(regression='kernel_ridge', **algo_params)
                elif algo_name == 'falkon':
                    algo = SCAMWrapper(regression='falkon', **algo_params)
                elif algo_name == 'xgboost':
                    algo = SCAMWrapper(regression='xgboost', **algo_params)
                elif algo_name == 'linear':
                    algo = SCAMWrapper(regression='linear', **algo_params)
                elif algo_name == 'camuv':
                    algo = CAMUV(**algo_params)
                elif algo_name == 'nogam':
                    algo = NoGAMWrapper(**algo_params)
                else:
                    raise NotImplementedError(algo_name)

                g_hat = algo.fit(data)

                metrics = g_hat.eval_all_metrics(ground_truth)
                results[algo_name][params['metric']].append(metrics[params['metric']])
                results[algo_name][params['parameter']].append(value)

    final_dict = {}
    for algo_name in params['algorithms']:
        df = pd.DataFrame(results[algo_name])
        res = df.groupby(params['parameter']).mean().idxmin()
        print(df.groupby(params['parameter']).mean())
        final_dict[algo_name] = res[params['metric']]  # Get the idx of the minimum w.r.t. metric

    with open(os.path.join(result_dir, 'cv.json'), 'w') as file:
        json.dump(final_dict, file)
