import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.models.resnet import ResNet, Bottleneck, BasicBlock
from torchvision import datasets
from torch import optim

import os
import argparse
import pandas as pd

import numpy as np
from sklearn.metrics import roc_auc_score

from multi_task_sampler import DataSampler
from Multi_pAUC_KL import Multi_pAUC_KL

from utils import PAUC_MultiLabel, ImageDataset1,ImageDataset, pretrain, partial_auc, load_celeba, pAUC_two_metric, resnet18,pAUC_mini,FocalLoss
from libauc.datasets import imbalance_generator
from libauc.datasets import CIFAR10, CIFAR100
from libauc.models import DenseNet121
from chexpert import CheXpert, CheXpert1

# Compute device shared by the whole script: CUDA when PyTorch reports it as
# available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')



def set_params():
    """Populate the global ``args`` namespace with per-dataset hyper-parameters.

    Reads ``args.ds`` and stores the matching preset on ``args``:

    * ``batch_size``, ``tasks``, ``sample_task``, ``epochs``, ``pretrain``
    * ``iter_record``: evaluate every this many iterations; ``None`` means
      evaluate once at the end of each epoch (see the ``run_*`` functions).
    * ``beta_pauc``, ``eta1_pauc``, ``eta2_pauc``, ``tau1_pauc``, ``tau2_pauc``
    * ``gamma_sopa``, ``tau_sopa``
    * ``gamma_focal``, ``alpha_focal``

    Raises:
        ValueError: if ``args.ds`` is not ``'cifar100'``, ``'celeba'`` or
            ``'chexpert'``. Previously an unknown name was silently ignored
            and the script failed later on a missing attribute.
    """
    presets = {
        'cifar100': {
            'batch_size': 100,
            'iter_record': None,
            'tasks': 100,
            'sample_task': 10,
            'epochs': 50,
            'pretrain': True,
            'beta_pauc': 0.2,
            'eta1_pauc': 0.1,
            'eta2_pauc': 0.1,
            'tau1_pauc': 1,
            'tau2_pauc': 1,
            'gamma_sopa': 0.1,
            'tau_sopa': 1.,
            'gamma_focal': 2,
            'alpha_focal': 0.25,
        },
        'celeba': {
            'batch_size': 50,
            'iter_record': None,
            'tasks': 40,
            'sample_task': 5,
            'epochs': 50,
            'pretrain': True,
            'beta_pauc': 0.5,
            'eta1_pauc': 0.1,
            'eta2_pauc': 0.1,
            'tau1_pauc': 1,
            'tau2_pauc': 1,
            'gamma_sopa': 0.1,
            'tau_sopa': 1.,
            'gamma_focal': 2,
            'alpha_focal': 0.25,
        },
        'chexpert': {
            'batch_size': 32,
            'tasks': 13,
            'sample_task': 1,
            'iter_record': 500,
            'epochs': 5,
            'pretrain': True,
            'beta_pauc': 0.7,
            'eta1_pauc': 0.1,
            'eta2_pauc': 0.1,
            'tau1_pauc': 10,
            'tau2_pauc': 1.,
            'gamma_sopa': 1.,
            'tau_sopa': 0.1,
            'gamma_focal': 1,
            'alpha_focal': 0.75,
        },
    }

    try:
        preset = presets[args.ds]
    except KeyError:
        raise ValueError(
            'Unknown dataset %r; expected one of %s' % (args.ds, sorted(presets))
        ) from None

    for name, value in preset.items():
        setattr(args, name, value)


def load_model():
    """Instantiate the network for ``args.ds`` and move it to ``device``.

    ``'chexpert'`` uses ``DenseNet121`` (pretrained flag taken from
    ``args.pretrain``, no final activation); every other dataset uses the
    project's ``resnet18``. Both are built with ``args.tasks`` outputs.

    Returns:
        The constructed model, already placed on ``device``.
    """
    if args.ds != 'chexpert':
        return resnet18(args.tasks).to(device)

    network = DenseNet121(
        pretrained=args.pretrain,
        last_activation=None,
        activations='relu',
        num_classes=args.tasks,
    )
    return network.to(device)

