from copy import deepcopy
import os
import numpy as np
from dataloader import GeneralData

def one_hot(x, num_class=2):
    """Return one-hot encoded rows for the integer class indices in ``x``.

    Every index in ``x`` selects the matching row of a
    ``num_class`` x ``num_class`` identity matrix, so the result carries one
    extra trailing axis of length ``num_class``.
    """
    identity = np.identity(num_class)
    return identity[x]

def sklearn_disparity(fairness_metric, model, features, sensitives, labels=None, **kwargs):
    """Dispatch to the disparity function selected by ``fairness_metric``.

    Args:
        fairness_metric: One of ``'DemParity'``, ``'ErrorParity'`` or
            ``'EqualityOfOdds'``.
        model: Object forwarded to the selected disparity function.
        features: Feature matrix forwarded to the selected function.
        sensitives: Sensitive-group identifiers forwarded to the selected
            function.
        labels: Ground-truth labels; forwarded for ``'ErrorParity'`` and
            ``'EqualityOfOdds'`` only (``'DemParity'`` does not receive them).
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        Whatever the selected disparity function returns.

    Raises:
        ValueError: If ``fairness_metric`` is not one of the names above.
    """
    if fairness_metric == 'DemParity':
        # Demographic parity only looks at predictions, so labels are dropped.
        return sklearn_dem_disparity(model, features, sensitives, **kwargs)

    if fairness_metric == 'ErrorParity':
        return sklearn_error_disparity(model, features, sensitives, labels, **kwargs)

    if fairness_metric == 'EqualityOfOdds':
        return sklearn_equality_of_odds_disparity(model, features, sensitives, labels, **kwargs)

    raise ValueError(f'Unknown fairness metric: {fairness_metric}')

def sklearn_dem_disparity(model, features, sensitives, interpretation='max_vs_min', **kwargs):
    """Measure the demographic-parity gap of ``model`` across sensitive groups.

    The predictions of ``model`` are summed per sensitive group and divided by
    the group size, giving a per-group prediction rate (the positive-prediction
    rate when predictions are 0/1 -- assumed, not checked here). The rates are
    then compared according to ``interpretation``.

    Args:
        model: Object exposing ``predict(features)``.
        features: Input passed to ``model.predict``.
        sensitives: Integer group index of every sample (0, 1, 2, ...). The
            number of groups is inferred from the largest index, with a
            minimum of two so binary inputs behave exactly as before.
        interpretation: How group rates are compared:
            ``'max_vs_min'`` -- largest rate minus smallest rate;
            ``'one_vs_average'`` -- largest excess of a group over the overall rate;
            ``'one_vs_others'`` -- largest excess of a group over the pooled
            rate of all remaining samples.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        The disparity as a NumPy float. A group index with no samples yields a
        0/0 rate, so the result is NaN in that case.

    Raises:
        ValueError: If ``interpretation`` is not one of the supported names.
    """
    group_ids = np.asarray(sensitives)
    # Size the encoding from the data instead of hard-coding two groups, so
    # attributes with more than two groups no longer raise IndexError.
    num_groups = max(2, int(group_ids.max()) + 1) if group_ids.size else 2
    membership = np.eye(num_groups)[sensitives]

    # Column g of `membership` flags the samples of group g, so the matrix
    # product sums the predictions inside each group.
    pos_prediction = model.predict(features) @ membership
    sensitive_group_count = membership.sum(axis=0)
    prob_pos_prediction_per_subgroup = pos_prediction / sensitive_group_count

    if interpretation == 'one_vs_average':
        avg_prob_pos_prediction = pos_prediction.sum() / sensitive_group_count.sum()
        dem_disparity = (prob_pos_prediction_per_subgroup - avg_prob_pos_prediction).max()
    elif interpretation == 'one_vs_others':
        # Complement statistics: everything that is not in the group.
        others_count = sensitive_group_count.sum() - sensitive_group_count
        others_pos_prediction = pos_prediction.sum() - pos_prediction
        others_prob_pos_prediction_per_subgroup = others_pos_prediction / others_count
        dem_disparity = (prob_pos_prediction_per_subgroup - others_prob_pos_prediction_per_subgroup).max()
    elif interpretation == 'max_vs_min':
        dem_disparity = prob_pos_prediction_per_subgroup.max() - prob_pos_prediction_per_subgroup.min()
    else:
        raise ValueError(f'Unknown interpretation: {interpretation}')

    return dem_disparity

