import experiments.utils
import experiments.data.utils
import experiments.evaluation.utils
import experiments.methods.utils
from tqdm import tqdm
import time
import numpy as np
import os
from itertools import product

# Hyperparameter search grid, keyed by method name.
# Each inner dict maps a parameter name to the list of candidate values.
# run_real() searches the Cartesian product of the inner lists in key order,
# so the key order below also fixes the order in which combinations are tried.
hyperparams = {
    "spinsvar": {
        "lambda1": [0, 0.0001, 0.001, 0.01, 0.1],
        "lambda2": [0, 0.01, 0.1, 1, 10],
        "omega": [0.2, 0.3, 0.4],
    },
    "sparserc": {
        "lambda1": [0.0001, 0.001, 0.01, 0.1],
        "lambda2": [0.01, 0.1, 1, 10],
        "lambda3": [0.0001, 0.001, 0.01, 0.1],
        "omega": [0.1, 0.2, 0.3, 0.4],
    },
    # The three LiNGAM variants only tune the omega value.
    "varlingam": {
        "omega": [0.2, 0.3, 0.4, 0.5, 0.6],
    },
    "d_varlingam": {
        "omega": [0.2, 0.3, 0.4, 0.5, 0.6],
    },
    "culingam": {
        "omega": [0.2, 0.3, 0.4, 0.5, 0.6],
    },
    "nts-notears": {
        "lambda1": [0.0001, 0.0005, 0.001, 0.002, 0.01, 0.1],
        "lambda2": [0.01, 0.05, 0.1, 1],
        "omega": [0.1, 0.2, 0.3, 0.4],
    },
    "tsfci": {
        "sig_level": [0.001, 0.01, 0.05, 0.1],
        "omega": [0.1, 0.2, 0.3, 0.4],
    },
    "pcmci": {
        "pc_alpha": [0.01, 0.05, 0.1],
        "alpha_level": [0.01, 0.05, 0.1],
        "omega": [0.1, 0.2, 0.3, 0.4],
    },
}

def run_real(args=None, path=None, dataset_id=1, time_steps=(20,)):
    """Grid-search every method's hyperparameters on one DREAM3 dataset.

    For each ``t`` in ``time_steps`` and each of ``args.runs`` runs, the
    dataset is loaded and transformed. Then, for every method in
    ``args.methods``, every combination in ``hyperparams[method]`` is executed,
    scored with ``compute_metrics`` and printed with ``print_results``. The
    best combination is reported with ``print_best_params``.

    Args:
        args: Parsed experiment arguments; must provide ``methods`` and
            ``runs`` (plus whatever the ``experiments`` helpers read).
            Defaults to the module-level ``args`` set in the ``__main__``
            block.
        path: Name of the results sub-directory under ``results/``. Defaults
            to the module-level ``path`` set in the ``__main__`` block.
        dataset_id: DREAM3 dataset used for tuning (1 = first dataset).
        time_steps: Values of ``t`` forwarded to ``execute_method`` and the
            metric helpers.
    """
    # Fall back to the globals the __main__ block defines, so the historical
    # zero-argument call ``run_real()`` keeps working.
    if args is None:
        args = globals()["args"]
    if path is None:
        path = globals()["path"]
    filename_data = "dream3"

    # Make the directory that receives the results.
    os.makedirs("results/{}/".format(path), exist_ok=True)

    for t in time_steps:
        # One list of metric rows per method; compute_metrics() appends the
        # newest row to current[method].
        current = {method: [] for method in args.methods}

        for r in range(args.runs):
            X, B_true = _load_dream3(args, dataset_id, filename_data)
            d = X.shape[-1]  # number of variables; identical for every method

            for method in args.methods:
                _search_method(method, X, B_true, d, r, t, current, args, path)


def _load_dream3(args, dataset_id, filename_data):
    """Load one DREAM3 dataset, apply the configured transform, return (X, B_true)."""
    start = time.time()
    X, _, _, B_true, _ = experiments.data.utils.get_data(args, 0, 0, dataset="dream3", dataset_id=dataset_id, filename_data=filename_data)
    print("\n\nData generation process done. Time: {:.3f}\n\n".format(time.time() - start))

    # Normalize or standardize the data if args asks for it.
    X = experiments.data.utils.data_transform(X, args)
    print(B_true.shape)
    return X, B_true


def _search_method(method, X, B_true, d, r, t, current, args, path):
    """Evaluate every hyperparameter combination of *method*, report and return the best one.

    The score used for selection is column 6 of the newest metrics row that
    ``compute_metrics`` appended to ``current[method]``. Presumably this is
    AUROC (the original code calls it ``best_auroc``); confirm against
    ``compute_metrics``. Returns ``None`` if no combination produced a
    comparable score (e.g. every score is NaN).
    """
    grid = hyperparams[method]
    # Reset per method so one method's winner can never be reported for
    # another. The -inf start guarantees that a real score is always selected.
    best_auroc = float("-inf")
    best_params = None

    for values in product(*grid.values()):
        parameters = dict(zip(grid.keys(), values))
        print(parameters)

        B_est, W_est, T = experiments.methods.utils.execute_method(X, method, args, 0, d, t, dataset="dream3", search_params=parameters, ground_truth=B_true)
        experiments.evaluation.utils.compute_metrics(method, current, path, r, t, T, X, None, B_true, 0, B_est, W_est, args)

        latest = current[method][-1]
        if latest[6] > best_auroc:
            best_auroc = latest[6]
            best_params = parameters

        # Save this combination's results.
        experiments.evaluation.utils.print_results(latest, path, method, search_params=parameters)

    experiments.evaluation.utils.print_best_params(path, method, search_params=best_params)
    return best_params

if __name__ == '__main__':
    # run_real() reads two module-level globals set here: `path`, the results
    # sub-directory name, and `args`, the parsed command-line namespace.
    path = "hyperparameter_dream3"
    parser, args = experiments.utils.get_args()
    print(vars(args))
    run_real()