def load_data():
    """Build the train/validation/test loaders for the dataset in ``args.ds``.

    * ``'chexpert'``: the train and validation sets are both read from
      ``train.csv`` (``data_split='train'`` / ``'valid'``); ``valid.csv`` is
      used as the test set. The training loader uses ``DataSampler``.
    * ``'cifar100'``: 5000 randomly permuted training images are held out for
      validation, integer labels are one-hot encoded into ``args.tasks``
      columns, and the training loader uses ``DataSampler``.
    * ``'celeba'``: delegated to ``load_celeba``.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: if ``args.ds`` is not a supported dataset. Previously
            this surfaced as an ``UnboundLocalError`` at the ``return``.
    """
    batch_size = args.batch_size
    if args.ds == 'chexpert':
        root = 'CheXpert-v1.0-small/'

        trainSet = CheXpert1(csv_path=root + 'train.csv', image_root_path=root, use_upsampling=False, use_frontal=True,
                             image_size=224, data_split='train', mode='train', class_index=-1)
        valSet = CheXpert1(csv_path=root + 'train.csv', image_root_path=root, use_upsampling=False, use_frontal=True,
                           image_size=224, data_split='valid', mode='valid', class_index=-1)
        testSet = CheXpert1(csv_path=root + 'valid.csv', image_root_path=root, use_upsampling=False, use_frontal=True,
                            image_size=224, data_split='test', mode='valid', class_index=-1)

        # The sampler receives the full label matrix together with the batch
        # size and the number of tasks to sample (args.sample_task).
        train_labels = torch.tensor(trainSet._labels_list)
        train_loader = DataLoader(
            trainSet,
            sampler=DataSampler(train_labels, batchSize=batch_size, multi_tasks=args.sample_task),
            batch_size=batch_size, num_workers=4, drop_last=True)

        test_loader = DataLoader(testSet, batch_size=batch_size, num_workers=2, shuffle=False)
        val_loader = DataLoader(valSet, batch_size=batch_size, num_workers=2, shuffle=False)

    elif args.ds == 'cifar100':
        (train_data, train_label), (test_data, test_label) = CIFAR100()

        # Hold out the first 5000 entries of a random permutation as the
        # validation split; the remainder stays as training data.
        idx = np.random.permutation(len(train_data))
        val_data, val_label = train_data[idx[:5000]], train_label[idx[:5000]]
        train_data, train_label = train_data[idx[5000:]], train_label[idx[5000:]]

        def _one_hot(class_ids):
            """Return a (len(class_ids), args.tasks) float matrix with one 1 per row."""
            encoded = torch.zeros(len(class_ids), args.tasks)
            rows = list(range(len(class_ids)))
            encoded[rows, torch.tensor(class_ids).squeeze().long()] += 1
            return encoded

        train_labels = _one_hot(train_label)
        test_labels = _one_hot(test_label)
        val_labels = _one_hot(val_label)

        train_loader = DataLoader(
            ImageDataset1(train_data, train_labels, mode='train'),
            sampler=DataSampler(train_labels, batchSize=batch_size, multi_tasks=args.sample_task),
            batch_size=args.batch_size, num_workers=4, pin_memory=True)
        val_loader = DataLoader(
            ImageDataset1(val_data, val_labels, mode='test'),
            batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)
        test_loader = DataLoader(
            ImageDataset1(test_data, test_labels, mode='test'),
            batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    elif args.ds == 'celeba':
        # NOTE(review): every other index-based check in this file uses
        # ['SOPA', 'PAUC', 'MB'], but 'MB' is absent here. Confirm against
        # load_celeba whether that is intentional. Also verify that
        # load_celeba really returns (train, test, val) in this order.
        train_loader, test_loader, val_loader = load_celeba(batch_size, sampler=(args.ls in ['SOPA','PAUC']))

    else:
        raise ValueError(
            "Unknown dataset %r; expected 'chexpert', 'cifar100' or 'celeba'" % (args.ds,))

    return train_loader, val_loader, test_loader

def pretrained(model):
    """Warm-start ``model`` in place from a cross-entropy checkpoint.

    Weights are loaded only when all of the following hold:

    * ``args.ds`` is not ``'chexpert'``,
    * ``args.pretrain`` is truthy,
    * ``args.ls`` is one of ``'SOPA'``, ``'PAUC'``, ``'MB'``.

    The checkpoint is expected at
    ``models/<ds>_resnet18_CrossEntropyLoss_pretrain.pth``. If the file is
    missing, a hint is printed and the model is left untouched.

    Args:
        model: module whose ``load_state_dict`` receives the checkpoint.
    """
    if args.ds == 'chexpert' or not args.pretrain:
        return
    if args.ls not in ['SOPA','PAUC','MB']:
        return

    model_path = 'models/'+args.ds+'_resnet18_CrossEntropyLoss_pretrain.pth'
    if not os.path.isfile(model_path):
        print('Please train a CE model first and name it as '+ model_path)
        return

    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved on a GPU still loads on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))

