import random
import os
import numpy as np
import copy
import argparse
from PIL import Image
# from torch
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.nn import DataParallel
from torch.autograd import grad
from torch.utils.data import Dataset
import torch.nn.functional as F   
import torchvision.transforms as transforms
from libauc.datasets import CIFAR100
from libauc.losses import MultiLabelAUCMLoss
from libauc.optimizers import PESG, Adam, SGD
from libauc.models import resnet18
from libauc.metrics import auc_roc_score # for multi-task
# from other files
from datasets.chexpert import CheXpert
from datasets.celeba import CelebaDataset
from models.densenet import DenseNet121
from auclosses import CrossEntropyBinaryLoss_MultiLabel
from sklearn.metrics import roc_auc_score
from utils import log
import higher
import copy
import shutil
import time



import argparse


# ---------------------------------------------------------------------------
# Command-line interface.
# Every option is declared once in the table below as
# (option strings, add_argument keyword arguments) and registered in order.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

_OPTION_SPECS = [
    # run / hardware
    (('--SEED',), {'default': 123, 'type': int}),
    (('--num_workers',), {'default': 0, 'type': int}),
    (('--gpuid',), {'default': '0', 'type': str, 'help': 'id(s) for CUDA_VISIBLE_DEVICES'}),
    (('--margin',), {'default': 1.0, 'type': float}),
    (('--inner_update_steps',), {'default': 1, 'type': int}),

    # batch layout
    (('--sample_batch_size', '-s'), {'default': 128, 'type': int}),
    (('--task_batch_size', '-t'), {'default': 10, 'type': int}),

    # cifar100: 2000, celeba: 40, chexpert: 6
    (('--total_epoch', '-e'), {'default': 100, 'type': int}),
    (('--lr',), {'default': 0.1, 'type': float}),
    (('--beta_ct',), {'default': 0.9, 'type': float}),
    (('--beta',), {'default': 0.9, 'type': float}),
    (('--lambda_value',), {'default': 6, 'type': float}),

    # output / data location
    (('--save_dir',), {'default': 'exp', 'type': str}),
    # cifar100, celeba, chexpert
    (('--dataset',), {'default': 'cifar100', 'type': str}),
    (('--data_dir',), {'default': '', 'type': str}),
]
for _option_strings, _option_kwargs in _OPTION_SPECS:
    parser.add_argument(*_option_strings, **_option_kwargs)

args = parser.parse_args()

# Restrict the visible GPUs to the requested id(s); this is set at module
# level, before any of the code below touches CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpuid


def set_all_seeds(SEED):
    """Seed every random number generator this script can draw from.

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    PyTorch's generator, then asks cuDNN for deterministic kernels and turns
    off its auto-tuner.

    Args:
        SEED: integer seed applied to all three generators.
    """
    # REPRODUCIBILITY
    # The stdlib generator was previously left unseeded even though the
    # function promises to set *all* seeds.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
def zero_grad(model):
    """Reset, in place, every gradient buffer already allocated on `model`.

    Parameters whose ``.grad`` is still ``None`` are skipped and stay ``None``.

    Args:
        model: object exposing ``named_parameters()``.
    """
    allocated = (
        param.grad
        for _, param in model.named_parameters()
        if param.grad is not None
    )
    for grad_buffer in allocated:
        grad_buffer.data.zero_()

def label2onehot(labels, num_classes):
    """Build a ``(len(labels), num_classes)`` indicator matrix from `labels`.

    Row ``i`` starts as all zeros and has the position(s) selected by
    ``labels[i]`` set to 1.

    Args:
        labels: sized, iterable collection of per-sample indices.
        num_classes: number of columns of the result.

    Returns:
        A float tensor created with ``torch.zeros`` (default device).
    """
    encoded = torch.zeros(len(labels), num_classes)
    # Iterating the tensor yields row views, so the writes land in `encoded`.
    for row, label in zip(encoded, labels):
        row[label] = 1
    return encoded

