import torch
import math
import pandas as pd
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from libauc.models import DenseNet121
from chexpert import CheXpert
from libauc.optimizers import PESG
from libauc.datasets import ImbalanceSampler
from parameters_dc import para
import numpy as np
from losses import AUCMLoss, CrossEntropyBinaryLoss

def set_all_seeds(SEED):
    """Seed every random number generator this script can draw from.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch,
    then switches cuDNN to deterministic mode with auto-tuning disabled.

    Args:
        SEED: integer seed applied to all generators.
    """
    # Local import: ``random`` is not imported at the top of the file.
    import random

    # REPRODUCIBILITY
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    # Benchmark mode lets cuDNN pick kernels per run, so it must be off for
    # the deterministic flag to give repeatable results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def partial_AUC(scores, labels, k1, k2, alpha, beta):
    """Estimate the partial AUC over a band of false-positive counts.

    Samples are ranked by descending score and the running true/false
    positive counts give one ROC point per rank.  Only the points whose
    false-positive count lies strictly between ``k1`` and ``k2`` are kept;
    the area is accumulated as a left Riemann sum of TPR over the
    false-positive count and normalised by ``n_neg * (beta - alpha)``.

    Args:
        scores: tensor of prediction scores (flattened to 1-D).
        labels: tensor aligned with ``scores``; its sum is taken as the
            number of positives, so it is expected to be 0/1 valued.
            Float, integer and bool dtypes are accepted.
        k1: exclusive lower bound on the false-positive count.
        k2: exclusive upper bound on the false-positive count.
        alpha: lower end of the range, used only in the normalisation.
        beta: upper end of the range, used only in the normalisation.

    Returns:
        A 0-dim tensor holding the normalised partial AUC.  It is 0 when
        fewer than two ROC points fall inside the band.
    """
    scores = scores.reshape(-1)
    labels = labels.reshape(-1)

    n = scores.shape[0]
    n_pos = int(labels.sum())
    n_neg = n - n_pos

    _, order = torch.sort(scores, descending=True)
    ranked_labels = labels[order]

    tp = torch.cumsum(ranked_labels, dim=0)
    # Build the rank index on the inputs' device so CUDA tensors work too.
    fp = torch.arange(1, n + 1, device=tp.device) - tp

    tpr = tp / n_pos

    in_band = (fp > k1) & (fp < k2)
    p_tpr = tpr[in_band]
    p_fp = fp[in_band]

    if p_tpr.shape[0] < 2:
        # No interval between consecutive ROC points, hence no area.
        area = p_tpr.new_zeros(())
    else:
        # Integer labels make ``fp`` integral; cast so both operands of the
        # inner product share the floating dtype of ``p_tpr``.
        widths = torch.diff(p_fp).to(p_tpr.dtype)
        area = torch.matmul(p_tpr[:-1], widths)

    pauc = area / (n_neg * (beta - alpha))

    return pauc

# All parameters.
# alpha1 / alpha2 are scaled by the negative counts below to obtain the
# k1 / k2 false-positive bounds, and are passed to partial_AUC as alpha / beta.
alpha1 = 0.05
alpha2 = 0.5
SEED = para.seed
BATCH_SIZE = 16  # the train loader draws 2 * BATCH_SIZE samples per step
num_iter = para.num_iter
class_id = para.class_id # 0:Cardiomegaly, 1:Edema, 2:Consolidation, 3:Atelectasis, 4:Pleural Effusion
root = '../dataset/CheXpert-v1.0-small/'

imratio = 0.5
lr = para.lr
gamma = para.gamma
weight_decay = 0
margin = 1.0

set_all_seeds(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'Using {device} device')

# Datasets: train.csv is cut 90/10 into train/validation by row range, while
# valid.csv serves as the held-out test set.
IMG_SIZE = 224
n_total = 223414  # hard-coded row count -- presumably the size of train.csv; TODO confirm
print("total data = ", n_total)
n_train = int(n_total * 0.9)

# Arguments shared by the two row-range views of train.csv.
_train_csv_kwargs = dict(
    csv_path=root+'train.csv',
    image_root_path=root,
    use_upsampling=False,
    use_frontal=False,
    image_size=IMG_SIZE,
    mode='train',
    seed=SEED,
    class_index=class_id,
    split=True,
)
trainSet = CheXpert(start=0, end=n_train, **_train_csv_kwargs)
valSet = CheXpert(start=n_train, end=n_total, **_train_csv_kwargs)
testSet = CheXpert(csv_path=root+'valid.csv',  image_root_path=root, use_upsampling=False, use_frontal=False, image_size=IMG_SIZE, mode='valid', seed=SEED, class_index=class_id)


# Data loaders.  The training sampler is given pos_num=BATCH_SIZE out of a
# batch of 2 * BATCH_SIZE (presumably a half-positive batch -- verify against
# ImbalanceSampler).  Evaluation loaders yield one sample at a time, in order.
trainloader = torch.utils.data.DataLoader(trainSet, batch_size=2*BATCH_SIZE,
                                          sampler=ImbalanceSampler(np.array(trainSet._labels_list).flatten().astype(int), 2*BATCH_SIZE, pos_num=BATCH_SIZE),
                                          num_workers=2, pin_memory=True, drop_last=True)
valloader = torch.utils.data.DataLoader(valSet, batch_size=1, num_workers=2, drop_last=False, shuffle=False)
testloader = torch.utils.data.DataLoader(testSet, batch_size=1, num_workers=2, drop_last=False, shuffle=False)

# Class counts per split.  Note the training negative count is derived from
# the nominal n_train, whereas validation/test use the actual label lists.
n_test = len(testSet._labels_list)
n_val = len(valSet._labels_list)
n_pos = int(sum(trainSet._labels_list))
n_neg = n_train - n_pos
print(n_pos)
print(n_neg)

