import sys
import os
import argparse
import json
import joblib
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Any
from tqdm import tqdm

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

from functools import partial
from itertools import product

from .args import process_arguments
from .fairpate import run_experiment, process_data, DataSet, FairPATERunParams
from .optuna_study import run_optuna

# Maximum number of solver iterations for the sklearn LogisticRegression models
# built in the `__main__` section; also stored in FairPATERunParams(MAX_ITER=...).
MAX_ITER = 1000

# Borrowed from governance-games/agents.py
def is_pareto_efficient(costs, return_mask = False):
    """
    Find the Pareto-efficient points of a set of costs (lower is better in every column).

    A point is efficient if no other point is <= in every cost and < in at
    least one. When several rows are identical, only the one with the lowest
    index is reported as efficient.

    :param costs: An (n_points, n_costs) array-like of costs to minimize.
        Lists of lists are accepted and converted with ``np.asarray``.
    :param return_mask: True to return a boolean mask instead of indices.
    :return: If return_mask is True, an (n_points, ) boolean array that is True
        for efficient points. Otherwise an (n_efficient_points, ) integer array
        of their indices, in increasing order.
    """
    # Accept plain Python sequences as well as ndarrays; asarray does not copy
    # an existing ndarray, and the input is never mutated below.
    costs = np.asarray(costs)
    n_points = costs.shape[0]
    is_efficient = np.arange(n_points)  # original indices of surviving points
    next_point_index = 0  # Next index in the (shrinking) costs array to search for
    while next_point_index < len(costs):
        # Keep every point that beats the current one in at least one cost;
        # everything else is weakly dominated by it (including exact duplicates).
        nondominated_point_mask = np.any(costs < costs[next_point_index], axis=1)
        nondominated_point_mask[next_point_index] = True  # the current point itself survives
        is_efficient = is_efficient[nondominated_point_mask]  # Remove dominated points
        costs = costs[nondominated_point_mask]
        # Re-locate the successor of the current point in the filtered array.
        next_point_index = np.sum(nondominated_point_mask[:next_point_index]) + 1
    if return_mask:
        is_efficient_mask = np.zeros(n_points, dtype=bool)
        is_efficient_mask[is_efficient] = True
        return is_efficient_mask
    return is_efficient

