from copy import deepcopy
import os
import pdb
import numpy as np
from dataloader import GeneralData
from imblearn.under_sampling import RandomUnderSampler
import pandas as pd

def one_hot(x, num_class=2):
    """Return the one-hot encoding of integer class ids ``x``.

    Rows of the ``num_class`` x ``num_class`` identity matrix are selected by
    ``x``, so a scalar id yields a 1-D vector and an array of ids yields one
    row per id (float dtype).
    """
    identity_rows = np.identity(num_class)
    return identity_rows[x]

def get_disparity(fairness_metric, preds, sensitives, labels=None, **kwargs):
    """Compute the fairness disparity named by ``fairness_metric`` from predictions.

    Args:
        fairness_metric: one of 'DemParity', 'ErrorParity', 'EqualityOfOdds'.
        preds: model predictions.
        sensitives: sensitive-attribute value per sample.
        labels: ground-truth labels; not used by 'DemParity'.
        **kwargs: forwarded unchanged to the selected metric function.

    Raises:
        ValueError: if ``fairness_metric`` is not one of the supported names.
    """
    if fairness_metric not in ('DemParity', 'ErrorParity', 'EqualityOfOdds'):
        raise ValueError(f'Unknown fairness metric: {fairness_metric}')

    if fairness_metric == 'DemParity':
        # Demographic parity only looks at predictions per group, not labels.
        return dem_disparity(preds, sensitives, **kwargs)

    metric_fn = error_disparity if fairness_metric == 'ErrorParity' else equality_of_odds_disparity
    return metric_fn(preds, sensitives, labels, **kwargs)
    
def dem_disparity(preds, sensitives, num_sensitives=2, interpretation='max_vs_min', **kwargs):
    """Demographic-parity disparity of predictions across sensitive groups.

    Each group's positive-prediction rate is ``(preds @ one_hot(sensitives)) / group size``.
    ``preds`` is assumed to hold 0/1 predictions so that this product counts positives.

    Args:
        preds: per-sample predictions (assumed 0/1).
        sensitives: integer group ids in ``[0, num_sensitives)``.
        num_sensitives: number of sensitive groups.
        interpretation: how the per-group rates are reduced to one number:
            'max_vs_min'     -> highest group rate minus lowest group rate;
            'one_vs_average' -> max over groups of (group rate - overall rate);
            'one_vs_others'  -> max over groups of (group rate - rate of all other samples).
        **kwargs: ignored.

    Raises:
        ValueError: if any group has no samples, or ``interpretation`` is unknown.
    """
    membership = one_hot(sensitives, num_class=num_sensitives)
    positives_per_group = preds @ membership
    group_sizes = membership.sum(axis=0)

    # An empty group would make its rate 0/0.
    if (group_sizes == 0).any():
        raise ValueError(f'One of sensitive_group_count is 0: {group_sizes}')

    group_rates = positives_per_group / group_sizes

    if interpretation == 'max_vs_min':
        return group_rates.max() - group_rates.min()
    if interpretation == 'one_vs_average':
        overall_rate = positives_per_group.sum() / group_sizes.sum()
        return (group_rates - overall_rate).max()
    if interpretation == 'one_vs_others':
        rest_sizes = group_sizes.sum() - group_sizes
        rest_rates = (positives_per_group.sum() - positives_per_group) / rest_sizes
        return (group_rates - rest_rates).max()
    raise ValueError(f'Unknown interpretation: {interpretation}')

def error_disparity(preds, sensitives, labels, **kwargs):
    """Gap between the highest and the lowest per-group error rate.

    The error rate of group ``z`` is the fraction of samples with
    ``sensitive == z`` whose prediction differs from the label. The result is
    ``max_z rate(z) - min_z rate(z)`` over the sensitive values present.

    Args:
        preds: per-sample predictions (1-D, or a column vector).
        sensitives: per-sample sensitive-attribute values.
        labels: per-sample ground-truth labels.
        **kwargs: ignored.

    Returns:
        The disparity as a numpy float.

    Raises:
        ValueError: if the inputs differ in length or are empty.
    """
    preds = np.asarray(preds).ravel()
    labels = np.asarray(labels).ravel()
    sensitives = np.asarray(sensitives).ravel()
    if not (len(preds) == len(labels) == len(sensitives)):
        raise ValueError(
            f'preds, labels and sensitives must have the same length, got '
            f'{len(preds)}, {len(labels)} and {len(sensitives)}')
    if len(preds) == 0:
        raise ValueError('Cannot compute error disparity on empty input')

    # Error rate = mean of the mismatch indicator over ALL samples in the group
    # (averaging the prediction over misclassified rows only is not an error
    # rate, and is NaN for a group that has no errors).
    is_error = preds != labels
    error_per_subgroup = [is_error[sensitives == z].mean() for z in np.unique(sensitives)]

    return np.max(error_per_subgroup) - np.min(error_per_subgroup)

