import argparse
import os
import pickle as pkl
import subprocess
import warnings
from datetime import timedelta
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.spatial import distance
from scipy.stats import wasserstein_distance, wasserstein_distance_nd
from sklearn.feature_selection import mutual_info_regression, mutual_info_classif


def load_best_models(file_name):
    """
    Read the model-selection table and turn each row into a run-config tuple.

    Parameters
    ----------
    file_name : str or file-like
        Anything ``pd.read_csv`` accepts. The table must have the columns
        ``dataset``, ``sensitive_attr``, ``model``, ``metric`` and ``hparam_seed``.

    Returns
    -------
    list of tuple
        One tuple per row, laid out as
        ``(dataset, sensitive_attr, model, fair_model, metric, hparam_seed,
        seed, pretrain, shift, group_shift)``. The non-CSV fields are fixed
        strings: ``fair_model='None'``, ``seed='0'``, ``pretrain='True'``,
        ``shift='None'``, ``group_shift='None'``.
    """
    table = pd.read_csv(file_name)
    # Constant run settings shared by every selected model.
    fair_model, seed, pretrain, shift, group_shift = 'None', '0', 'True', 'None', 'None'
    return [
        (
            row['dataset'],
            row['sensitive_attr'],
            row['model'],
            fair_model,
            row['metric'],
            row['hparam_seed'],
            seed,
            pretrain,
            shift,
            group_shift,
        )
        for _, row in table.iterrows()
    ]


def get_repr(dataset, sensitive_attribute, surv_model, fair_model, metric, hparams_seed, seed, pretrained, shift, group_shift):
    """
    Load the pickled representation bundle of one run and return its ``'train'`` entry.

    The file is looked up, relative to the current working directory, as
    ``repr/repr_<surv_model>_<fair_model>_<dataset>_<sensitive_attribute>_<metric>_
    <pretrained>_<shift>_<group_shift>_<hparams_seed>_<seed>.pkl``. Each part is
    rendered with ``str()``.

    Note: ``pickle.load`` can execute arbitrary code, so only point this at
    trusted, locally produced files.
    """
    # Order of the parts matches the on-disk naming scheme (model first, seed last).
    name_parts = (surv_model, fair_model, dataset, sensitive_attribute, metric,
                  pretrained, shift, group_shift, hparams_seed, seed)
    pickle_path = Path('repr') / ('repr_' + '_'.join(str(part) for part in name_parts) + '.pkl')
    with pickle_path.open('rb') as handle:
        bundle = pkl.load(handle)
    return bundle['train']


def discretize_samples(x, y, bins=10):
    """
    Bin two paired samples on a 2-D grid and return their empirical distributions.

    Parameters
    ----------
    x, y : array-like
        Paired 1-D samples of equal length.
    bins : int or sequence, optional
        Forwarded unchanged to ``np.histogram2d`` (e.g. ``10`` or ``(100, 2)``).

    Returns
    -------
    joint_prob : ndarray, shape (bins_x, bins_y)
        Histogram counts divided by their total.
    marginal_x : ndarray, shape (bins_x,)
        Row sums of ``joint_prob``.
    marginal_y : ndarray, shape (bins_y,)
        Column sums of ``joint_prob``.
    """
    counts, _x_edges, _y_edges = np.histogram2d(x, y, bins=bins)
    total = counts.sum()
    joint_prob = counts / total
    # Summing out one axis of the joint gives the marginal of the other variable.
    return joint_prob, joint_prob.sum(axis=1), joint_prob.sum(axis=0)

def calculate_nmi(x, y, bins=10):
    """
    Calculate the normalized mutual information between two continuous samples.

    Both samples are discretised with ``discretize_samples`` (a 2-D histogram).
    The plug-in mutual information (natural log) is then divided by the
    geometric mean of the two marginal entropies:
    ``NMI = I(X; Y) / sqrt(H(X) * H(Y))``.

    Parameters
    ----------
    x, y : array-like
        Paired 1-D samples of equal length.
    bins : int or sequence, optional
        Forwarded to ``discretize_samples`` / ``np.histogram2d``.

    Returns
    -------
    float
        The normalized mutual information. If either discretised marginal has
        zero entropy (all of its mass falls in a single bin), the mutual
        information is necessarily zero, and 0.0 is returned instead of the NaN
        that 0 / 0 would produce. A NaN here would otherwise propagate through
        any mean taken over features.
    """
    joint_prob, marginal_x, marginal_y = discretize_samples(x, y, bins=bins)

    # Only occupied cells contribute to the sum (0 * log 0 := 0).
    nonzero = joint_prob > 0
    joint_nz = joint_prob[nonzero]
    # Product of the marginals at the same cells, in the same (C) order as joint_nz.
    independent_nz = np.outer(marginal_x, marginal_y)[nonzero]
    mutual_info = np.sum(joint_nz * np.log(joint_nz / independent_nz))

    # Shannon entropies of the marginals, again skipping empty bins.
    px = marginal_x[marginal_x > 0]
    py = marginal_y[marginal_y > 0]
    entropy_x = -np.sum(px * np.log(px))
    entropy_y = -np.sum(py * np.log(py))

    denom = np.sqrt(entropy_x * entropy_y)
    if denom == 0:
        # A degenerate (single-bin) variable shares no information with anything.
        return 0.0
    return mutual_info / denom


