import os
import torch
import torch.optim as optim
import numpy as np
from tqdm import tqdm
import pickle
import random
import time
from fairtorch_local import DemographicParityLoss, EqualiedOddsLoss
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score
from sklearn.cluster import KMeans

from meta_fairness import equal_opp_binary, fair_loss_binary, avg_odds_binary, acc_diff_binary, disparate_impact_binary

# Per-dataset loss weights, keyed by (label, group).
# NOTE(review): the values are hard-coded constants; how they were derived is
# not shown in this file — presumably precomputed reweighing factors, confirm.
_REWEIGH_FACTORS = {
    'ACSIncome': {
        (0, 0): 1.103221,
        (1, 0): 0.881572,
        (0, 1): 0.905575,
        (1, 1): 1.176071,
    },
    'CelebA': {
        (0, 0): 0.866271,
        (1, 0): 1.201112,
        (0, 1): 1.125507,
        (1, 1): 0.892101,
    },
    'ACSEmployment': {
        (0, 0): 1.066670,
        (1, 0): 0.930995,
        (0, 1): 0.943020,
        (1, 1): 1.077683,
    },
}


def get_weights(labels, groups, dataset='ACSIncome'):
    """Return one loss weight per sample, chosen by its (label, group) pair.

    Args:
        labels: tensor of class labels; entries equal to 0 or 1 are matched.
        groups: tensor of group ids with the same shape as ``labels`` (and on
            the same device); entries equal to 0 or 1 are matched.
        dataset: which factor table to use: ``'ACSIncome'`` (the default, and
            the only behaviour available before this parameter existed),
            ``'CelebA'`` or ``'ACSEmployment'``.

    Returns:
        A float tensor shaped like ``labels`` and living on ``labels.device``.
        Samples whose (label, group) pair is not in the table keep weight 1.

    Raises:
        ValueError: if ``dataset`` has no factor table.
    """
    try:
        factors = _REWEIGH_FACTORS[dataset]
    except KeyError:
        raise ValueError(
            "Unknown dataset %r; expected one of %s"
            % (dataset, sorted(_REWEIGH_FACTORS))
        ) from None

    # Allocate on the labels' device: the boolean masks below are built from
    # `labels`/`groups`, and a mask must be co-located with the tensor it
    # indexes (a CPU tensor cannot be indexed with a CUDA mask).
    weights = torch.ones(labels.size(), device=labels.device)
    for (label, group), factor in factors.items():
        mask = torch.logical_and(labels == label, groups == group)
        weights[mask] = factor

    return weights