import pandas as pd

def sklearn_error_disparity(model, features, sensitives, labels, **kwargs):
    """Return the gap between the highest and lowest per-group error rate.

    The error rate of a sensitive group is the fraction of its samples whose
    prediction differs from the ground-truth label.

    Args:
        model: Object exposing ``predict(features)``.
        features: Input passed to ``model.predict``.
        sensitives: Sensitive-group identifier of every sample.
        labels: Ground-truth label of every sample.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        ``max(error rates) - min(error rates)`` over the groups present in
        ``sensitives``, as a NumPy float.
    """
    predictions = np.asarray(model.predict(features)).ravel()
    truths = np.asarray(labels).ravel()
    groups = np.asarray(sensitives).ravel()

    # Boolean per-sample error indicator; its mean inside a group is that
    # group's misclassification rate. (Averaging the predicted values of the
    # misclassified rows instead would give NaN for an error-free group and
    # would not be an error rate at all.)
    is_error = predictions != truths
    error_per_subgroup = [is_error[groups == z].mean() for z in np.unique(groups)]

    error_disparity = np.max(error_per_subgroup) - np.min(error_per_subgroup)
    return error_disparity

def sklearn_equality_of_odds_disparity(model, features, sensitives, labels, **kwargs):
    """Return the largest equality-of-odds gap between a group and its complement.

    For every sensitive group ``z``, predicted value ``yhat`` and true label
    ``y`` (both drawn from the distinct values of ``labels``), the share of
    samples predicted ``yhat`` among those with truth ``y`` is computed inside
    the group and outside it. The result is the maximum signed difference
    (inside minus outside) over all combinations.

    Args:
        model: Object exposing ``predict(features)``.
        features: Input passed to ``model.predict``.
        sensitives: Sensitive-group identifier of every sample.
        labels: Ground-truth label of every sample.
        **kwargs: Accepted for signature compatibility and ignored.

    Raises:
        ZeroDivisionError: If some (label, group) cell, or its complement,
            contains no samples.
    """
    table = pd.DataFrame(
        np.c_[model.predict(features), labels, sensitives],
        columns=['prediction', 'truth', 'sensitive'],
    )
    label_values = np.unique(labels)

    gaps = []
    for group in np.unique(sensitives):
        in_group = table['sensitive'] == group
        out_group = ~in_group
        for predicted in label_values:
            predicted_hit = table['prediction'] == predicted
            for actual in label_values:
                truth_hit = table['truth'] == actual
                joint_hit = predicted_hit & truth_hit
                # Counts are converted to Python ints so an empty denominator
                # raises ZeroDivisionError rather than producing NaN.
                rate_inside = int((joint_hit & in_group).sum()) / int((truth_hit & in_group).sum())
                rate_outside = int((joint_hit & out_group).sum()) / int((truth_hit & out_group).sum())
                gaps.append(rate_inside - rate_outside)

    return np.max(gaps)

def _dataset_to_arrays(dataset):
    """Collect features, labels and sensitive attributes of ``dataset`` into arrays.

    Each sample is indexed as ``x[0]`` (features, an object with ``.numpy()``),
    ``x[2]`` (label) and ``x[3]`` (sensitive attribute).
    """
    # Walk the dataset once so all three arrays come from the same samples.
    samples = list(dataset)
    features = np.concatenate([x[0].numpy()[None, :] for x in samples], axis=0)
    labels = np.array([x[2] for x in samples])
    sensitives = np.array([x[3] for x in samples])
    return features, labels, sensitives