def equality_of_odds_disparity(preds, sensitives, labels, **kwargs):
    """Largest equalized-odds gap between one sensitive group and all the others.

    For every sensitive value ``z``, predicted class ``yhat`` and true class ``y``
    (both ranging over the label values present), computes

        P(pred == yhat | truth == y, sensitive == z)
      - P(pred == yhat | truth == y, sensitive != z)

    and returns the maximum over all ``(z, yhat, y)``.

    Boolean masks are used instead of string-built ``DataFrame.query``
    expressions, so non-numeric labels/sensitive values work too.

    Args:
        preds: per-sample predictions (1-D, or a column vector).
        sensitives: per-sample sensitive-attribute values.
        labels: per-sample ground-truth labels.
        **kwargs: ignored.

    Returns:
        The disparity as a numpy float.

    Raises:
        ValueError: if the inputs differ in length or are empty, or if some
            conditioning set (truth == y inside / outside group z) is empty so
            that a conditional rate is undefined.
    """
    preds = np.asarray(preds).ravel()
    labels = np.asarray(labels).ravel()
    sensitives = np.asarray(sensitives).ravel()
    if not (len(preds) == len(labels) == len(sensitives)):
        raise ValueError(
            f'preds, labels and sensitives must have the same length, got '
            f'{len(preds)}, {len(labels)} and {len(sensitives)}')

    label_values = np.unique(labels)
    # One mask per predicted class, reused for every (z, y) pair.
    pred_masks = [preds == yhat for yhat in label_values]

    disparity_list = []
    for z in np.unique(sensitives):
        in_group = sensitives == z
        for y in label_values:
            has_truth = labels == y
            cond_for_z = has_truth & in_group
            cond_but_z = has_truth & ~in_group
            count_for_z = int(cond_for_z.sum())
            count_but_z = int(cond_but_z.sum())
            if count_for_z == 0 or count_but_z == 0:
                raise ValueError(
                    f'Conditional rate undefined for truth == {y}, sensitive == {z}: '
                    f'{count_for_z} samples inside the group, {count_but_z} outside it')
            for pred_is_yhat in pred_masks:
                prob_for_z = (pred_is_yhat & cond_for_z).sum() / count_for_z
                prob_but_z = (pred_is_yhat & cond_but_z).sum() / count_but_z
                disparity_list.append(prob_for_z - prob_but_z)

    if not disparity_list:
        raise ValueError('Cannot compute equality-of-odds disparity on empty input')
    return np.max(disparity_list)



def sklearn_disparity(fairness_metric, model, features, sensitives, labels=None, **kwargs):
    """Compute the named fairness disparity of a fitted model's predictions.

    Args:
        fairness_metric: one of 'DemParity', 'ErrorParity', 'EqualityOfOdds'.
        model: object exposing ``predict(features)``.
        features: model inputs.
        sensitives: sensitive-attribute value per sample.
        labels: ground-truth labels; not used by 'DemParity'.
        **kwargs: forwarded unchanged to the selected metric function.

    Raises:
        ValueError: if ``fairness_metric`` is not one of the supported names.
    """
    if fairness_metric not in ('DemParity', 'ErrorParity', 'EqualityOfOdds'):
        raise ValueError(f'Unknown fairness metric: {fairness_metric}')

    if fairness_metric == 'DemParity':
        return sklearn_dem_disparity(model, features, sensitives, **kwargs)

    metric_fn = (sklearn_error_disparity if fairness_metric == 'ErrorParity'
                 else sklearn_equality_of_odds_disparity)
    return metric_fn(model, features, sensitives, labels, **kwargs)