def run_with_idx(model, train_loader, val_loader, test_loader, criterion, optimizer, records, T, best_result):
    """Train for one pass over ``train_loader`` whose batches carry sample indices.

    Each batch is unpacked as ``(ind, inputs, targets)``. The model output is
    passed through a sigmoid before the loss; for ``args.ls`` in
    ``'SOPA'``, ``'PAUC'``, ``'MB'`` the criterion also receives ``ind``.

    Evaluation runs every ``args.iter_record`` iterations, or once after the
    pass when ``args.iter_record`` is ``None``. All three loaders are scored
    with ``evaluation_with_id``. The end-of-epoch path previously called
    ``evaluation``, which unpacks two-item batches and cannot iterate the
    three-item batches this function trains on.

    Args:
        model: network being trained.
        train_loader, val_loader, test_loader: loaders yielding
            ``(ind, inputs, targets)``.
        criterion: loss callable; for ``'PAUC'`` its ``beta1`` attribute is
            set to ``1/sqrt(T)`` on every iteration.
        optimizer: optimizer stepped once per batch.
        records: dict of lists with keys ``tr_/va_/te_`` + ``auc1``/``auc2``;
            appended to in place.
        T: iteration counter at entry. It is incremented locally only; the
            updated value is not returned to the caller.
        best_result: best validation AUC1 seen so far.

    Returns:
        tuple: ``(save_flag, best_result)`` where ``save_flag`` is True if
        the validation AUC1 improved during this call.
    """
    def _evaluate_and_record(best):
        """Score train/val/test, append to ``records``; return (improved, best)."""
        improved = False
        for prefix, loader in (('tr', train_loader), ('va', val_loader), ('te', test_loader)):
            auc1, auc2 = evaluation_with_id(model, criterion, loader)
            # Only the validation split drives model selection.
            if prefix == 'va' and auc1 > best:
                best = auc1
                improved = True
            records[prefix + '_auc1'].append(auc1)
            records[prefix + '_auc2'].append(auc2)
        print('\n')
        return improved, best

    model.train()
    save_flag = False
    index_based = args.ls in ['SOPA','PAUC','MB']

    for ind, inputs, targets in train_loader:
        T += 1.
        if args.ls == 'PAUC':
            criterion.beta1 = 1./np.sqrt(T)

        inputs, targets, ind = inputs.to(device), targets.float().to(device), ind.to(device)
        optimizer.zero_grad()
        outputs = torch.sigmoid(model(inputs))
        if index_based:
            loss = criterion(outputs, targets, ind)
        else:
            loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        if args.iter_record is not None and T % args.iter_record == 0:
            improved, best_result = _evaluate_and_record(best_result)
            save_flag = save_flag or improved

    if args.iter_record is None:
        improved, best_result = _evaluate_and_record(best_result)
        save_flag = save_flag or improved

    return save_flag, best_result