def sliced_wasserstein(X, Y, n_proj=100, n_bins=1024, rng=None):
    """
    Approximate the sliced 1-Wasserstein distance between two point clouds.

    Both samples are projected onto ``n_proj`` random unit directions. For each
    direction, the sorted projections (empirical quantile functions) are
    linearly interpolated onto ``n_bins`` evenly spaced quantile levels in
    [0, 1], which lets the two samples have different sizes. The mean absolute
    gap between the two quantile curves is then taken. The result is the
    average of that gap over all directions.

    Parameters
    ----------
    X : array-like, shape (n_x, d) or (n_x,)
    Y : array-like, shape (n_y, d) or (n_y,)
        1-D inputs are treated as a single feature.
    n_proj : int, optional
        Number of random projection directions.
    n_bins : int, optional
        Number of quantile levels used to compare the projected samples.
    rng : None, int, numpy.random.Generator or numpy.random.RandomState, optional
        Source of the random directions.
        - ``None`` (default): draws from the global ``np.random`` state,
          giving exactly the same draws as before this parameter existed.
        - int: seeds a fresh ``np.random.default_rng``, for reproducible results.
        - Generator or RandomState instance: used as-is.

    Returns
    -------
    float
        Mean over projections of the mean absolute quantile difference.

    Raises
    ------
    ValueError
        If an input is not 1-D/2-D, is empty, or the two inputs disagree on the
        number of features.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    if X.ndim == 1:
        X = X[:, None]
    if Y.ndim == 1:
        Y = Y[:, None]
    if X.ndim != 2 or Y.ndim != 2:
        raise ValueError('X and Y must be 1-D or 2-D arrays')
    if X.shape[1] != Y.shape[1]:
        raise ValueError('X and Y must have the same number of features, got %d and %d'
                         % (X.shape[1], Y.shape[1]))
    if X.shape[0] == 0 or Y.shape[0] == 0 or X.shape[1] == 0:
        raise ValueError('X and Y must contain at least one sample and one feature')

    if rng is None:
        # Module-level legacy API: standard_normal(d) consumes the same stream as randn(d).
        rng = np.random
    elif isinstance(rng, (int, np.integer)):
        rng = np.random.default_rng(rng)

    n_features = X.shape[1]
    # Loop invariants: the common quantile grid and each sample's own quantile positions.
    levels = np.linspace(0, 1, n_bins)
    pos_x = np.linspace(0, 1, X.shape[0])
    pos_y = np.linspace(0, 1, Y.shape[0])

    distances = []
    for _ in range(n_proj):
        direction = rng.standard_normal(n_features)
        direction /= np.linalg.norm(direction)
        quant_x = np.interp(levels, pos_x, np.sort(X @ direction))
        quant_y = np.interp(levels, pos_y, np.sort(Y @ direction))
        distances.append(np.mean(np.abs(quant_x - quant_y)))
    return np.mean(distances)



def _mean_feature_nmi(features, target, bins):
    """Return the mean of ``calculate_nmi(features[:, j], target)`` over all columns ``j``."""
    return np.mean([calculate_nmi(features[:, j], target, bins=bins) for j in range(features.shape[1])])


def calculate_bias(data):
    """
    Measure several candidate sources of bias between two sensitive groups.

    Groups are the sorted unique values of ``data['s']``. The first two of
    them, called g0 and g1 below, are compared.

    Parameters
    ----------
    data : dict
        Equal-length arrays under these keys:
        - ``'s'``: sensitive-group labels.
        - ``'repr'``: 2-D representation, one row per sample.
        - ``'y'``: times.
        - ``'d'``: compared with 1 to select rows; presumably the event
          indicator (TODO confirm).

    Returns
    -------
    dict
        - ``'censoring_disparity'``: ``|mean(d | g0) - mean(d | g1)| / mean(d)``,
          where the denominator is taken over all samples.
        - ``'repr_disparity'``: ``sliced_wasserstein`` distance between the
          two groups' representations.
        - ``'tte_disparity'``: 1-D Wasserstein distance between the groups'
          ``y`` values restricted to ``d == 1``, divided by the largest such value.
        - ``'nmi_xy_disparity'``: ``|mean_j NMI(repr_j, y) in g0 - same in g1|``,
          using 100x100 bins.
        - ``'nmi_xd_disparity'``: the same with ``d`` in place of ``y``, using
          100x2 bins.

        Each value is also printed.

    Raises
    ------
    ValueError
        If ``data['s']`` contains fewer than two distinct groups.
    """
    group_list = np.unique(data['s'])
    if len(group_list) < 2:
        raise ValueError('calculate_bias needs at least two sensitive groups, got %d' % len(group_list))
    if len(group_list) > 2:
        # Every disparity below is pairwise, so groups beyond the first two are not compared.
        warnings.warn('calculate_bias compares only the first two sensitive groups (%s, %s); '
                      '%d other group(s) are ignored' % (group_list[0], group_list[1], len(group_list) - 2))

    repr_list, y_list, d_list = [], [], []
    for group in group_list:
        g_idx = data['s'] == group
        repr_list.append(data['repr'][g_idx])
        y_list.append(data['y'][g_idx])
        d_list.append(data['d'][g_idx])

    output = dict()

    # calculate source of bias: d (difference in event rates, relative to the overall rate)
    score = np.abs(np.mean(d_list[0]) - np.mean(d_list[1])) / np.mean(np.concatenate(d_list, axis=0))
    output['censoring_disparity'] = score
    print('Source of bias (d): %.4f' % score)

    # calculate source of bias: x
    score = sliced_wasserstein(repr_list[0], repr_list[1])
    print('Source of bias (x): %.4f' % score)
    output['repr_disparity'] = score

    # calculate source of bias: y (only rows with d == 1), scaled by the largest such y
    y_events_0 = y_list[0][d_list[0] == 1]
    y_events_1 = y_list[1][d_list[1] == 1]
    max_val = max(np.max(y_events_0), np.max(y_events_1))
    score = wasserstein_distance(y_events_0, y_events_1) / max_val
    print('Source of bias (y): %.4f' % score)
    output['tte_disparity'] = score

    # calculate source of bias: MI(x,y)
    mi = np.abs(_mean_feature_nmi(repr_list[0], y_list[0], (100, 100))
                - _mean_feature_nmi(repr_list[1], y_list[1], (100, 100)))
    print('Source of bias (MI(x,y)): %.4f' % mi)
    output['nmi_xy_disparity'] = mi

    # calculate source of bias: MI(x,d)
    mi = np.abs(_mean_feature_nmi(repr_list[0], d_list[0], (100, 2))
                - _mean_feature_nmi(repr_list[1], d_list[1], (100, 2)))
    print('Source of bias (MI(x,d)): %.4f' % mi)
    output['nmi_xd_disparity'] = mi
    return output


# Script body: for every selected model tuned on the 'ctd' metric, load its
# training representation, compute the source-of-bias statistics and write
# them to result/source_of_bias.csv.
best_model_list = load_best_models('result/tte_model_selection.csv')
score_list = []
for best_model in best_model_list:
    dataset, sensitive_attribute, surv_model, fair_model, metric, hparams_seed, seed, pretrained, shift, group_shift = best_model
    # Only models selected on the 'ctd' metric are analysed.
    if metric != 'ctd':
        continue
    print('Calculating bias for %s' % str(best_model))
    data = get_repr(dataset, sensitive_attribute, surv_model, fair_model, metric, hparams_seed, seed, pretrained, shift, group_shift)
    score = calculate_bias(data)
    # Identifying columns are appended after the metrics, in the original column order.
    score.update(dataset=dataset, sensitive_attr=sensitive_attribute, metric=metric, fair_model=fair_model)
    score_list.append(score)
score_list = pd.DataFrame(score_list)
# With no 'ctd' rows the frame has no columns, and sort_values would raise KeyError.
if not score_list.empty:
    score_list = score_list.sort_values(['dataset', 'sensitive_attr', 'metric'])
score_list.to_csv('result/source_of_bias.csv', index=False)