if __name__ == '__main__':
    # Sweep a grid of (privacy budget, fairness threshold) pairs, run the
    # experiment (or an optuna study) for each, and save the resulting
    # accuracy / privacy / fairness / coverage values under ./pareto/.

    # ---- Command-line interface --------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='adult')

    parser.add_argument('--num_classes', type=int, default=2)
    parser.add_argument('--output_col_name', type=str, default='income')
    parser.add_argument('--split', type=float, default=0.75)

    parser.add_argument('--dem_disparity_interpretation', type=str, default='max_vs_min')

    parser.add_argument('--teacher_query_set_split', type=float, default=0.7)

    parser.add_argument('--num_teachers', type=int, default=4)
    parser.add_argument('--list_num_teachers', nargs='+', type=int, help='A list of number of teachers to run the experiment on.')

    parser.add_argument('--threshold', type=int, default=2)
    parser.add_argument('--list_threshold', nargs='+', type=int, help='A list of thresholds to run the experiment on.')

    parser.add_argument('--fairness_threshold', type=float, default=0.2)
    parser.add_argument('--list_fairness_threshold', nargs='+', type=float, help='A list of fairness thresholds to run the experiment on.')

    parser.add_argument('--sigma_threshold', type=float, default=60)
    parser.add_argument('--list_sigma_threshold', nargs='+', type=float, help='A list of sigma thresholds to run the experiment on.')

    parser.add_argument('--sigma_fair_threshold', type=int, default=0)

    parser.add_argument('--sigma_gnmax', type=float, default=25)
    parser.add_argument('--list_sigma_gnmax', nargs='+', type=float, help='A list of sigma gnmax to run the experiment on.')

    parser.add_argument('--budget', type=float, default=1000)
    parser.add_argument('--list_budget', nargs='+', type=float, help='A list of budgets to run the experiment on.')

    parser.add_argument('--delta', type=float, default=1e-5)
    parser.add_argument('--verbose', action='store_true')

    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--list_seed', nargs='+', type=int, help='A list of seeds to run the experiment on.')

    parser.add_argument('--data_path', type=str, default='./fairpate_tabular/data/')
    parser.add_argument('--min_group_count', type=int, default=50)
    parser.add_argument('--results_dir', type=str, default='.', help='Directory to store the results in.')

    parser.add_argument('--use_optuna', action='store_true', help='Whether to use optuna to find the best hyperparameters.')
    parser.add_argument('--num_optuna_trials', type=int, default=1000, help='Number of optuna trials to run.')

    parser.add_argument('--use_stratification', action='store_true', help='Whether to use stratification to split the data.')

    parser.add_argument('--fairness_metric', type=str, default='DemParity', help='Fairness metric to use for the experiment. Can be `DemParity`, `ErrorParity`, or `EqualityOfOdds`.')
    parser.add_argument('--list_fairness_metric', nargs='+', type=str, help='A list of fairness metrics to run the experiment on.')

    parser.add_argument('--num_calib', type=int, default=100, help='Number of calibration samples to use for ground-truth-based fairness metrics.')

    parser.add_argument('--pate_based_model', type=str, default='fairpate', help='What PATE-based model to use. Can be `fairpate`, `pate`, `pateSpre`, `pateSin` or `pateSpost`.')

    parser.add_argument('--use_inference_time_postprocessing', action='store_true', help='Whether to use inference-time postprocessing to mitigate fairness violations.')

    parser.add_argument('--undersampling_ratio', type=float, default=None, help='Ratio of the majority class to the minority class. If None, no undersampling is done.')

    parser.add_argument('--optuna_db_path', type=str, default='./optuna', help='Path to the optuna study db.')

    parser.add_argument('--backend', type=str, default='sklearn', help='Backend to use for the experiment. Can be `sklearn`, `keras+pytorch` or `deep_auc`.')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train the student model for.')
    # argparse applies `type` to string defaults, so these JSON defaults arrive as dicts.
    parser.add_argument('--keras_dict', type=json.loads,
                        default='{"optimizer": "adam", "loss": "binary_crossentropy", "metrics": ["accuracy"]}',
                        help='Dictionary of arguments to pass to the keras model.')

    parser.add_argument('--parallel', type=int, default=None, help='Whether to use parallel processing to train the teachers.')

    parser.add_argument('--deep_auc_dict', type=json.loads,
                        default='{"batch_size": 32, "lr": 0.05, "margin": 1.0, '
                                '"epoch_decay": 2e-3, "weight_decay": 1e-5, "eval_every": 200, '
                                '"epochs": 2, "train_validation_split" : 0.8}',
                        help='Dictionary of arguments to pass to the deep_auc model.')

    parser.add_argument('--skip', type=str, default=None, help='Whether to skip part of the pipeline. Can be `training_teachers`, `training_all`, `voting` or None (default).')

    parser.add_argument('--log_path', type=str, default='./logs/', help='Path to store the logs and artifacts from the experiments.')

    args = process_arguments(parser.parse_args())

    # `log` is print when --verbose is set and a silent no-op otherwise.
    log = print if args.verbose else (lambda *x, **y: None)

    # PATE analysis needs this
    np.random.seed(args.seed)
    # Otherwise we use an rng
    np_rng = np.random.default_rng(args.seed)

    # Initialize results db
    if os.path.exists(args.results_db_path):
        results_db = pd.read_parquet(args.results_db_path)
    else:
        results_db = pd.DataFrame(columns=['dataset', 'model', 'fairness_metric', 'num_teachers', 'threshold',
                                           'fairness_threshold', 'sigma_threshold', 'sigma_gnmax', 'budget',
                                           'delta', 'seed', 'student_validation_accuracy', 'student_test_accuracy',
                                           'validation_disparity', 'test_disparity', 'achieved_eps', 'max_num_query',
                                           'max_actual_query'])

    # ---- Model factory for the selected backend ----------------------------
    # Each backend provides generate_model_and_fit(train_features, train_labels, model_id=None)
    # plus optional pre-built train/test sets (None means process_data builds them).
    if args.backend == 'sklearn':
        def generate_model_and_fit(train_features, train_labels, model_id=None):
            model = LogisticRegression(max_iter=MAX_ITER, random_state=args.seed).fit(X=train_features, y=train_labels)
            return model
        train_set, test_set = None, None
    elif args.backend == 'keras+pytorch':
        from arch import LogisticRegrressionKeras

        def generate_model_and_fit(train_features, train_labels, model_id=None):
            model = LogisticRegrressionKeras(input_num_attr=args.num_inp_attr)
            model.compile(**args.keras_dict)
            model.fit(x=train_features, y=train_labels, epochs=args.epochs, verbose=0)
            # Monkey-patch predict so it returns hard 0/1 labels (threshold 0.5)
            # instead of raw scores, matching the sklearn backend's interface.
            model._predict = getattr(model, "predict")
            setattr(model, "predict", lambda x: (model._predict(x, verbose=0) >= 0.5).astype(int).squeeze())
            return model
        train_set, test_set = None, None
    elif args.backend == 'deep_auc':
        from fairpate_tabular.custom_trainers.deep_auc import train_set, test_set, generate_model_and_fit
        generate_model_and_fit = partial(generate_model_and_fit, args=args)
    else:
        # Previously an unknown backend surfaced later as a NameError on
        # train_set / generate_model_and_fit; fail early with a clear message.
        raise ValueError(
            f"Unknown backend {args.backend!r}; expected 'sklearn', 'keras+pytorch' or 'deep_auc'."
        )

    # ---- Data ---------------------------------------------------------------
    train_features, train_labels, train_sensitives, test_features, test_labels, test_sensitives = \
                process_data(np_rng, args, log, train_set=train_set, test_set=test_set)

    train_data = DataSet(features=train_features, labels=train_labels, sensitives=train_sensitives, set=train_set)
    test_data = DataSet(features=test_features, labels=test_labels, sensitives=test_sensitives, set=test_set)

    packed_data = FairPATERunParams(train=train_data, test=test_data, np_rng=np_rng, generate_model_and_fit=generate_model_and_fit, MAX_ITER=MAX_ITER)

    # NOTE(review): this discards the results DataFrame loaded/created above,
    # so run_experiment always receives None in this sweep. Confirm intent.
    results_db = None

    # Grid of (privacy budget epsilon, fairness threshold gamma) pairs to sweep.
    # For a quick smoke test, use e.g. priv_list = [4] and fair_list = [0.2].
    priv_list = [2, 3, 4, 5, 6, 7, 8, 9, 10]
    fair_list = [0.05, 0.1, 0.15, 0.2, 0.25]

    # One entry per recorded point: a single run per grid cell, or one entry per
    # Pareto-efficient optuna trial when --use_optuna is set.
    loss_builder_acc = []
    loss_privacy = []
    loss_fairness = []
    priv_fair_values = []
    loss_builder_cov = []

    # Create the output directory before the (long) sweep: np.save raises
    # FileNotFoundError for a missing directory, which would lose all results.
    os.makedirs('pareto', exist_ok=True)

    priv_fairs = list(product(priv_list, fair_list))
    priv_fairs_len = len(priv_fairs)
    for i, (priv, fair) in enumerate(tqdm(priv_fairs), start=1):
        args.budget = priv
        args.fairness_threshold = fair
        print(f'({i / priv_fairs_len:.3%}) Testing epsilon {priv}, gamma {fair}')
        try:
            if args.use_optuna:
                best_trials = run_optuna(args, packed_data, run_experiment)

                # One row per trial, every column a cost to minimize: achieved
                # epsilon, test disparity (stored as 'test_dem_parity'), and the
                # negated student test accuracy (higher accuracy = lower cost).
                costs = np.array([
                    [trial.user_attrs['achieved_eps'],
                     trial.user_attrs['test_dem_parity'],
                     -trial.user_attrs['student_model_test_accuracy']]
                    for trial in best_trials
                ])

                # Record only the Pareto-efficient trials.
                for idx in is_pareto_efficient(costs):
                    trial_attrs = best_trials[idx].user_attrs
                    loss_builder_acc.append(trial_attrs['student_model_test_accuracy'])
                    loss_privacy.append(trial_attrs['achieved_eps'])
                    loss_fairness.append(trial_attrs['test_dem_parity'])
                    priv_fair_values.append(np.array([priv, fair]))
                    loss_builder_cov.append(trial_attrs['test_coverage'])
            else:
                student_model_validation_accuracy, validation_disparity, validation_coverage, validation_auc, \
                        achieved_eps, max_num_query, num_queries_answered, student_model_test_accuracy, test_disparity, test_coverage, test_auc = run_experiment(args, log, packed_data, results_db)
                loss_builder_acc.append(student_model_test_accuracy)
                loss_privacy.append(achieved_eps)
                loss_fairness.append(test_disparity)
                priv_fair_values.append(np.array([priv, fair]))
                loss_builder_cov.append(test_coverage)
        except ValueError:
            # Raised when the (epsilon, gamma) pair is infeasible; skip this cell.
            print('Too strict conditions.')

    np.save('pareto/builder_loss_acc.npy', loss_builder_acc)
    np.save('pareto/privacy_loss.npy', loss_privacy)
    np.save('pareto/fairness_loss.npy', loss_fairness)
    np.save('pareto/priv_fair_values.npy', priv_fair_values)
    np.save('pareto/builder_loss_cov.npy', loss_builder_cov)
    print(loss_builder_acc)
    print(priv_fair_values)