import pandas as pd
import torch
import math
from libauc.datasets import ImbalanceSampler
from libauc.models import DenseNet121
import numpy as np
from losses import pAUCloss
from chexpert import CheXpert
from parameters_dc import para
from datetime import datetime

def individualloss(h):
    """Logistic surrogate loss ``log(1 + exp(-h))``, computed stably.

    ``softplus(-h)`` is the same function, but the naive form overflows
    ``exp(-h)`` to ``inf`` (and returns ``inf``) once ``-h`` exceeds the
    float range, whereas softplus returns ``-h`` there.

    Args:
        h: tensor of margins (positive score minus negative score).

    Returns:
        Tensor of the same shape as ``h`` with the elementwise loss.
    """
    return torch.nn.functional.softplus(-h)

def diff_matrix(x, y):
    """Pairwise differences between two 1-D score vectors.

    Returns a ``(len(x), len(y))`` matrix whose ``[i, j]`` entry is
    ``x[i] - y[j]``.
    """
    rows = x.repeat(y.shape[0], 1).T   # rows[i, j] == x[i]
    cols = y.repeat(x.shape[0], 1)     # cols[i, j] == y[j]
    return rows - cols
 
def split_pos_index(y):
    """Return per-dimension index tensors of the entries of ``y`` above 0.1.

    Equivalent to ``torch.where(y > 0.1)``; the result can index any tensor
    of the same shape as ``y``.
    """
    return torch.nonzero(y > 0.1, as_tuple=True)

def split_neg_index(y):
    """Return per-dimension index tensors of the entries of ``y`` below 0.1.

    Equivalent to ``torch.where(y < 0.1)``; the result can index any tensor
    of the same shape as ``y``.
    """
    return torch.nonzero(y < 0.1, as_tuple=True)

def loss(y_pred, y_true, lamb):
    """Pairwise surrogate losses kept only where they reach the positive's threshold.

    Every (positive, negative) pair in the batch gets
    ``individualloss(score_pos - score_neg)``. Entries smaller than that
    positive's threshold in ``lamb`` are set to 0, so the caller can count the
    surviving nonzeros.

    Args:
        y_pred: batch scores (assumed same shape as ``y_true`` — TODO confirm).
        y_true: batch labels; > 0.1 is positive, < 0.1 negative.
        lamb: 1-D thresholds, indexed by ``split_pos_index(y_true)[0]``.

    Returns:
        Tensor of shape (n_pos_batch, n_neg_batch).
    """
    pos_idx = split_pos_index(y_true)
    neg_idx = split_neg_index(y_true)
    pairwise = individualloss(diff_matrix(y_pred[pos_idx], y_pred[neg_idx]))

    # NOTE(review): thresholds are looked up by the positive's row position in
    # the batch (pos_idx[0]), not by a dataset-wide sample id, even though
    # lamb has one entry per training positive — confirm this is intended.
    rows = pos_idx[0]
    thresholds = lamb[rows].reshape((rows.shape[0], 1))

    below_threshold = (pairwise - thresholds) < 0
    return torch.where(below_threshold, torch.zeros_like(pairwise), pairwise)

def set_all_seeds(SEED):
    """Seed every RNG the run draws from and force deterministic cuDNN.

    Seeds Python's ``random`` module (previously missed), NumPy's global RNG
    and PyTorch (``torch.manual_seed`` seeds all devices). It also disables
    cuDNN autotuning so convolution algorithms are chosen deterministically.

    Args:
        SEED: integer seed applied to all generators.
    """
    import random  # local import keeps the file's import header unchanged

    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def partial_AUC(scores, labels, k1, k2, alpha, beta):
    """Approximate the normalized partial AUC over a false-positive window.

    Samples are ranked by descending score. The window is every rank position
    whose cumulative false-positive count ``fp`` satisfies ``k1 < fp < k2``.
    At each position in the window, the TPR is weighted by the increase in
    ``fp`` to the next in-window position. The weighted sum is divided by
    ``n_neg * (beta - alpha)``.

    Args:
        scores: one score per sample; any shape, flattened.
        labels: 0/1 labels aligned with ``scores`` (float, int or bool).
            They are flattened and cast to float32 if not already floating.
        k1, k2: exclusive bounds on the false-positive count.
        alpha, beta: FPR bounds, used only in the normalization constant.

    Returns:
        0-d tensor. Tied scores are ordered however ``torch.sort`` leaves them.
        The result is NaN/inf when one class is absent.
    """
    scores = scores.reshape(-1)
    labels = labels.reshape(-1)
    if not labels.is_floating_point():
        # int/bool labels made tp/fp integral and broke the float matmul below
        labels = labels.float()

    n = scores.shape[0]
    n_pos = int(labels.sum().item())
    n_neg = n - n_pos

    _, order = torch.sort(scores, descending=True)
    tp = torch.cumsum(labels[order], dim=0)
    # build the rank counter on the labels' device so GPU inputs work too
    fp = torch.arange(1, n + 1, device=labels.device, dtype=labels.dtype) - tp
    tpr = tp / n_pos

    in_window = (fp > k1) & (fp < k2)
    window_tpr = tpr[in_window]
    fp_steps = torch.diff(fp[in_window])
    return torch.matmul(window_tpr[:-1], fp_steps) / (n_neg * (beta - alpha))

