import os
import pickle
import argparse
import numpy as np
import torch
from programmable_synthesizer import ProgrammableSynthesizer
from itertools import product
from utils import evaluate_sampled_dataset, statistics, Timer
from constraints import ConstraintEvaluator
from query import get_all_marginals, query_marginal
from tabular_datasets import ADULT, HealthHeritage
import copy


def _program_grid(dataset_name, dp):
    """Return the constraint programs and their hyper-parameter grids.

    The result maps a program file name to a dict with two entries:
    'arguments' (program argument name -> array of values to sweep) and
    'denoiser_config' (settings handed to the synthesizer's denoiser).

    :param dataset_name: 'ADULT' or 'HealthHeritage'.
    :param dp: True to select the differentially private program variants.
    :raises ValueError: if no grid is configured for ``dataset_name``.
    """
    if not dp:
        if dataset_name == 'ADULT':
            return {
                'eliminate_predictability_sex.txt': {'arguments': {'param1': np.linspace(0.0005, 0.002, 10)}, 'denoiser_config': {'finetuning_epochs': 200}},
                'fairness_downstream_sex.txt': {'arguments': {'param1': np.linspace(0.0007, 0.001, 10)}, 'denoiser_config': {'finetuning_epochs': 200}},
                'minimize_correlation_sex.txt': {'arguments': {'param1': np.linspace(0.01, 1.0, 5)}, 'denoiser_config': {'finetuning_epochs': 200}},
                'avg_age_30.txt': {'arguments': {'param1': np.linspace(0.0, 0.00005, 5)}, 'denoiser_config': {'finetuning_epochs': 200}},
                'avg_male_female_age.txt': {'arguments': {'param1': np.linspace(0.0, 0.00005, 5)}, 'denoiser_config': {'finetuning_epochs': 200}},
                'implication1.txt': {'arguments': {'param1': np.linspace(0.0, 0.00001, 5)}, 'denoiser_config': {'finetuning_epochs': 500}},
                'implication2.txt': {'arguments': {'param1': np.linspace(0.0, 0.00001, 5)}, 'denoiser_config': {'finetuning_epochs': 500}},
                'implication3.txt': {'arguments': {'param1': np.linspace(0.0, 0.00001, 5)}, 'denoiser_config': {'finetuning_epochs': 500}},
                'line_constraint1.txt': {'arguments': {'param1': np.linspace(0.0, 0.00001, 5)}, 'denoiser_config': {'finetuning_epochs': 500}},
                'line_constraint2.txt': {'arguments': {'param1': np.linspace(0.0, 0.00001, 5)}, 'denoiser_config': {'finetuning_epochs': 500}},
            }
        if dataset_name == 'HealthHeritage':
            return {
                'implication1.txt': {'arguments': {'param1': np.linspace(0.000001, 0.00001, 5)}, 'denoiser_config': {'finetuning_epochs': 500, 'finetuning_batch_size': 20000}},
                'line_constraint1.txt': {'arguments': {'param1': np.linspace(0.00001, 0.0005, 5)}, 'denoiser_config': {'finetuning_epochs': 500, 'finetuning_batch_size': 15000}},
            }
    else:
        if dataset_name == 'ADULT':
            return {
                'fairness_downstream_sex_dp.txt': {'arguments': {'param1': np.linspace(0.001, 0.01, 10)}, 'denoiser_config': {'finetuning_epochs': 200}},
                'implication1_dp.txt': {'arguments': {'param1': np.linspace(0.0, 0.00005, 5)}, 'denoiser_config': {'finetuning_epochs': 200}},
                'implication2_dp.txt': {'arguments': {'param1': np.linspace(0.0, 0.00005, 5)}, 'denoiser_config': {'finetuning_epochs': 200}},
                'implication3_dp.txt': {'arguments': {'param1': np.linspace(0.0, 0.0005, 5)}, 'denoiser_config': {'finetuning_epochs': 200}},
                'line_constraint1_dp.txt': {'arguments': {'param1': np.linspace(0.0, 0.00005, 5)}, 'denoiser_config': {'finetuning_epochs': 200}},
                'line_constraint2_dp.txt': {'arguments': {'param1': np.linspace(0.0, 0.00005, 5)}, 'denoiser_config': {'finetuning_epochs': 200}},
            }
        if dataset_name == 'HealthHeritage':
            # no DP constraint programs are configured for this dataset
            return {}

    raise ValueError(f'No constraint programs configured for dataset {dataset_name!r}')


