import torch
import pandas as pd
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from libauc.models import DenseNet121
from libauc.datasets import ImbalanceSampler
from chexpert import CheXpert
import math
from torch import optim
import numpy as np
from sklearn.metrics import roc_auc_score

from losses import AUCMLoss, CrossEntropyBinaryLoss
from parameters_dc import para

def diff_matrix(x, y):
    """Return the pairwise difference matrix ``x[i] - y[j]``.

    Args:
        x: tensor of length ``n`` (expected to be 1-D).
        y: tensor of length ``m`` (expected to be 1-D).

    Returns:
        ``(n, m)`` tensor whose entry ``(i, j)`` equals ``x[i] - y[j]``.
    """
    # Broadcasting produces the result directly instead of first materializing
    # two repeated n-by-m copies of the inputs.
    return x.unsqueeze(1) - y.unsqueeze(0)
 
def split_pos_index(y):
    """Locate the entries of ``y`` that count as positive.

    An entry is positive when it is strictly greater than 0.1 (presumably
    chosen to separate 0/1 labels without an exact float comparison).

    Returns:
        Tuple of index tensors, one per dimension of ``y`` -- the same format
        as ``torch.where(condition)``.
    """
    is_positive = y > 0.1
    return torch.nonzero(is_positive, as_tuple=True)

def split_neg_index(y):
    """Locate the entries of ``y`` that count as negative.

    An entry is negative when it is strictly less than 0.1 (presumably
    chosen to separate 0/1 labels without an exact float comparison).

    Returns:
        Tuple of index tensors, one per dimension of ``y`` -- the same format
        as ``torch.where(condition)``.
    """
    return torch.nonzero(y.lt(0.1), as_tuple=True)

def partial_AUC(scores, labels, k1, k2, alpha, beta):
    """Compute the normalized one-way partial AUC from the empirical ROC curve.

    Examples are ranked by descending score, and the ROC curve is traced one
    example at a time. Only ROC points whose false-positive count ``fp``
    satisfies ``k1 < fp < k2`` are kept. The area under that step-shaped
    piece of the curve is divided by ``beta - alpha``.

    Args:
        scores: tensor of model scores with ``n`` elements. It is flattened,
            so both ``(n,)`` and ``(n, 1)`` are accepted.
        labels: tensor of 0/1 targets with the same number of elements as
            ``scores``. It may be integer, bool or float, on any device.
        k1: exclusive lower bound on the false-positive count.
        k2: exclusive upper bound on the false-positive count.
        alpha: lower FPR bound; used only to normalize the area.
        beta: upper FPR bound; used only to normalize the area.

    Returns:
        0-d float tensor on ``scores.device``. The result divides by the
        positive and negative counts, so it can be NaN or inf when ``labels``
        has no positives or no negatives.
    """
    scores = scores.reshape(-1)
    # Convert labels to float on the scores' device. Integer labels would make
    # the final matmul mix Long and Float, and the sort indices used below
    # live on scores.device.
    labels = labels.reshape(-1).to(device=scores.device, dtype=torch.float32)
    n = scores.shape[0]
    n_pos = int(labels.sum().item())
    n_neg = n - n_pos

    _, order = torch.sort(scores, descending=True)
    ranked_labels = labels[order]

    # Running true/false-positive counts after the top-1, top-2, ... examples.
    tp = torch.cumsum(ranked_labels, dim=0)
    fp = torch.arange(1, n + 1, device=scores.device, dtype=tp.dtype) - tp
    tpr = tp / n_pos

    in_window = (fp > k1) & (fp < k2)
    win_tpr = tpr[in_window]
    win_fp = fp[in_window]

    # Each unit increase of fp adds a rectangle of height TPR (taken at the
    # left point) and width 1 / n_neg in FPR.
    fp_steps = torch.diff(win_fp)
    return torch.matmul(win_tpr[:-1], fp_steps) / (n_neg * (beta - alpha))