# all parameters of the run; most come from parameters_dc.para
alpha1, alpha2 = 0.05, 0.5  # FPR fractions that set the k1/k2 negative-count bounds
SEED = para.seed
BATCH_SIZE = 16
lr_0, lr_outer = para.lr_0, para.lr_outer
T0, mu = para.T0, para.mu
inner_iter = 1  # placeholder; each outer stage recomputes it from T0
num_iter = para.num_iter
class_id = para.class_id  # 0:Cardiomegaly, 1:Edema, 2:Consolidation, 3:Atelectasis, 4:Pleural Effusion
root = '../dataset/CheXpert-v1.0-small/'

set_all_seeds(SEED)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

# dataloader
IMG_SIZE = 224
# NOTE(review): hard-coded row count of train.csv — verify against the CSV in use.
n_total = 223414
print("total data = ", n_total)
# train and val both come from train.csv, split at n_train via start/end;
# the official valid.csv is used as the test set
n_train = int(n_total * 0.9)

# options shared by all three CheXpert splits
_chexpert_kwargs = dict(image_root_path=root, use_upsampling=False, use_frontal=False,
                        image_size=IMG_SIZE, seed=SEED, class_index=class_id)
trainSet = CheXpert(csv_path=root + 'train.csv', mode='train', split=True,
                    start=0, end=n_train, **_chexpert_kwargs)
valSet = CheXpert(csv_path=root + 'train.csv', mode='train', split=True,
                  start=n_train, end=n_total, **_chexpert_kwargs)
testSet = CheXpert(csv_path=root + 'valid.csv', mode='valid', **_chexpert_kwargs)

# batches of 2*BATCH_SIZE; pos_num=BATCH_SIZE presumably fixes the number of
# positives per batch — confirm against libauc's ImbalanceSampler
trainloader = torch.utils.data.DataLoader(
    trainSet, batch_size=2 * BATCH_SIZE,
    sampler=ImbalanceSampler(np.array(trainSet._labels_list).flatten().astype(int),
                             2 * BATCH_SIZE, pos_num=BATCH_SIZE),
    num_workers=2, pin_memory=True, drop_last=True)

valloader = torch.utils.data.DataLoader(valSet, batch_size=1, num_workers=2, drop_last=False, shuffle=False)
testloader = torch.utils.data.DataLoader(testSet, batch_size=1, num_workers=2, drop_last=False, shuffle=False)

# Class counts per split. np.sum accepts a list of floats or of 1-element
# arrays; int(sum(...)) over 1-element arrays relies on NumPy's deprecated
# array-to-scalar conversion.
n_test = len(testSet._labels_list)
n_val = len(valSet._labels_list)
n_pos = int(np.sum(trainSet._labels_list))
# Derive n_neg from the actual training set, the same way n_val/n_test are
# derived, rather than from the requested split size n_train.
n_neg = len(trainSet._labels_list) - n_pos
print(n_pos)
print(n_neg)

n_pos_val = int(np.sum(valSet._labels_list))
n_neg_val = n_val - n_pos_val

n_pos_test = int(np.sum(testSet._labels_list))
n_neg_test = n_test - n_pos_test


def _fp_window(n_negatives):
    """Return ``(floor(alpha1 * n), ceil(alpha2 * n))``: negative-count bounds of the pAUC window."""
    return int(math.floor(alpha1 * n_negatives)), int(math.ceil(alpha2 * n_negatives))


k1_value, k2_value = _fp_window(n_neg)
k1_value_test, k2_value_test = _fp_window(n_neg_test)
k1_value_val, k2_value_val = _fp_window(n_neg_val)

