__author__ = 'Qi'
# Created on 4/26/22.
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from sklearn.metrics import average_precision_score, accuracy_score, balanced_accuracy_score
from torch import optim
from torch.autograd import Variable
from tqdm import tqdm

from mlp import FC, Net, TwoLayerFC, Adversarial_head
from myDataLoader import meps_dataloader, adult_dataloader, celeba_dataloader
from mylosses import RAAN, construct_neighbourhood_loss, class_balanced_attributes, DRO, feature_neutralization, EOR_regularizer
from myutils import get_accuracy
from parameters import get_parameters
from resnets import resnet18


def train_epoch_progress(args, model, classification_head, adv_head_model, train_iter, loss_function, optimizer, optimizer_head, epoch, RAAN_adversarial_weights, raan_criterion = None):
    """Train for one epoch and return ``(avg_loss, acc)`` on the training set.

    Two-stage schedule controlled by the module-level ``EPOCHS_first_stage``:

    * ``epoch < EPOCHS_first_stage``: ``model`` is trained end-to-end on its
      own logits with ``loss_function`` and ``optimizer``.
    * otherwise: ``classification_head`` is trained on top of ``model``'s
      representation with the objective selected by ``args.methods``
      ('RNF', any name containing 'SCRAAN', 'adversarial_learning', 'EOR',
      'CE') and stepped with ``optimizer_head``.  ``model``'s parameters only
      receive gradients when ``args.is_feature_trainable`` is true.

    Besides its arguments, this function reads the module-level globals
    ``EPOCHS_first_stage``, ``temperature``, ``kd_loss_function`` and
    ``optimizer_adv_head`` that the training script defines.

    Args:
        args: parsed run configuration (from ``get_parameters``).
        model: feature extractor, called as ``pred, representation = model(x)``.
        classification_head: maps a representation to class logits.
        adv_head_model: adversary, only used by 'adversarial_learning'.
        train_iter: iterable of ``(x, label, sensitive_label, index, ...)`` batches.
        loss_function: first-stage loss applied to log-probabilities.
        optimizer: first-stage optimizer.
        optimizer_head: second-stage optimizer.
        epoch: zero-based epoch index.
        RAAN_adversarial_weights: per-sample weights indexed by ``index``;
            only used by 'SCRAAN_ALL'.
        raan_criterion: RAAN loss callable; only used by 'SCRAAN'.

    Returns:
        Tuple ``(avg_loss, acc)``: mean training loss per batch and training
        accuracy as computed by ``get_accuracy``.
    """
    model.train()
    classification_head.train()
    adv_head_model.train()

    if epoch >= EPOCHS_first_stage:
        # Second stage: the feature extractor is optionally frozen, the head always trains.
        for param in model.parameters():
            param.requires_grad = args.is_feature_trainable

        for param in classification_head.parameters():
            param.requires_grad = True

    avg_loss = 0.0
    truth_res = []
    pred_res = []
    softmax = nn.Softmax(dim=1)
    for batch in tqdm(train_iter, desc='Train epoch ' + str(epoch + 1)):

        # Pre-processing the data
        sent, label, sensitive_label, index = batch[0], batch[1], batch[2], batch[3]
        sent = sent.type(torch.cuda.FloatTensor)
        truth_res += list(label.data.numpy())

        # Get probability of the original input from the biased model
        pred, representation = model(sent)
        pred_softmax = softmax(pred / temperature)  # temperature scaling
        pred_softmax = pred_softmax.cpu()

        if epoch < EPOCHS_first_stage:
            # ---- First training stage: train the whole model end-to-end ----
            pred = F.log_softmax(pred, dim=1)
            pred = pred.cpu()
            pred_label = pred.data.max(1)[1].numpy()
            pred_res += list(pred_label)

            model.zero_grad()
            loss = loss_function(pred, label)
            avg_loss += loss.item()
            loss.backward()
            optimizer.step()

        else:
            # ---- Second training stage: train the classification head ----
            if args.methods == 'RNF':
                # Get the interpolated features and probability
                neutra_repre_5, neutra_repre_6, neutra_repre_7, neutra_repre_8, neutra_repre_9, neutra_probability5 = feature_neutralization(
                    representation, pred_softmax, label, sensitive_label, args.HIDDEN_DIM)

                alpha = args.neb_tau

                # Knowledge distillation loss as in equation 1
                pred_neutra = softmax(classification_head(neutra_repre_5)).cpu()
                loss = kd_loss_function(neutra_probability5, pred_neutra)

                # Regularization as in equation 3
                difference_sum = 0
                for augmented in (neutra_repre_6, neutra_repre_7, neutra_repre_8, neutra_repre_9):
                    pred_augmented = softmax(classification_head(augmented)).cpu()
                    difference_sum += torch.abs(pred_augmented - pred_neutra)

                # Linearly combine the two losses as in equation 4
                loss += alpha * torch.sum(difference_sum)

                avg_loss += loss.item()
                # The training script builds optimizer_head over model's parameters
                # too, so their gradients must be cleared as well; otherwise they
                # accumulate across batches (and stale first-stage gradients keep
                # moving a frozen feature extractor).
                model.zero_grad()
                classification_head.zero_grad()
                loss.backward()
                optimizer_head.step()

            elif 'SCRAAN' in args.methods:
                criterion = nn.NLLLoss(reduction='none')
                pred = classification_head(representation)
                # Per-sample temperature: L2 norm of the logits raised to class_tau.
                dynamic_temperature = torch.norm(pred, p=2, dim=1, keepdim=True) ** args.class_tau
                pred_softmax = F.log_softmax(pred / dynamic_temperature, dim=1)
                label = label.cuda()
                sensitive_label = sensitive_label.cuda()
                loss = criterion(pred_softmax, label)  # per-sample losses

                if args.methods == 'SCRAAN':
                    emd_loss = raan_criterion(index, representation, loss, label, sensitive_label,
                                              neb_tau=args.neb_tau)
                # NOTE(review): the 'CACE' and 'DRO' cases below can never be
                # reached because neither name contains 'SCRAAN'; confirm whether
                # they should be dispatched from the outer method chain instead.
                elif args.methods == 'CACE':
                    emd_loss = class_balanced_attributes(loss, sensitive_label)
                elif args.methods == 'DRO':
                    emd_loss = DRO(loss, label, args.tau)
                elif args.methods == 'SCRAAN_ALL':
                    raan_loss = RAAN_adversarial_weights[index] * loss * len(RAAN_adversarial_weights) / len(RAAN_adversarial_weights[index])
                    emd_loss = raan_loss.sum()
                else:
                    emd_loss = loss.mean()

                avg_loss += emd_loss.item()
                model.zero_grad()
                classification_head.zero_grad()
                emd_loss.backward()
                optimizer_head.step()

            elif args.methods == 'adversarial_learning':
                args.adv_reg = args.neb_tau
                label = label.cuda()
                sensitive_label = sensitive_label.cuda()
                pred = classification_head(representation)
                pred_softmax = F.log_softmax(pred, dim=1)
                adv_pred = adv_head_model(pred, label, sensitive_label)
                adv_pred_softmax = F.log_softmax(adv_pred, dim=1)
                criterion = nn.NLLLoss()

                class_loss = criterion(pred_softmax, label)
                adv_loss = criterion(adv_pred_softmax, sensitive_label)
                loss = class_loss - args.adv_reg * adv_loss

                avg_loss += loss.item()
                # Updates the model/head and the adversarial head.
                model.zero_grad()
                classification_head.zero_grad()
                adv_head_model.zero_grad()
                loss.backward()
                optimizer_head.step()
                optimizer_adv_head.step()

            elif args.methods == 'EOR':
                criterion = nn.NLLLoss()
                args.eor_reg = args.neb_tau
                pred = classification_head(representation)
                dynamic_temperature = torch.norm(pred, p=2, dim=1, keepdim=True) ** args.class_tau
                pred_softmax = F.log_softmax(pred / dynamic_temperature, dim=1)
                label = label.cuda()
                sensitive_label = sensitive_label.cuda()
                loss = criterion(pred_softmax, label)
                eor_reg = EOR_regularizer(pred, label, sensitive_label)
                loss += args.eor_reg * eor_reg

                avg_loss += loss.item()
                model.zero_grad()
                classification_head.zero_grad()
                loss.backward()
                optimizer_head.step()

            elif args.methods == 'CE':
                criterion = nn.NLLLoss()
                pred = classification_head(representation)
                dynamic_temperature = torch.norm(pred, p=2, dim=1, keepdim=True) ** args.class_tau
                pred_softmax = F.log_softmax(pred / dynamic_temperature, dim=1)
                label = label.cuda()
                sensitive_label = sensitive_label.cuda()
                loss = criterion(pred_softmax, label)

                avg_loss += loss.item()
                model.zero_grad()
                classification_head.zero_grad()
                loss.backward()
                optimizer_head.step()

            # Training accuracy with the freshly updated parameters; no graph needed.
            with torch.no_grad():
                _, representation = model(sent)
                pred = F.log_softmax(classification_head(representation), dim=1).cpu()
            pred_res += list(pred.max(1)[1].numpy())

    avg_loss /= len(train_iter)
    acc = get_accuracy(truth_res, pred_res)

    return avg_loss, acc