def sklearn_dem_disparity(model, features, sensitives, interpretation='max_vs_min', num_sensitives=2, **kwargs):
    """Demographic-parity disparity of ``model``'s predictions across sensitive groups.

    ``model.predict(features)`` is assumed to return 0/1 predictions, so that
    ``predictions @ one_hot(sensitives)`` counts positives per group.

    Args:
        model: object exposing ``predict(features)``.
        features: model inputs.
        sensitives: integer group ids in ``[0, num_sensitives)``.
        interpretation: how per-group positive rates are reduced:
            'max_vs_min'     -> highest group rate minus lowest group rate;
            'one_vs_average' -> max over groups of (group rate - overall rate);
            'one_vs_others'  -> max over groups of (group rate - rate of all other samples).
        num_sensitives: number of sensitive groups (default 2, the previously
            hard-coded value).
        **kwargs: ignored.

    Returns:
        The disparity as a numpy float.

    Raises:
        ValueError: if any group has no samples (its rate would be 0/0 = NaN),
            or ``interpretation`` is unknown.
    """
    predictions = model.predict(features)
    # Same encoding as one_hot(sensitives, num_class=num_sensitives), built once.
    membership = np.eye(num_sensitives)[sensitives]
    pos_prediction = predictions @ membership
    sensitive_group_count = membership.sum(axis=0)

    if np.any(sensitive_group_count == 0):
        raise ValueError(f'One of sensitive_group_count is 0: {sensitive_group_count}')

    prob_pos_prediction_per_subgroup = pos_prediction / sensitive_group_count
    if interpretation == 'one_vs_average':
        avg_prob_pos_prediction = pos_prediction.sum() / sensitive_group_count.sum()
        disparity = (prob_pos_prediction_per_subgroup - avg_prob_pos_prediction).max()
    elif interpretation == 'one_vs_others':
        others_count = sensitive_group_count.sum() - sensitive_group_count
        others_pos_prediction = pos_prediction.sum() - pos_prediction
        others_prob_pos_prediction_per_subgroup = others_pos_prediction / others_count
        disparity = (prob_pos_prediction_per_subgroup - others_prob_pos_prediction_per_subgroup).max()
    elif interpretation == 'max_vs_min':
        disparity = prob_pos_prediction_per_subgroup.max() - prob_pos_prediction_per_subgroup.min()
    else:
        raise ValueError(f'Unknown interpretation: {interpretation}')

    return disparity

import pandas as pd

def sklearn_error_disparity(model, features, sensitives, labels, **kwargs):
    """Gap between the highest and lowest per-group error rate of ``model``.

    The error rate of group ``z`` is the fraction of samples with
    ``sensitive == z`` for which ``model.predict(features)`` differs from the
    label. The result is ``max_z rate(z) - min_z rate(z)``.

    Args:
        model: object exposing ``predict(features)``.
        features: model inputs.
        sensitives: per-sample sensitive-attribute values.
        labels: per-sample ground-truth labels.
        **kwargs: ignored.

    Returns:
        The disparity as a numpy float.

    Raises:
        ValueError: if predictions, labels and sensitives differ in length or are empty.
    """
    predictions = np.asarray(model.predict(features)).ravel()
    labels = np.asarray(labels).ravel()
    sensitives = np.asarray(sensitives).ravel()
    if not (len(predictions) == len(labels) == len(sensitives)):
        raise ValueError(
            f'predictions, labels and sensitives must have the same length, got '
            f'{len(predictions)}, {len(labels)} and {len(sensitives)}')
    if len(predictions) == 0:
        raise ValueError('Cannot compute error disparity on empty input')

    # Error rate = mean of the mismatch indicator over ALL samples in the group
    # (averaging the prediction over misclassified rows only is not an error
    # rate, and is NaN for a group that has no errors).
    is_error = predictions != labels
    error_per_subgroup = [is_error[sensitives == z].mean() for z in np.unique(sensitives)]

    return np.max(error_per_subgroup) - np.min(error_per_subgroup)

