import argparse
from itertools import product
import os
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from analysis.calibration import generate_calibration_constants
from analysis.rdp_cumulative import analyze_multiclass_confident_fair_gnmax, analyze_multiclass_confident_gnmax
from fairpate_tabular.args import process_arguments
from fairpate_tabular.grid_search import run_grid_search
from fairpate_tabular.optuna_study import run_optuna
from fairpate_tabular.utils import one_hot, sklearn_disparity
from fairpate_tabular.utils import process_data, write_results
from fairpate_tabular.query import FairPATEQuery, PATEQuery, PATESPreProcessor
import pandas as pd

MAX_ITER = 1000


def create_teacher_indices(indices, num_teachers):
    '''Partition ``indices`` into ``num_teachers`` stratified subsets of (almost) equal size.

    The subsets are produced by repeatedly splitting the currently largest
    subset in half, stratifying each split on the joint (label, sensitive)
    class returned by ``define_joint_label``.

    NOTE: this function reads the module-level globals ``train_labels``,
    ``train_sensitives`` and ``args`` (for ``args.seed``), so it can only be
    called after those have been defined by the script entry point.

    Args:
        indices: index array into ``train_labels`` / ``train_sensitives``.
        num_teachers: number of subsets to create; must be a power of two
            that is at least 2.

    Returns:
        A list with ``num_teachers`` index collections. Each split removes the
        largest subset and appends its two halves at the end of the list.

    Raises:
        ValueError: if ``num_teachers`` is not an integral power of two >= 2.
    '''
    # Accept any power of two (not just the ones up to 512) while still
    # rejecting 0, 1, non-integers and non-numeric values.
    try:
        teacher_count = int(num_teachers)
    except (TypeError, ValueError, OverflowError):
        teacher_count = 0
    if teacher_count != num_teachers or teacher_count < 2 or teacher_count & (teacher_count - 1):
        raise ValueError("Number of teachers must be a power of 2")

    teacher_indices = [indices]
    while len(teacher_indices) < teacher_count:
        # np.argmax returns the first maximum, so ties are resolved by position.
        idx = int(np.argmax([len(subset) for subset in teacher_indices]))
        largest_teacher_indices = teacher_indices.pop(idx)
        teacherset_1, teacherset_2 = train_test_split(
            largest_teacher_indices,
            train_size=0.5,
            shuffle=True,
            stratify=define_joint_label(train_labels[largest_teacher_indices],
                                        train_sensitives[largest_teacher_indices]),
            random_state=args.seed,
        )
        # The two halves replace the subset that was just removed.
        teacher_indices.append(teacherset_1)
        teacher_indices.append(teacherset_2)

    # Internal sanity check: halving the largest subset each time should keep
    # all subsets within one sample of each other.
    final_sizes = np.array([len(subset) for subset in teacher_indices])
    assert final_sizes.max() - final_sizes.min() <= 1, "teacher subsets are not balanced"
    return teacher_indices


def define_joint_label(labels, sensitives):
    '''Encode every (label, sensitive) pair as a single integer class id.

    Ids are assigned over the Cartesian product of the sorted unique labels
    and the sorted unique sensitive values, labels varying slowest. Pairs
    that never occur in the data still reserve an id.

    Args:
        labels: 1-D sequence of class labels.
        sensitives: 1-D sequence of sensitive-attribute values, aligned with
            ``labels``.

    Returns:
        A numpy array holding the joint class id of each sample.
    '''
    unique_labels = np.unique(labels)
    unique_sensitives = np.unique(sensitives)

    # Number the pairs in row-major order of (label, sensitive).
    pair_to_id = {}
    for pair in product(unique_labels, unique_sensitives):
        pair_to_id[pair] = len(pair_to_id)

    encoded = [pair_to_id[(y, z)] for y, z in zip(labels, sensitives)]
    return np.array(encoded)

def sklearn_score(model, features, truth):
    '''Return the accuracy of ``model`` on ``features`` against ``truth``.

    ``model.score`` is used when it works; if calling it raises
    ``AttributeError`` the accuracy is computed by hand as the mean of
    ``model.predict(features) == truth``.
    '''
    try:
        accuracy = model.score(features, truth)
    except AttributeError:
        # Fallback for estimators without a usable ``score`` method.
        accuracy = (model.predict(features) == truth).mean()
    return accuracy