def set_all_seeds(SEED):
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        SEED: integer seed applied to ``random``, ``numpy.random`` and torch.
    """
    # Local import: the file header does not import ``random``.
    import random

    # REPRODUCIBILITY
    random.seed(SEED)
    np.random.seed(SEED)
    # torch.manual_seed seeds the CPU generator and all CUDA devices.
    torch.manual_seed(SEED)
    # Give up cuDNN autotuning in exchange for reproducible convolution results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# --- experiment configuration ----------------------------------------------
# FPR window [alpha1, alpha2] from which the partial-AUC cut-offs are derived.
alpha1, alpha2 = 0.05, 0.5
BATCH_SIZE = 16
# Target label index: 0 Cardiomegaly, 1 Edema, 2 Consolidation,
# 3 Atelectasis, 4 Pleural Effusion.
class_id = para.class_id
root = '../dataset/CheXpert-v1.0-small/'

# Run-specific hyper-parameters are read from the shared `para` object.
num_iter, SEED, lr = para.num_iter, para.seed, para.lr

set_all_seeds(SEED)
print(SEED)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

# dataloader
IMG_SIZE = 224
# Size of the pool split 90/10 into train/validation below (presumably the row
# count of train.csv -- TODO confirm it matches the CSV in use).
n_total = 223414
print("total data = ", n_total)
n_train = int(n_total * 0.9)
n_val = n_total - n_train

# Options shared by all three CheXpert splits.
_chexpert_kwargs = dict(image_root_path=root, use_upsampling=False, use_frontal=False,
                        image_size=IMG_SIZE, seed=SEED, class_index=class_id)
# Train and validation are consecutive start/end index ranges of train.csv;
# the separate valid.csv file is used as the test set.
trainSet = CheXpert(csv_path=root + 'train.csv', mode='train', split=True,
                    start=0, end=n_train, **_chexpert_kwargs)
valSet = CheXpert(csv_path=root + 'train.csv', mode='train', split=True,
                  start=n_train, end=n_total, **_chexpert_kwargs)
testSet = CheXpert(csv_path=root + 'valid.csv', mode='valid', **_chexpert_kwargs)

# Training batches hold 2*BATCH_SIZE images; pos_num=BATCH_SIZE presumably
# fixes the number of positives per batch drawn by ImbalanceSampler.
trainloader = torch.utils.data.DataLoader(trainSet, batch_size=2*BATCH_SIZE,
                                          sampler=ImbalanceSampler(np.array(trainSet._labels_list).flatten().astype(int), 2*BATCH_SIZE, pos_num=BATCH_SIZE),
                                          num_workers=2, pin_memory=True, drop_last=True)
# Evaluation loaders: one image per batch, in fixed order.
valloader = torch.utils.data.DataLoader(valSet, batch_size=1, num_workers=2, drop_last=False, shuffle=False)
testloader = torch.utils.data.DataLoader(testSet, batch_size=1, num_workers=2, drop_last=False, shuffle=False)

# Positive/negative counts per split, taken from the datasets actually loaded.
n_test = len(testSet._labels_list)
n_val = len(valSet._labels_list)
n_pos = int(sum(trainSet._labels_list))
# Count negatives against the loaded training set (as for validation/test)
# rather than the nominal slice length n_train.
n_neg = len(trainSet._labels_list) - n_pos
print(n_pos)
print(n_neg)

n_pos_val = int(sum(valSet._labels_list))
n_neg_val = n_val - n_pos_val

n_pos_test = int(sum(testSet._labels_list))
n_neg_test = n_test - n_pos_test

# Partial-AUC cut-offs: the range of negatives (by count) corresponding to
# FPR in [alpha1, alpha2] for each split. math.floor/ceil already return int.
k1_value = math.floor(alpha1 * n_neg)
k2_value = math.ceil(alpha2 * n_neg)

k1_value_test = math.floor(alpha1 * n_neg_test)
k2_value_test = math.ceil(alpha2 * n_neg_test)

k1_value_val = math.floor(alpha1 * n_neg_val)
k2_value_val = math.ceil(alpha2 * n_neg_val)

# model
model = DenseNet121(pretrained=False, last_activation='sigmoid', activations='relu', num_classes=1)
# Place the model on the device selected at start-up instead of hard-coding
# `.cuda()`, which fails on hosts without a GPU.
model = model.to(device)

# Evaluation history: one entry is appended per evaluation during training.
pauc_test_list = []
pauc_val_list = []
data_pass_list = []
    
# Training
if True:
    # L_AVG baseline: plain binary cross-entropy.
    Loss_CE = CrossEntropyBinaryLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0)

    print('-' * 30)

    trainloader_copy = iter(trainloader)

    data_pass = 0    # samples seen since the last evaluation
    data_pass2 = 0   # samples seen since training started
    num_epoch = 0

    # outer loop
    for _ in range(num_iter):
        model.train()

        # Draw the next batch. When the loader is exhausted, start a new pass
        # and draw from it, so every step trains on a fresh batch.
        # (Python-3 iterators have no `.next()` method; use the builtin.)
        try:
            data_in, target_in = next(trainloader_copy)
        except StopIteration:
            trainloader_copy = iter(trainloader)
            data_in, target_in = next(trainloader_copy)

        data_in, target_in = data_in.to(device), target_in.to(device)
        logits = model(data_in)
        # NOTE(review): the model is built with last_activation='sigmoid'. If
        # that already applies a sigmoid, this second one squashes predictions
        # into (0.5, 0.73) -- confirm against libauc's DenseNet121.
        y_pred = torch.sigmoid(logits)
        loss_ce = Loss_CE(y_pred, target_in)
        optimizer.zero_grad()
        loss_ce.backward()
        optimizer.step()

        # calculate num of data pass
        data_pass += 2 * BATCH_SIZE
        data_pass2 += 2 * BATCH_SIZE

        if data_pass / n_train > 1:
            # One (approximate) pass over the training split has completed.
            data_pass = 0
            num_epoch += 1
            data_pass_list.append(data_pass2 / n_train)

            model.eval()
            # Evaluation needs no autograd graph.
            with torch.no_grad():
                # --- test set: pAUC via per-positive ranking losses ---
                test_pred_pos = []
                test_pred_neg = []
                for test_data, test_targets in testloader:
                    # batch_size=1, so output and target each hold one value.
                    score = float(model(test_data.to(device)))
                    if bool(test_targets > 0.1):
                        test_pred_pos.append(score)
                    else:
                        test_pred_neg.append(score)

                z_pos = torch.tensor(test_pred_pos)
                z_neg = torch.tensor(test_pred_neg)
                n_test_pos = z_pos.shape[0]

                # For each positive, m = number of negatives scored strictly
                # above it. Ranked by score, those are the top-m negatives, so
                # the count falling in negative ranks [k1, k2) is
                # clamp(m - k1, 0, k2 - k1).
                window = k2_value_test - k1_value_test
                n_above = (z_neg.unsqueeze(0) > z_pos.unsqueeze(1)).sum(dim=1)
                sum_loss_pauc = (n_above - k1_value_test).clamp(0, window).sum() / (window * n_test_pos)
                pauc_test = 1 - sum_loss_pauc
                pauc_test_list.append(pauc_test.cpu().numpy())

                # --- validation set: pAUC from the empirical ROC curve ---
                val_pred = []
                val_true = []
                for val_data, val_targets in valloader:
                    val_pred.append(float(model(val_data.to(device))))
                    val_true.append(float(val_targets))

                vval_pred = torch.tensor(val_pred)
                vval_true = torch.tensor(val_true)
                pauc_val = partial_AUC(vval_pred, vval_true, k1_value_val, k2_value_val, alpha1, alpha2)
                pauc_val_list.append(pauc_val.cpu().numpy())

            model.train()

            # print results
            print("valid partial auc=", pauc_val)
            print("test partial auc =", pauc_test)

            # Decay the learning rate 10x once, right after the 5th epoch.
            # The optimizer must be updated directly: rebinding `lr` has no
            # effect on Adam. `lr` itself is kept for the output filename.
            if num_epoch == 5:
                for group in optimizer.param_groups:
                    group['lr'] /= 10

    df1 = pd.DataFrame(pauc_test_list, columns=['pauc_test'])
    df2 = pd.DataFrame(pauc_val_list, columns=['pauc_val'])
    df3 = pd.DataFrame(data_pass_list, columns=['data_pass'])

    # One row per evaluation: validation pAUC, test pAUC, epochs of data seen.
    d2 = df2.join(df1).join(df3)
    s = "lr=" + str(lr) + "_num_iter=" + str(num_iter) + "_seed=" + str(SEED) + "_class_id=" + str(class_id) + ".csv"
    d2.to_csv(s)
