import experiments.utils
import experiments.data.utils
import experiments.evaluation.utils
import experiments.methods.utils
from tqdm import tqdm
import time
import numpy as np
import os
from itertools import product

# Hyperparameter search grid, one entry per causal-discovery method.
# run_real() iterates over the Cartesian product of each method's value lists
# and zips the results back onto the keys. The key order inside each grid
# therefore fixes both the search order and the name -> value pairing.
hyperparams = {
    "spinsvar": dict(
        lambda1=[0, 0.0001, 0.001, 0.01, 0.1],
        lambda2=[0, 0.01, 0.1, 1, 10],
        omega=[0.5],
    ),
    "sparserc": dict(
        lambda1=[0.0001, 0.001, 0.01, 0.1],
        lambda2=[0.01, 0.1, 1, 10],
        lambda3=[0.0001, 0.001, 0.01, 0.1],
        omega=[0.1, 0.2, 0.3, 0.4],
    ),
    "varlingam": dict(
        omega=[0.2, 0.3, 0.4, 0.5, 0.6],
    ),
    "d_varlingam": dict(
        omega=[0.2, 0.3, 0.4, 0.5, 0.6],
    ),
    "culingam": dict(
        omega=[0.2, 0.3, 0.4, 0.5, 0.6],
    ),
    "dynotears": dict(
        lambda_w=[0.01, 0.05, 0.1],
        lambda_a=[0.01, 0.05, 0.1],
        omega=[0.1, 0.2, 0.3, 0.4],
    ),
    "nts-notears": dict(
        lambda1=[0.0001, 0.0005, 0.001, 0.002, 0.01, 0.1],
        lambda2=[0.01, 0.05, 0.1, 1],
        omega=[0.1, 0.2, 0.3, 0.4],
    ),
    "tsfci": dict(
        sig_level=[0.001, 0.01, 0.05, 0.1],
        omega=[0.1, 0.2, 0.3, 0.4],
    ),
    "pcmci": dict(
        pc_alpha=[0.01, 0.05, 0.1],
        alpha_level=[0.01, 0.05, 0.1],
        omega=[0.1, 0.2, 0.3, 0.4],
    ),
    "TCDF": dict(
        significance=[0.8, 0.9, 1.0],
        nrepochs=[1000, 2000, 5000],
        omega=[0.1, 0.2, 0.3, 0.4],
    ),
}

def run_real():
    """Grid-search each method's hyperparameters on the FinanceCPT returns data.

    For every method in ``args.methods``, every combination of the value lists
    in ``hyperparams[method]`` is run through
    ``experiments.methods.utils.execute_method`` and scored with
    ``experiments.evaluation.utils.compute_metrics``. Each combination's
    metrics are reported with ``print_results``. The combination with the
    lowest first metric entry is reported with ``print_best_params``.

    Reads the module-level globals ``args`` and ``path``, which the
    ``__main__`` block sets before calling this function.

    Raises:
        KeyError: if a method in ``args.methods`` has no entry in ``hyperparams``.
    """
    # make directory to put results
    os.makedirs("results/{}/".format(path), exist_ok=True)

    data_dir = "experiments/data/FinanceCPT/returns"
    # os.listdir order is arbitrary and platform-dependent. Sort it so that the
    # single file evaluated below is the same on every run.
    data_files = sorted(os.listdir(data_dir))[:1]

    for t in [50]:
        # NOTE(review): the handle is never used. Opening in append mode only
        # guarantees that results/<path>.csv exists; kept for that side effect.
        with open('results/{}.csv'.format(path), 'a'):
            # Per-method list of metric rows; compute_metrics appends to it.
            current = {key: [] for key in args.methods}

            for r, filename_data in enumerate(data_files):
                # Ground-truth file shares the prefix before "_returns".
                filename_gt = filename_data.split("_returns")[0] + ".csv"

                # graph initialization
                start = time.time()
                X, _, _, B_true, _ = experiments.data.utils.get_data(args, 0, 0, dataset="finance", filename_data=filename_data, filename_gt=filename_gt)
                print("\n\nData generation process done. Time: {:.3f}\n\n".format(time.time() - start))

                print(B_true.shape)
                # Number of variables; the same for every method.
                d = X.shape[-1]

                # causal discovery algorithms
                for method in args.methods:
                    if method not in hyperparams:
                        raise KeyError("no hyperparameter grid defined for method {!r}".format(method))
                    grid = hyperparams[method]

                    # Reset per method so a method never reports another
                    # method's best parameters. inf guarantees the first
                    # finite score is accepted.
                    best_nshd = float("inf")
                    best_params = None

                    # Iterate through all combinations of hyperparameters
                    for params in product(*grid.values()):
                        parameters = dict(zip(grid.keys(), params))
                        print(parameters)

                        B_est, W_est, T = experiments.methods.utils.execute_method(X, method, args, 0, d, t, dataset="finance", search_params=parameters, ground_truth=B_true)

                        experiments.evaluation.utils.compute_metrics(method, current, path, r, t, T, X, None, B_true, 0, B_est, W_est, args)

                        # NOTE(review): assumes entry 0 of the latest metric row
                        # is nSHD (lower is better); confirm in compute_metrics.
                        score = current[method][-1][0]
                        if score < best_nshd:
                            best_nshd = score
                            best_params = parameters

                        # save results for this combination in csv
                        experiments.evaluation.utils.print_results(current[method][-1], path, method, search_params=parameters)

                    experiments.evaluation.utils.print_best_params(path, method, search_params=best_params)

if __name__ == '__main__':
    # Prefix for results/<path>/ and results/<path>.csv; run_real() reads it
    # as a module global.
    path = "hyperparameter_finance"

    # get_args() yields (parser, parsed_args). run_real() reads the parsed
    # namespace from the module global `args`.
    parser, args = experiments.utils.get_args()
    print(vars(args))

    run_real()
