#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torch.nn import Parameter
from PIL import Image
import numpy as np
from torchvision.models.resnet import ResNet, Bottleneck, BasicBlock
import torchvision.transforms as transforms
from sklearn.metrics import roc_auc_score
from multi_task_sampler import DataSampler
import pandas as pd
import os

# Per-channel normalization statistics shared by both pipelines.
# NOTE(review): these look like CIFAR-100 statistics; an alternative set
# (0.4914, 0.4822, 0.4465) / (0.2023, 0.1994, 0.2010) was used previously.
_CHANNEL_MEAN = (0.5071, 0.4867, 0.4408)
_CHANNEL_STD = (0.2675, 0.2565, 0.2761)

# Training pipeline: 4-pixel padded random 32x32 crop and random horizontal
# flip, then tensor conversion and normalization.
train_trans = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

# Evaluation pipeline: no augmentation, same normalization as training.
test_trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

def resnet18(num_classes):
    """Return an untrained torchvision ResNet-18 with ``num_classes`` outputs.

    ResNet-18 is ``BasicBlock`` residual units stacked as [2, 2, 2, 2].
    """
    return ResNet(BasicBlock, layers=[2, 2, 2, 2], num_classes=num_classes)

def pAUC_two_metric(target, pred, max_fpr):
    """Compute a two-way partial AUC for one set of binary labels.

    Keeps the ``round(n_pos * max_fpr)`` lowest-scored positives and the
    ``round(n_neg * max_fpr)`` highest-scored negatives. At least one, and at
    most all, of each class is kept. Returns the ROC AUC of that subset.

    :param target: array-like of labels; entries equal to 1 are positives and
        every other value is treated as a negative.
    :param pred: array-like of scores with the same number of elements.
    :param max_fpr: fraction of each class to keep, expected in (0, 1].
    :returns: ``roc_auc_score`` of the selected positives/negatives.
    :raises ValueError: if ``target`` has no positives or no negatives.
    """
    target = np.asarray(target).reshape(-1)
    pred = np.asarray(pred).reshape(-1)

    idx_pos = np.where(target == 1)[0]
    idx_neg = np.where(target != 1)[0]
    if len(idx_pos) == 0 or len(idx_neg) == 0:
        raise ValueError('pAUC_two_metric needs at least one positive and one negative sample')

    # Clamp to [1, class size]: at least one sample per class, and
    # np.argpartition rejects an out-of-range kth.
    num_pos = min(max(int(round(len(idx_pos) * max_fpr)), 1), len(idx_pos))
    num_neg = min(max(int(round(len(idx_neg) * max_fpr)), 1), len(idx_neg))

    pos_pred = pred[idx_pos]
    neg_pred = pred[idx_neg]
    # kth = num - 1 puts the num smallest values in the first num slots.
    hardest_pos = np.argpartition(pos_pred, num_pos - 1)[:num_pos]
    hardest_neg = np.argpartition(-neg_pred, num_neg - 1)[:num_neg]

    # Binarize explicitly so any non-1 negative code is scored as class 0.
    selected_target = np.concatenate((np.ones(num_pos, dtype=int), np.zeros(num_neg, dtype=int)))
    selected_pred = np.concatenate((pos_pred[hardest_pos], neg_pred[hardest_neg]))

    return roc_auc_score(selected_target, selected_pred)


class CelebaDataset(Dataset):
    """Dataset of CelebA face images with the attribute table read from a CSV.

    The CSV's first column becomes the index and supplies the image file
    names; every remaining column is kept as a label in ``self.y``.
    """

    def __init__(self, csv_path, img_dir, transform=None):
        frame = pd.read_csv(csv_path, index_col=0)
        self.csv_path, self.img_dir = csv_path, img_dir
        self.transform = transform
        self.img_names = frame.index.values
        self.y = frame.values  # full label matrix, one row per image

    def __getitem__(self, index):
        """Return ``(index, image, label_row)``; the image is transformed when a transform is set."""
        # Author's note: column 33 is Wavy_Hair and column 20 is Male.
        img = Image.open(os.path.join(self.img_dir, self.img_names[index]))
        if self.transform is not None:
            img = self.transform(img)
        return index, img, self.y[index]

    def __len__(self):
        return self.y.shape[0]