n_pos_val = int(sum(valSet._labels_list))
n_neg_val = n_val - n_pos_val

n_pos_test = int(sum(testSet._labels_list))
n_neg_test = n_test-n_pos_test

# False-positive count bounds for each split: floor(alpha1 * #neg) and
# ceil(alpha2 * #neg).
k1_value = int(math.floor(alpha1 * n_neg))
k2_value = int(math.ceil(alpha2 * n_neg))

k1_value_test = int(math.floor(alpha1 * n_neg_test))
k2_value_test = int(math.ceil(alpha2 * n_neg_test))

k1_value_val = int(math.floor(alpha1 * n_neg_val))
k2_value_val = int(math.ceil(alpha2 * n_neg_val))

# model
model = DenseNet121(pretrained=False, last_activation='sigmoid', activations='relu', num_classes=1)
# Honour the device selected above instead of unconditionally requiring CUDA.
model = model.to(device)

# Per-data-pass metric history, written to CSV at the end of the run.
pauc_test_list = []
pauc_val_list = []
data_pass_list = []
    
# Training: optimise the AUC-M loss with PESG and, after every full pass over
# n_train samples, record partial AUC on the test and validation sets.


def _collect_predictions(loader):
    """Run ``model`` over ``loader`` and gather flat CPU score/label tensors.

    The caller is responsible for putting ``model`` in eval mode.  Gradients
    are disabled so no autograd graph is built during evaluation.

    Returns:
        ``(scores, labels)``: two 1-D CPU tensors in loader order.
    """
    scores, labels = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = model(inputs.to(device))
            scores.append(outputs.detach().reshape(-1).cpu())
            labels.append(targets.reshape(-1).cpu())
    if not scores:
        return torch.empty(0), torch.empty(0)
    return torch.cat(scores), torch.cat(labels)


def _ranked_negative_pauc(z_pos, z_neg, k1, k2):
    """Partial AUC over the negatives ranked ``k1`` .. ``k2 - 1`` by score.

    For each positive, counts how many negatives score strictly higher and
    keeps only the part of that count falling inside the rank window
    ``[k1, k2)``; the mean fraction of such mis-ranked pairs is subtracted
    from 1.  Builds a ``len(z_pos) x len(z_neg)`` comparison matrix, so it is
    meant for small evaluation sets.

    Returns:
        A 0-dim tensor (NaN when there are no positives or the window is empty).
    """
    n_above = (z_neg.unsqueeze(0) > z_pos.unsqueeze(1)).sum(dim=1)
    misranked = torch.clamp(n_above - k1, min=0, max=k2 - k1)
    return 1 - misranked.sum() / ((k2 - k1) * z_pos.shape[0])


# L_AVG
Loss = AUCMLoss(imratio=imratio)
optimizer = PESG(model,
                 a=Loss.a,
                 b=Loss.b,
                 alpha=Loss.alpha,
                 imratio=imratio,
                 lr=lr,
                 gamma=gamma,
                 margin=margin,
                 weight_decay=weight_decay)

print ('-'*30)

trainloader_copy = iter(trainloader)

data_pass = 0   # samples seen since the last evaluation
data_pass2 = 0  # samples seen since the start of training
num_epoch = 0

# outer loop
for i in range(num_iter):
    model.train()
    print(i)

    # Restart the loader when it is exhausted and fetch a fresh batch, so a
    # stale batch is never reused.  Only StopIteration is handled; any other
    # error is a real failure and must surface.
    try:
        data_in, target_in = next(trainloader_copy)
    except StopIteration:
        trainloader_copy = iter(trainloader)
        data_in, target_in = next(trainloader_copy)

    data_in, target_in = data_in.to(device), target_in.to(device)
    preds = model(data_in)
    loss = Loss(preds, target_in)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # calculate num of data pass
    data_pass += 2 * BATCH_SIZE
    data_pass2 += 2 * BATCH_SIZE

    if data_pass / n_train > 1:
        data_pass = 0
        num_epoch += 1
        data_pass_list.append(data_pass2 / n_train)

        model.eval()

        # evaluation of test dataset
        test_scores, test_labels = _collect_predictions(testloader)
        test_is_pos = test_labels > 0.1
        pauc_test = _ranked_negative_pauc(test_scores[test_is_pos],
                                          test_scores[~test_is_pos],
                                          k1_value_test, k2_value_test)
        pauc_test_list.append(pauc_test.cpu().numpy())

        # evaluation of validation dataset
        val_scores, val_labels = _collect_predictions(valloader)
        pauc_val = partial_AUC(val_scores, val_labels, k1_value_val, k2_value_val, alpha1, alpha2).detach()
        pauc_val_list.append(pauc_val.cpu().numpy())

        model.train()

        # print results
        print("valid partial auc=", pauc_val)
        print("test partial auc =", pauc_test)

        # Decay the learning rate once, when the fifth data pass completes,
        # and push it into the optimizer: rebinding the Python variable alone
        # never reaches the optimizer's parameter groups.
        if num_epoch == 5:
            lr = lr / 10
            for group in optimizer.param_groups:
                group['lr'] = lr

df1 = pd.DataFrame(pauc_test_list, columns=['pauc_test'])
df2 = pd.DataFrame(pauc_val_list, columns=['pauc_val'])
df3 = pd.DataFrame(data_pass_list, columns=['data_pass'])

# One row per evaluation: validation pAUC, test pAUC, data passes so far.
d2 = df2.join(df1).join(df3)
s = f"lr={para.lr}_num_iter={num_iter}_seed={SEED}_class_id={class_id}_gamma={gamma}.csv"
d2.to_csv(s)