# model
# Re-seed right before building the networks, presumably so their initial
# weights are reproducible regardless of what the data setup consumed.
set_all_seeds(SEED)
# Honor the device selected above instead of hard-coding .cuda(), which
# crashes on CPU-only hosts; on a GPU host "cuda" is the same target.
model1 = DenseNet121(pretrained=False, last_activation='sigmoid', activations='relu', num_classes=1).to(device)
model2 = DenseNet121(pretrained=False, last_activation='sigmoid', activations='relu', num_classes=1).to(device)

# evaluation logs: one entry for the initial evaluation, then one per outer stage
pauc_test_list = []
pauc_val_list = []
time_list = []
data_pass_list = []

# Training


def _collect_scores(model, loader):
    """Score every sample of ``loader`` with ``model`` in eval mode.

    Returns ``(scores, labels)`` as 1-D float32 CPU tensors in loader order.
    Gradients are disabled because the outputs only feed metrics.
    """
    model.eval()
    scores, labels = [], []
    with torch.no_grad():
        for batch_data, batch_targets in loader:
            outputs = model(batch_data.to(device))
            scores.append(outputs.reshape(-1).float().cpu())
            labels.append(batch_targets.reshape(-1).float().cpu())
    if not scores:
        return torch.empty(0), torch.empty(0)
    return torch.cat(scores), torch.cat(labels)


def _test_pauc(scores, labels, k1, k2):
    """Return 1 minus the mis-ranking rate within each positive's k1:k2 window.

    Labels > 0.1 are positives; all others are negatives. For each positive,
    a negative that scores strictly higher is mis-ranked (ties count as
    correct). Each positive's indicator row is sorted with mis-ranked pairs
    first, and only columns ``k1:k2`` are summed. The total is divided by
    ``(k2 - k1) * n_positives``.
    """
    is_pos = labels > 0.1
    pos_scores = scores[is_pos]
    neg_scores = scores[~is_pos]
    misranked = (neg_scores.unsqueeze(0) > pos_scores.unsqueeze(1)).float()
    hardest_first, _ = torch.sort(misranked, dim=1, descending=True)
    window_loss = hardest_first[:, k1:k2].sum() / ((k2 - k1) * pos_scores.shape[0])
    return 1 - window_loss


# pAUC surrogate losses: k1_value drives model1, k2_value drives model2
Loss1 = pAUCloss(n_pos=n_pos, n_neg=n_neg, k_value=k1_value)
Loss2 = pAUCloss(n_pos=n_pos, n_neg=n_neg, k_value=k2_value)

print('-' * 30)

# dual variables, one entry per training positive
lambda1 = torch.zeros(n_pos, dtype=torch.float32, device=device)
lambda2 = torch.zeros(n_pos, dtype=torch.float32, device=device)

# compute initial partial auc
# NOTE(review): the initial test score uses model1 while every later test
# score (and the initial val score) uses model2 — kept as in the original.
model1.eval()
model2.eval()
test_scores, test_labels = _collect_scores(model1, testloader)
pauc_test = _test_pauc(test_scores, test_labels, k1_value_test, k2_value_test)
pauc_test_list.append(pauc_test.cpu().numpy())
time_list.append(0)
data_pass_list.append(0)

print("initial pauc=", pauc_test)

# initialize val pauc
val_scores, val_labels = _collect_scores(model2, valloader)
pauc_val = partial_AUC(val_scores, val_labels, k1_value_val, k2_value_val, alpha1, alpha2).data
pauc_val_list.append(pauc_val.cpu().numpy())

# proximal anchor shared by both models (same architecture, so same shapes)
zz = [torch.zeros(p.shape, dtype=torch.float32, device=device) for p in model1.parameters()]

trainloader_copy = iter(trainloader)

time_spent = 0
data_pass = 0

