import argparse
import pickle
import torch
from tabular_datasets import ADULT, HealthHeritage
import numpy as np
import pandas as pd
import os
from utils import Timer, evaluate_sampled_dataset
from query import get_all_marginals, query_marginal



def _sampled_dataset_path(args, base_path, dataset_name, wtl, epsilon, sample, resample):
    """Return the path of one pre-sampled synthetic dataset for a non-ProgSyn algorithm.

    Args:
        args: parsed command-line namespace (reads ``algorithm``, ``random_seed``
            and ``n_resamples``).
        base_path: directory prefix of the sampled datasets (ends with '/').
        dataset_name: dataset name as used in the file names ('health' or the
            lower-cased dataset argument).
        wtl: workload identifier as it appears in the file names.
        epsilon: privacy budget of the run.
        sample: zero-based training-sample index.
        resample: zero-based resample index (not part of the GEM file name).

    Returns:
        The ``.npy`` path for GEM, otherwise the ``.csv`` path.
    """
    if args.algorithm == 'gem':
        # there is no sampling in GEM as it operates with fixed noise, so every resample reads the same file
        return base_path + f'synth_{dataset_name}_{args.algorithm}_32_{wtl}_286_{epsilon}_{sample}_{args.random_seed}.npy'
    return base_path + f'synth_{dataset_name}_{args.algorithm}_32_{wtl}_{epsilon}_{sample}_{args.random_seed}_{resample+1}_{args.n_resamples}.csv'


def _load_sampled_dataset(args, path, dataset, device):
    """Load a pre-sampled synthetic dataset as a torch tensor on ``device``.

    GEM files are ``.npy`` arrays that are loaded directly. All other algorithms
    store ordinal CSVs, which are one-hot encoded with
    ``dataset.encode_full_one_hot_batch``.
    """
    if args.algorithm == 'gem':
        return torch.tensor(np.load(path)).to(device)
    synthetic_ordinal_data = pd.read_csv(path, delimiter=',').to_numpy()
    return dataset.encode_full_one_hot_batch(
        synthetic_ordinal_data, with_label=True, already_ordinal=True, return_torch=True
    ).to(device)


