import os
import pickle
import argparse
import numpy as np
import torch
from programmable_synthesizer import ProgrammableSynthesizer
from itertools import product
from utils import evaluate_sampled_dataset, statistics, Timer
from constraints import ConstraintEvaluator
from query import get_all_marginals, query_marginal
from tabular_datasets import ADULT, HealthHeritage
import copy
import re


def _flatten_constraint_scores(constraint_eval_data):
    """Flatten ConstraintEvaluator output into a list of per-score statistics.

    Each entry of ``constraint_eval_data`` is read for its ``'score'`` key. A list
    of scores contributes one ``statistics([score])`` entry per element; any other
    value contributes a single ``statistics([score])`` entry.

    Args:
        constraint_eval_data: iterable of mappings with a ``'score'`` key, as
            returned by ``ConstraintEvaluator.evaluate_constraints``.

    Returns:
        list: one ``statistics`` result per individual score, in input order.
    """
    constraint_stats = []
    for ced in constraint_eval_data:
        scores = ced['score']
        if not isinstance(scores, list):
            scores = [scores]
        constraint_stats.extend(statistics([score]) for score in scores)
    return constraint_stats


def main(args):
    """Run the Adult constraint-chaining ablation and save all metrics to disk.

    For every rerun (``args.n_samples``) and every training constraint program of
    the selected chain, a ProgrammableSynthesizer is fit and then sampled
    ``args.n_resamples`` times. Each synthetic dataset is scored with
    ``evaluate_sampled_dataset`` (tv / l2 / js / acc / bac / f1 statistics) on the
    label-containing marginals from ``get_all_marginals(..., 3, ...)``, and with
    ``ConstraintEvaluator`` on the chain's evaluation program.

    Args:
        args: argparse namespace with ``option`` (1 or 2), ``n_samples``,
            ``n_resamples``, ``random_seed`` and ``baseline_mode``.

    Raises:
        ValueError: if ``args.option`` is not 1 or 2, or if the evaluation
            program yields more constraint scores than the results array holds.

    Side effects:
        Writes a ``(5, n_samples, n_resamples, 11, 5)`` array to
        ``experiment_data/adult_constraint_ablation/evaluation_results/``.
        Metric axis: 0..5 = tv, l2, js, acc, bac, f1 statistics;
        6..10 = constraint-score statistics. The last axis holds whatever
        ``statistics`` returns (assumed to have 5 entries — TODO confirm in utils).
    """
    if args.option not in (1, 2):
        # Any other value used to silently select chain 2 while building paths
        # for a non-existent 'option<N>' directory.
        raise ValueError(f'option must be 1 or 2, got {args.option!r}')

    # set the random seed
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    device = 'cuda'

    # Each chain grows by one constraint per step; the value of each entry is
    # passed to `fit` as `program_arguments` (presumably filling the program's
    # `PARAM <paramN>` placeholders).
    constraints_to_chain_option_1 = {
        'fairness_downstream_sex.txt': {'param1': 9e-4},
        'fairness_downstream_sex_avg_age_30.txt': {'param1': 9e-4, 'param2': 2.5e-5},
        'fairness_downstream_sex_avg_age_30_avg_male_female_age.txt': {'param1': 9e-4, 'param2': 2.5e-5, 'param3': 1.25e-5},
        'fairness_downstream_sex_avg_age_30_avg_male_female_age_implication3.txt': {'param1': 9e-4, 'param2': 2.5e-5, 'param3': 1.25e-5, 'param4': 7.5e-6},
        'fairness_downstream_sex_avg_age_30_avg_male_female_age_implication3_implication2.txt': {'param1': 9e-4, 'param2': 2.5e-5, 'param3': 1.25e-5, 'param4': 7.5e-6, 'param5': 7.5e-6}
    }

    constraints_to_chain_option_2 = {
        'implication1.txt': {'param1': 7.5e-6},
        'implication1_implication2.txt': {'param1': 7.5e-6, 'param2': 7.5e-6},
        'implication1_implication2_implication3.txt': {'param1': 7.5e-6, 'param2': 7.5e-6, 'param3': 7.5e-6},
        'implication1_implication2_implication3_line_constraint1.txt': {'param1': 7.5e-6, 'param2': 7.5e-6, 'param3': 7.5e-6, 'param4': 2.5e-6},
        'implication1_implication2_implication3_line_constraint1_line_constraint2.txt': {'param1': 7.5e-6, 'param2': 7.5e-6, 'param3': 7.5e-6, 'param4': 2.5e-6, 'param5': 7.5e-6}
    }
    constraints_and_params = constraints_to_chain_option_1 if args.option == 1 else constraints_to_chain_option_2
    eval_program_name = 'fairness_downstream_sex_avg_age_30_avg_male_female_age_implication3_implication2_eval.txt' if args.option == 1 else 'implication1_implication2_implication3_line_constraint1_line_constraint2_eval.txt'

    # The results array is sized for the full chain even in baseline mode (only
    # row 0 is filled then), so saved arrays have the same shape either way.
    n_chain_steps = max(len(constraints_to_chain_option_1), len(constraints_to_chain_option_2))
    if args.baseline_mode:
        constraints_and_params = {'base_program.txt': None}

    # dataset preps
    dataset = ADULT(device=device)
    full_one_hot_train = dataset.get_Dtrain_full_one_hot(return_torch=True)
    full_one_hot_test = dataset.get_Dtest_full_one_hot(return_torch=True)

    # workload for eval
    workload_all_three_with_labels = [m for m in get_all_marginals(list(dataset.features.keys()), 3, downward_closure=False) if dataset.label in m]
    measured_workload = {m: query_marginal(full_one_hot_train, m, dataset.full_one_hot_index_map, normalize=True, input_torch=True, max_slice=1000) for m in workload_all_three_with_labels}

    base_load_path_train = f'experiment_data/adult_constraint_ablation/option{args.option}/training_constraints/'
    base_load_path_test = f'experiment_data/adult_constraint_ablation/option{args.option}/evaluation_constraints/'
    base_save_path = f'experiment_data/adult_constraint_ablation/evaluation_results/'
    os.makedirs(base_save_path, exist_ok=True)

    # The evaluation program is identical for every run, so read it only once.
    with open(f'{base_load_path_test}{eval_program_name}', 'r', encoding='utf-8') as f:
        eval_program = f.read()

    denoiser_config = {'finetuning_epochs': 500}

    # Metric slots: 0..5 hold the evaluate_sampled_dataset stats, the rest hold
    # constraint-score stats.
    n_metric_slots = 11
    first_constraint_slot = 6
    collected_data = np.zeros((n_chain_steps, args.n_samples, args.n_resamples, n_metric_slots, 5))
    timer = Timer(len(constraints_and_params) * args.n_samples)
    for sample in range(args.n_samples):
        for i, (constraint_file, params) in enumerate(constraints_and_params.items()):

            timer.start()
            print(f'Sample: {sample+1}/{args.n_samples}    Constraint: {constraint_file}    {timer}                              ', end='\r')
            # load the training constraints
            with open(f'{base_load_path_train}{constraint_file}', 'r', encoding='utf-8') as f:
                training_program = f.read()

            # instantiate ProgSyn
            # NOTE(review): the same args.random_seed is passed on every rerun
            # `sample`; if ProgrammableSynthesizer seeds itself from it, the
            # n_samples reruns may be identical — confirm this is intended.
            synthesizer = ProgrammableSynthesizer(
                constraint_program=training_program,
                workload='all_three_with_labels',
                random_seed=args.random_seed,
                device=device,
                denoiser_config=denoiser_config
            )

            synthesizer.fit(program_arguments=params)

            # Programs with implication / line constraints are sampled through
            # generate_data_with_rejection_sampling, using a copy of the program
            # whose `PARAM <...>` placeholders are replaced by `PARAM 0.1`.
            # Both values are invariant across resamples, so compute them once.
            use_rejection_sampling = 'implication' in constraint_file or 'line_constraint' in constraint_file
            if use_rejection_sampling:
                training_program_param_stripped = re.sub(r'(?i)param\s<[^>\s]*>', 'PARAM 0.1', training_program)

            for resample in range(args.n_resamples):

                if use_rejection_sampling:
                    synthetic_data = synthesizer.generate_data_with_rejection_sampling(len(synthesizer.base_data), training_program_param_stripped)
                else:
                    synthetic_data = synthesizer.generate_data(len(synthesizer.base_data)).detach()

                tv_stats, l2_stats, js_stats, acc_stats, bac_stats, f1_stats = evaluate_sampled_dataset(
                    synthetic_dataset=synthetic_data.detach().clone(),
                    workload=workload_all_three_with_labels,
                    true_measured_workload=measured_workload,
                    dataset=synthesizer.dataset,
                    max_slice=1000,
                    random_seed=args.random_seed
                )

                # evaluate the constraints (strings are immutable, so the
                # program text can be passed directly without copying)
                ce = ConstraintEvaluator(
                    program=eval_program,
                    dataset=dataset,
                    base_data=full_one_hot_test.detach().clone(),
                    xgb_random_state=args.random_seed,
                    program_arguments=None,
                    device=device
                )
                constraint_stats = _flatten_constraint_scores(ce.evaluate_constraints(synthetic_data.detach().clone()))
                if first_constraint_slot + len(constraint_stats) > n_metric_slots:
                    raise ValueError(
                        f'{eval_program_name} produced {len(constraint_stats)} constraint scores, '
                        f'but only {n_metric_slots - first_constraint_slot} result slots are available'
                    )

                all_stats = (tv_stats, l2_stats, js_stats, acc_stats, bac_stats, f1_stats, *constraint_stats)
                for slot, stats in enumerate(all_stats):
                    collected_data[i, sample, resample, slot] = stats

            timer.end()
    suffix = '_baseline' if args.baseline_mode else ''
    full_save_path = f'{base_save_path}collected_data_option{args.option}_{args.n_samples}_{args.n_resamples}_{args.random_seed}{suffix}.npy'
    np.save(full_save_path, collected_data)


def build_arg_parser():
    """Build the command-line parser for the Adult constraint-chaining ablation.

    Returns:
        argparse.ArgumentParser: parser whose namespace is consumed by ``main``.
    """
    parser = argparse.ArgumentParser('adult_ablation')
    # `default` was previously the builtin `int` type rather than a value, so
    # omitting --option produced paths containing "<class 'int'>". Use a real
    # default and restrict to the two chains that main() defines.
    parser.add_argument('--option', type=int, default=1, choices=(1, 2), help='Choose chaining options')
    parser.add_argument('--n_samples', type=int, default=5, help='Number of reruns')
    parser.add_argument('--n_resamples', type=int, default=5, help='Number of resamples')
    parser.add_argument('--random_seed', type=int, default=42, help='Set the random seed')
    parser.add_argument('--baseline_mode', action='store_true', help='Toggle for baseline mode')
    # NOTE(review): --force is accepted but not read by main() as visible here.
    parser.add_argument('--force', action='store_true', help='Force the execution')
    return parser


if __name__ == '__main__':
    in_args = build_arg_parser().parse_args()
    main(in_args)
