import pickle
from tabular_datasets import ADULT, HealthHeritage
import pandas as pd 
import numpy as np
import torch
import os
from utils import Timer, evaluate_sampled_dataset, statistics
import argparse
from query import get_all_marginals, query_marginal
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score
from xgboost import XGBClassifier


def _model_load_path(args, base_model_path, sample, backbone_model):
    """Return the pickle path of the trained model for the 0-based ``sample`` index."""
    prefix = f'trained_{args.model}_{sample+1}_{args.n_samples}'
    if args.model == 'GReaT':
        # GReaT checkpoints are additionally keyed by the language-model backbone
        file_name = f'{prefix}_{backbone_model}_{args.random_seed}.pickle'
    elif args.model == 'ProgSyn':
        # ProgSyn checkpoints are additionally keyed by their training workload
        file_name = f'{prefix}_{args.workload}_{args.random_seed}.pickle'
    else:
        file_name = f'{prefix}_{args.random_seed}.pickle'
    return base_model_path + file_name


def _draw_synthetic_data(synthesizer, dataset, n_rows, is_progsyn):
    """Sample ``n_rows`` synthetic rows and return (full one-hot, mixed X, labels)."""
    if is_progsyn:
        # ProgSyn emits the full one-hot encoding directly
        full_one_hot_synth = synthesizer.generate_data(n_rows).detach()
        full_one_hot_np = full_one_hot_synth.clone().cpu().numpy()
        # the label is read from the last column and the trailing two columns are
        # stripped before decoding the features
        # NOTE(review): this presumably relies on a binary, one-hot encoded label -- confirm
        ymixed_synth = full_one_hot_np[:, -1]
        Xmixed_synth = dataset.decode_full_one_hot_batch(full_one_hot_np[:, :-2])
        Xmixed_synth = dataset.encode_batch(Xmixed_synth, standardize=False).cpu().numpy()
        return full_one_hot_synth, Xmixed_synth, ymixed_synth

    # every other synthesizer returns an object exposing ``to_numpy()``; its last
    # column is treated as the integer label index
    synthetic_data_np = synthesizer.sample(n_rows).to_numpy()

    # mixed representation (must be computed before the label column is overwritten)
    Xmixed_synth = dataset.encode_batch(synthetic_data_np[:, :-1], standardize=False).cpu().numpy()
    ymixed_synth = synthetic_data_np[:, -1].astype(int)

    # full one-hot representation: map the label indices back to label values first
    synthetic_data_np[:, -1] = [dataset.features[dataset.label][idx] for idx in ymixed_synth]
    full_one_hot_synth = dataset.encode_full_one_hot_batch(synthetic_data_np, with_label=True, return_torch=True)
    return full_one_hot_synth, Xmixed_synth, ymixed_synth