def main(args):
    """Evaluate DP synthetic datasets against a marginal workload and save the metrics.

    For each epsilon in a fixed grid, each training sample and each resample, the
    synthetic data of ``args.algorithm`` is obtained and scored. ProgSyn data is
    generated from a pickled synthesizer; all other algorithms load pre-sampled
    files. Scoring uses ``evaluate_sampled_dataset`` against the workload marginals
    measured on the real training data.

    The six returned statistics are stored in an array of shape
    ``(len(epsilons), n_samples, n_resamples, 6, 5)`` and saved with ``np.save``.
    If the result file already exists and ``args.force`` is not set, nothing is
    evaluated.

    Runs whose input file is missing are skipped. Their entries stay 0, as before,
    so existing analysis scripts keep working. The number of skipped runs is
    reported at the end because those zeros are not real scores.

    Args:
        args: parsed command-line namespace with ``algorithm``, ``workload``,
            ``dataset``, ``n_samples``, ``n_resamples``, ``random_seed`` and ``force``.
    """
    # Prefer the GPU, but fall back to the CPU so the script also runs without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # set the random seed
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    # tabular_datasets
    available_datasets = {
        'ADULT': ADULT,
        'HealthHeritage': HealthHeritage
    }

    # base paths
    base_model_path = f'experiment_data/dp_benchmarks/trained_models/{args.dataset}/random_seed_{args.random_seed}/{args.algorithm}/'
    base_sampled_datasets_path = f'experiment_data/dp_benchmarks/sampled_datasets/{args.dataset}/random_seed_{args.random_seed}/{args.algorithm}/'
    base_save_path = f'experiment_data/dp_benchmarks/evaluation_data/{args.dataset}/random_seed_{args.random_seed}/{args.algorithm}/'
    os.makedirs(base_save_path, exist_ok=True)
    save_path = base_save_path + f'collected_data_{args.algorithm}_{args.dataset}_{args.n_samples}_{args.n_resamples}_{args.workload}_{args.random_seed}.npy'

    # Check before the (expensive) dataset loading and workload precomputation.
    if os.path.isfile(save_path) and not args.force:
        print('Experiment already evaluated -- Aborting evaluation')
        return

    # ------- Instantiate the dataset ------- #
    dataset = available_datasets[args.dataset](drop_education_num=True, device=device) if args.dataset == 'ADULT' \
        else available_datasets[args.dataset](device=device)
    # Marginals are measured on the training split only.
    full_one_hot = dataset.get_Dtrain_full_one_hot(return_torch=True).to(device)

    # ------- Workload ------- #
    all_2_way_marginals = get_all_marginals(list(dataset.features.keys()), 2, downward_closure=False)
    all_3_way_marginals = get_all_marginals(list(dataset.features.keys()), 3, downward_closure=False)
    available_workloads = {
        'all_two': all_2_way_marginals,
        'all_three': all_3_way_marginals,
        'all_three_with_labels': [m for m in all_3_way_marginals if dataset.label in m]
    }
    precomputed_workload = {m: query_marginal(full_one_hot, m, dataset.full_one_hot_index_map, normalize=True, input_torch=True, max_slice=10000) for m in available_workloads[args.workload]}
    # workload identifier as it appears in the sampled-dataset file names
    workload_translator = {'all_two': 2, 'all_three': 3, 'all_three_with_labels': 'all_three_with_labels'}
    wtl = workload_translator[args.workload]
    dataset_name = 'health' if args.dataset == 'HealthHeritage' else args.dataset.lower()

    # params
    epsilons = [0.1, 0.25, 0.5, 0.75, 1.0]
    total_runs = len(epsilons) * args.n_samples * args.n_resamples

    timer = Timer(total_runs)
    collected_data = np.zeros((len(epsilons), args.n_samples, args.n_resamples, 6, 5))
    n_missing = 0
    for i, epsilon in enumerate(epsilons):

        for sample in range(args.n_samples):

            if args.algorithm == 'ProgSyn':
                model_path = base_model_path + f'trained_{args.algorithm}_{sample+1}_{args.n_samples}_{args.workload}_{epsilon}_{args.random_seed}.pickle'
                if not os.path.isfile(model_path):
                    print(f'File not found: {model_path}')
                    n_missing += args.n_resamples
                    continue
                # NOTE(security): unpickling can execute arbitrary code -- only load model files produced by this pipeline.
                with open(model_path, 'rb') as f:
                    synthesizer = pickle.load(f)

            for resample in range(args.n_resamples):

                data_path = None
                if args.algorithm != 'ProgSyn':
                    data_path = _sampled_dataset_path(args, base_sampled_datasets_path, dataset_name, wtl, epsilon, sample, resample)
                    # Check before timer.start() so every start is matched by a timer.end().
                    if not os.path.isfile(data_path):
                        print(f'File not found: {data_path}')
                        n_missing += 1
                        continue

                timer.start()
                print(f'Epsilon: {epsilon}    Sample: {sample+1}/{args.n_samples}    Resample: {resample+1}/{args.n_resamples}    {timer}', end='\r')

                if args.algorithm == 'ProgSyn':
                    synthetic_data = synthesizer.generate_data(synthesizer.data_len).detach()
                else:
                    synthetic_data = _load_sampled_dataset(args, data_path, dataset, device)

                # evaluate the synthetic data
                tv_stats, l2_stats, js_stats, acc_stats, bac_stats, f1_stats = evaluate_sampled_dataset(
                    synthetic_dataset=synthetic_data,
                    workload=available_workloads[args.workload],
                    true_measured_workload=precomputed_workload,
                    dataset=dataset,
                    max_slice=1000,
                    random_seed=args.random_seed
                )

                # record: metric index 0..5 = TV, L2, JS, accuracy, balanced accuracy, F1
                for k, stats in enumerate((tv_stats, l2_stats, js_stats, acc_stats, bac_stats, f1_stats)):
                    collected_data[i, sample, resample, k] = stats

                timer.end()

    np.save(save_path, collected_data)
    timer.duration()
    if n_missing:
        print(f'\nWarning: {n_missing}/{total_runs} runs had no input file; their entries were saved as 0.')


if __name__ == '__main__':
    parser = argparse.ArgumentParser('dp_benchmark_parser')
    # required: without it every path contains 'None', no input file is found,
    # and an all-zero result file would still be saved
    parser.add_argument('--algorithm', type=str, required=True, help='Specify the algorithm to be evaluated')
    parser.add_argument('--workload', type=str, default='all_three', choices=['all_two', 'all_three', 'all_three_with_labels'],
                        help='Specify the workload of training and evaluation')
    parser.add_argument('--dataset', type=str, default='ADULT', choices=['ADULT', 'HealthHeritage'], help='Select the dataset')
    parser.add_argument('--n_samples', type=int, default=5, help='Number of retraining samples')
    parser.add_argument('--n_resamples', type=int, default=5, help='Number of resamples from the data')
    parser.add_argument('--random_seed', type=int, default=42, help='Random seed for reproducibility')
    parser.add_argument('--force', action='store_true', help='Force the execution of the experiment')
    in_args = parser.parse_args()
    main(in_args)
