import os
import pickle
import argparse
import numpy as np
import torch
from programmable_synthesizer import ProgrammableSynthesizer
from itertools import product
from utils import evaluate_sampled_dataset, statistics, Timer
from constraints import ConstraintEvaluator
from query import get_all_marginals, query_marginal
from tabular_datasets import ADULT, HealthHeritage
import copy
import re


def main(args):
    """Evaluate rejection-sampled synthetic data for every constraint program.

    For each program file belonging to the selected dataset, a
    ``ProgrammableSynthesizer`` is fitted ``args.n_samples`` times and, for
    each fit, ``args.n_resamples`` synthetic datasets are drawn with rejection
    sampling. Every draw is scored with ``evaluate_sampled_dataset`` and with a
    ``ConstraintEvaluator`` built from the matching ``*_eval.txt`` program. The
    results of one program are stored in a single ``.npy`` file under
    ``<base_path>/testing_results/``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``dataset``, ``workload``, ``epsilon``, ``n_samples``,
            ``n_resamples``, ``random_seed`` and ``force``.

    Raises:
        KeyError: If ``args.dataset`` or ``args.workload`` is not one of the
            keys of the lookup tables defined below.
    """

    # set the random seed
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    # Use the GPU when there is one, otherwise fall back to the CPU so the
    # script still runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    available_datasets = {
        'ADULT': ADULT,
        'HealthHeritage': HealthHeritage
    }
    # instantiate the dataset (only ADULT takes the drop_education_num flag)
    if args.dataset == 'ADULT':
        dataset = available_datasets[args.dataset](drop_education_num=True, device=device)
    else:
        dataset = available_datasets[args.dataset](device=device)
    full_one_hot = dataset.get_Dtrain_full_one_hot(return_torch=True)
    full_one_hot_test = dataset.get_Dtest_full_one_hot(return_torch=True)

    # prepare everything workload related
    # value handed to the synthesizer for each workload name
    translated_workload = {
        'all_two': 2,
        'all_three': 3,
        'all_three_with_labels': 'all_three_with_labels'
    }
    feature_names = list(dataset.features.keys())
    workload_marginal_names = {
        'all_two': get_all_marginals(feature_names, 2, downward_closure=False),
        'all_three': get_all_marginals(feature_names, 3, downward_closure=False)
    }
    # the labelled workload keeps only the three-way marginals containing the label
    workload_marginal_names['all_three_with_labels'] = [m for m in workload_marginal_names['all_three'] if dataset.label in m]
    # reference marginals measured on the training data
    measured_workload = {
        m: query_marginal(full_one_hot, m, dataset.full_one_hot_index_map, normalize=True, input_torch=True, max_slice=1000)
        for m in workload_marginal_names[args.workload]
    }

    # a positive epsilon switches on the differentially private variant
    dp = args.epsilon > 0
    additional_string = '_dp' if dp else ''

    adult_programs = [
        f'implication1{additional_string}.txt',
        f'implication2{additional_string}.txt',
        f'implication3{additional_string}.txt',
        f'line_constraint1{additional_string}.txt',
        f'line_constraint2{additional_string}.txt'
    ]

    health_programs = [
        'implication1.txt',
        'line_constraint1.txt'
    ]

    programs = adult_programs if args.dataset == 'ADULT' else health_programs

    base_path = f'experiment_data/constraint_program_experiments/{args.dataset}/'
    base_path += 'dp_constraints/' if dp else 'non_dp_constraints/'
    eval_base_path = base_path + 'testing_results/'
    os.makedirs(eval_base_path, exist_ok=True)

    for program_name in programs:

        stripped_program_name = program_name.split('.')[0]
        eval_save_path = eval_base_path + f'{stripped_program_name}_{args.workload}_{args.n_samples}_{args.n_resamples}_{args.random_seed}_{args.epsilon}_rejection_sampling.npy'

        print(f'Evaluating {stripped_program_name}')

        # skip programs whose result file already exists unless --force is given
        if os.path.isfile(eval_save_path) and not args.force:
            print('This experiment has been conducted already, abort')
            continue

        load_path = base_path + f'training_constraints/{program_name}'
        with open(load_path, 'r') as f:
            program = f.read()

        eval_load_path = base_path + f'evaluation_constraints/{stripped_program_name}_eval.txt'
        with open(eval_load_path, 'r') as f:
            eval_program = f.read()

        if dp:
            program = program.replace('<epsilon>', str(args.epsilon))

        # replace every `param <name>` placeholder (case-insensitive) by the fixed value 0.1
        program = re.sub(r'(?i)param\s<[^>\s]*>', 'PARAM 0.1', program)

        timer = Timer(args.n_samples * args.n_resamples)

        # allocated on the first draw, once the number of constraint scores is known
        collected_data = None

        for sample in range(args.n_samples):

            # NOTE(review): every rerun receives the same random_seed; if the
            # synthesizer reseeds from it, the reruns may be identical - confirm.
            synthesizer = ProgrammableSynthesizer(
                constraint_program=program,
                workload=translated_workload[args.workload],
                random_seed=args.random_seed,
                device=device
            )

            synthesizer.fit(finetune=False)

            for resample in range(args.n_resamples):

                timer.start()
                print(f'Sample: {sample+1}/{args.n_samples}    Resample: {resample+1}/{args.n_resamples}    {timer}', end='\r')

                synthetic_data = synthesizer.generate_data_with_rejection_sampling(len(synthesizer.base_data), program).detach()

                # each consumer gets its own copy of the synthetic data
                tv_stats, l2_stats, js_stats, acc_stats, bac_stats, f1_stats = evaluate_sampled_dataset(
                    synthetic_dataset=synthetic_data.clone(),
                    workload=workload_marginal_names[args.workload],
                    true_measured_workload=measured_workload,
                    dataset=synthesizer.dataset,
                    max_slice=1000,
                    random_seed=args.random_seed
                )

                # evaluate the constraints
                # (eval_program is an immutable str, so it is passed without copying)
                ce = ConstraintEvaluator(
                    program=eval_program,
                    dataset=dataset,
                    base_data=full_one_hot_test.detach().clone(),
                    xgb_random_state=args.random_seed,
                    program_arguments=None,
                    device=device
                )
                constraint_eval_data = ce.evaluate_constraints(synthetic_data.clone())

                # a constraint may report one score or a list of scores; flatten
                # them into one statistics entry per score
                constraint_stats = []
                for ced in constraint_eval_data:
                    scores = ced['score']
                    if isinstance(scores, list):
                        constraint_stats += [statistics([score]) for score in scores]
                    else:
                        constraint_stats += [statistics([scores])]

                if collected_data is None:
                    # rows 0-5 hold the six utility metrics, the remaining rows the
                    # constraint scores; assumes every stats entry has 5 values and
                    # the number of constraint scores is the same for every draw
                    collected_data = np.zeros((args.n_samples, args.n_resamples, 6 + len(constraint_stats), 5))

                utility_stats = (tv_stats, l2_stats, js_stats, acc_stats, bac_stats, f1_stats)
                for row, stats in enumerate(utility_stats):
                    collected_data[sample, resample, row] = stats

                for row, c_stats in enumerate(constraint_stats, start=len(utility_stats)):
                    collected_data[sample, resample, row] = c_stats

                timer.end()

        timer.duration()

        # With zero samples or resamples nothing was evaluated. Saving None would
        # leave a file that later runs mistake for a finished experiment.
        if collected_data is None:
            print(f'No results collected for {stripped_program_name}, nothing saved')
            continue

        np.save(eval_save_path, collected_data)


if __name__ == '__main__':
    # Command-line entry point: parse the experiment settings and run main().
    parser = argparse.ArgumentParser('validation_param_search_parser')
    # main() looks the dataset and workload names up in fixed dictionaries, so
    # unknown values are rejected here instead of failing later with a KeyError.
    parser.add_argument('--dataset', type=str, default='ADULT', choices=['ADULT', 'HealthHeritage'], help='Dataset name')
    parser.add_argument('--n_samples', type=int, default=5, help='Number of reruns')
    parser.add_argument('--n_resamples', type=int, default=5, help='Number of resamples')
    parser.add_argument('--random_seed', type=int, default=42, help='Set the random seed')
    parser.add_argument('--workload', type=str, default='all_three_with_labels', choices=['all_two', 'all_three', 'all_three_with_labels'], help='Set the base workload')
    # epsilon <= 0 selects the non-DP constraint programs
    parser.add_argument('--epsilon', type=float, default=0.0, help='Epsilon for DP constraints')
    # rerun programs even when their result file already exists
    parser.add_argument('--force', action='store_true', help='Force the execution')
    in_args = parser.parse_args()
    main(in_args)