def run_no_idx(model, train_loader, val_loader, test_loader, criterion, optimizer, records, T, best_result):
    """Train for one pass over ``train_loader`` whose batches carry no indices.

    Each batch is unpacked as ``(inputs, targets)``. The model output is
    passed through a sigmoid and fed to ``criterion(outputs, targets)``.

    Evaluation runs every ``args.iter_record`` iterations, or once after the
    pass when ``args.iter_record`` is ``None``. All three loaders are scored
    with ``evaluation``. The end-of-epoch path previously called
    ``evaluation_with_id``, which unpacks three-item batches and cannot
    iterate the two-item batches this function trains on.

    Args:
        model: network being trained.
        train_loader, val_loader, test_loader: loaders yielding
            ``(inputs, targets)``.
        criterion: loss callable taking ``(outputs, targets)``.
        optimizer: optimizer stepped once per batch.
        records: dict of lists with keys ``tr_/va_/te_`` + ``auc1``/``auc2``;
            appended to in place.
        T: iteration counter at entry. It is incremented locally only; the
            updated value is not returned to the caller.
        best_result: best validation AUC1 seen so far.

    Returns:
        tuple: ``(save_flag, best_result)`` where ``save_flag`` is True if
        the validation AUC1 improved during this call.
    """
    def _evaluate_and_record(best):
        """Score train/val/test, append to ``records``; return (improved, best)."""
        improved = False
        for prefix, loader in (('tr', train_loader), ('va', val_loader), ('te', test_loader)):
            auc1, auc2 = evaluation(model, criterion, loader)
            # Only the validation split drives model selection.
            if prefix == 'va' and auc1 > best:
                best = auc1
                improved = True
            records[prefix + '_auc1'].append(auc1)
            records[prefix + '_auc2'].append(auc2)
        print('\n')
        return improved, best

    model.train()
    save_flag = False

    for inputs, targets in train_loader:
        T += 1.

        inputs, targets = inputs.to(device), targets.float().to(device)
        optimizer.zero_grad()
        outputs = torch.sigmoid(model(inputs))
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        if args.iter_record is not None and T % args.iter_record == 0:
            improved, best_result = _evaluate_and_record(best_result)
            save_flag = save_flag or improved

    if args.iter_record is None:
        improved, best_result = _evaluate_and_record(best_result)
        save_flag = save_flag or improved

    return save_flag, best_result

def evaluation_with_id(model, criterion, data_loader):
    """Score ``model`` on a loader whose batches are ``(ind, inputs, targets)``.

    Puts the model in eval mode, collects sigmoid outputs and targets for
    every batch under ``torch.no_grad()``, and feeds the concatenated arrays
    to ``partial_auc`` with ``max_fpr`` 0.1 and 0.3. The criterion is still
    invoked per batch (with ``ind`` for ``'SOPA'``/``'MB'``/``'PAUC'``), but
    the accumulated loss is not reported. The model is switched back to
    train mode before returning.

    Returns:
        tuple: ``(auc_at_0.1, auc_at_0.3)`` as returned by ``partial_auc``.
    """
    model.eval()
    with torch.no_grad():
        running_loss = 0
        scores, labels = [], []
        for ind, inputs, targets in data_loader:
            ind = ind.to(device)
            inputs = inputs.to(device)
            targets = targets.float().to(device)

            outputs = torch.sigmoid(model(inputs))
            scores.append(outputs.cpu().detach().numpy())
            labels.append(targets.cpu().numpy())

            # NOTE(review): the loss value is never used here; confirm that
            # calling an index-based criterion on evaluation data is intended.
            if args.ls in ['SOPA','MB','PAUC']:
                batch_loss = criterion(outputs, targets, ind)
            else:
                batch_loss = criterion(outputs, targets)
            running_loss += batch_loss.item()

        labels = np.concatenate(labels)
        scores = np.concatenate(scores)
        pauc_low = partial_auc(labels, scores, max_fpr=0.1)
        pauc_high = partial_auc(labels, scores, max_fpr=0.3)

    print('AUC1: %.3f, AUC2: %.3f'% (pauc_low, pauc_high))
    model.train()
    return pauc_low, pauc_high

def evaluation(model, criterion, data_loader):
    """Score ``model`` on a loader whose batches are ``(inputs, targets)``.

    Puts the model in eval mode, collects sigmoid outputs and targets for
    every batch under ``torch.no_grad()``, and feeds the concatenated arrays
    to ``partial_auc`` with ``max_fpr`` 0.1 and 0.3. Prints the mean
    per-batch loss with both scores, then switches the model back to train
    mode.

    Returns:
        tuple: ``(auc_at_0.1, auc_at_0.3)`` as returned by ``partial_auc``.
    """
    model.eval()
    with torch.no_grad():
        running_loss = 0
        scores, labels = [], []
        for num_batches, (inputs, targets) in enumerate(data_loader, start=1):
            inputs = inputs.to(device)
            targets = targets.float().to(device)

            outputs = torch.sigmoid(model(inputs))
            scores.append(outputs.cpu().detach().numpy())
            labels.append(targets.cpu().numpy())

            running_loss += criterion(outputs, targets).item()

        labels = np.concatenate(labels)
        scores = np.concatenate(scores)
        pauc_low = partial_auc(labels, scores, max_fpr=0.1)
        pauc_high = partial_auc(labels, scores, max_fpr=0.3)

    mean_loss = running_loss / num_batches
    print('Loss: %.3f, AUC1: %.3f, AUC2: %.3f'%(mean_loss, pauc_low, pauc_high))
    model.train()
    return pauc_low, pauc_high