def sklearn_equality_of_odds_disparity(model, features, sensitives, labels, **kwargs):
    """Largest equalized-odds gap of ``model`` between one group and all others.

    With ``pred = model.predict(features)``, for every sensitive value ``z``,
    predicted class ``yhat`` and true class ``y`` (both over the label values
    present), computes

        P(pred == yhat | truth == y, sensitive == z)
      - P(pred == yhat | truth == y, sensitive != z)

    and returns the maximum. Boolean masks replace string-built
    ``DataFrame.query`` expressions, so non-numeric values work too.

    Args:
        model: object exposing ``predict(features)``.
        features: model inputs.
        sensitives: per-sample sensitive-attribute values.
        labels: per-sample ground-truth labels.
        **kwargs: ignored.

    Returns:
        The disparity as a numpy float.

    Raises:
        ValueError: if the inputs differ in length or are empty, or if some
            conditioning set (truth == y inside / outside group z) is empty.
    """
    predictions = np.asarray(model.predict(features)).ravel()
    labels = np.asarray(labels).ravel()
    sensitives = np.asarray(sensitives).ravel()
    if not (len(predictions) == len(labels) == len(sensitives)):
        raise ValueError(
            f'predictions, labels and sensitives must have the same length, got '
            f'{len(predictions)}, {len(labels)} and {len(sensitives)}')

    label_values = np.unique(labels)
    # One mask per predicted class, reused for every (z, y) pair.
    pred_masks = [predictions == yhat for yhat in label_values]

    disparity_list = []
    for z in np.unique(sensitives):
        in_group = sensitives == z
        for y in label_values:
            has_truth = labels == y
            cond_for_z = has_truth & in_group
            cond_but_z = has_truth & ~in_group
            count_for_z = int(cond_for_z.sum())
            count_but_z = int(cond_but_z.sum())
            if count_for_z == 0 or count_but_z == 0:
                raise ValueError(
                    f'Conditional rate undefined for truth == {y}, sensitive == {z}: '
                    f'{count_for_z} samples inside the group, {count_but_z} outside it')
            for pred_is_yhat in pred_masks:
                prob_for_z = (pred_is_yhat & cond_for_z).sum() / count_for_z
                prob_but_z = (pred_is_yhat & cond_but_z).sum() / count_but_z
                disparity_list.append(prob_for_z - prob_but_z)

    if not disparity_list:
        raise ValueError('Cannot compute equality-of-odds disparity on empty input')
    return np.max(disparity_list)