def train_mlp(model, trainloader, savefldr,
              validloader=None, save_best=True, save_ite=True,
              cuda=True, losstype='ce', epochs=100, lr=0.001, lmbd=1,
              reweigh=False, gradnoise=False):
    """Train ``model`` with SGD (momentum 0.9) and optionally checkpoint it.

    Args:
        model: the network to train; called as ``model(inputs.float())`` and
            expected to return per-class scores (column 1 is used as the
            positive-class score for the fairness term).
        trainloader: iterable yielding ``(inputs, labels, groups)`` batches.
        savefldr: string prefix for checkpoint files; file names are appended
            by plain concatenation, so it should end with a path separator.
        validloader: optional loader evaluated with ``test_mlp`` after every
            epoch; required for ``save_best`` to have any effect.
        save_best: when validating, save the model with the highest
            validation F-score to ``savefldr + 'best.pth'``.
        save_ite: save the model to ``savefldr + '<epoch>.pth'`` at the START
            of every epoch (so ``0.pth`` is the untrained model).
        cuda: move each batch to the GPU. The model is not moved here.
        losstype: ``'ce'`` for cross-entropy only, ``'fairce'`` for
            cross-entropy plus the ``EqualiedOddsLoss`` penalty.
        epochs: number of passes over ``trainloader``.
        lr: SGD learning rate.
        lmbd: passed as ``alpha`` to ``EqualiedOddsLoss``.
        reweigh: weight each sample's cross-entropy by ``get_weights`` and use
            the weighted mean.
        gradnoise: add standard-normal noise to every gradient before the
            optimizer step.

    Returns:
        The same ``model`` object, after the last epoch (not the best one).

    Raises:
        ValueError: if ``losstype`` is neither ``'ce'`` nor ``'fairce'``.
    """
    # Fail before any checkpoint is written or batch is processed; an unknown
    # loss type would otherwise leave `loss` undefined inside the loop.
    if losstype not in ('ce', 'fairce'):
        raise ValueError("Unknown losstype %r; expected 'ce' or 'fairce'" % (losstype,))

    torch.manual_seed(0)

    if reweigh:
        # Per-sample losses are needed so they can be weighted individually.
        criterion = torch.nn.CrossEntropyLoss(reduction='none')
    else:
        criterion = torch.nn.CrossEntropyLoss()
    dp_loss = EqualiedOddsLoss(sensitive_classes=[0, 1], alpha=lmbd)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    best_score = -1000
    # Stay None when no validation pass ever runs (e.g. epochs == 0).
    fscore = unfair_per = None

    for epoch in tqdm(range(epochs)):
        if save_ite:
            torch.save(model, savefldr + '%d.pth' % epoch)

        # test_mlp() switches the model to eval mode, so re-enable training
        # mode at the top of every epoch.
        model.train()

        # Re-seed per epoch so each epoch's randomness is reproducible.
        torch.manual_seed(epoch)
        for inputs, labels, groups in trainloader:
            # NOTE(review): squeeze() without a dim also removes a batch
            # dimension of size 1 — confirm loaders never yield such batches.
            inputs, labels, groups = inputs.squeeze(), labels.squeeze(), groups.squeeze()
            if cuda:
                inputs, labels, groups = inputs.cuda(), labels.cuda(), groups.cuda()

            optimizer.zero_grad()

            outputs = model(inputs.float())
            loss = criterion(outputs, labels.long())
            if reweigh:
                # Build the weights from CPU copies and move the result to the
                # loss's device, so mask, weights and loss are co-located.
                weights = get_weights(labels.cpu(), groups.cpu()).to(loss.device)
                # Weighted mean; also reduces the per-sample losses to the
                # scalar that backward() requires, for both loss types.
                loss = torch.sum(loss * weights) / torch.sum(weights)

            if losstype == 'fairce':
                outprob = torch.nn.functional.softmax(outputs, dim=-1)[:, 1]
                lossfair = dp_loss(inputs, outprob, groups, labels.long())
                # Skip the penalty for batches where it evaluates to NaN.
                if not torch.isnan(lossfair):
                    loss = loss + lossfair

            loss.backward()
            if gradnoise:
                for p in model.parameters():
                    # Parameters that took no part in the loss have no grad.
                    if p.grad is None:
                        continue
                    # Sample on the CPU generator, then move to wherever the
                    # gradient lives (works with and without CUDA).
                    p.grad += torch.randn(p.grad.shape).to(p.grad.device)
            optimizer.step()

        if validloader is not None:
            fscore, unfair_per = test_mlp(model, validloader, cuda=cuda)

            if fscore > best_score:
                best_score = fscore
                if save_best:
                    torch.save(model, savefldr + 'best.pth')

    # Reports the LAST epoch's validation numbers, not the best ones.
    if validloader is not None and fscore is not None:
        print("Final Score : ", fscore)
        print("Final Unfair : ", unfair_per)

    # Replace the deterministic per-epoch seed with a wall-clock seed
    # (presumably so later code does not inherit it).
    torch.manual_seed(int(time.time()))

    return model

def _resolve_fairness_metric(fairness_criteria):
    """Return the metric function for a criterion name; ``None`` passes through.

    Raises:
        ValueError: if the name is not one of the supported criteria.
    """
    if fairness_criteria is None:
        return None
    if fairness_criteria == 'eqopp':
        return equal_opp_binary
    if fairness_criteria == 'avgodds':
        return avg_odds_binary
    if fairness_criteria == 'accdiff':
        return acc_diff_binary
    if fairness_criteria == 'disimp':
        return disparate_impact_binary
    raise ValueError(
        "Unknown fairness_criteria %r; expected 'eqopp', 'avgodds', "
        "'accdiff', 'disimp' or None" % (fairness_criteria,)
    )