def main():
    """Parse the command line, build every component and run training.

    Command-line options:
        --ds: dataset, one of ``cifar100``, ``celeba``, ``chexpert``
            (default ``chexpert``).
        --ls: loss, one of ``CE``, ``focal``, ``SOPA``, ``PAUC``, ``MB``
            (default ``MB``).
        --id: integer run id, used in the saved model file name.

    Side effects:
        * sets the module-global ``args``;
        * writes ``models/<ds>_<ls><id>.pth`` whenever the validation AUC1
          improves (the ``models`` directory is created if missing);
        * writes ``<ds>_<ls>_history_.npy`` with the recorded metrics.
    """
    global args

    parser = argparse.ArgumentParser()
    # Restricting the choices turns a typo into a clear argparse error
    # instead of a late AttributeError / UnboundLocalError.
    parser.add_argument('--ds', default='chexpert', type=str,
                        choices=['cifar100', 'celeba', 'chexpert'])
    parser.add_argument('--ls', default='MB', type=str,
                        choices=['CE', 'focal', 'SOPA', 'PAUC', 'MB'])
    parser.add_argument('--id', default=0, type=int)
    args = parser.parse_args()

    set_params()

    model = load_model()
    train_loader, val_loader, test_loader = load_data()

    if args.ls == 'CE':
        criterion = nn.BCELoss(reduction='mean')
    elif args.ls == 'focal':
        criterion = FocalLoss(gamma=args.gamma_focal,alpha=args.alpha_focal)
    elif args.ls == 'SOPA':
        criterion = Multi_pAUC_KL(data_len=len(train_loader.dataset), gamma=args.gamma_sopa, Lambda=args.tau_sopa, total_tasks=args.tasks)
    elif args.ls == 'PAUC':
        criterion = PAUC_MultiLabel(num_classes=args.tasks, eta1=args.eta1_pauc, eta2=args.eta2_pauc, beta=args.beta_pauc, tau1=args.tau1_pauc, tau2=args.tau2_pauc)
    else:  # 'MB' -- the only remaining value permitted by argparse
        criterion = pAUC_mini(threshold=1., gamma=0.7,)

    if args.ls in ['PAUC']:
        # The PAUC criterion exposes its own parameters; they are optimised
        # jointly with the model.
        optimizer = optim.SGD(list(model.parameters())+list(criterion.parameters()),momentum=0.9, lr=5e-3, weight_decay=5e-4)
    else:
        optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

    # Learning rate is multiplied by 0.1 once the scheduler has stepped 3 times.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[3],gamma=0.1)

    pretrained(model)

    # NOTE(review): T is passed by value, so the increments made inside
    # run_with_idx are not visible here and every epoch starts from 0 again.
    # Confirm whether the PAUC beta1 = 1/sqrt(T) schedule and the
    # iter_record cadence are meant to restart each epoch.
    T = 0
    records = {'tr_auc1':[],
               'tr_auc2':[],
               'va_auc1':[],
               'va_auc2':[],
               'te_auc1':[],
               'te_auc2':[],
               }
    model_path = 'models/'+args.ds+'_'+args.ls+str(args.id)+'.pth'
    # torch.save does not create missing directories; make sure it exists
    # before the first checkpoint is written.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    best_result = 0
    for i in range(args.epochs):
        print('\nEpoch: '+str(i+1))
        save_flag, best_result = run_with_idx(model, train_loader, val_loader, test_loader, criterion, optimizer, records, T, best_result)
        if save_flag:
            torch.save(model.state_dict(), model_path)
        scheduler.step()

    np.save(args.ds+'_'+args.ls+'_history_'+'.npy',records)

# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    main()






