#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue May 10 20:23:04 2022

@author: yaoyao
"""

import pandas as pd
import torch
import math
from torchvision import transforms, datasets
from data_sampler import ImbalanceSampler
# from libauc.datasets import ImbalanceSampler
# from libauc.models import DenseNet121
from resnet_cifar import ResNet20
from cifar10_LT import IMBALANCECIFAR10
import numpy as np
from losses import pAUC_mini
# from chexpert import CheXpert
from parameters_dc import para
from datetime import datetime

def individualloss(h):
    """Return the element-wise logistic loss ``log(1 + exp(-h))``.

    The value is computed as ``softplus(-h)``, which is mathematically the
    same expression but does not overflow: the naive form evaluates
    ``exp(-h)`` first and therefore returns ``inf`` for strongly negative
    ``h``.

    Args:
        h: tensor of margins (any shape).

    Returns:
        Tensor with the same shape as ``h`` holding the per-element loss.
    """
    # softplus is only defined for floating dtypes; the naive formula
    # promoted integer inputs to the default float dtype, so do the same.
    if not torch.is_floating_point(h):
        h = h.to(torch.get_default_dtype())
    return torch.nn.functional.softplus(-h)

def diff_matrix(x, y):
    """Return the matrix of pairwise differences between ``x`` and ``y``.

    For 1-D inputs the result has shape ``(len(x), len(y))`` and entry
    ``[i, j]`` equals ``x[i] - y[j]``.
    """
    n_rows, n_cols = x.shape[0], y.shape[0]
    # Tile y along the rows and x along the columns, then subtract.
    tiled_y = y.repeat(n_rows, 1)
    tiled_x = x.repeat(n_cols, 1).t()
    return tiled_x - tiled_y
 
def split_pos_index(y):
    """Return the index tuple of the entries of ``y`` that are greater than 0.1."""
    return torch.nonzero(y > 0.1, as_tuple=True)

def split_neg_index(y):
    """Return the index tuple of the entries of ``y`` that are less than 0.1."""
    return torch.nonzero(y < 0.1, as_tuple=True)

def loss(y_pred, y_true, lamb):
    """Pairwise logistic loss between positives and negatives, truncated per positive.

    The scores in ``y_pred`` are split by ``y_true`` (> 0.1 positive,
    < 0.1 negative).  Every (positive, negative) pair gets the logistic
    loss of its score difference; entries that fall below the threshold
    ``lamb`` of their positive row are set to zero.

    Args:
        y_pred: predicted scores.
        y_true: labels used to separate positives from negatives.
        lamb: per-sample thresholds, indexed with the positive indices
            found in ``y_true`` (presumably aligned with the batch
            positions; verify against the caller).

    Returns:
        Tensor with one row per positive and one column per negative.
    """
    pos_idx = split_pos_index(y_true)
    neg_idx = split_neg_index(y_true)

    pairwise = individualloss(diff_matrix(y_pred[pos_idx], y_pred[neg_idx]))

    # One threshold per positive row, broadcast over the negatives.
    n_pos_batch = pos_idx[0].shape[0]
    thresholds = lamb[pos_idx[0]].reshape((n_pos_batch, 1))

    below_threshold = (pairwise - thresholds) < 0
    return pairwise.masked_fill(below_threshold, 0)

def set_all_seeds(SEED):
    """Seed the torch and NumPy random generators for reproducibility.

    Also sets ``torch.backends.cudnn.deterministic`` to True and
    ``torch.backends.cudnn.benchmark`` to False.
    """
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    np.random.seed(SEED)
    torch.manual_seed(SEED)

def partial_AUC(scores, labels, k1, k2, alpha, beta):
    """Compute a partial AUC from scores ranked in descending order.

    Samples are sorted by score (highest first) and cumulative
    true-positive / false-positive counts are formed.  Only the ranks
    whose false-positive count lies strictly between ``k1`` and ``k2``
    are kept; over that window the true-positive rate is summed against
    the increments of the false-positive count and the result is divided
    by ``n_neg * (beta - alpha)``.

    Args:
        scores: 1-D tensor of predicted scores.
        labels: 1-D tensor of labels aligned with ``scores``; the number
            of positives is taken as ``labels.sum()``.
        k1: exclusive lower bound on the false-positive count.
        k2: exclusive upper bound on the false-positive count.
        alpha: lower end of the false-positive-rate range (normalisation only).
        beta: upper end of the false-positive-rate range (normalisation only).

    Returns:
        0-dim tensor holding the partial AUC.
    """
    # Keep every intermediate on the device of `scores`; a bare
    # torch.arange would live on the CPU and break for CUDA inputs.
    device = scores.device
    labels = labels.to(device)

    n = scores.shape[0]
    n_pos = int(labels.sum())
    n_neg = n - n_pos

    _, order = torch.sort(scores, descending=True)
    ranked_labels = labels[order]

    # Cumulative counts after each rank position.
    tp = torch.cumsum(ranked_labels, dim=0)
    fp = torch.arange(1, n + 1, device=device) - tp
    tpr = tp / n_pos

    window = torch.where((fp > k1) & (fp < k2))
    p_tpr = tpr[window].float()
    fp_steps = torch.diff(fp[window]).float()

    # Left-endpoint sum: each FP increment is weighted by the TPR before it.
    return torch.matmul(p_tpr[:-1], fp_steps) / (n_neg * (beta - alpha))

# all parameters
alpha1 = 0.05   # lower end of the false-positive-rate range
alpha2 = 0.5    # upper end of the false-positive-rate range
SEED = para.seed
BATCH_SIZE = 64
lr_0 = para.lr_0
num_iter = para.num_iter

root = '../dataset/cifar-10-batches-py/'
# root = '../dataset/CheXpert-v1.0-small/'
# root = '../CheXpert/CheXpert-v1.0-small/'
# root = './CheXpert-v1.0-small/'
#root = '/Users/yaoyao/Desktop/DC_deep/CheXpert-v1.0-small/'
#root = '/dual_data/not_backed_up/CheXpert/CheXpert-v1.0/'

#torch.manual_seed(SEED)
set_all_seeds(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# dataloader
# NOTE(review): presumably the number of samples in the imbalanced training
# set produced by IMBALANCECIFAR10 -- confirm against the dataset class.
n_total = 12406
print("total data = ", n_total)
n_train = int(n_total * 0.9)
n_val = n_total - n_train

# random split training and validation data
index = torch.randperm(n_total)
train_index = index[0:n_train]
val_index = index[n_train:n_total]

# Training-time augmentation: random crop + horizontal flip, then normalisation.
transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

# Evaluation transform: normalisation only, no randomness.
transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

trainSet = IMBALANCECIFAR10(root=root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True,
                            transform=transform_train, target_transform=None, download=False, index=train_index, split=True)
# The validation split is evaluated, not trained on, so it gets the
# deterministic transform rather than the random training augmentation.
valSet = IMBALANCECIFAR10(root=root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True,
                            transform=transform_val, target_transform=None, download=False, index=val_index, split=True)
testSet = IMBALANCECIFAR10(root=root, imb_type='exp', imb_factor=0.01, rand_number=0, train=False,
                            transform=transform_val, target_transform=None, download=False)

# Each training batch holds 2*BATCH_SIZE samples; the sampler is asked for
# BATCH_SIZE positives per batch.
trainloader = torch.utils.data.DataLoader(trainSet, batch_size=2*BATCH_SIZE,
                                          sampler=ImbalanceSampler(np.array(trainSet.targets).flatten().astype(int), 2*BATCH_SIZE, pos_num=BATCH_SIZE),
                                          num_workers=2, pin_memory=True, drop_last=True)
valloader = torch.utils.data.DataLoader(valSet, batch_size=1, num_workers=2, drop_last=False, shuffle=False)
testloader = torch.utils.data.DataLoader(testSet, batch_size=1, num_workers=2, drop_last=False, shuffle=False)

# Positive / negative counts per split (targets are summed to count positives).
n_test = len(testSet.targets)
n_val = len(valSet.targets)
n_pos = int(sum(trainSet.targets))
n_neg = n_train - n_pos
print(n_pos)
print(n_neg)

n_pos_val = int(sum(valSet.targets))
n_neg_val = n_val - n_pos_val

n_pos_test = int(sum(testSet.targets))
n_neg_test = n_test-n_pos_test

# Convert the FPR range [alpha1, alpha2] into rank bounds over the negatives.
k1_value = int(math.floor(alpha1 * n_neg))
k2_value = int(math.ceil(alpha2 * n_neg))

k1_value_test = int(math.floor(alpha1 * n_neg_test))
k2_value_test = int(math.ceil(alpha2 * n_neg_test))

k1_value_val = int(math.floor(alpha1 * n_neg_val))
k2_value_val = int(math.ceil(alpha2 * n_neg_val))

# model
# Re-seed so the model initialisation does not depend on the RNG state
# consumed above.
set_all_seeds(SEED)
model = ResNet20(pretrained=False, last_activation='sigmoid', activations='relu', num_classes=1)
# Place the model on the device selected above instead of assuming CUDA.
model = model.to(device)

pauc_test_list = []
pauc_val_list = []
data_pass_list = []

# Training
def _collect_scores(loader, net, score_device):
    """Score every sample of ``loader`` with ``net`` and split the scores by label.

    Samples whose target is greater than 0.1 are treated as positives,
    all others as negatives.  Assumes ``net`` produces exactly one score
    per sample.

    Returns:
        ``(z_pos, z_neg)``: two 1-D CPU tensors of positive and negative scores.
    """
    pos_scores = []
    neg_scores = []
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for inputs, targets in loader:
            # Flatten so the scores form a 1-D vector regardless of the
            # (batch, 1) layout of the network output.
            scores = net(inputs.to(score_device)).detach().reshape(-1).cpu()
            is_pos = targets.reshape(-1) > 0.1
            pos_scores.append(scores[is_pos])
            neg_scores.append(scores[~is_pos])
    if not pos_scores:
        return torch.empty(0), torch.empty(0)
    return torch.cat(pos_scores), torch.cat(neg_scores)


def _ranked_pauc(z_pos, z_neg, k1, k2):
    """Return the partial AUC estimate used for reporting.

    For every positive score the negatives that outrank it are marked
    with 1 (others 0), the marks are sorted in descending order and the
    entries at ranks ``[k1, k2)`` are summed.  The sums are averaged over
    all positives and over the ``k2 - k1`` ranks, and the result is
    ``1 - average``.

    Raises:
        ValueError: if there are no positives or the rank window is empty.
    """
    n_positive = z_pos.shape[0]
    if n_positive == 0 or k2 <= k1:
        raise ValueError("partial AUC needs at least one positive and k1 < k2")

    total = 0
    for pos_score in z_pos:
        misranked = (z_neg > pos_score).to(z_neg.dtype)
        ranked, _ = torch.sort(misranked, descending=True)
        total = total + torch.sum(ranked[k1:k2])
    return 1 - total / ((k2 - k1) * n_positive)


if True:

    # loss_pAUC
    Loss = pAUC_mini(alpha1 = alpha1, alpha2 = alpha2, num_neg = BATCH_SIZE)

    print ('-'*30)

    # Run data on whatever device the model parameters live on.
    model_device = next(model.parameters()).device

    # compute initial partial auc
    model.eval()
    z_pos, z_neg = _collect_scores(testloader, model, model_device)
    pauc_test = _ranked_pauc(z_pos, z_neg, k1_value_test, k2_value_test)
    pauc_test_list.append(pauc_test.cpu().numpy())
    data_pass_list.append(0)

    print("initial pauc=", pauc_test)

    # Validation evaluation is currently disabled; 0 is recorded as a
    # placeholder so the result columns stay aligned.
    pauc_val_list.append(0)

    trainloader_copy = iter(trainloader)

    data_pass = 0    # samples seen since the last epoch boundary
    data_pass2 = 0   # samples seen since the start of training
    num_epoch = 0
    # Initialised once, outside the loop, so that the per-epoch decay
    # below actually affects the following updates.
    lr = lr_0

    # outer loop
    for i in range(num_iter):
        model.train()

        # Fetch the next batch; restart the loader when it is exhausted
        # and draw a fresh batch from the new iterator.
        try:
            data_in, target_in = next(trainloader_copy)
        except StopIteration:
            trainloader_copy = iter(trainloader)
            data_in, target_in = next(trainloader_copy)

        # find positive / negative sample indices within the batch
        index_pos = split_pos_index(target_in)
        index_neg = split_neg_index(target_in)

        data_in, target_in = data_in.to(model_device), target_in.to(model_device)
        # compute score
        logits = model(data_in)
        logits_pos = logits[index_pos]
        logits_neg = logits[index_neg]

        # compute loss
        loss_mini = Loss(logits_pos, logits_neg)

        # compute gradients of model parameters
        grads = torch.autograd.grad(loss_mini, model.parameters())

        # plain SGD step on every parameter tensor
        for g, w in zip(grads, model.parameters()):
            w.data -= lr * g

        # calculate num of data pass
        data_pass += 2 * BATCH_SIZE
        data_pass2 += 2 * BATCH_SIZE

        # Once more than n_train samples have been seen, count an epoch,
        # decay the learning rate and evaluate on the test set.
        if data_pass / n_train > 1:
            data_pass = 0
            num_epoch += 1
            lr = lr_0/math.sqrt(num_epoch)
            data_pass_list.append(data_pass2 / n_train)

            # evaluations of test dataset
            model.eval()
            z_pos, z_neg = _collect_scores(testloader, model, model_device)
            pauc_test = _ranked_pauc(z_pos, z_neg, k1_value_test, k2_value_test)
            pauc_test_list.append(pauc_test.cpu().numpy())

            # validation evaluation disabled; keep the placeholder column aligned
            pauc_val_list.append(0)

            model.train()

            print("test partial auc =", pauc_test)

    # Collect the per-epoch curves into one table and write it as CSV.
    df1 = pd.DataFrame(pauc_test_list, columns=['pauc_test'])
    df2 = pd.DataFrame(pauc_val_list, columns=['pauc_val'])
    df3 = pd.DataFrame(data_pass_list, columns=['data_pass'])

    d1 = df2.join(df1)
    d2 = d1.join(df3)
    s = "mini_lr0="+str(para.lr_0)+"_num_iter="+str(para.num_iter)+"_seed="+str(para.seed)+".csv"
    d2.to_csv(s)