def _constraint_statistics(constraint_eval_data):
    """Flatten constraint evaluation results into one statistics entry per score.

    Each element of ``constraint_eval_data`` is indexed with 'score'; the value
    is either a single score or a list of scores. Every individual score is
    passed to ``statistics`` as a one-element list.
    """
    collected = []
    for constraint_result in constraint_eval_data:
        scores = constraint_result['score']
        if not isinstance(scores, list):
            scores = [scores]
        collected.extend(statistics([score]) for score in scores)
    return collected


def main(args):
    """Evaluate every configured constraint program and save the results.

    For each program of the selected dataset the hyper-parameter grid is swept.
    For every parameter combination ``args.n_samples`` synthesizers are fitted
    and from each ``args.n_resamples`` datasets are drawn and scored. The
    results of one program are stored in a single ``.npy`` file of shape
    ``(n_param_combinations, n_samples, n_resamples, 6 + n_constraint_scores, 5)``
    where the fourth axis holds, in order: TV, L2, JS, accuracy, balanced
    accuracy, F1, followed by the constraint scores.

    :param args: parsed command-line namespace; the attributes read are
        ``dataset``, ``workload``, ``epsilon`` (> 0 enables the DP programs),
        ``n_samples``, ``n_resamples``, ``random_seed``, ``baseline_mode`` and
        ``force`` (recompute even if the result file exists).
    :raises KeyError: if ``args.dataset`` or ``args.workload`` is unknown.
    """
    # set the random seed
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    # use the GPU when there is one, otherwise fall back to the CPU instead of crashing
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cpu':
        print('CUDA is not available, running on CPU')

    available_datasets = {
        'ADULT': ADULT,
        'HealthHeritage': HealthHeritage
    }
    # instantiate the dataset; only ADULT receives the drop_education_num flag
    dataset_cls = available_datasets[args.dataset]
    dataset_kwargs = {'drop_education_num': True} if args.dataset == 'ADULT' else {}
    dataset = dataset_cls(device=device, **dataset_kwargs)
    full_one_hot = dataset.get_Dtrain_full_one_hot(return_torch=True)
    full_one_hot_test = dataset.get_Dtest_full_one_hot(return_torch=True)

    # prepare everything workload related
    translated_workload = {
        'all_two': 2,
        'all_three': 3,
        'all_three_with_labels': 'all_three_with_labels'
    }
    feature_names = list(dataset.features.keys())
    workload_marginal_names = {
        'all_two': get_all_marginals(feature_names, 2, downward_closure=False),
        'all_three': get_all_marginals(feature_names, 3, downward_closure=False)
    }
    workload_marginal_names['all_three_with_labels'] = [m for m in workload_marginal_names['all_three'] if dataset.label in m]
    # marginals of the real training data, used as the reference in the evaluation
    measured_workload = {
        m: query_marginal(full_one_hot, m, dataset.full_one_hot_index_map, normalize=True, input_torch=True, max_slice=1000)
        for m in workload_marginal_names[args.workload]
    }

    # a positive epsilon switches to the differentially private programs
    dp = args.epsilon > 0
    programs_and_params = _program_grid(args.dataset, dp)

    base_path = f'experiment_data/constraint_program_experiments/{args.dataset}/'
    base_path += 'dp_constraints/' if dp else 'non_dp_constraints/'
    eval_base_path = base_path + 'testing_results/'
    os.makedirs(eval_base_path, exist_ok=True)

    for program_name, program_setups in programs_and_params.items():

        stripped_program_name = program_name.split('.')[0]
        run_tag = f'{stripped_program_name}_{args.workload}_{args.n_samples}_{args.n_resamples}_{args.random_seed}_{args.epsilon}'
        eval_save_path = eval_base_path + run_tag + ('_baselines.npy' if args.baseline_mode else '.npy')

        print(f'Evaluating {stripped_program_name}, Baseline mode: {args.baseline_mode}')

        if os.path.isfile(eval_save_path) and not args.force:
            print('This experiment has been conducted already, abort')
            continue

        load_path = base_path + f'training_constraints/{program_name}'
        with open(load_path, 'r') as f:
            program = f.read()
        print(program)

        eval_load_path = base_path + f'evaluation_constraints/{stripped_program_name}_eval.txt'
        with open(eval_load_path, 'r') as f:
            eval_program = f.read()

        if dp:
            program = program.replace('<epsilon>', str(args.epsilon))

        argument_names = list(program_setups['arguments'].keys())
        param_combinations = list(product(*program_setups['arguments'].values()))
        timer = Timer(2) if args.baseline_mode else Timer(len(param_combinations) * args.n_samples)

        # rejection sampling against the evaluation program is only used for these program families
        uses_rejection_sampling = program_name.startswith(('implication', 'line_constraint')) and not args.baseline_mode

        # allocated lazily: the number of constraint scores is only known after the first evaluation
        collected_data = None
        for i, params_combination in enumerate(param_combinations):

            current_arguments = dict(zip(argument_names, params_combination))

            for sample in range(args.n_samples):

                # baseline mode evaluates exactly two datasets: slot i == 0 and slot i == 1, first sample only
                if args.baseline_mode and (sample > 0 or i > 1):
                    continue

                timer.start()
                if args.baseline_mode:
                    print(f'Baseline datasets: {i+1}/2    {timer}', end='\r')
                else:
                    print(f'Parameter Combination: {i+1}/{len(param_combinations)}    Sample: {sample+1}/{args.n_samples}    {timer}', end='\n')

                # baselines are produced without any constraint finetuning
                denoiser_config = {'finetuning_epochs': 0} if args.baseline_mode else program_setups['denoiser_config']

                # NOTE(review): every rerun passes the same random_seed; whether the reruns
                # actually differ depends on how ProgrammableSynthesizer uses it -- confirm.
                synthesizer = ProgrammableSynthesizer(
                    constraint_program=program,
                    workload=translated_workload[args.workload],
                    random_seed=args.random_seed,
                    device=device,
                    denoiser_config=denoiser_config
                )

                synthesizer.fit(program_arguments=current_arguments, verbose=False)

                for resample in range(args.n_resamples):

                    if args.baseline_mode and i == 0:
                        # first baseline: score the real training data itself
                        synthetic_data = full_one_hot.clone().detach()
                    elif uses_rejection_sampling:
                        synthetic_data = synthesizer.generate_data_with_rejection_sampling(len(synthesizer.base_data), eval_program)
                    else:
                        synthetic_data = synthesizer.generate_data(len(synthesizer.base_data)).detach()

                    tv_stats, l2_stats, js_stats, acc_stats, bac_stats, f1_stats = evaluate_sampled_dataset(
                        synthetic_dataset=synthetic_data.detach().clone(),
                        workload=workload_marginal_names[args.workload],
                        true_measured_workload=measured_workload,
                        dataset=synthesizer.dataset,
                        max_slice=1000,
                        random_seed=args.random_seed
                    )

                    # evaluate the constraints, with the test split as base data
                    ce = ConstraintEvaluator(
                        program=copy.copy(eval_program),
                        dataset=dataset,
                        base_data=full_one_hot_test.detach().clone(),
                        xgb_random_state=args.random_seed,
                        program_arguments=None,
                        device=device
                    )
                    constraint_eval_data = ce.evaluate_constraints(synthetic_data.detach().clone())
                    constraint_stats = _constraint_statistics(constraint_eval_data)

                    if collected_data is None:
                        collected_data = np.zeros((len(param_combinations), args.n_samples, args.n_resamples, 6 + len(constraint_stats), 5))

                    utility_stats = (tv_stats, l2_stats, js_stats, acc_stats, bac_stats, f1_stats)
                    for slot, stats in enumerate(utility_stats + tuple(constraint_stats)):
                        collected_data[i, sample, resample, slot] = stats

                timer.end()

        timer.duration()

        if collected_data is None:
            # nothing was evaluated (e.g. zero samples or resamples); writing a file here
            # would make later runs skip this experiment as already conducted
            print(f'No results were collected for {stripped_program_name}, nothing is saved')
            continue

        np.save(eval_save_path, collected_data)


if __name__ == '__main__':
    # Script entry point: declare the command-line options, parse them and run the experiment.
    cli = argparse.ArgumentParser('validation_param_search_parser')

    # (flag, value type, default, help text) of every option that carries a value
    valued_options = (
        ('--dataset', str, 'ADULT', 'Dataset name'),
        ('--n_samples', int, 5, 'Number of reruns'),
        ('--n_resamples', int, 5, 'Number of resamples'),
        ('--random_seed', int, 42, 'Set the random seed'),
        ('--workload', str, 'all_three_with_labels', 'Set the base workload'),
        ('--epsilon', float, 0.0, 'Epsilon for DP constraints'),
    )
    for flag, value_type, default_value, help_text in valued_options:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    # boolean switches, False unless given on the command line
    switches = (
        ('--baseline_mode', 'Evaluate only single baseline points'),
        ('--force', 'Force the execution'),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action='store_true', help=help_text)

    main(cli.parse_args())