def run_experiment(args, log, results_db=None):
    '''Run one PATE / FairPATE experiment: train teachers, analyse privacy, train and score a student.

    NOTE: besides its parameters, this function reads the module-level globals
    ``train_features``, ``train_labels``, ``train_sensitives``,
    ``test_features``, ``test_labels``, ``test_sensitives`` and ``rng`` that
    the script entry point defines before calling it.

    Args:
        args: processed argument namespace (model type, thresholds, noise
            scales, budget, seed, ...).
        log: print-like callable used for progress messages.
        results_db: optional ``pandas.DataFrame``; when given, a row
            describing this run is appended to it.

    Returns:
        ``(validation_accuracy, validation_disparity, achieved_eps,
        max_num_query, num_answered, test_accuracy, test_disparity)``.
        When ``results_db`` is not ``None`` the updated frame is prepended
        to that tuple as its first element.

    Raises:
        NotImplementedError: for an unsupported ``args.pate_based_model``, or
            for a fairness metric other than ``'DemParity'`` with the
            ``'pateSpre'`` / ``'pateSin'`` models.
    '''
    train_indices = np.arange(len(train_features))
    rng.shuffle(train_indices)

    if not args.use_stratification:
        teachers_indices, query_set_indices = train_test_split(
            train_indices, train_size=args.teacher_query_set_split, shuffle=True, random_state=args.seed)

        # Teacher train size. One extra 'teacher' slot is reserved for validation purposes.
        tt_size = len(teachers_indices) // (args.num_teachers + 1)
        individual_teacher_indices = {_id: teachers_indices[_id * tt_size: (_id + 1) * tt_size]
                                      for _id in range(args.num_teachers)}
        validation_set_indices = teachers_indices[args.num_teachers * tt_size:]
    else:
        # Stratified split on the joint (label, sensitive) class.
        teachers_indices, query_set_indices = train_test_split(
            train_indices, train_size=args.teacher_query_set_split, shuffle=True,
            stratify=define_joint_label(train_labels, train_sensitives), random_state=args.seed)

        # The last subset is held out for validation, so only
        # ``num_teachers - 1`` teachers are trained in this mode.
        split_indices = create_teacher_indices(teachers_indices, args.num_teachers)
        validation_set_indices = split_indices[-1]
        individual_teacher_indices = dict(zip(range(args.num_teachers - 1), split_indices[:-1]))
        tt_size = len(individual_teacher_indices[0])

    log("Size of training set: ", len(train_indices))
    log("Size of Each Teacher Training Set: ", tt_size)
    log("Size of Validation Set: ", len(validation_set_indices))
    log("Size of Query Set: ", len(query_set_indices))

    # One logistic-regression teacher per index subset (the number of subsets
    # already reflects whether stratification is used).
    teacher_models = {
        _id: LogisticRegression(max_iter=MAX_ITER, random_state=args.seed).fit(
            X=train_features[member_indices], y=train_labels[member_indices])
        for _id, member_indices in individual_teacher_indices.items()
    }

    # Stack the one-hot predictions of every teacher along a third axis and
    # sum over it to obtain the per-sample vote histogram.
    # Votes are computed over all training samples; this is only an
    # intermediate value, only the calibration/query rows are used below.
    train_votes = np.concatenate(
        [one_hot(model.predict(train_features))[:, :, None] for model in teacher_models.values()],
        axis=2)
    train_votes = train_votes.sum(axis=2)

    # Privacy Analysis
    if args.gt_fairness:
        # Note that with calibration, query set will be smaller
        calibration_indices, query_set_indices = train_test_split(
            query_set_indices, train_size=args.num_calib, shuffle=True,
            stratify=define_joint_label(train_labels[query_set_indices], train_sensitives[query_set_indices]),
            random_state=args.seed)

        c_votes = train_votes[calibration_indices]
        c_targets = train_labels[calibration_indices]
        c_sensitives = train_sensitives[calibration_indices]

        # These are array-indexing results and therefore never None, so the
        # constants are always computed (a skipped computation would leave
        # for_z / but_z undefined further down).
        for_z, but_z = generate_calibration_constants(
            args.fairness_metric, c_votes=c_votes, c_targets=c_targets, c_sensitives=c_sensitives)

        log("Size of Calibration Set: ", len(calibration_indices))
    else:
        for_z, but_z = None, None

    raw_votes = train_votes[query_set_indices]
    sensitives = train_sensitives[query_set_indices]

    if args.pate_based_model == 'fairpate':
        # NOTE(review): sigma_fair_threshold is fixed to 0 here even though a
        # --sigma_fair_threshold option exists; confirm whether that is intended.
        (max_num_query, dp_eps, partition, answered, order_opt,
         sensitive_group_count, pos_prediction_one_hot, answered_curr, gaps, pr_answered) = \
            analyze_multiclass_confident_fair_gnmax(votes=raw_votes, sensitives=sensitives,
                                                    threshold=args.threshold,
                                                    fair_threshold=args.fairness_threshold,
                                                    sigma_threshold=args.sigma_threshold,
                                                    sigma_fair_threshold=0,
                                                    sigma_gnmax=args.sigma_gnmax,
                                                    budget=args.budget,
                                                    for_z=for_z,
                                                    but_z=but_z,
                                                    delta=args.delta, file='.', log=lambda *x: None, args=args)
    elif args.pate_based_model in ['pate', 'pateSpre', 'pateSin']:
        (max_num_query, dp_eps, partition, answered, order_opt) = \
            analyze_multiclass_confident_gnmax(votes=raw_votes,
                                               threshold=args.threshold,
                                               sigma_threshold=args.sigma_threshold,
                                               sigma_gnmax=args.sigma_gnmax,
                                               budget=args.budget,
                                               delta=args.delta, file='.')
    else:
        raise NotImplementedError

    log("Maximum #Queries to be answered: ", max_num_query)
    achieved_eps = dp_eps[max_num_query - 1]

    # Build the student training set by querying the teacher ensemble.
    if args.pate_based_model == 'fairpate':
        query = FairPATEQuery(sensitive_group_list=np.unique(sensitives),
                              min_group_count=args.min_group_count,
                              max_fairness_violation=args.fairness_threshold,
                              num_classes=args.num_classes,
                              threshold=args.threshold,
                              sigma_threshold=args.sigma_threshold,
                              sigma_gnmax=args.sigma_gnmax,
                              fairness_metric=args.fairness_metric,
                              dataset=args.dataset)
        answered, student_train_labels = query.create_student_training_set(
            train_features[query_set_indices], train_sensitives[query_set_indices], raw_votes,
            for_z=for_z, but_z=but_z)
    else:
        query = PATEQuery(num_classes=args.num_classes,
                          threshold=args.threshold,
                          sigma_threshold=args.sigma_threshold,
                          sigma_gnmax=args.sigma_gnmax)
        answered, student_train_labels = query.create_student_training_set(
            train_features[query_set_indices], raw_votes)

    # max_num_query comes from the privacy analysis: keep at most that many
    # answered queries, and keep labels/sensitives aligned with them.
    indices_answered = query_set_indices[answered][:max_num_query]
    student_train_labels = student_train_labels[:max_num_query]
    student_train_sensitives = train_sensitives[indices_answered]

    if args.pate_based_model == 'pateSpre':
        # PATE-S_pre mitigation occurs here:
        log("#Queries before fairness processing: ", len(indices_answered))
        if args.fairness_metric == 'DemParity':
            processor = PATESPreProcessor(sensitive_group_list=np.unique(sensitives),
                                          num_classes=args.num_classes,
                                          min_group_count=args.min_group_count,
                                          max_fairness_violation=args.fairness_threshold)
            mask = processor.filter_student_training_set(student_train_labels, student_train_sensitives)
            # Apply the same filter to indices, labels and sensitives so that
            # the student's X and y stay the same length.
            indices_answered = indices_answered[mask]
            student_train_labels = np.asarray(student_train_labels)[mask]
            student_train_sensitives = student_train_sensitives[mask]
            log('#Queries after fairness processing: ', len(indices_answered))
        else:
            raise NotImplementedError("Only DemParity is supported for now")

    log("Actual #Queries answered: ", len(indices_answered))
    student_train_features = train_features[indices_answered]

    # Train Student model
    if args.pate_based_model == 'pateSin':
        if args.fairness_metric == 'DemParity':
            # Imported lazily so fairlearn is only required for this model.
            from fairlearn.reductions import DemographicParity, ExponentiatedGradient
            unmitigated_student_model = LogisticRegression(max_iter=MAX_ITER, random_state=args.seed)
            constraint = DemographicParity()
            student_model = ExponentiatedGradient(unmitigated_student_model, constraint)
            student_model.fit(X=student_train_features, y=student_train_labels,
                              sensitive_features=student_train_sensitives)
        else:
            raise NotImplementedError("Only DemParity is supported for now")
    else:
        student_model = LogisticRegression(max_iter=MAX_ITER, random_state=args.seed).fit(
            X=student_train_features, y=student_train_labels)

    # Accuracy on the held-out validation subset and on the test set.
    student_model_validation_accuracy = sklearn_score(
        student_model, train_features[validation_set_indices], train_labels[validation_set_indices])
    log("Student Model Validation Accuracy: ", student_model_validation_accuracy)

    student_model_test_accuracy = sklearn_score(student_model, test_features, test_labels)
    log("Student Model Test Accuracy: ", student_model_test_accuracy)

    ### Disparity
    # NOTE(review): unlike the test call below, the validation call does not
    # pass ``interpretation``; confirm that relying on the helper's default is intended.
    validation_disparity = sklearn_disparity(args.fairness_metric, student_model,
                                             train_features[validation_set_indices],
                                             train_sensitives[validation_set_indices],
                                             labels=train_labels[validation_set_indices], args=args)
    test_dem_disparity = sklearn_disparity(args.fairness_metric, student_model, test_features, test_sensitives,
                                           labels=test_labels,
                                           interpretation=args.dem_disparity_interpretation, args=args)

    log("Validation Disparity: ", validation_disparity)
    log("Test Disparity: ", test_dem_disparity)

    metrics = (student_model_validation_accuracy, validation_disparity, achieved_eps, max_num_query,
               len(indices_answered), student_model_test_accuracy, test_dem_disparity)
    if results_db is None:
        return metrics

    run_record = pd.DataFrame({'model': args.pate_based_model,
                               'fairness_metric': args.fairness_metric,
                               'dataset': args.dataset,
                               'num_teachers': args.num_teachers,
                               'threshold': args.threshold,
                               'fairness_threshold': args.fairness_threshold,
                               'sigma_threshold': args.sigma_threshold,
                               'sigma_gnmax': args.sigma_gnmax,
                               'budget': args.budget,
                               'delta': args.delta,
                               'seed': args.seed,
                               'student_validation_accuracy': student_model_validation_accuracy,
                               'student_test_accuracy': student_model_test_accuracy,
                               'validation_disparity': validation_disparity,
                               'test_disparity': test_dem_disparity,
                               'achieved_eps': achieved_eps,
                               'max_num_query': max_num_query,
                               'max_actual_query': len(indices_answered)}, index=[0])
    results_db = pd.concat([results_db, run_record], ignore_index=True)
    return (results_db, *metrics)

    