def process_data(np_rng, args, log, return_train_test_index=False, train_set=None, test_set=None):
    """Build (or load) train/test arrays of features, labels and sensitive attributes.

    Two modes:

    * ``train_set`` and ``test_set`` both given: feature processing is assumed to
      happen in the PyTorch dataloaders, so ``*_features`` are just sample indices
      ``0..len-1``, and labels / sensitive attributes are items ``x[1]`` / ``x[2]``
      of each sample, cast to ``int``.
    * otherwise: data comes from ``GeneralData`` configured by ``args`` (items
      ``x[0]`` features, ``x[2]`` labels, ``x[3]`` sensitives), or from the cached
      ``<args.data_path>/<args.dataset>_fp.npz``. If ``args.undersampling_ratio``
      is set, train and test are each undersampled with ``RandomUnderSampler``
      (``random_state=args.seed``) and the result is never cached.

    Args:
        np_rng: random state forwarded to ``GeneralData``.
        args: config namespace; reads ``undersampling_ratio``, ``data_path``,
            ``dataset``, ``seed``, ``path``, ``sensitive_attributes``,
            ``cols_to_norm``, ``output_col_name`` and ``split``.
        log: print-like callable used for progress messages.
        return_train_test_index: in the ``GeneralData`` mode, return the split
            indices from ``full_data.get_train_test_idx()`` instead of the data;
            with undersampling, only the indices kept by the resampler.
        train_set, test_set: optional indexable datasets (see first mode).

    Returns:
        ``(train_features, train_labels, train_sensitives, test_features,
        test_labels, test_sensitives)``, or ``(train_idx, test_idx)`` when
        ``return_train_test_index`` applies.
    """
    if train_set is not None and test_set is not None:
        # Feature processing happens in the PyTorch dataloaders; only pass
        # indices, labels and sensitive attributes.
        train_features = np.arange(len(train_set))
        train_labels = np.array([x[1] for x in train_set]).astype(int).squeeze()
        train_sensitives = np.array([x[2] for x in train_set]).astype(int).squeeze()

        test_features = np.arange(len(test_set))
        test_labels = np.array([x[1] for x in test_set]).astype(int).squeeze()
        test_sensitives = np.array([x[2] for x in test_set]).astype(int).squeeze()

    else:
        undersample = args.undersampling_ratio is not None
        if undersample:
            output_path = os.path.join(args.data_path, f'{args.dataset}_fp_undersampled_{args.undersampling_ratio}.npz')
            rus = RandomUnderSampler(sampling_strategy=args.undersampling_ratio, random_state=args.seed)
        else:
            output_path = os.path.join(args.data_path, f'{args.dataset}_fp.npz')

        # Undersampled data is never cached, so it is always rebuilt.
        if return_train_test_index or not os.path.exists(output_path) or undersample:
            full_data = GeneralData(path = args.path, random_state=np_rng, sensitive_attributes = args.sensitive_attributes, cols_to_norm = args.cols_to_norm, output_col_name = args.output_col_name, split = args.split)

            # Without undersampling the split indices are final. With it, they
            # must first be filtered by the resampler below, otherwise the
            # returned indices would not match the undersampled data.
            if return_train_test_index and not undersample:
                return full_data.get_train_test_idx()

            dataset_train = full_data.getTrain(return_tensor=False)
            dataset_test = full_data.getTest(return_tensor=False)

            train_features = np.concatenate([x[0][None, :] for x in dataset_train], axis=0)
            train_labels = np.array([x[2] for x in dataset_train])
            train_sensitives = np.array([x[3] for x in dataset_train])

            test_features = np.concatenate([x[0][None, :] for x in dataset_test], axis=0)
            test_labels = np.array([x[2] for x in dataset_test])
            test_sensitives = np.array([x[3] for x in dataset_test])

            if undersample:
                # The resampler needs a 2-D feature matrix; positions are passed
                # as a single column and flattened back to 1-D afterwards.
                train_indices = np.arange(len(train_labels)).reshape(-1, 1)
                train_indices, train_labels = rus.fit_resample(train_indices, train_labels)
                train_indices = train_indices.ravel()
                train_features = train_features[train_indices]
                train_sensitives = train_sensitives[train_indices]

                test_indices = np.arange(len(test_labels)).reshape(-1, 1)
                test_indices, test_labels = rus.fit_resample(test_indices, test_labels)
                test_indices = test_indices.ravel()
                test_features = test_features[test_indices]
                test_sensitives = test_sensitives[test_indices]

                if return_train_test_index:
                    # Assumes get_train_test_idx() is aligned with getTrain()/getTest()
                    # order (as the positional indexing here requires) — TODO confirm.
                    full_data_train_indices, full_data_test_indices = full_data.get_train_test_idx()
                    return (np.asarray(full_data_train_indices)[train_indices],
                            np.asarray(full_data_test_indices)[test_indices])
            else:
                # Only the full (non-undersampled) data is cached.
                np.savez(output_path, train_features=train_features, train_labels=train_labels, train_sensitives=train_sensitives, test_features=test_features, test_labels=test_labels, test_sensitives=test_sensitives)

        else:
            log("Dataset already processed. Loading from file")
            data = np.load(output_path)
            train_features = data['train_features']
            train_labels = data['train_labels']
            train_sensitives = data['train_sensitives']
            test_features = data['test_features']
            test_labels = data['test_labels']
            test_sensitives = data['test_sensitives']

    # Share of the most frequent training label (a majority-class baseline).
    log("Train Label Top-Count/All Ratio: ", np.unique(train_labels, return_counts=True)[1].max() / len(train_labels))
    return train_features, train_labels, train_sensitives, test_features, test_labels, test_sensitives

def write_results(args, student_model_validation_accuracy, validation_disparity, achieved_eps, max_num_query, num_queries_answered, student_model_test_accuracy, test_disparity):
    """Append one tab-separated experiment row to ``results.tsv`` in the working directory.

    The header row is written first when the file does not exist yet. Run
    configuration fields are read from ``args``; the remaining columns are the
    metric values passed in.
    """
    results_file = 'results.tsv'
    needs_header = not os.path.exists(results_file)

    with open(results_file, 'a') as out:
        if needs_header:
            out.write('dataset\tmodel\tnum_teachers\tthreshold\tfairness_threshold\tsigma_threshold\tsigma_gnmax\tbudget\tdelta\tseed\tstudent_model_validation_accuracy\tvalidation_disparity\tachieved_eps\tmax_num_query\tnum_queries_answered\tstudent_model_test_accuracy\ttest_disparity\n')
        out.write(f'{args.dataset}\t{args.pate_based_model}\t{args.num_teachers}\t{args.threshold}\t{args.fairness_threshold}\t{args.sigma_threshold}\t{args.sigma_gnmax}\t{args.budget}\t{args.delta}\t{args.seed}\t{student_model_validation_accuracy}\t{validation_disparity}\t{achieved_eps}\t{max_num_query}\t{num_queries_answered}\t{student_model_test_accuracy}\t{test_disparity}\n')