def auc_score(true, pred):
    """Return ``roc_auc_score(true, pred)``, or 0 when the score is undefined.

    Args:
        true: ground-truth labels, forwarded unchanged to ``roc_auc_score``.
        pred: predicted scores, forwarded unchanged to ``roc_auc_score``.

    Returns:
        The ROC-AUC value, or ``0`` if ``roc_auc_score`` raises ``ValueError``.
    """
    try:
        return roc_auc_score(true, pred)
    except ValueError:
        # scikit-learn reports inputs it cannot score (e.g. a label column
        # containing a single class) with ValueError. Only that case is
        # mapped to 0; a bare `except:` also turned Ctrl-C and genuine
        # programming errors into a silent AUC of 0.
        return 0

def evaluate(loader, model, weight):
    """Score `model`, run with the parameters `weight`, on every batch of `loader`.

    Each batch is moved to the GPU, passed through ``model(data, params=weight)``
    and squashed with a sigmoid. The targets are expanded with
    ``label2onehot`` and the concatenated predictions/targets are handed to
    ``auc_score``.

    Args:
        loader: iterable yielding ``(data, targets)`` batches.
        model: callable accepting ``(data, params=...)``.
        weight: parameter collection forwarded as ``params``.

    Returns:
        The value returned by ``auc_score`` for the whole loader.
    """
    pred = []
    true = []
    # No gradients are needed here; without no_grad every forward pass would
    # build (and briefly keep alive) an autograd graph through `weight`.
    with torch.no_grad():
        for data, targets in loader:
            data = data.cuda()
            outputs = model(data, params=weight)
            y_pred = torch.sigmoid(outputs)
            pred.append(y_pred.cpu().detach().numpy())
            # Size the indicator matrix from the model head instead of a
            # hard-coded 100, so targets and predictions always have the same
            # number of columns (identical result for a 100-way head).
            targets = label2onehot(targets, num_classes=outputs.shape[1])
            true.append(targets.numpy())
    true = np.concatenate(true)
    pred = np.concatenate(pred)
    return auc_score(true, pred)

class ImageDataset(Dataset):
    """Dataset over an in-memory image array and a parallel target sequence.

    Two transform pipelines are prepared up front:

    * ``transform_train``: to-tensor, random ``crop_size`` crop, random
      horizontal flip, resize to ``image_size``;
    * ``transform_test``: to-tensor, resize to ``image_size``.

    ``mode == 'train'`` selects the first pipeline; any other value selects
    the second.
    """

    def __init__(self, images, targets, image_size=32, crop_size=30, mode='train'):
        """Store the data (images cast to uint8) and build both pipelines."""
        self.images = images.astype(np.uint8)
        self.targets = targets
        self.mode = mode

        output_size = (image_size, image_size)
        train_steps = [
            transforms.ToTensor(),
            transforms.RandomCrop((crop_size, crop_size), padding=None),
            transforms.RandomHorizontalFlip(),
            transforms.Resize(output_size),
        ]
        test_steps = [
            transforms.ToTensor(),
            transforms.Resize(output_size),
        ]
        self.transform_train = transforms.Compose(train_steps)
        self.transform_test = transforms.Compose(test_steps)

    def __len__(self):
        """Number of stored images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(transformed image, target)`` for position `idx`."""
        raw = self.images[idx]
        label = self.targets[idx]
        pil_image = Image.fromarray(raw.astype('uint8'))
        pipeline = self.transform_train if self.mode == 'train' else self.transform_test
        return pipeline(pil_image), label