def process_data(rng, args, log):
    """Load the train/test arrays for ``args.dataset``, caching them as ``.npz``.

    On the first call the dataset is built with ``GeneralData``, converted to
    NumPy arrays and saved to ``<args.data_path>/<args.dataset>_fp.npz``.
    Later calls load the arrays from that file instead.

    Args:
        rng: Random generator forwarded to ``GeneralData``; unused when the
            cached file already exists.
        args: Namespace providing ``data_path`` and ``dataset`` and, when the
            cache is missing, ``path``, ``sensitive_attributes``,
            ``cols_to_norm``, ``output_col_name`` and ``split``.
        log: Callable used for progress messages; it is called with one
            argument and with two arguments.

    Returns:
        Tuple ``(train_features, train_labels, train_sensitives,
        test_features, test_labels, test_sensitives)``.
    """
    # NOTE(review): the raw data is read from args.path while the cache lives
    # under args.data_path -- confirm both attributes are intended.
    output_path = os.path.join(args.data_path, f'{args.dataset}_fp.npz')

    if not os.path.exists(output_path):
        full_data = GeneralData(path = args.path, rng=rng, sensitive_attributes = args.sensitive_attributes, cols_to_norm = args.cols_to_norm, output_col_name = args.output_col_name, split = args.split)

        train_features, train_labels, train_sensitives = _dataset_to_arrays(full_data.getTrain())
        test_features, test_labels, test_sensitives = _dataset_to_arrays(full_data.getTest())

        np.savez(output_path, train_features=train_features, train_labels=train_labels, train_sensitives=train_sensitives, test_features=test_features, test_labels=test_labels, test_sensitives=test_sensitives)

    else:
        log("Dataset already processed. Loading from file")
        # np.load keeps the archive open; the context manager closes it once
        # every array has been read into memory.
        with np.load(output_path) as data:
            train_features = data['train_features']
            train_labels = data['train_labels']
            train_sensitives = data['train_sensitives']
            test_features = data['test_features']
            test_labels = data['test_labels']
            test_sensitives = data['test_sensitives']

    log("Train Label Top-Count/All Ratio: ", np.unique(train_labels, return_counts=True)[1].max() / len(train_labels))
    return train_features, train_labels, train_sensitives, test_features, test_labels, test_sensitives

def write_results(args, student_model_validation_accuracy, validation_disparity, achieved_eps, max_num_query, num_queries_answered, student_model_test_accuracy, test_disparity):
    """Append one tab-separated result row to ``results.tsv``.

    The file is looked up relative to the current working directory. When it
    does not exist yet, a header line naming the columns is written first.
    The row combines the run configuration read from ``args`` with the
    metrics passed as the remaining arguments.
    """
    results_path = 'results.tsv'
    column_names = (
        'dataset', 'model', 'num_teachers', 'threshold', 'fairness_threshold',
        'sigma_threshold', 'sigma_gnmax', 'budget', 'delta', 'seed',
        'student_model_validation_accuracy', 'validation_disparity',
        'achieved_eps', 'max_num_query', 'num_queries_answered',
        'student_model_test_accuracy', 'test_disparity',
    )

    # Decide about the header before opening: append mode creates the file.
    header_needed = not os.path.exists(results_path)

    with open(results_path, 'a') as out:
        if header_needed:
            out.write('\t'.join(column_names) + '\n')

        row_values = (
            args.dataset, args.pate_based_model, args.num_teachers,
            args.threshold, args.fairness_threshold, args.sigma_threshold,
            args.sigma_gnmax, args.budget, args.delta, args.seed,
            student_model_validation_accuracy, validation_disparity,
            achieved_eps, max_num_query, num_queries_answered,
            student_model_test_accuracy, test_disparity,
        )
        out.write('\t'.join(f'{value}' for value in row_values) + '\n')