# outer loop
for i in range(num_iter):
    lr = lr_0 / (i + 1)
    print("learning rate:", lr)

    inner_iter = (i + 1) ** 2 * T0

    # running sums of the inner-loop iterates, averaged at the end of the stage
    sum_w1 = [torch.zeros(p.shape, dtype=torch.float32, device=device) for p in model1.parameters()]
    sum_w2 = [torch.zeros(p.shape, dtype=torch.float32, device=device) for p in model2.parameters()]
    sum_lambda1 = torch.zeros(n_pos, device=device)
    sum_lambda2 = torch.zeros(n_pos, device=device)

    start_time = datetime.now()
    # inner loop
    for _ in range(inner_iter):
        model1.train()
        model2.train()

        try:
            data_in, target_in = next(trainloader_copy)
        except StopIteration:
            # End of an epoch: restart the loader AND fetch from it. The old
            # code restarted but silently reused the previous step's batch.
            trainloader_copy = iter(trainloader)
            data_in, target_in = next(trainloader_copy)

        data_in, target_in = data_in.to(device), target_in.to(device)
        # rows of the batch that hold positive samples
        index_pos = split_pos_index(target_in)

        # compute score for 2 models
        logits1 = model1(data_in)
        logits2 = model2(data_in)

        # compute loss
        loss1 = Loss1(logits1, target_in, lambda1)
        loss2 = Loss2(logits2, target_in, lambda2)

        # compute gradients of model parameters
        grads1 = torch.autograd.grad(loss1, model1.parameters())
        grads2 = torch.autograd.grad(loss2, model2.parameters())

        with torch.no_grad():
            # proximal gradient step pulled towards zz, then accumulate iterates
            for g, w, z, acc in zip(grads1, model1.parameters(), zz, sum_w1):
                w -= lr * (g + (w - z) / mu)
                acc += w
            for g, w, z, acc in zip(grads2, model2.parameters(), zz, sum_w2):
                w -= lr * (g + (w - z) / mu)
                acc += w

            # update lambda1, lambda2 from the pre-step logits
            loss_1_dual = loss(logits1, target_in, lambda1)
            loss_2_dual = loss(logits2, target_in, lambda2)
            grad_lamb1 = k1_value / n_neg - torch.count_nonzero(loss_1_dual, dim=1) / BATCH_SIZE
            grad_lamb2 = k2_value / n_neg - torch.count_nonzero(loss_2_dual, dim=1) / BATCH_SIZE

            # Index assignment writes into lambda1/lambda2 themselves; the old
            # `lambda1[idx].data -= ...` only modified a temporary copy made by
            # advanced indexing, so lambda never changed.
            # NOTE(review): idx is the row position within the batch, not a
            # dataset-wide positive id — confirm this mapping is intended.
            lambda1[index_pos[0]] -= lr * grad_lamb1
            lambda2[index_pos[0]] -= lr * grad_lamb2

            sum_lambda1 += lambda1
            sum_lambda2 += lambda2

    # move the shared anchor by the gap between the two models' stage averages
    with torch.no_grad():
        for z, s1, s2 in zip(zz, sum_w1, sum_w2):
            z -= lr_outer * (s1 / inner_iter - s2 / inner_iter)

    end_time = datetime.now()

    # compute time spent
    time_spent += (end_time - start_time).total_seconds()
    time_list.append(time_spent)

    # training samples drawn so far, as a multiple of the training-set size
    data_pass += inner_iter * 2 * BATCH_SIZE
    data_pass_list.append(data_pass / n_train)

    lambda1 = sum_lambda1 / inner_iter
    lambda2 = sum_lambda2 / inner_iter

    # evaluations of test dataset
    test_scores, test_labels = _collect_scores(model2, testloader)
    pauc_test = _test_pauc(test_scores, test_labels, k1_value_test, k2_value_test)
    pauc_test_list.append(pauc_test.cpu().numpy())

    # evaluation of validation dataset
    val_scores, val_labels = _collect_scores(model2, valloader)
    pauc_val = partial_AUC(val_scores, val_labels, k1_value_val, k2_value_val, alpha1, alpha2).data
    pauc_val_list.append(pauc_val.cpu().numpy())

    model1.train()
    model2.train()

    print("test partial auc =", pauc_test)

df1 = pd.DataFrame(pauc_test_list, columns=['pauc_test'])
df2 = pd.DataFrame(pauc_val_list, columns=['pauc_val'])
df3 = pd.DataFrame(time_list, columns=['time_pass'])
df4 = pd.DataFrame(data_pass_list, columns=['data_pass'])

d1 = df2.join(df1)
d2 = d1.join(df3)
d3 = d2.join(df4)
s = "lr0="+str(para.lr_0)+"_lr_out="+str(para.lr_outer)+"_mu="+str(para.mu)+"_T0="+str(para.T0)+"_numStages="+str(para.num_iter)+"_seed="+str(para.seed)+"_class_id="+str(para.class_id)+".csv"
d3.to_csv(s)