def evaluate_first_stage(model, data, loss_function, name):
    """Evaluate ``model``'s own logits on ``data`` and print loss/accuracy.

    Args:
        model: network called as ``pred, representation = model(x)``.
        data: iterable of ``(x, label, sensitive_label, ...)`` batches.
        loss_function: loss applied to CPU log-probabilities and CPU labels.
        name: prefix for the printed summary line.

    Returns:
        Accuracy as computed by ``get_accuracy``.
    """
    model.eval()
    avg_loss = 0.0
    truth_res = []
    pred_res = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for batch in data:
            sent, label, sensitive_label = batch[0], batch[1], batch[2]
            sent = sent.type(torch.cuda.FloatTensor)
            # Labels stay on the CPU, next to the CPU log-probabilities below.
            label = label.long().cpu()
            truth_res += list(label.numpy())
            model.batch_size = len(label)
            pred, representation = model(sent)
            pred = F.log_softmax(pred, dim=1).cpu()
            pred_res += list(pred.max(1)[1].numpy())
            avg_loss += loss_function(pred, label).item()
    avg_loss /= len(data)
    acc = get_accuracy(truth_res, pred_res)
    print(name + ': loss %.2f acc %.1f' % (avg_loss, acc * 100))
    return acc


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or NaN when the denominator is 0."""
    return float(numerator) / float(denominator) if denominator else np.nan


def _group_rates(y_true, y_pred, protected, value):
    """Return (TPR, FPR, positive-prediction rate) restricted to ``protected == value``.

    Rates whose denominator is empty are reported as NaN instead of raising.
    """
    in_group = protected == value
    pred_pos = y_pred == 1
    pred_neg = y_pred == 0
    true_pos = y_true == 1
    true_neg = y_true == 0
    tp = np.sum(in_group & pred_pos & true_pos)
    fn = np.sum(in_group & pred_neg & true_pos)
    fp = np.sum(in_group & pred_pos & true_neg)
    tn = np.sum(in_group & pred_neg & true_neg)
    tpr = _safe_ratio(tp, tp + fn)
    fpr = _safe_ratio(fp, fp + tn)
    pr = _safe_ratio(np.sum(in_group & pred_pos), np.sum(in_group))
    return tpr, fpr, pr


def evaluate_second_stage(args, model, classification_head, adv_head_model, data, loss_function, epoch, name):
    """Evaluate on ``data``, print accuracy and fairness metrics, and log them to wandb.

    Before ``args.first_stage_epochs`` the model's own logits are scored; after
    that the logits come from ``classification_head`` applied to the model's
    representation.

    Reported metrics: loss, accuracy, balanced accuracy, average precision,
    EOD (|dFPR| + |dTPR|), EOP (|dTPR|), DPD (|d positive rate|) between
    protected groups 0 and 1, and the worst accuracy over the four
    (protected attribute, true class) cells.  A rate whose group is empty is
    reported as NaN.

    Args:
        args: run configuration; reads ``args.first_stage_epochs``.
        model: feature extractor, called as ``pred, representation = model(x)``.
        classification_head: second-stage head.
        adv_head_model: adversarial head (only switched to eval mode here).
        data: iterable of ``(x, label, sensitive_label, ...)`` batches.
        loss_function: loss applied to CPU log-probabilities and CPU labels.
        epoch: zero-based epoch index.
        name: prefix for the printed summary lines.

    Returns:
        Accuracy as computed by ``get_accuracy``.
    """
    model.eval()
    classification_head.eval()
    adv_head_model.eval()

    avg_loss = 0.0
    truth_res = []
    truth_sensitive = []
    pred_res = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for batch in data:
            sent, label, sensitive_label = batch[0], batch[1], batch[2]
            sent = sent.type(torch.cuda.FloatTensor)
            label = label.long().cpu()
            truth_res += list(label.numpy())
            truth_sensitive += list(sensitive_label.cpu().numpy())
            model.batch_size = len(label)
            pred, representation = model(sent)
            if epoch >= args.first_stage_epochs:
                pred = classification_head(representation)
            pred = F.log_softmax(pred, dim=1).cpu()
            pred_res += list(pred.max(1)[1].numpy())
            avg_loss += loss_function(pred, label).item()
    avg_loss /= len(data)

    acc = get_accuracy(truth_res, pred_res)
    # NOTE(review): AP is computed from hard 0/1 predictions, not scores.
    ap = average_precision_score(truth_res, pred_res)
    bca = balanced_accuracy_score(truth_res, pred_res)

    y_true = np.asarray(truth_res)
    y_pred = np.asarray(pred_res)
    protected = np.asarray(truth_sensitive)

    TPR_0, FPR_0, PR_0 = _group_rates(y_true, y_pred, protected, 0)
    TPR_1, FPR_1, PR_1 = _group_rates(y_true, y_pred, protected, 1)

    EOD = abs(FPR_0 - FPR_1) + abs(TPR_0 - TPR_1)
    EOP = abs(TPR_0 - TPR_1)
    DPD = abs(PR_1 - PR_0)

    print (name + ' > : EOD %.3f | DPD %.3f | EOP %.3f ' % (EOD, DPD, EOP))
    print( name + ' >> : loss %.2f | acc %.1f | bacc %.1f | AP %.1f' % (avg_loss, acc * 100, bca * 100, ap * 100))

    # Assumes protected value 0 = female, 1 = male — TODO confirm per dataset.
    print(" >>> Female TPR {:.4f}, FPR {:.4f} and PR {:.4f}".format(TPR_0, FPR_0, PR_0)) # Positive Rate
    print(" >>>> Male TPR, {:.4f}, FPR {:.4f} and PR {:.4f}".format(TPR_1, FPR_1, PR_1))

    wandb.log({'loss': avg_loss, 'EOD': EOD, 'DPD': DPD, 'EOP': EOP,  'ACC': acc * 100, 'AP': ap*100, 'BACC': bca*100})

    # Accuracy inside each (protected attribute, true class) cell; empty cells
    # are skipped instead of dividing by zero.
    cell_accs = []
    for attr in (0, 1):
        for cls in (0, 1):
            in_cell = (protected == attr) & (y_true == cls)
            if in_cell.any():
                cell_accs.append(np.mean(y_pred[in_cell] == cls))
    worst_group_acc = min(cell_accs) if cell_accs else np.nan
    print('worst_group_acc', worst_group_acc*100)

    wandb.log({'worst_group_acc': worst_group_acc*100})

    return acc

def get_the_representations(model, data):
    """Collect ``model``'s representation and the labels for every sample in ``data``.

    The model is switched to eval mode and left there.  Each batch is indexed
    as ``(x, label, sensitive_label, index, ...)``.

    Returns:
        Tuple of tensors ``(representations, labels, sensitive_labels, indices)``
        in iteration order, each built from the accumulated per-sample lists.
    """
    model.eval()
    truth_res = []
    truth_sensitive = []
    representation_res = []
    index_res = []
    # Inference only: skip building an autograd graph for every batch.
    with torch.no_grad():
        for batch in data:
            sent, label, sensitive_label, index = batch[0], batch[1], batch[2], batch[3]
            sent = sent.type(torch.cuda.FloatTensor)
            truth_res += list(label.long().cpu().numpy())
            index_res += index.cpu().numpy().tolist()
            truth_sensitive += list(sensitive_label.cpu().numpy())
            _, representation = model(sent)
            representation_res += representation.cpu().numpy().tolist()

    return torch.tensor(representation_res), torch.tensor(truth_res), torch.tensor(truth_sensitive), torch.tensor(index_res)




if __name__ == '__main__':
    args = get_parameters()
    neb_tau_list = args.neb_tau_list
    print('args : ', args)
    print('neb_tau_list :', neb_tau_list)
    for neb_tau in neb_tau_list:
        args.neb_tau = neb_tau
        # One wandb run per neb_tau value; it is finished at the end of the iteration.
        wandb.init(project="Fairness-RAAN", entity="qiqi-helloworld", config=args)
        rand_seed = 1
        torch.set_num_threads(2)
        torch.manual_seed(rand_seed)
        random.seed(rand_seed)
        np.random.seed(rand_seed)

        # EPOCHS_first_stage, temperature, kd_loss_function and optimizer_adv_head
        # are read as module-level globals by train_epoch_progress.
        EPOCHS_first_stage = args.first_stage_epochs
        EPOCHS_second_stage = args.second_stage_epochs
        USE_GPU = torch.cuda.is_available()
        HIDDEN_DIM = args.HIDDEN_DIM
        BATCH_SIZE = args.train_batch
        INPUT_DIM = args.INPUT_DIM
        # Only reported in the final summary; train_epoch_progress uses
        # args.neb_tau as the Equation 4 fairness/accuracy trade-off weight.
        alpha = 0
        temperature = 5  # Temperature scaling

        if args.data == 'meps':
            # Suggested: INPUT_DIM = 138, HIDDEN_DIM = 50, BATCH_SIZE = 64, epochs 9 + 5.
            # 11362 train samples; classes 9401 / 1961; attributes 7263 / 4099.
            train_iter, dev_iter, test_iter = meps_dataloader(BATCH_SIZE)
        elif args.data == 'adult':
            # Suggested: INPUT_DIM = 120, HIDDEN_DIM = 50, BATCH_SIZE = 64, epochs 9 + 5.
            # Classes 20448 / 6684; attributes 18320 / 8812.
            train_iter, dev_iter, test_iter = adult_dataloader(BATCH_SIZE)
        elif args.data == 'celeba':
            # Suggested: HIDDEN_DIM = 512, BATCH_SIZE = 390, epochs 5 + 4.
            train_iter, dev_iter, test_iter = celeba_dataloader(BATCH_SIZE)
        else:
            raise ValueError('Unknown dataset: %s' % args.data)

        print(args.data, ': Train ', len(train_iter.dataset), ' | Val : ', len(dev_iter.dataset), ' | Test : ', len(test_iter.dataset))

        lr = args.lr
        lr2 = args.lr2

        print('Training...', args.methods, args.neb_tau, args.train_batch)

        adv_head_model = Adversarial_head()
        if args.data == 'celeba':
            model = resnet18(pretrained=False, num_classes=2)
            classification_head = TwoLayerFC(INPUT_DIM, HIDDEN_DIM, 2)
            print('model is resnet')
        else:
            model = Net(INPUT_DIM, HIDDEN_DIM, 2)
            print('model is mlp')
            classification_head = FC(INPUT_DIM, HIDDEN_DIM, 2)
        if USE_GPU:
            model = model.cuda()
            adv_head_model = adv_head_model.cuda()
            classification_head = classification_head.cuda()

        # optimizer: first stage (whole model).  optimizer_head: second stage,
        # over both the feature extractor and the classification head.
        if args.optimizer == 'SGD':
            optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
            optimizer_head = optim.SGD([{'params': model.parameters()}, {'params': classification_head.parameters()}], lr=lr2, momentum=0.9)
        elif args.optimizer == 'Adam':
            optimizer = optim.Adam(model.parameters(), lr=lr)
            optimizer_head = optim.Adam([{'params': model.parameters()}, {'params': classification_head.parameters()}], lr=lr2)
        else:
            raise ValueError('Unknown optimizer: %s' % args.optimizer)

        raan_criterion = None
        if args.methods == 'SCRAAN':
            raan_criterion = RAAN(args.gamma, data_size=len(train_iter.dataset))
            print(raan_criterion)

        optimizer_adv_head = optim.SGD(adv_head_model.parameters(), lr=lr2, momentum=0.9)

        loss_function = nn.NLLLoss()
        kd_loss_function = nn.MSELoss()

        best_test_acc, best_epoch = 0, 0
        # NOTE(review): nothing assigns RAAN_adversarial_weights, so the
        # 'SCRAAN_ALL' method fails when train_epoch_progress indexes it.
        RAAN_adversarial_weights = None
        for epoch in range(EPOCHS_first_stage + EPOCHS_second_stage):
            avg_loss, train_acc = train_epoch_progress(args, model, classification_head, adv_head_model, train_iter, loss_function, optimizer,
                                                       optimizer_head, epoch, RAAN_adversarial_weights, raan_criterion=raan_criterion)

            test_acc = evaluate_second_stage(args, model, classification_head, adv_head_model, test_iter, loss_function, epoch, 'Final Test')

            tqdm.write('Train: loss %.2f acc %.1f' % (avg_loss, train_acc * 100))

            # Track the best test accuracy reached during the second stage.
            if epoch >= EPOCHS_first_stage and test_acc > best_test_acc:
                best_test_acc, best_epoch = test_acc, epoch

        print("End of training method : ", args.methods, 'neb_tau : ', args.neb_tau, 'class_tau : ', args.class_tau,
              'optimizer :', args.optimizer)
        print('Data : ', args.data, 'rand_seed :', rand_seed)
        print('lr :', lr, "lr2: ", lr2, 'best_test_acc :', best_test_acc, 'best_epoch :', best_epoch)
        print("Requires Gradient :", args.is_feature_trainable, 'alpha:', alpha, 'temperature:', temperature)

        wandb.finish()