
import os
import argparse
import copy
from libauc.losses import AUCMLoss
from libauc.models import resnet20 as ResNet20
from libauc.datasets import CIFAR10, CIFAR100, STL10, CAT_VS_DOG
from libauc.utils import ImbalancedDataGenerator
from libauc.sampler import DualSampler
from libauc.metrics import auc_roc_score
from multiprocessing.spawn import freeze_support
import torch 
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from sklearn.metrics import roc_auc_score
import warnings

from optimizer import VRAda, Adam, TiAda, TiAda_Adam
warnings.filterwarnings("ignore")


def set_all_seeds(SEED):
    """Make runs reproducible by seeding global RNGs and pinning cuDNN.

    Args:
        SEED: integer passed to ``torch.manual_seed`` and ``np.random.seed``.
    """
    # Seed the torch global RNG first, then NumPy's legacy global RNG.
    for seeder in (torch.manual_seed, np.random.seed):
        seeder(SEED)
    # Prefer deterministic cuDNN kernels over autotuned (possibly faster) ones.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False




class ImageDataset(Dataset):
    """Map-style dataset over an in-memory image array with train/test transforms.

    Args:
        images: array of images; cast to ``np.uint8`` once at construction.
        targets: labels indexed in parallel with ``images``.
        image_size: side length every image is resized to as the final step.
        crop_size: side length of the random crop used in ``'train'`` mode.
        mode: ``'train'`` applies ToTensor -> RandomCrop -> RandomHorizontalFlip
            -> Resize; any other value applies ToTensor -> Resize only.
    """

    def __init__(self, images, targets, image_size=32, crop_size=30, mode='train'):
        self.images = images.astype(np.uint8)
        self.targets = targets
        self.mode = mode
        augmentations = [
            transforms.RandomCrop((crop_size, crop_size), padding=None),
            transforms.RandomHorizontalFlip(),
        ]
        self.transform_train = transforms.Compose(
            [transforms.ToTensor()] + augmentations
            + [transforms.Resize((image_size, image_size))]
        )
        self.transform_test = transforms.Compose(
            [transforms.ToTensor(), transforms.Resize((image_size, image_size))]
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Fetch the image and its label before any (random) transform runs.
        raw, label = self.images[idx], self.targets[idx]
        pipeline = self.transform_train if self.mode == 'train' else self.transform_test
        return pipeline(Image.fromarray(raw.astype('uint8'))), label

def main_VRAda(imratio,dataset,beta,lr_x,lr_y):
    BATCH_SIZE = 128

    total_epochs = 100
    decay_epochs = [50, 75]


    margin = 1.0
    epoch_decay = 0.003  # refers gamma in the paper
    weight_decay = 0.0001

    # oversampling minority class, you can tune it in (0, 0.5]
    # e.g., sampling_rate=0.2 is that num of positive samples in mini-batch is sampling_rate*batch_size=13
    sampling_rate = 0.2

    # load data as numpy arrays
    if dataset == 'C10':
        IMG_SIZE = 32
        train_data, train_targets = CIFAR10(root='./data', train=True)
        test_data, test_targets = CIFAR10(root='./data', train=False)
    elif dataset == 'C100':
        IMG_SIZE = 32
        train_data, train_targets = CIFAR100(root='./data', train=True)
        test_data, test_targets = CIFAR100(root='./data', train=False)
    elif dataset == 'STL10':
        BATCH_SIZE = 32
        IMG_SIZE = 96
        train_data, train_targets = STL10(root='./data/', split='train')
        test_data, test_targets = STL10(root='./data/', split='test')
        train_data = train_data.transpose(0, 2, 3, 1)
        test_data = test_data.transpose(0, 2, 3, 1)
    elif dataset == 'C2':
        IMG_SIZE = 50
        train_data, train_targets = CAT_VS_DOG('./data/', train=True)
        test_data, test_targets = CAT_VS_DOG('./data/', train=False)

    # generate imbalanced data
    generator = ImbalancedDataGenerator(verbose=True, random_seed=0)
    (train_images, train_labels) = generator.transform(train_data, train_targets, imratio=imratio)
    (test_images, test_labels) = generator.transform(test_data, test_targets, imratio=0.5)

    # data augmentations
    trainSet = ImageDataset(train_images, train_labels)
    trainSet_eval = ImageDataset(train_images, train_labels, mode='test')
    testSet = ImageDataset(test_images, test_labels, mode='test')

    # dataloaders
    sampler = DualSampler(trainSet, BATCH_SIZE, sampling_rate=sampling_rate)
    trainloader = torch.utils.data.DataLoader(trainSet, batch_size=BATCH_SIZE, sampler=sampler, num_workers=2)
    trainloader_eval = torch.utils.data.DataLoader(trainSet_eval, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    testloader = torch.utils.data.DataLoader(testSet, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    """# **Creating models & AUC Optimizer**"""
    # You can include sigmoid/l2 activations on model's outputs before computing loss
    model = ResNet20(pretrained=False, last_activation=None, num_classes=1)
    model = model.cuda()

    model_old = None

    # You can also pass Loss.a, Loss.b, Loss.alpha to optimizer (for old version users)
    loss_fn = AUCMLoss()

    optimizer = VRAda(model,
                     loss_fn=loss_fn,
                     lr_x=lr_x,
                     lr_y=lr_y,
                     beta=beta,
                     momentum=0.9,
                     margin=margin,
                     epoch_decay=epoch_decay,
                     weight_decay=weight_decay)
    """# **Training**"""
    print ('Start Training')
    print ('-'*30)

    train_auc_log = []
    test_auc_log = []
    train_loss_log = []
    iteration = 0
    for epoch in range(total_epochs):
        if epoch in decay_epochs:
            optimizer.update_regularizer(decay_factor=10) # decrease learning rate by 10x & update regularizer

        train_loss = []
        model.train()
        for data, targets in trainloader:
            data, targets = data.cuda(), targets.cuda()
            y_pred = model(data)
            y_pred = torch.sigmoid(y_pred)
            loss = loss_fn(y_pred, targets)

            if model_old is not None:
                model_old_pred = model_old(data)
                model_old_pred = torch.sigmoid(model_old_pred)
                model_old_loss = loss_fn(model_old_pred, targets)
                model_old.zero_grad()
                model_old_loss.backward()
                delta_x = [g.grad.data.clone() for g in model_old.parameters()]
                delta_y = loss_fn.alpha.grad.data.clone()
            else:
                delta_x = None
                delta_y = None
            model_old = copy.deepcopy(model).cuda()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step(delta_x=delta_x, delta_y=delta_y)
            train_loss.append(loss.item())

         # evaluation on train & test sets
        model.eval()
        train_pred_list = []
        train_true_list = []
        for train_data, train_targets in trainloader_eval:
            train_data = train_data.cuda()
            train_pred = model(train_data)
            train_pred_list.append(train_pred.cpu().detach().numpy())
            train_true_list.append(train_targets.numpy())
        train_true = np.concatenate(train_true_list)
        train_pred = np.concatenate(train_pred_list)
        train_auc = auc_roc_score(train_true, train_pred)
        train_loss = np.mean(train_loss)

        test_pred_list = []
        test_true_list = []
        for test_data, test_targets in testloader:
            test_data = test_data.cuda()
            test_pred = model(test_data)
            test_pred_list.append(test_pred.cpu().detach().numpy())
            test_true_list.append(test_targets.numpy())
        test_true = np.concatenate(test_true_list)
        test_pred = np.concatenate(test_pred_list)
        val_auc = auc_roc_score(test_true, test_pred)
        model.train()

         # print results
        print("epoch: %s, lr_x: %.1f, lr_y: %.1f, train_loss: %.4f, train_auc: %.4f, test_auc: %.4f"%(epoch, lr_x, lr_y, train_loss, train_auc, val_auc))
        train_auc_log.append(train_auc)
        test_auc_log.append(val_auc)
        train_loss_log.append(train_loss)
        iteration += 1


def main_Adam(imratio, dataset, lr_x, lr_y):

    # HyperParameters
    SEED = 123
    BATCH_SIZE = 128

    total_epochs = 100
    decay_epochs = [50, 75]

    margin = 1.0
    epoch_decay = 0.003  # refers gamma in the paper
    weight_decay = 0.0001

    # oversampling minority class, you can tune it in (0, 0.5]
    # e.g., sampling_rate=0.2 is that num of positive samples in mini-batch is sampling_rate*batch_size=13
    sampling_rate = 0.2

    # load data as numpy arrays
    if dataset == 'C10':
        IMG_SIZE = 32
        train_data, train_targets = CIFAR10(root='./data', train=True)
        test_data, test_targets = CIFAR10(root='./data', train=False)
    elif dataset == 'C100':
        IMG_SIZE = 32
        train_data, train_targets = CIFAR100(root='./data', train=True)
        test_data, test_targets = CIFAR100(root='./data', train=False)
    elif dataset == 'STL10':
        BATCH_SIZE = 32
        IMG_SIZE = 96
        train_data, train_targets = STL10(root='./data/', split='train')
        test_data, test_targets = STL10(root='./data/', split='test')
        train_data = train_data.transpose(0, 2, 3, 1)
        test_data = test_data.transpose(0, 2, 3, 1)
    elif dataset == 'C2':
        IMG_SIZE = 50
        train_data, train_targets = CAT_VS_DOG('./data/', train=True)
        test_data, test_targets = CAT_VS_DOG('./data/', train=False)

    # generate imbalanced data
    generator = ImbalancedDataGenerator(verbose=True, random_seed=0)
    (train_images, train_labels) = generator.transform(train_data, train_targets, imratio=imratio)
    (test_images, test_labels) = generator.transform(test_data, test_targets, imratio=0.5)

    # data augmentations
    trainSet = ImageDataset(train_images, train_labels)
    trainSet_eval = ImageDataset(train_images, train_labels, mode='test')
    testSet = ImageDataset(test_images, test_labels, mode='test')

    # dataloaders
    sampler = DualSampler(trainSet, BATCH_SIZE, sampling_rate=sampling_rate)
    trainloader = torch.utils.data.DataLoader(trainSet, batch_size=BATCH_SIZE, sampler=sampler, num_workers=2)
    trainloader_eval = torch.utils.data.DataLoader(trainSet_eval, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    testloader = torch.utils.data.DataLoader(testSet, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    """# **Creating models & AUC Optimizer**"""
    # You can include sigmoid/l2 activations on model's outputs before computing loss
    model = ResNet20(pretrained=False, last_activation=None, num_classes=1)
    model = model.cuda()

    # You can also pass Loss.a, Loss.b, Loss.alpha to optimizer (for old version users)
    loss_fn = AUCMLoss()
    optimizer = Adam(model,
                     loss_fn=loss_fn,
                     margin=margin,
                     lr_x=lr_x,
                     lr_y=lr_y,
                     epoch_decay=epoch_decay,
                     weight_decay=weight_decay)
    """# **Training**"""
    print ('Start Training')
    print ('-'*30)
    train_loss_log = []
    train_auc_log = []
    test_auc_log = []
    iteration = 0
    for epoch in range(total_epochs):
        if epoch in decay_epochs:
            optimizer.update_regularizer(decay_factor=10)  # decrease learning rate by 10x & update regularizer

        train_loss = []
        model.train()
        for data, targets in trainloader:
            data, targets = data.cuda(), targets.cuda()
            y_pred = model(data)
            y_pred = torch.sigmoid(y_pred)
            loss = loss_fn(y_pred, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        # evaluation on train & test sets
        model.eval()
        train_pred_list = []
        train_true_list = []
        for train_data, train_targets in trainloader_eval:
            train_data = train_data.cuda()
            train_pred = model(train_data)
            train_pred_list.append(train_pred.cpu().detach().numpy())
            train_true_list.append(train_targets.numpy())
        train_true = np.concatenate(train_true_list)
        train_pred = np.concatenate(train_pred_list)
        train_auc = auc_roc_score(train_true, train_pred)
        train_loss = np.mean(train_loss)

        test_pred_list = []
        test_true_list = []
        for test_data, test_targets in testloader:
            test_data = test_data.cuda()
            test_pred = model(test_data)
            test_pred_list.append(test_pred.cpu().detach().numpy())
            test_true_list.append(test_targets.numpy())
        test_true = np.concatenate(test_true_list)
        test_pred = np.concatenate(test_pred_list)
        val_auc = auc_roc_score(test_true, test_pred)
        model.train()

        # print results
        print(
            "epoch: %s, lr_x: %.1f, lr_y: %.1f, dataset: %s, imratio: %.2f, train_loss: %.4f, train_auc: %.4f, test_auc: %.4f" % (
                epoch, lr_x, lr_y, dataset, imratio, train_loss, train_auc, val_auc))
        train_auc_log.append(train_auc)
        test_auc_log.append(val_auc)
        train_loss_log.append(train_loss)
        iteration += 1


def main_TiAda(imratio,dataset,lr_x,lr_y):

    # HyperParameters
    SEED = 123
    BATCH_SIZE = 128

    total_epochs = 100
    decay_epochs = [50, 75]

    margin = 1.0
    epoch_decay = 0.003  # refers gamma in the paper
    weight_decay = 0.0001

    # oversampling minority class, you can tune it in (0, 0.5]
    # e.g., sampling_rate=0.2 is that num of positive samples in mini-batch is sampling_rate*batch_size=13
    sampling_rate = 0.2

    # load data as numpy arrays
    if dataset == 'C10':
        IMG_SIZE = 32
        train_data, train_targets = CIFAR10(root='./data', train=True)
        test_data, test_targets = CIFAR10(root='./data', train=False)
    elif dataset == 'C100':
        IMG_SIZE = 32
        train_data, train_targets = CIFAR100(root='./data', train=True)
        test_data, test_targets = CIFAR100(root='./data', train=False)
    elif dataset == 'STL10':
        BATCH_SIZE = 32
        IMG_SIZE = 96
        train_data, train_targets = STL10(root='./data/', split='train')
        test_data, test_targets = STL10(root='./data/', split='test')
        train_data = train_data.transpose(0, 2, 3, 1)
        test_data = test_data.transpose(0, 2, 3, 1)
    elif dataset == 'C2':
        IMG_SIZE = 50
        train_data, train_targets = CAT_VS_DOG('./data/', train=True)
        test_data, test_targets = CAT_VS_DOG('./data/', train=False)

    # generate imbalanced data
    generator = ImbalancedDataGenerator(verbose=True, random_seed=0)
    (train_images, train_labels) = generator.transform(train_data, train_targets, imratio=imratio)
    (test_images, test_labels) = generator.transform(test_data, test_targets, imratio=0.5)

    # data augmentations
    trainSet = ImageDataset(train_images, train_labels)
    trainSet_eval = ImageDataset(train_images, train_labels, mode='test')
    testSet = ImageDataset(test_images, test_labels, mode='test')

    # dataloaders
    sampler = DualSampler(trainSet, BATCH_SIZE, sampling_rate=sampling_rate)
    trainloader = torch.utils.data.DataLoader(trainSet, batch_size=BATCH_SIZE, sampler=sampler, num_workers=2)
    trainloader_eval = torch.utils.data.DataLoader(trainSet_eval, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    testloader = torch.utils.data.DataLoader(testSet, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    """# **Creating models & AUC Optimizer**"""
    # You can include sigmoid/l2 activations on model's outputs before computing loss
    model = ResNet20(pretrained=False, last_activation=None, num_classes=1)
    model = model.cuda()

    # You can also pass Loss.a, Loss.b, Loss.alpha to optimizer (for old version users)
    loss_fn = AUCMLoss()
    optimizer = TiAda(model,
                     loss_fn=loss_fn,
                     margin=margin,
                     lr_x=lr_x,
                     lr_y=lr_y,
                     epoch_decay=epoch_decay,
                     weight_decay=weight_decay)
    """# **Training**"""
    print ('Start Training')
    print ('-'*30)
    train_loss_log = []
    train_auc_log = []
    test_auc_log = []
    iteration = 0
    for epoch in range(total_epochs):
        if epoch in decay_epochs:
            optimizer.update_regularizer(decay_factor=10)  # decrease learning rate by 10x & update regularizer

        train_loss = []
        model.train()
        for data, targets in trainloader:
            data, targets = data.cuda(), targets.cuda()
            y_pred = model(data)
            y_pred = torch.sigmoid(y_pred)
            loss = loss_fn(y_pred, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        # evaluation on train & test sets
        model.eval()
        train_pred_list = []
        train_true_list = []
        for train_data, train_targets in trainloader_eval:
            train_data = train_data.cuda()
            train_pred = model(train_data)
            train_pred_list.append(train_pred.cpu().detach().numpy())
            train_true_list.append(train_targets.numpy())
        train_true = np.concatenate(train_true_list)
        train_pred = np.concatenate(train_pred_list)
        train_auc = auc_roc_score(train_true, train_pred)
        train_loss = np.mean(train_loss)

        test_pred_list = []
        test_true_list = []
        for test_data, test_targets in testloader:
            test_data = test_data.cuda()
            test_pred = model(test_data)
            test_pred_list.append(test_pred.cpu().detach().numpy())
            test_true_list.append(test_targets.numpy())
        test_true = np.concatenate(test_true_list)
        test_pred = np.concatenate(test_pred_list)
        val_auc = auc_roc_score(test_true, test_pred)
        model.train()

        # print results
        print(
            "epoch: %s, lr_x: %.1f, lr_y: %.1f, dataset: %s, imratio: %.2f, train_loss: %.4f, train_auc: %.4f, test_auc: %.4f" % (
                epoch, lr_x, lr_y, dataset, imratio, train_loss, train_auc, val_auc))
        train_auc_log.append(train_auc)
        test_auc_log.append(val_auc)
        train_loss_log.append(train_loss)
        iteration += 1

def mian_TiAda_Adam(imratio,dataset,lr_x,lr_y):

    # HyperParameters
    SEED = 123
    BATCH_SIZE = 128

    total_epochs = 100
    decay_epochs = [50, 75]

    margin = 1.0
    epoch_decay = 0.003  # refers gamma in the paper
    weight_decay = 0.0001

    # oversampling minority class, you can tune it in (0, 0.5]
    # e.g., sampling_rate=0.2 is that num of positive samples in mini-batch is sampling_rate*batch_size=13
    sampling_rate = 0.2

    # load data as numpy arrays
    if dataset == 'C10':
        IMG_SIZE = 32
        train_data, train_targets = CIFAR10(root='./data', train=True)
        test_data, test_targets = CIFAR10(root='./data', train=False)
    elif dataset == 'C100':
        IMG_SIZE = 32
        train_data, train_targets = CIFAR100(root='./data', train=True)
        test_data, test_targets = CIFAR100(root='./data', train=False)
    elif dataset == 'STL10':
        BATCH_SIZE = 32
        IMG_SIZE = 96
        train_data, train_targets = STL10(root='./data/', split='train')
        test_data, test_targets = STL10(root='./data/', split='test')
        train_data = train_data.transpose(0, 2, 3, 1)
        test_data = test_data.transpose(0, 2, 3, 1)
    elif dataset == 'C2':
        IMG_SIZE = 50
        train_data, train_targets = CAT_VS_DOG('./data/', train=True)
        test_data, test_targets = CAT_VS_DOG('./data/', train=False)

    # generate imbalanced data
    generator = ImbalancedDataGenerator(verbose=True, random_seed=0)
    (train_images, train_labels) = generator.transform(train_data, train_targets, imratio=imratio)
    (test_images, test_labels) = generator.transform(test_data, test_targets, imratio=0.5)

    # data augmentations
    trainSet = ImageDataset(train_images, train_labels)
    trainSet_eval = ImageDataset(train_images, train_labels, mode='test')
    testSet = ImageDataset(test_images, test_labels, mode='test')

    # dataloaders
    sampler = DualSampler(trainSet, BATCH_SIZE, sampling_rate=sampling_rate)
    trainloader = torch.utils.data.DataLoader(trainSet, batch_size=BATCH_SIZE, sampler=sampler, num_workers=2)
    trainloader_eval = torch.utils.data.DataLoader(trainSet_eval, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    testloader = torch.utils.data.DataLoader(testSet, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    """# **Creating models & AUC Optimizer**"""
    # You can include sigmoid/l2 activations on model's outputs before computing loss
    model = ResNet20(pretrained=False, last_activation=None, num_classes=1)
    model = model.cuda()

    # You can also pass Loss.a, Loss.b, Loss.alpha to optimizer (for old version users)
    loss_fn = AUCMLoss()
    optimizer = TiAda_Adam(model,
                     loss_fn=loss_fn,
                     margin=margin,
                     lr_x=lr_x,
                     lr_y=lr_y,
                     epoch_decay=epoch_decay,
                     weight_decay=weight_decay)
    """# **Training**"""
    print ('Start Training')
    print ('-'*30)
    train_loss_log = []
    train_auc_log = []
    test_auc_log = []
    iteration = 0
    for epoch in range(total_epochs):
        if epoch in decay_epochs:
            optimizer.update_regularizer(decay_factor=10)  # decrease learning rate by 10x & update regularizer

        train_loss = []
        model.train()
        for data, targets in trainloader:
            data, targets = data.cuda(), targets.cuda()
            y_pred = model(data)
            y_pred = torch.sigmoid(y_pred)
            loss = loss_fn(y_pred, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        # evaluation on train & test sets
        model.eval()
        train_pred_list = []
        train_true_list = []
        for train_data, train_targets in trainloader_eval:
            train_data = train_data.cuda()
            train_pred = model(train_data)
            train_pred_list.append(train_pred.cpu().detach().numpy())
            train_true_list.append(train_targets.numpy())
        train_true = np.concatenate(train_true_list)
        train_pred = np.concatenate(train_pred_list)
        train_auc = auc_roc_score(train_true, train_pred)
        train_loss = np.mean(train_loss)

        test_pred_list = []
        test_true_list = []
        for test_data, test_targets in testloader:
            test_data = test_data.cuda()
            test_pred = model(test_data)
            test_pred_list.append(test_pred.cpu().detach().numpy())
            test_true_list.append(test_targets.numpy())
        test_true = np.concatenate(test_true_list)
        test_pred = np.concatenate(test_pred_list)
        val_auc = auc_roc_score(test_true, test_pred)
        model.train()

        print(
            "epoch: %s, lr_x: %.1f, lr_y: %.1f, dataset: %s, imratio: %.2f, train_loss: %.4f, train_auc: %.4f, test_auc: %.4f" % (
                epoch, lr_x, lr_y, dataset, imratio, train_loss, train_auc, val_auc))
        train_auc_log.append(train_auc)
        test_auc_log.append(val_auc)
        train_loss_log.append(train_loss)
        iteration += 1


if __name__ == '__main__':
    # Needed for DataLoader worker processes on spawn-based platforms.
    freeze_support()
    parser = argparse.ArgumentParser(
        description='Train ResNet20 with AUCM loss on an imbalanced binary image task.')
    parser.add_argument('--lr_x', type=float, default=0.3, help='Learning rate x')
    parser.add_argument('--lr_y', type=float, default=0.7, help='Learning rate y')
    parser.add_argument('--beta', type=float, default=0.9, help='beta (VRAda only)')
    parser.add_argument('--imratio', type=float, default=0.1, help='imratio')
    parser.add_argument('--dataset', type=str, default='C10',
                        choices=['C10', 'C100', 'STL10', 'C2'], help='dataset')
    # choices= rejects unknown optimizers instead of silently doing nothing.
    parser.add_argument('--optim', type=str, default='TiAda_Adam',
                        choices=['VRAda', 'Adam', 'TiAda', 'TiAda_Adam'], help='optimizer')
    args = parser.parse_args()
    imratio = args.imratio  # for demo
    dataset = args.dataset
    beta = args.beta
    lr_x = args.lr_x
    lr_y = args.lr_y

    # Each option must dispatch to its own entry point (TiAda and TiAda_Adam
    # previously both ran main_Adam).
    if args.optim == 'VRAda':
        main_VRAda(imratio, dataset, beta, lr_x, lr_y)
    elif args.optim == 'Adam':
        main_Adam(imratio, dataset, lr_x, lr_y)
    elif args.optim == 'TiAda':
        main_TiAda(imratio, dataset, lr_x, lr_y)
    elif args.optim == 'TiAda_Adam':
        mian_TiAda_Adam(imratio, dataset, lr_x, lr_y)