import experiments.utils
import experiments.data.utils
import experiments.evaluation.utils
import experiments.methods.utils
from tqdm import tqdm
import time
import numpy as np
import os
from itertools import product

# Hyperparameter search grids, keyed by method name.
# For each method the ``__main__`` driver evaluates every combination in the
# Cartesian product of the listed values (iterating keys in the order written
# here) and passes each combination to the method runner as ``search_params``.
hyperparams = {
    "spinsvar": {
        "lambda1": [0, 1e-5, 1e-4, 1e-3, 1e-2],
        "lambda2": [0, 0.01, 0.1, 1, 10],
    },
    "sparserc": {
        "lambda1": [0, 1e-4, 1e-3, 1e-2, 0.1],
        "lambda2": [0.01, 0.1, 1, 10],
        "lambda3": [1e-3, 0.01, 0.1, 1],
        "omega": [0.01, 0.05, 0.09, 0.2],
    },
    "varlingam": {
        "omega": [0.01, 0.05, 0.09, 0.2],
    },
    "d_varlingam": {
        "omega": [0.01, 0.05, 0.09, 0.2],
    },
    "culingam": {
        "omega": [0.01, 0.05, 0.09, 0.2],
    },
    # Alternative (smaller) sparserc grid, kept for reference only:
    # "sparserc": {"lambda1": [0, 0.001, 0.01, 0.1], "lambda2": [0.1, 1, 1, 10],
    #              "lambda3": [0.001, 0.1], "omega": [0.01, 0.05, 0.09, 0.2]},
    "dynotears": {
        "lambda_w": [0.01, 0.05, 0.1],
        "lambda_a": [0.01, 0.05, 0.1],
        "omega": [0.01, 0.05, 0.09, 0.2],
    },
    "nts-notears": {
        "lambda1": [1e-4, 5e-4, 1e-3, 2e-3, 0.01, 0.1],
        "lambda2": [0.01, 0.05, 0.1, 1],
        "omega": [0.01, 0.05, 0.09, 0.2],
    },
    "tsfci": {
        "sig_level": [1e-3, 0.01, 0.05, 0.1],
        "omega": [0.01, 0.05, 0.09, 0.2],
    },
    "pcmci": {
        "pc_alpha": [0.01, 0.05, 0.1],
        "alpha_level": [0.01, 0.05, 0.1],
        "omega": [0.01, 0.05, 0.09, 0.2],
    },
    "TCDF": {
        "significance": [0.8, 0.9, 1.0],
        "nrepochs": [1000, 2000, 5000],
        "omega": [0.01, 0.05, 0.09, 0.2],
    },
}


if __name__ == '__main__':
    # Hyperparameter grid-search driver.
    # For every (samples, nodes, timesteps) setting and every run, one dataset is
    # generated; then each requested method is evaluated on every combination in
    # ``hyperparams[method]``. The combination with the lowest first metric
    # recorded by ``compute_metrics`` (treated by this script as normalized SHD)
    # is reported via ``print_best_params``.
    parser, args = experiments.utils.get_args()
    print(vars(args))

    # ``get_filename`` is still called for the ``label`` it returns, but all
    # output of this script deliberately goes to one fixed file name.
    _, label = experiments.utils.get_filename(parser, args)
    filename = "hyperparameter_search"

    # make directory to put results (this also creates the parent ``results/``
    # directory that the summary CSV below is appended to)
    os.makedirs("results/{}/".format(filename), exist_ok=True)

    for n in args.samples:
        for d in args.nodes:
            for t in args.timesteps:
                n_edges = args.edges * d + 2 * d * args.number_of_lags
                header = 'samples = {}, timesteps = {}, nodes = {}, edges = {}'.format(n, t, d, n_edges)

                # the ``with`` statement closes the file on exit
                with open('results/{}.csv'.format(filename), 'a') as f:
                    f.write('{}\n'.format(label))
                    print(header)
                    f.write(header + '\n')

                # per-method list that ``compute_metrics`` appends metric rows to
                current = {key: [] for key in args.methods}

                for r in tqdm(range(args.runs)):

                    # graph initialization
                    start = time.time()

                    X, C_true, cond_num, B_true, W_true = experiments.data.utils.get_data(args, n, d, T=t, dataset=args.dataset)
                    print("Total number of edges {}".format(np.sum(B_true)))
                    # X has shape n x T x d where n is the number of independent realizations, T the length of the time series and d the number of nodes in the dag
                    # B_true and W_true have shape d x (p + 1)d where p is the number of time-lags. They are expressed in the form B_true = [A, B_1, ..., B_p]

                    print("\n\nData generation process done. Time: {:.3f}\n\n".format(time.time() - start))

                    # normalizes or standardizes data if supposed to
                    X = experiments.data.utils.data_transform(X, args)

                    # skip unusable data explicitly instead of silently
                    if np.isnan(X).any() or not experiments.utils.is_bounded(X):
                        print("Run {}: transformed data contains NaN or is unbounded; skipping all methods.".format(r))
                        continue

                    # causal discovery algorithms
                    for method in args.methods:
                        grid = hyperparams[method]
                        # reset per method so a method that never produces a
                        # comparable score cannot inherit another method's params
                        best_nshd = float("inf")
                        best_params = None

                        # iterate over all combinations of hyperparameters
                        for params in product(*grid.values()):
                            parameters = dict(zip(grid.keys(), params))
                            print(parameters)

                            B_est, W_est, T = experiments.methods.utils.execute_method(X, method, args, n, d, t, dataset=args.dataset, search_params=parameters, ground_truth=B_true)

                            experiments.evaluation.utils.compute_metrics(method, current, filename, r, t, T, X, C_true, B_true, W_true, B_est, W_est, args)

                            # the row just appended by compute_metrics; its first
                            # entry is the score minimized here (NaN never wins)
                            latest = current[method][-1]
                            if latest[0] < best_nshd:
                                best_nshd = latest[0]
                                best_params = parameters

                            # save results of this combination in csv
                            experiments.evaluation.utils.print_results(latest, filename, method, search_params=parameters)

                        if best_params is None:
                            print("Method {}, run {}: no combination produced a comparable score; best parameters not recorded.".format(method, r))
                            continue

                        experiments.evaluation.utils.print_best_params(filename, method, search_params=best_params)