if __name__ == '__main__':
    # Command-line entry point: parse options, load the data into the
    # module-level globals that run_experiment / create_teacher_indices read,
    # then run Optuna, a grid search, or a single experiment.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='adult')
    parser.add_argument('--list_dataset', nargs='+', type=str, help='A list of datasets to run the experiment on.')

    parser.add_argument('--num_classes', type=int, default=2)
    parser.add_argument('--output_col_name', type=str, default='income')
    parser.add_argument('--split', type=float, default=0.75)

    parser.add_argument('--dem_disparity_interpretation', type=str, default='max_vs_min')

    parser.add_argument('--teacher_query_set_split', type=float, default=0.7)

    parser.add_argument('--num_teachers', type=int, default=4)
    parser.add_argument('--list_num_teachers', nargs='+', type=int, help='A list of number of teachers to run the experiment on.')

    parser.add_argument('--threshold', type=int, default=2)
    parser.add_argument('--list_threshold', nargs='+', type=int, help='A list of thresholds to run the experiment on.')

    parser.add_argument('--fairness_threshold', type=float, default=0.2)
    parser.add_argument('--list_fairness_threshold', nargs='+', type=float, help='A list of fairness thresholds to run the experiment on.')

    parser.add_argument('--sigma_threshold', type=float, default=60)
    parser.add_argument('--list_sigma_threshold', nargs='+', type=float, help='A list of sigma thresholds to run the experiment on.')

    parser.add_argument('--sigma_fair_threshold', type=int, default=0)

    parser.add_argument('--sigma_gnmax', type=float, default=25)
    parser.add_argument('--list_sigma_gnmax', nargs='+', type=float, help='A list of sigma gnmax to run the experiment on.')

    parser.add_argument('--budget', type=float, default=1000)
    parser.add_argument('--list_budget', nargs='+', type=float, help='A list of budgets to run the experiment on.')

    parser.add_argument('--delta', type=float, default=1e-5)
    parser.add_argument('--verbose', action='store_true')

    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--list_seed', nargs='+', type=int, help='A list of seeds to run the experiment on.')

    parser.add_argument('--data_path', type=str, default='./fairpate_tabular/data/')
    parser.add_argument('--min_group_count', type=int, default=50)
    parser.add_argument('--results_dir', type=str, default='.', help='Directory to store the results in.')

    parser.add_argument('--use_optuna', action='store_true', help='Whether to use optuna to find the best hyperparameters.')
    parser.add_argument('--num_optuna_trials', type=int, default=1000, help='Number of optuna trials to run.')

    parser.add_argument('--use_stratification', action='store_true', help='Whether to use stratification to split the data.')

    parser.add_argument('--fairness_metric', type=str, default='DemParity', help='Fairness metric to use for the experiment. Can be `DemParity`, `ErrorParity`, or `EqualityOfOdds`.')
    parser.add_argument('--list_fairness_metric', nargs='+', type=str, help='A list of fairness metrics to run the experiment on.')

    parser.add_argument('--num_calib', type=int, default=100, help='Number of calibration samples to use for ground-truth-based fairness metrics.')

    parser.add_argument('--pate_based_model', type=str, default='fairpate', help='What PATE-based model to use. Can be `fairpate`, `pate`, `pateSpre`, `pateSin` or `pateSpost`.')

    # ``args`` must stay a module-level name: create_teacher_indices reads it.
    args = process_arguments(parser.parse_args())
    if args.verbose:
        log = print
    else:
        log = lambda *x, **y: None

    log('\n'.join(f'{k}: {v}' for k, v in vars(args).items()))

    # PATE analysis needs this
    np.random.seed(args.seed)
    # Otherwise we use an rng
    rng = np.random.default_rng(args.seed)

    # Import data
    train_features, train_labels, train_sensitives, test_features, test_labels, test_sensitives = process_data(rng, args, log)

    # Initialize results db
    if os.path.exists(args.results_db_path):
        results_db = pd.read_parquet(args.results_db_path)
    else:
        results_db = pd.DataFrame(columns=['dataset', 'model', 'fairness_metric', 'num_teachers', 'threshold', 'fairness_threshold', 'sigma_threshold', 'sigma_gnmax', 'budget', 'delta', 'seed', 'student_validation_accuracy', 'student_test_accuracy', 'validation_disparity', 'test_disparity', 'achieved_eps', 'max_num_query', 'max_actual_query'])

    # Collect the list-valued options that were actually supplied; they
    # decide whether a grid search is run.
    args_dict = vars(args)
    list_args = {name: value for name, value in args_dict.items() if 'list' in name}
    non_None_list_args = {name: value for name, value in list_args.items() if value is not None}

    if args.use_optuna:
        print("Running Optuna (Ignoring list items)")
        run_optuna(args, run_experiment)

    elif len(non_None_list_args) > 0:
        print("Running multiple experiments")
        run_grid_search(args, run_experiment, results_db, non_None_list_args)
    else:
        # Single run: record the metrics and persist the de-duplicated results table.
        results_db, student_model_validation_accuracy, validation_dem_disparity, achieved_eps, max_num_query, num_queries_answered, student_model_test_accuracy, test_dem_parity = run_experiment(args, log, results_db)
        write_results(args, student_model_validation_accuracy, validation_dem_disparity, achieved_eps, max_num_query, num_queries_answered, student_model_test_accuracy, test_dem_parity)
        results_db.drop_duplicates().to_parquet(args.results_db_path)