def test_mlp(model, validloader, cuda=True, return_outputs=False, return_logits=False, return_class_scores=False, return_class_corr_prob=False, fairness_criteria='eqopp'):
    """Evaluate ``model`` on ``validloader`` and return scores as a tuple.

    Args:
        model: the network to evaluate; it is put into eval mode and left
            there.
        validloader: iterable yielding ``(inputs, labels, groups)`` batches.
        cuda: move inputs and labels to the GPU before the forward pass.
        return_outputs: also return the collected labels, predictions and
            groups.
        return_logits: with ``return_outputs``, return the softmax
            probabilities instead of the argmax predictions (despite the
            name, these are probabilities, not raw logits).
        return_class_scores: also return four per-group scores taken from
            the fairness metric; requires a non-None ``fairness_criteria``.
        return_class_corr_prob: also return six per-group statistics of the
            probability assigned to the true class.
        fairness_criteria: ``'eqopp'``, ``'avgodds'``, ``'accdiff'``,
            ``'disimp'``, or ``None`` to skip the fairness metric.

    Returns:
        A tuple assembled in this order (optional parts appear only when
        requested):
          1. macro F1 * 100
          2. fairness metric * 100 (unless ``fairness_criteria`` is None)
          3. labels, predictions-or-probabilities, groups (``return_outputs``)
          4. ``prec_z``, ``rec_z``, ``prec_o``, ``rec_o`` entries of the
             metric's class scores, each * 100 (``return_class_scores``)
          5. for group 0 then group 1: sum of true-class probabilities;
             then count of label-1 samples with true-class probability
             >= 0.5; then the same count for label-0 samples
             (``return_class_corr_prob``)

    Raises:
        ValueError: for an unknown ``fairness_criteria``, or when
            ``return_class_scores`` is set while ``fairness_criteria`` is None.
    """
    # Validate before running inference so a typo fails immediately rather
    # than after a full pass over the data.
    metric_fn = _resolve_fairness_metric(fairness_criteria)
    if return_class_scores and metric_fn is None:
        raise ValueError("return_class_scores=True requires a fairness_criteria")

    model.eval()
    label_arr = []
    pred_arr = []
    predonehot_arr = []
    group_arr = []

    # NOTE(review): unlike train_mlp, batches are not squeeze()d here —
    # confirm the validation loader already yields the shapes the model expects.
    with torch.no_grad():
        for inputs, labels, groups in validloader:
            if cuda:
                inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs.float())
            probs = torch.nn.functional.softmax(outputs.data, 1)
            _, predicted = torch.max(outputs.data, 1)
            pred_arr.extend(predicted.cpu().detach().numpy())
            label_arr.extend(labels.cpu().detach().numpy())
            group_arr.extend(groups.cpu().detach().numpy())
            predonehot_arr.extend(probs.cpu().detach().numpy())

    label_arr = np.array(label_arr)
    pred_arr = np.array(pred_arr)
    group_arr = np.array(group_arr)
    predonehot_arr = np.array(predonehot_arr)

    fscore = f1_score(label_arr, pred_arr, average='macro')

    if metric_fn is not None:
        if return_class_scores:
            unfair_per, class_scores = metric_fn(group_arr, label_arr, pred_arr, return_class_scores=True)
        else:
            unfair_per = metric_fn(group_arr, label_arr, pred_arr)

    out_tuple = [fscore*100]
    if metric_fn is not None:
        out_tuple.append(unfair_per*100)
    if return_outputs:
        if return_logits:
            out_tuple.extend([label_arr, predonehot_arr, group_arr])
        else:
            out_tuple.extend([label_arr, pred_arr, group_arr])
    if return_class_scores:
        out_tuple.extend([class_scores['prec_z']*100, class_scores['rec_z']*100, class_scores['prec_o']*100, class_scores['rec_o']*100])
    if return_class_corr_prob:
        # Integer labels are needed to index into each probability row; a new
        # name keeps any label array already placed in out_tuple untouched.
        int_labels = label_arr.astype(int)
        # Probability the model assigned to each sample's true class.
        corr_prob = np.array([ele[i] for i, ele in zip(int_labels, predonehot_arr)])
        in_group_z = group_arr == 0
        in_group_o = group_arr == 1
        labels_z = int_labels[in_group_z]
        labels_o = int_labels[in_group_o]

        out_tuple.extend(
            [np.sum(corr_prob[in_group_z]),
             np.sum(corr_prob[in_group_o]),
             np.sum(corr_prob[in_group_z][labels_z==1]>=0.5),
             np.sum(corr_prob[in_group_o][labels_o==1]>=0.5),
             np.sum(corr_prob[in_group_z][labels_z==0]>=0.5),
             np.sum(corr_prob[in_group_o][labels_o==0]>=0.5)])

    return tuple(out_tuple)


def get_savefldr(dataset, protected_class, losstype, weight_init, model_identifier, repeat_num, finetune_loss=None):
    """Build the checkpoint folder path for a run and create it on disk.

    The path is ``models/<dataset>/<protected_class>/<losstype>/<weight_init>/
    <model_identifier>_<repeat_num>/``, with an extra ``<finetune_loss>/``
    level when ``finetune_loss`` is given. The returned string ends with a
    path separator, so a file name can be appended by plain concatenation
    (as ``train_mlp`` does with its ``savefldr`` argument).

    Returns:
        The folder path (relative to the current working directory).
    """
    run_name = "%s_%d" % (model_identifier, repeat_num)
    components = ["models", dataset, protected_class, losstype, weight_init, run_name]
    if finetune_loss is not None:
        components.append(finetune_loss)

    # The empty last component makes join() emit the trailing separator.
    savefldr = os.path.join(*components, '')
    os.makedirs(os.path.dirname(savefldr), exist_ok=True)

    return savefldr

def get_loadfile(dataset, protected_class, losstype, weight_init, model_identifier, repeat_num, epoch, finetune_loss=None):
    """Return the checkpoint file path for one epoch of a run.

    The path is ``models/<dataset>/<protected_class>/<losstype>/<weight_init>/
    <model_identifier>_<repeat_num>/<epoch>.pth``, with an extra
    ``<finetune_loss>/`` level before the file name when ``finetune_loss`` is
    given. ``epoch`` is converted with ``str()``, so any value works (for
    example a number or ``'best'``). Nothing is read or created on disk.
    """
    run_name = "%s_%d" % (model_identifier, repeat_num)
    components = ["models", dataset, protected_class, losstype, weight_init, run_name]
    if finetune_loss is not None:
        components.append(finetune_loss)
    components.append(str(epoch) + ".pth")

    return os.path.join(*components)