if __name__ == '__main__':
    # Script entry point: reads the parsed `args`, builds the data loaders,
    # the model and three parameter copies (w, u, v), then runs the training
    # loop, logging AUC on `testloader` every 400 batches.

    ###################################
    # Hyperparameter setting
    ###################################
    print('Hyperparameter setting...')
    # HyperParameters
    SEED = args.SEED
    set_all_seeds(SEED=SEED)
    imratio = 0.1
    total_epochs = args.total_epoch

    lr = args.lr
    lr_u = args.lr
    lr_v = lr_u * args.lambda_value
    # NOTE(review): `margin`, `epoch_decay` and `weight_decay` are assigned
    # but never read in the visible script, and `--margin`,
    # `--sample_batch_size`, `--task_batch_size` and `--inner_update_steps`
    # are parsed but not used below — confirm whether they should be wired in.
    margin = 1.0
    epoch_decay = 0.003
    weight_decay = 0.0001

    beta = args.beta
    beta_ct = args.beta_ct
    _lambda = args.lambda_value

    ###################################
    # Logging
    ###################################
    save_dir = args.save_dir
    try:
        os.mkdir(save_dir)
    except FileExistsError:
        # Output of a previous run is discarded so the log starts empty.
        # Only "already exists" triggers the wipe; any other OS error
        # (missing parent, permissions, ...) is raised as-is.
        shutil.rmtree(save_dir)
        os.mkdir(save_dir)

    logdir = os.path.join(save_dir, 'training_log.txt')
    log(logdir, str(vars(args)))

    ##################################
    # Dataset related
    ##################################
    # Each branch defines: decay_epochs, SAMPLE_BATCH_SIZE, TASK_BATCH_SIZE,
    # NUM_CLASSES, TOTAL_BATCH_SIZE, trainloader (outer batches), testloader,
    # and trainloader2 / trainloader2_copy (a second pass over the training
    # set that supplies the inner batches).
    print('Dataset setting...')
    if args.dataset == 'cifar100':
        decay_epochs = []
        SAMPLE_BATCH_SIZE = 128
        TASK_BATCH_SIZE = 10
        NUM_CLASSES = 100
        NUM_SAMPLE_TASKS = 10
        TOTAL_BATCH_SIZE = SAMPLE_BATCH_SIZE * TASK_BATCH_SIZE
        # load data as numpy arrays
        train_data, train_targets = CIFAR100(root=args.data_dir, train=True).as_array()
        test_data, test_targets = CIFAR100(root=args.data_dir, train=False).as_array()

        trainSet = ImageDataset(train_data, train_targets)
        trainSet_eval = ImageDataset(train_data, train_targets, mode='test')
        testSet = ImageDataset(test_data, test_targets, mode='test')

        # NOTE(review): unlike the other branches, no `shuffle` argument is
        # passed to the two training loaders here and `num_workers` is fixed
        # at 8 instead of `args.num_workers` — confirm this is intended.
        trainloader = torch.utils.data.DataLoader(trainSet, batch_size=TOTAL_BATCH_SIZE, num_workers=8, drop_last=True)
        trainloader_eval = torch.utils.data.DataLoader(trainSet_eval, batch_size=TOTAL_BATCH_SIZE, shuffle=False, num_workers=8)
        testloader = torch.utils.data.DataLoader(testSet, batch_size=TOTAL_BATCH_SIZE, shuffle=False, num_workers=8)

        trainloader2 = torch.utils.data.DataLoader(trainSet, batch_size=TOTAL_BATCH_SIZE, num_workers=8, drop_last=True)
        trainloader2_copy = iter(trainloader2)
    elif args.dataset == 'celeba':
        decay_epochs = [30]
        SAMPLE_BATCH_SIZE = 128
        TASK_BATCH_SIZE = 1
        NUM_CLASSES = 40
        NUM_SAMPLE_TASKS = 1
        TOTAL_BATCH_SIZE = SAMPLE_BATCH_SIZE * TASK_BATCH_SIZE
        # load data as numpy arrays
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        root = args.data_dir
        train_dataset = CelebaDataset(
            root + 'celeba_attr_train.csv',
            root + 'img_align_celeba/img_align_celeba/',
            transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        val_dataset = CelebaDataset(
            root + 'celeba_attr_val.csv',
            root + 'img_align_celeba/img_align_celeba/',
            transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ]))
        test_dataset = CelebaDataset(
            root + 'celeba_attr_test.csv',
            root + 'img_align_celeba/img_align_celeba/',
            transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ]))
        train_sampler = None
        trainloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=TOTAL_BATCH_SIZE, shuffle=(train_sampler is None),
            num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)
        valloader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=TOTAL_BATCH_SIZE, shuffle=False,
            num_workers=args.num_workers, pin_memory=True)
        testloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=TOTAL_BATCH_SIZE, shuffle=False,
            num_workers=args.num_workers, pin_memory=True)

        trainloader2 = torch.utils.data.DataLoader(
            train_dataset, batch_size=TOTAL_BATCH_SIZE, shuffle=(train_sampler is None),
            num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)
        trainloader2_copy = iter(trainloader2)

        imratio = train_dataset.imratio_list

    elif args.dataset == 'chexpert':
        decay_epochs = [4]
        SAMPLE_BATCH_SIZE = 32
        TASK_BATCH_SIZE = 1
        NUM_CLASSES = 13
        NUM_SAMPLE_TASKS = 1
        TOTAL_BATCH_SIZE = SAMPLE_BATCH_SIZE * TASK_BATCH_SIZE
        root = args.data_dir
        trainSet = CheXpert(csv_path=root + 'train.csv', image_root_path=root, use_upsampling=False, use_frontal=True,
                            image_size=224, mode='train', class_index=-1, data_split='train')
        valSet = CheXpert(csv_path=root + 'train.csv', image_root_path=root, use_upsampling=False, use_frontal=True,
                          image_size=224, mode='valid', class_index=-1, data_split='valid')
        testSet = CheXpert(csv_path=root + 'valid.csv', image_root_path=root, use_upsampling=False, use_frontal=True,
                           image_size=224, mode='valid', class_index=-1, data_split='test')
        trainloader = torch.utils.data.DataLoader(trainSet, batch_size=TOTAL_BATCH_SIZE, num_workers=args.num_workers, shuffle=True)

        valloader = torch.utils.data.DataLoader(valSet, batch_size=TOTAL_BATCH_SIZE, num_workers=args.num_workers,
                                                shuffle=False)
        testloader = torch.utils.data.DataLoader(testSet, batch_size=TOTAL_BATCH_SIZE, num_workers=args.num_workers, shuffle=False)

        trainloader2 = torch.utils.data.DataLoader(trainSet, batch_size=TOTAL_BATCH_SIZE, num_workers=args.num_workers, shuffle=True)
        trainloader2_copy = iter(trainloader2)

        imratio = trainSet.imratio_list
    else:
        raise NotImplementedError('No corresponding dataset')

    ##################################
    # model and training config
    ##################################
    print('Model setting...')

    if args.dataset == 'cifar100' or args.dataset == 'celeba':
        model = resnet18(pretrained=False, last_activation=None, num_classes=NUM_CLASSES).cuda()
    elif args.dataset == 'chexpert':
        model = DenseNet121(pretrained=True, last_activation=None, activations='relu', num_classes=13).cuda()
    else:
        raise NotImplementedError('No corresponding dataset')

    # Functional view of the model: every forward pass below names the
    # parameter list it runs with via `params=`.
    mp_model = higher.monkeypatch(model, copy_initial_weights=True).cuda()

    # w: the model's own parameters; u and v: independent deep copies.
    w_weights_list = list(model.parameters())
    w_weights = [param.requires_grad_(True) for param in w_weights_list]
    u_weights_list = copy.deepcopy(w_weights_list)
    u_weights = [param.requires_grad_(True) for param in u_weights_list]
    v_weights_list = copy.deepcopy(w_weights_list)
    v_weights = [param.requires_grad_(True) for param in v_weights_list]

    loss_fn = MultiLabelAUCMLoss(num_labels=NUM_CLASSES)
    Loss_ce = CrossEntropyBinaryLoss_MultiLabel()
    # Only `optimizer.zero_grad()` is used below; the parameter updates are
    # written out by hand.
    optimizer = SGD(model.parameters(), lr=lr)
    # Moving-average buffers for the gradients of loss_fn.a, loss_fn.b and w.
    z_a = copy.deepcopy(loss_fn.a)
    z_b = copy.deepcopy(loss_fn.b)
    z_w_list = [torch.zeros_like(w) for _, w in model.named_parameters()]

    ############################
    # training
    ############################
    label_set = np.linspace(0, NUM_CLASSES - 1, NUM_CLASSES).astype(int)
    print ('Start Training')
    print ('-'*30)
    best_val_auc_u = 0
    best_val_auc_v = 0
    training_start = time.time()
    time_list = []
    step_flag = 1
    for epoch in range(total_epochs):
        if epoch in decay_epochs:
            # NOTE(review): only `lr` is divided; `lr_v` keeps the value
            # computed from the initial learning rate — confirm intended.
            lr = lr / 10

        #################################inner loop#################################################
        # for FO:
        # 1. v is updated with inner loss, i.e. CE loss
        # 2. u is updated with Lagrangian, i.e. auc loss + lambda (inner_w - inner_v)
        # 3. note that for CT method inner and outer are both updated on training set, so we are doing the same

        for idx, data in enumerate(trainloader):
            iter_start = time.time()

            # Pick TASK_BATCH_SIZE distinct task (label) ids for this step.
            np.random.shuffle(label_set)
            selectTasks = torch.Tensor(label_set[:TASK_BATCH_SIZE]).to(torch.int)

            # Inner batch: pulled from the second training loader, which is
            # restarted once it is exhausted. The built-in next() is used
            # because the `.next()` alias is not part of the iterator
            # protocol, and only StopIteration triggers the restart so real
            # data-loading errors are no longer swallowed.
            try:
                data_in, targets_in = next(trainloader2_copy)
            except StopIteration:
                trainloader2_copy = iter(trainloader2)
                data_in, targets_in = next(trainloader2_copy)
            data_in, targets_in = data_in.cuda(), label2onehot(targets_in, num_classes=NUM_CLASSES).cuda()
            data_out, targets_out = data
            data_out, targets_out = data_out.cuda(), label2onehot(targets_out, num_classes=NUM_CLASSES).cuda()

            loss_ce = 0.
            loss_auc_u = 0.
            loss_auc_v = 0.

            loss_fn.zero_grad()
            # The k-th slice of SAMPLE_BATCH_SIZE samples is paired with the
            # k-th selected task.
            for task in range(TASK_BATCH_SIZE):
                data_in_perTask = data_in[task * SAMPLE_BATCH_SIZE : (task+1) * SAMPLE_BATCH_SIZE]
                targets_in_perTask = targets_in[task * SAMPLE_BATCH_SIZE: (task + 1) * SAMPLE_BATCH_SIZE]

                data_out_perTask = data_out[task * SAMPLE_BATCH_SIZE : (task+1) * SAMPLE_BATCH_SIZE]
                targets_out_perTask = targets_out[task * SAMPLE_BATCH_SIZE: (task + 1) * SAMPLE_BATCH_SIZE]

                outputs_in_w = mp_model(data_in_perTask, params=w_weights)
                outputs_out_u = mp_model(data_out_perTask, params=u_weights)

                pred_out_u = torch.sigmoid(outputs_out_u)

                loss_ce += Loss_ce(outputs_in_w, targets_in_perTask, selectTasks=[selectTasks[task]])
                loss_auc_u += loss_fn(pred_out_u, targets_out_perTask, task_id=torch.Tensor([selectTasks[task]]).to(torch.int))

            loss_ce = loss_ce / TASK_BATCH_SIZE
            loss_auc_u = loss_auc_u / TASK_BATCH_SIZE
            grads_ce = torch.autograd.grad(loss_ce, w_weights, retain_graph=True)
            grads_u = torch.autograd.grad(loss_auc_u, u_weights, retain_graph=False)

            # Hand-written updates on .data. The order matters: u reads
            # w.data before w itself is overwritten on the last line.
            for v, grad_ce, u, grad_u, w in zip(v_weights, grads_ce, u_weights, grads_u, w_weights):
                v.data = (1 - lr) * v.data + lr * (v.data - beta_ct * grad_ce.data)
                u.data = u.data - lr * (grad_u.data + _lambda * (u.data - w.data + beta_ct * grad_ce.data))
                # from Hu. et al. 2022
                w.data = (1 - lr_v) * w.data + lr_v * (w.data - beta_ct * grad_ce.data)

            #################################outer loop#################################################
            # 1. calculate auc loss for other components that does not depend on inner variable using v
            # 2. gradient of w is from Lagrangian

            # Clear the gradients left on the loss variables by the previous
            # backward pass; before the first backward they are still None.
            for loss_var in (loss_fn.a, loss_fn.b, loss_fn.alpha):
                if loss_var.grad is not None:
                    loss_var.grad.data.zero_()
            for task in range(TASK_BATCH_SIZE):
                data_out_perTask = data_out[task * SAMPLE_BATCH_SIZE : (task+1) * SAMPLE_BATCH_SIZE]
                targets_out_perTask = targets_out[task * SAMPLE_BATCH_SIZE: (task + 1) * SAMPLE_BATCH_SIZE]

                outputs_out_v = mp_model(data_out_perTask, params=v_weights)
                pred_out_v = torch.sigmoid(outputs_out_v)

                loss = loss_fn(pred_out_v, targets_out_perTask, task_id=torch.Tensor([selectTasks[task]]).to(torch.int))
                loss_auc_v += loss

            optimizer.zero_grad()
            loss_auc_v = loss_auc_v / TASK_BATCH_SIZE
            loss_auc_v.backward()

            # a and b: descent step along a moving average of their gradient.
            z_a.data = (1 - beta) * z_a + beta * loss_fn.a.grad
            loss_fn.a.data = loss_fn.a - lr * z_a

            z_b.data = (1 - beta) * z_b + beta * loss_fn.b.grad
            loss_fn.b.data = loss_fn.b - lr * z_b

            # alpha: ascent step (note the plus sign), clamped to [0, 999].
            loss_fn.alpha.data = loss_fn.alpha + lr * loss_fn.alpha.grad
            loss_fn.alpha.data = torch.clamp(loss_fn.alpha.data, 0, 999)

            # w: descent along a moving average of lambda * (v - u); u and v
            # are then reset to the new w.
            for w, u, v, z_w in zip(w_weights, u_weights, v_weights, z_w_list):
                grad_w = _lambda * (v - u)
                z_w.data = (1 - beta) * z_w + beta * grad_w
                w.data = w.data - lr * z_w
                u.data = w.data
                v.data = w.data

            iter_end = time.time()
            iter_time = iter_end - iter_start
            time_list.append(iter_time)

            # validation
            # NOTE(review): the "Val"/"Best_Val" numbers are computed on
            # `testloader`; `valloader` is never used — confirm intended.
            if idx % 400 == 0:
                test_auc_u = evaluate(testloader, mp_model, u_weights)
                test_auc_v = evaluate(testloader, mp_model, v_weights)

                if best_val_auc_u < test_auc_u:
                    best_val_auc_u = test_auc_u
                if best_val_auc_v < test_auc_v:
                    best_val_auc_v = test_auc_v

                log(logdir, f'''Epoch={epoch}, BatchID={idx}, 
                    Val_AUC_u={test_auc_u}, Best_Val_AUC_u={best_val_auc_u},
                    Val_AUC_v={test_auc_v}, Best_Val_AUC_v={best_val_auc_v},
                    Train_AUC_Loss_u={loss_auc_u}, Train_AUC_Loss_v={loss_auc_v},
                    a = {torch.mean(loss_fn.a)}, b = {torch.mean(loss_fn.b)}, alpha = {torch.mean(loss_fn.alpha)},
                    average iteration time = {np.mean(time_list)}
                    ''')

    training_end = time.time()
    training_time = training_end - training_start
    log(logdir, f'''Training Finished! Total time = {training_time}''')