def load_celeba(batch_size, sampler=False, data_dir='/dual_data/not_backed_up/CelebA/', num_workers=4):
    """Build CelebA train/validation/test DataLoaders.

    Reads ``celeba_attr_train.csv``, ``celeba_attr_val.csv`` and
    ``celeba_attr_test.csv`` from ``data_dir``. Images are read from
    ``<data_dir>/img_align_celeba/img_align_celeba``. Training images get a
    random horizontal flip. Every split is converted to a tensor and
    normalized with mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225).

    :param batch_size: batch size of all three loaders.
    :param sampler: if True, training batches come from ``DataSampler`` built
        on the training label matrix (``multi_tasks=10``) instead of shuffling.
    :param data_dir: CelebA root directory; defaults to the original path.
    :param num_workers: worker processes per DataLoader.
    :returns: ``(train_loader, val_loader, test_loader)``.
    """
    img_dir = os.path.join(data_dir, 'img_align_celeba', 'img_align_celeba')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = CelebaDataset(os.path.join(data_dir, 'celeba_attr_train.csv'), img_dir, train_transform)
    val_dataset = CelebaDataset(os.path.join(data_dir, 'celeba_attr_val.csv'), img_dir, eval_transform)
    test_dataset = CelebaDataset(os.path.join(data_dir, 'celeba_attr_test.csv'), img_dir, eval_transform)

    if sampler:
        train_sampler = DataSampler(train_dataset.y, batchSize=batch_size, multi_tasks=10)
    else:
        train_sampler = None

    def make_loader(dataset, shuffle=False, loader_sampler=None):
        # Same settings for every split; only shuffling/sampling differ.
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle,
            num_workers=num_workers, pin_memory=True, sampler=loader_sampler)

    # DataLoader forbids shuffle=True together with a sampler.
    train_loader = make_loader(train_dataset, shuffle=(train_sampler is None), loader_sampler=train_sampler)
    val_loader = make_loader(val_dataset)
    test_loader = make_loader(test_dataset)

    print(len(train_dataset), len(val_dataset), len(test_dataset))
    return train_loader, val_loader, test_loader


def focal_loss(input_values, alpha, gamma, reduction = 'mean'):
    r"""Compute the focal loss from per-element negative log-likelihoods.

    ``input_values`` holds :math:`-\log(p_t)`, so :math:`p_t` is recovered as
    ``exp(-input_values)``. The per-element loss is
    ``alpha * (1 - p_t) ** gamma * input_values``.

    :param input_values: tensor of per-element ``-log(p_t)`` values.
    :param alpha: scalar (or broadcastable tensor) weighting factor.
    :param gamma: focusing exponent; 0 reduces to ``alpha``-scaled NLL.
    :param reduction: ``'none'`` (per-element), ``'mean'`` or ``'sum'``.
    :returns: the reduced (or unreduced) loss tensor.
    :raises ValueError: for any other ``reduction``.
    """
    p = torch.exp(-input_values)
    loss = alpha * (1 - p) ** gamma * input_values

    if reduction == 'none':
        return loss
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    raise ValueError("reduction must be 'none', 'mean' or 'sum', got %r" % (reduction,))