def main(args):
    """Evaluate the trained synthesizers of one benchmark configuration.

    For every retrained model (``args.n_samples``) and every resampling round
    (``args.n_resamples``), synthetic data of the same size as the training set
    is drawn and scored against three marginal workloads. An XGBoost classifier
    is also trained on the non-discretized synthetic data and scored on the
    real test set.

    The results are stored with ``np.save`` in an array of shape
    ``(n_workloads, n_samples, n_resamples, 9, 5)``. The nine metric slots are,
    in order: TV, L2 and JS errors, then accuracy / balanced accuracy / F1 on
    the discretized data, then accuracy / balanced accuracy / F1 on the
    non-discretized data. The last axis holds whatever ``statistics`` and
    ``evaluate_sampled_dataset`` return per metric (five values each).

    Models whose pickle file is missing are skipped, so their slice of the
    saved array stays all-zero.

    Args:
        args: Parsed command-line namespace; reads ``model``, ``dataset``,
            ``n_samples``, ``n_resamples``, ``random_seed``, ``distilled``,
            ``workload`` and ``force``.

    Returns:
        None. If the result file already exists and ``args.force`` is not set,
        nothing is evaluated.
    """
    # prefer the GPU, but stay runnable on machines without CUDA
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # set the random seed
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    # base paths
    base_model_path = f'experiment_data/non_dp_benchmarks/trained_models/{args.dataset}/random_seed_{args.random_seed}/{args.model}/'
    base_save_path = f'experiment_data/non_dp_benchmarks/evaluation_data/{args.dataset}/random_seed_{args.random_seed}/{args.model}/'
    os.makedirs(base_save_path, exist_ok=True)
    save_path = base_save_path + f'collected_data_{args.model}_{args.dataset}_{args.n_samples}_{args.n_resamples}_{args.random_seed}_{args.distilled}.npy'

    # bail out before the expensive dataset loading and marginal precomputation
    if os.path.isfile(save_path) and not args.force:
        print('Experiment already evaluated -- Aborting evaluation')
        return

    # tabular_datasets
    available_datasets = {
        'ADULT': ADULT,
        'HealthHeritage': HealthHeritage
    }

    # ------- Instantiate the dataset ------- #
    dataset_kwargs = {'device': device}
    if args.dataset == 'ADULT':
        dataset_kwargs['drop_education_num'] = True
    dataset = available_datasets[args.dataset](**dataset_kwargs)
    full_one_hot = dataset.get_Dtrain_full_one_hot(return_torch=True).to(device)
    Xtest, ytest = dataset.get_Xtest().cpu().numpy(), dataset.get_ytest().cpu().numpy()

    # ------- Workload ------- #
    all_2_way_marginals = get_all_marginals(list(dataset.features.keys()), 2, downward_closure=False)
    all_3_way_marginals = get_all_marginals(list(dataset.features.keys()), 3, downward_closure=False)
    available_workloads = {
        'all_two': all_2_way_marginals,
        'all_three': all_3_way_marginals,
        'all_three_with_labels': [m for m in all_3_way_marginals if dataset.label in m]
    }
    # true marginals measured on the real training data, one dict per workload
    precomputed_workloads = {
        name: {
            m: query_marginal(full_one_hot, m, dataset.full_one_hot_index_map, normalize=True, input_torch=True, max_slice=1000)
            for m in marginals
        }
        for name, marginals in available_workloads.items()
    }

    backbone_model = 'distilgpt2' if args.distilled else 'gpt2'

    collected_data = np.zeros((len(available_workloads), args.n_samples, args.n_resamples, 9, 5))

    timer = Timer(len(available_workloads) * args.n_samples * args.n_resamples)

    for sample in range(args.n_samples):

        specific_model_load_path = _model_load_path(args, base_model_path, sample, backbone_model)
        if not os.path.isfile(specific_model_load_path):
            print('Model not found -- skipping evaluation')
            continue

        # SECURITY: unpickling executes arbitrary code -- only load checkpoints
        # produced by the trusted training scripts of this project
        with open(specific_model_load_path, 'rb') as f:
            synthesizer = pickle.load(f)

        for resample in range(args.n_resamples):

            # sample synthetic data
            full_one_hot_synth, Xmixed_synth, ymixed_synth = _draw_synthetic_data(
                synthesizer, dataset, len(full_one_hot), args.model == 'ProgSyn'
            )

            # train an xgboost on the non-discretized dataset; its inputs do not
            # depend on the workload, so it is fitted once per resample instead
            # of once per workload
            xgb = XGBClassifier()
            xgb.fit(Xmixed_synth, ymixed_synth)
            predictions = xgb.predict(Xtest)
            acc_nd_stats = statistics([accuracy_score(ytest, predictions)])
            bac_nd_stats = statistics([balanced_accuracy_score(ytest, predictions)])
            f1_nd_stats = statistics([f1_score(ytest, predictions)])

            for i, (workload_name, workload_marginals) in enumerate(available_workloads.items()):

                timer.start()
                print(f'Workload: {workload_name}    Sample: {sample+1}/{args.n_samples}    Resample: {resample+1}/{args.n_resamples}    {timer}', end='\r')

                # do the usual evaluation on the discretized dataset
                tv_errors, l2_errors, js_errors, xgb_acc, xgb_bac, xgb_f1 = evaluate_sampled_dataset(
                    synthetic_dataset=full_one_hot_synth,
                    workload=workload_marginals,
                    true_measured_workload=precomputed_workloads[workload_name],
                    dataset=dataset,
                    max_slice=1000,
                    random_seed=args.random_seed
                )

                # metric slots 0-8, in the order documented in the docstring
                metrics = (
                    tv_errors, l2_errors, js_errors,
                    xgb_acc, xgb_bac, xgb_f1,
                    acc_nd_stats, bac_nd_stats, f1_nd_stats,
                )
                for slot, values in enumerate(metrics):
                    collected_data[i, sample, resample, slot] = values

                timer.end()

    np.save(save_path, collected_data)


if __name__ == '__main__':
    # Command-line entry point: parse the benchmark configuration and run the evaluation.
    parser = argparse.ArgumentParser('non_dp_benchmark_parser')
    # --model is mandatory: without it every path would contain 'None', no trained
    # model would be found and an all-zero result file would be written
    parser.add_argument('--model', type=str, required=True, help='Specify the model to be evaluated')
    # restrict to the datasets main() can instantiate, so a typo fails at parse time
    parser.add_argument('--dataset', type=str, default='ADULT', choices=['ADULT', 'HealthHeritage'], help='Select the dataset')
    parser.add_argument('--n_samples', type=int, default=5, help='Number of retraining samples')
    parser.add_argument('--n_resamples', type=int, default=5, help='Number of resamples from the data')
    parser.add_argument('--random_seed', type=int, default=42, help='Random seed for reproducibility')
    parser.add_argument('--distilled', action='store_true', help='Use distilled gpt2 as backbone')
    parser.add_argument('--workload', type=str, default='all_three_with_labels', help='Training workload of ProgSyn')
    parser.add_argument('--force', action='store_true', help='Force the execution of the experiment')
    in_args = parser.parse_args()
    main(in_args)