class FocalLoss(nn.Module):
    """Binary focal loss computed on probabilities (inputs must lie in [0, 1]).

    ``input`` and ``target`` are flattened. The per-element binary cross
    entropy (optionally weighted by ``weight``) is then rescaled by
    ``focal_loss`` using this module's ``alpha``, ``gamma`` and ``reduction``.
    """

    def __init__(self, alpha = 1, gamma=0, reduction = 'mean'):
        super().__init__()
        assert gamma >= 0
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input, target, weight=None):
        per_element_bce = F.binary_cross_entropy(
            input.view(-1), target.view(-1), reduction='none', weight=weight)
        return focal_loss(per_element_bce, self.alpha, self.gamma, reduction=self.reduction)

def partial_auc(y_pred, y_true, max_fpr=None):
    """Average (partial) ROC AUC over the columns of a multi-label prediction.

    :param y_pred: 2-D array of scores, one column per class.
    :param y_true: 2-D array of binary labels with the same shape.
    :param max_fpr: forwarded to ``roc_auc_score``. ``None`` gives the full
        AUC; otherwise it is the standardized partial AUC over FPR in
        [0, max_fpr].
    :returns: the mean of the per-column scores.
    """
    num_classes = y_pred.shape[1]
    # roc_auc_score takes labels first and scores second.
    total = sum(roc_auc_score(y_true[:, i], y_pred[:, i], max_fpr=max_fpr)
                for i in range(num_classes))
    return total / num_classes


def pretrain(model, train_loader, test_loader, file_path, epochs=70):
    """Train a multi-label model with Adam and binary cross entropy, then save it.

    Both loaders must yield ``(index, image, label)`` batches; labels are cast
    to float. The mean train and test loss is printed after every epoch, and
    the final ``state_dict`` is saved to ``file_path``.

    :param model: module producing one logit per label; batches are moved to
        the device of its first parameter.
    :param epochs: number of training epochs (default 70, as before).
    """
    # Equivalent to Sigmoid + BCELoss, but computed with the log-sum-exp
    # trick, so saturated logits still receive a gradient.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), weight_decay=5e-4)
    device = next(model.parameters()).device

    for epoch in range(epochs):
        model.train()
        train_loss, train_batches = 0.0, 0
        for _, img, label in train_loader:
            img, label = img.to(device), label.float().to(device)
            loss = criterion(model(img), label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_batches += 1

        model.eval()
        test_loss, test_batches = 0.0, 0
        with torch.no_grad():
            for _, img, label in test_loader:
                img, label = img.to(device), label.float().to(device)
                test_loss += criterion(model(img), label).item()
                test_batches += 1

        # max(..., 1) keeps an empty loader from dividing by zero.
        print('Epoch: %5d, tr_loss: %5.3f, te_loss: %5.3f'
              % (epoch + 1, train_loss / max(train_batches, 1), test_loss / max(test_batches, 1)))

    torch.save(model.state_dict(), file_path)


def pretrain_cifar(model, train_loader, test_loader, file_path, epochs=70):
    """Train a single-label classifier with Adam and cross entropy, then save it.

    Both loaders must yield ``(index, image, label)`` batches with integer
    class labels. The mean train and test loss is printed after every epoch,
    and the final ``state_dict`` is saved to ``file_path``.

    :param model: module producing class logits; batches are moved to the
        device of its first parameter.
    :param epochs: number of training epochs (default 70, as before).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), weight_decay=5e-4)
    device = next(model.parameters()).device

    for epoch in range(epochs):
        model.train()
        train_loss, train_batches = 0.0, 0
        for _, img, label in train_loader:
            img, label = img.to(device), label.to(device)
            loss = criterion(model(img), label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_batches += 1

        model.eval()
        test_loss, test_batches = 0.0, 0
        with torch.no_grad():
            for _, img, label in test_loader:
                img, label = img.to(device), label.to(device)
                test_loss += criterion(model(img), label).item()
                test_batches += 1

        # max(..., 1) keeps an empty loader from dividing by zero.
        print('Epoch: %5d, tr_loss: %5.3f, te_loss: %5.3f'
              % (epoch + 1, train_loss / max(train_batches, 1), test_loss / max(test_batches, 1)))

    torch.save(model.state_dict(), file_path)


class ImageDataset(Dataset):
    """In-memory image dataset yielding ``(idx, image, target)`` triples.

    ``images`` is cast to uint8 once, at construction. Each item is converted
    to a PIL image and passed through ``train_trans`` when ``mode == 'train'``,
    or through ``test_trans`` otherwise. ``image_size`` and ``crop_size`` are
    accepted for API compatibility but are not used.
    """

    def __init__(self, images, targets, image_size=32, crop_size=30, mode='train'):
        self.images = images.astype(np.uint8)
        self.targets = targets
        self.mode = mode
        self.transform_train, self.transform_test = train_trans, test_trans

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        label = self.targets[idx]
        pil_image = Image.fromarray(self.images[idx].astype('uint8'))
        transform = self.transform_train if self.mode == 'train' else self.transform_test
        return idx, transform(pil_image), label


class ImageDataset1(ImageDataset):
    """Variant of ``ImageDataset`` that yields ``(image, target)`` pairs.

    The constructor, storage and transform selection are inherited unchanged
    instead of being duplicated; only the sample index is dropped from each item.
    """

    def __getitem__(self, idx):
        _, image, target = super().__getitem__(idx)
        return image, target


class PAUC_MultiLabel(torch.nn.Module):
    """Stateful multi-label partial-AUC surrogate loss.

    Keeps per-task (per-column) auxiliary variables. ``a`` and ``b`` are
    learnable ``Parameter``s. ``Alpha``, ``Lambda`` and ``Hessian`` are plain
    tensors that ``forward`` updates in place under ``torch.no_grad()`` before
    building the loss.
    NOTE(review): the update rules look like a specific min-max pAUC
    optimization scheme; TODO link the reference they come from.
    """

    def __init__(self, margin=1.0, num_classes=10, eta1=0.1, eta2=1e-3, beta=0.1, beta0=0.9, beta1=0.9, tau1=1,
                 tau2=0.2, device=None):
        """Create zero-initialized per-task state of shape (num_classes, 1).

        :param margin: margin added inside the ``Alpha``-weighted term.
        :param num_classes: number of tasks (label columns).
        :param eta1: step size of the ``Alpha`` update.
        :param eta2: step size of the ``Lambda`` update.
        :param beta: divisor on the negative-sample terms and constant in the
            ``Lambda`` gradient.
        :param beta0: stored but not read anywhere in this class.
        :param beta1: weight of the new estimate in the ``Hessian`` moving average.
        :param tau1: temperature of the sigmoids in the ``Lambda``/``Hessian``
            updates and in ``tmp3`` (not used in the ``Alpha`` update or ``tmp``).
        :param tau2: regularization added to the ``Lambda`` gradient and the
            ``Hessian`` estimate.
        :param device: device for the state tensors; when falsy, CUDA is used
            if available, else CPU.
        """
        super(PAUC_MultiLabel, self).__init__()
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        self.margin = margin
        self.num_classes = num_classes
        self.a = Parameter(torch.zeros([num_classes, 1], dtype=torch.float32).to(self.device))
        self.b = Parameter(torch.zeros([num_classes, 1], dtype=torch.float32).to(self.device))
        # The three tensors below are plain attributes (not Parameters or
        # registered buffers), so they are absent from state_dict() and are
        # not moved by Module.to().
        self.Hessian = torch.zeros([num_classes, 1], dtype=torch.float32).to(self.device)
        self.Lambda = torch.zeros([num_classes, 1], dtype=torch.float32).to(self.device)
        self.Alpha = torch.zeros([num_classes, 1], dtype=torch.float32).to(self.device)
        self.eta1, self.eta2, self.beta0, self.beta1 = eta1, eta2, beta0, beta1
        self.beta, self.tau1, self.tau2 = beta, tau1, tau2
        print('loaded!')

    def forward(self, y_pred, y_true, index_s):
        """Update the per-task state in place, then return the averaged surrogate loss.

        :param y_pred: 2-D scores indexed as (sample, task). Presumably bounded
            scores such as sigmoid outputs; TODO confirm with callers.
        :param y_true: labels of the same shape: 1 = positive, 0 = negative,
            NaN = missing.
        :param index_s: unused.
        :returns: the summed per-task loss divided by the number of tasks that
            have at least one positive and one negative in this batch.
            NOTE(review): if no task qualifies, ``0 / 0`` raises ZeroDivisionError.
        """
        # Only tasks with both a positive and a negative label in the batch take part.
        task_ids = torch.nonzero((((y_true == 1).sum(dim=0) > 0) & ((y_true == 0).sum(dim=0) > 0)), as_tuple=True)[0]
        mask = ~torch.isnan(y_true[:, task_ids])

        with torch.no_grad():
            # NaN * 0 stays NaN in task_true, but NaN compares unequal to both 1
            # and 0, so missing labels drop out of mask_p / mask_n below.
            task_pred = y_pred[:, task_ids] * mask
            task_true = y_true[:, task_ids] * mask

            h_w_p = task_pred * (task_true == 1)
            h_w_n = task_pred * (task_true == 0)
            mask_p = (task_true == 1).float()
            mask_n = (task_true == 0).float()

            # Alpha: projected gradient ascent step; relu keeps Alpha >= 0.
            G_alpha = 2 * (
                    (torch.sigmoid(task_pred - self.Lambda[task_ids].view(1, -1)) * h_w_n).sum(0, keepdim=True) / (
                    self.beta * (mask_n.sum(0, keepdim=True) + 1e-6))
                    - h_w_p.sum(0, keepdim=True) / mask_p.sum(0, keepdim=True) + self.margin).view(-1, 1) - 2 * \
                      self.Alpha[task_ids]
            self.Alpha[task_ids] = torch.nn.functional.relu(self.Alpha[task_ids] + self.eta1 * G_alpha)

            # Lambda: gradient descent step. The gradient is beta + tau2 * Lambda
            # minus the mean over negatives of sigmoid((pred - Lambda) / tau1).
            G_lambda = self.beta + self.tau2 * self.Lambda[task_ids] - ((mask_n * torch.sigmoid((task_pred - self.Lambda[task_ids].view(1, -1)) / self.tau1)).sum(0) / (mask_n.sum(0) + 1e-6)).view(-1, 1)
            self.Lambda[task_ids] -= self.eta2 * G_lambda

            # Hessian: exp(x) / (1 + exp(x))^2 is the sigmoid derivative. The
            # estimate is tau2 plus its mean over negatives divided by tau1,
            # folded into an exponential moving average weighted by beta1.
            G_hessian = self.tau2 + ((mask_n * (
                    torch.exp((task_pred - self.Lambda[task_ids].view(1, -1)) / self.tau1) / torch.square(
                1 + torch.exp((task_pred - self.Lambda[task_ids].view(1, -1)) / self.tau1)))).sum(0) / (
                                self.tau1 * mask_n.sum(0) + 1e-6)).view(-1, 1)
            self.Hessian[task_ids] = (1 - self.beta1) * self.Hessian[task_ids] + self.beta1 * G_hessian

        # compute G_loss: the per-task surrogate, built with the freshly updated state
        G_loss = 0
        for task_id in task_ids:
            mask = ~torch.isnan(y_true[:, task_id])
            task_pred = y_pred[:, task_id][mask]
            task_true = y_true[:, task_id][mask]

            h_ps = task_pred[task_true == 1].reshape(-1)

            h_ns = task_pred[task_true == 0].reshape(-1)


            # Detached weights: no gradient flows through these factors.
            with torch.no_grad():
                tmp = torch.sigmoid(h_ns.detach() - self.Lambda[task_id])
                tmp1 = tmp.mul(1 - tmp)  # sigmoid derivative at h_ns - Lambda
                tmp4 = torch.square(h_ns.detach() - self.b[task_id].detach())
                tmp5 = h_ns.detach()

            # Squared distances of positives to a and of weighted negatives to b.
            term1 = torch.square(h_ps - self.a[task_id]).mean()
            term2 = torch.mean(tmp * torch.square(h_ns - self.b[task_id]))/self.beta


            # tmp3 is not detached: gradient flows through it into h_ns.
            tmp3 = torch.sigmoid((h_ns - self.Lambda[task_id]) / self.tau1).mean()
            L_phi = tmp1.mul(h_ns - tmp3 / self.Hessian[task_id])

            term3 = torch.mean(L_phi.mul(tmp4)) / self.beta
            term4 = torch.mean(tmp.mul(h_ns)) / self.beta


            term5 = torch.mean(L_phi.mul(tmp5))
            term6 = torch.mean(h_ps)

            G_loss += term1 + term2 + term3 + 2 * self.Alpha[task_id] * (term4 + term5 - term6 + self.margin)

            # Debug aid only: NaN is reported but not raised.
            if torch.isnan(G_loss):
                print('stop')
        G_loss = G_loss / len(task_ids)

        return G_loss


class pAUC_mini(nn.Module):
    def __init__(self, threshold, gamma, loss_type='sqh'):
        '''
        One-way partial-AUC surrogate computed on each mini-batch.

        :param threshold: margin for the squared hinge loss
        :param gamma: fraction of each task's negatives (highest-scored first)
            that positives are compared against; at least one negative is
            always used
        :param loss_type: surrogate name; only 'sqh' (squared hinge) is implemented
        '''
        super(pAUC_mini, self).__init__()
        self.gamma = gamma
        self.threshold = threshold
        self.loss_type = loss_type
        print('The loss type is :', self.loss_type)

    def forward(self, y_pred, y_true, index_s):
        """Return the squared-hinge pAUC loss, summed over the qualifying tasks.

        :param y_pred: 2-D scores indexed as (sample, task).
        :param y_true: labels of the same shape: 1 = positive, 0 = negative,
            NaN = missing.
        :param index_s: unused.
        :returns: the sum over tasks that have both a positive and a negative
            in the batch of the mean pairwise squared hinge
            ``max(threshold - (f_pos - f_neg), 0) ** 2``, taken over all
            positives and the top-``gamma`` negatives. Returns 0 if no task
            qualifies.
        :raises ValueError: if ``loss_type`` is not 'sqh'.
        """
        if self.loss_type != 'sqh':
            # Previously an unsupported type left neg_loss unbound.
            raise ValueError("unsupported loss_type %r; only 'sqh' is implemented" % (self.loss_type,))

        task_ids = torch.nonzero((((y_true == 1).sum(dim=0) > 0) & ((y_true == 0).sum(dim=0) > 0)), as_tuple=True)[0]
        loss = 0
        for task_id in task_ids:
            mask = ~torch.isnan(y_true[:, task_id])
            task_pred = y_pred[:, task_id][mask]
            task_true = y_true[:, task_id][mask]

            f_ps = task_pred[task_true == 1].reshape(-1, 1)  # column: broadcasts over negatives
            f_ns = task_pred[task_true == 0].reshape(-1)

            # Clamp k to [1, num_neg]: k = 0 made topk empty and the mean NaN.
            num_neg = f_ns.numel()
            k = min(max(int(num_neg * self.gamma), 1), num_neg)
            hardest_negs = torch.topk(f_ns, k, sorted=False)[0]

            # (P, 1) - (k,) broadcasts to the (P, k) matrix of pairwise margins.
            neg_loss = torch.clamp(self.threshold - (f_ps - hardest_negs), min=0) ** 2
            loss += torch.mean(neg_loss)

        return loss
