import time

import torch
from PIL import Image
import numpy as np
import time, os, sys, copy
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from sklearn.metrics import roc_auc_score
from torch.autograd import Variable
import torch.nn.functional as F
from datetime import datetime as dt
from datetime import timedelta as td

class AUCLoss():
    """Component functions (g1, g2, g3) of an AUC-style objective and their gradients.

    ``imratio`` is stored as ``self.p`` and weights the two label groups;
    ``m`` is stored as ``self.m`` and only enters the constant term of ``g2``.
    NOTE(review): ``p`` is presumably the positive-class ratio and ``m`` a
    margin -- confirm against the callers.
    """

    def __init__(self, imratio, m=1.0):
        self.m = m
        self.p = imratio

    def g1(self, outputs, a, b, targets):
        """Weighted squared deviation of ``outputs`` from ``a`` (label 1) and ``b`` (label 0).

        Each mean is taken over all elements; entries of the other label
        contribute zero to it rather than being dropped.
        """
        is_pos = (targets == 1).float()
        is_neg = (targets == 0).float()
        pos_part = (1 - self.p) * torch.mean((outputs - a) ** 2 * is_pos)
        neg_part = self.p * torch.mean((outputs - b) ** 2 * is_neg)
        return pos_part + neg_part

    def g1_grad_a(self, outputs, a, targets):
        """Closed-form derivative of the label-1 term of ``g1`` with respect to ``a``."""
        is_pos = (targets == 1).float()
        coeff = -2 * (1 - self.p)
        return coeff * torch.mean((outputs - a) * is_pos)

    def g1_grad_b(self, outputs, b, targets):
        """Closed-form derivative of the label-0 term of ``g1`` with respect to ``b``."""
        is_neg = (targets == 0).float()
        coeff = -2 * self.p
        return coeff * torch.mean((outputs - b) * is_neg)

    def g2(self, outputs, targets):
        """Weighted difference of masked output means plus the constant ``2*p*(1-p)*m``."""
        is_pos = (targets == 1).float()
        is_neg = (targets == 0).float()
        pos_part = -2 * (1 - self.p) * torch.mean(outputs * is_pos)
        neg_part = 2 * self.p * torch.mean(outputs * is_neg)
        constant = 2 * self.p * (1 - self.p) * self.m
        return pos_part + neg_part + constant

    def g3(self, alpha):
        """Return ``p * (1 - p) * alpha ** 2``."""
        weight = self.p * (1 - self.p)
        return weight * alpha ** 2

    def g3_grad(self, alpha):
        """Derivative of ``g3`` with respect to ``alpha``: ``2 * p * (1 - p) * alpha``."""
        weight = 2 * self.p * (1 - self.p)
        return weight * alpha

class AUCLoss_multiLabel():
    """Per-task version of the g1/g2/g3 objective components.

    Every method picks entry ``task`` from ``self.p`` (and from ``a``, ``b``
    or ``alpha`` where used) and column ``task`` from the 2-D ``outputs`` /
    ``targets`` before applying the single-task formula.
    NOTE(review): ``imratio`` is presumably a per-task positive-class ratio
    and ``m`` a margin -- confirm against the callers.
    """

    def __init__(self, imratio, m=1.0):
        self.m = m
        self.p = imratio

    def g1(self, outputs, a, b, targets, task=0):
        """Weighted squared deviation of column ``task`` from ``a[task]`` (label 1) and ``b[task]`` (label 0)."""
        ratio = self.p[task]
        scores = outputs[:, task]
        labels = targets[:, task]
        is_pos = (labels == 1).float()
        is_neg = (labels == 0).float()
        # Means run over the whole column; other-label rows contribute zero.
        pos_part = (1 - ratio) * torch.mean((scores - a[task]) ** 2 * is_pos)
        neg_part = ratio * torch.mean((scores - b[task]) ** 2 * is_neg)
        return pos_part + neg_part

    def g1_grad_a(self, outputs, a, targets, task=0):
        """Closed-form derivative of the label-1 term of ``g1`` with respect to ``a[task]``."""
        ratio = self.p[task]
        scores = outputs[:, task]
        is_pos = (targets[:, task] == 1).float()
        coeff = -2 * (1 - ratio)
        return coeff * torch.mean((scores - a[task]) * is_pos)

    def g1_grad_b(self, outputs, b, targets, task=0):
        """Closed-form derivative of the label-0 term of ``g1`` with respect to ``b[task]``."""
        ratio = self.p[task]
        scores = outputs[:, task]
        is_neg = (targets[:, task] == 0).float()
        coeff = -2 * ratio
        return coeff * torch.mean((scores - b[task]) * is_neg)

    def g2(self, outputs, targets, task=0):
        """Weighted difference of masked column means plus the constant ``2*p*(1-p)*m``."""
        ratio = self.p[task]
        scores = outputs[:, task]
        labels = targets[:, task]
        pos_part = -2 * (1 - ratio) * torch.mean(scores * (labels == 1).float())
        neg_part = 2 * ratio * torch.mean(scores * (labels == 0).float())
        constant = 2 * ratio * (1 - ratio) * self.m
        return pos_part + neg_part + constant

    def g3(self, alpha, task=0):
        """Return ``p * (1 - p) * alpha[task] ** 2`` for the selected task."""
        ratio = self.p[task]
        weight = ratio * (1 - ratio)
        return weight * alpha[task] ** 2

    def g3_grad(self, alpha, task=0):
        """Derivative of ``g3`` with respect to ``alpha[task]``."""
        ratio = self.p[task]
        weight = 2 * ratio * (1 - ratio)
        return weight * alpha[task]

class AUCM_MultiLabel_selectTasks(torch.nn.Module):
    """AUC-margin style loss summed over a caller-chosen subset of label columns.

    The module keeps three per-class auxiliary tensors ``a``, ``b`` and
    ``alpha`` (all initialised to zero, ``requires_grad=True``). They are
    plain leaf tensors, not ``nn.Parameter`` objects, so they do not appear
    in ``parameters()`` and must be handed to the optimizer explicitly.

    Args:
        margin: Scales the ``p * (1 - p)`` constant inside the ``alpha`` term.
        imratio: Sequence with one ratio per class; its length must equal
            ``num_classes``. NOTE: the defaults (``[0.1]`` with
            ``num_classes=10``) do not satisfy that check, so callers have
            to pass a matching pair.
        num_classes: Number of label columns / length of ``a``, ``b``, ``alpha``.
        device: Device for all internal tensors. When falsy, CUDA is used if
            available, otherwise CPU.

    Raises:
        AssertionError: If ``len(imratio) != num_classes``.
    """

    def __init__(self, margin=1.0, imratio=[0.1], num_classes=10, device=None):
        super(AUCM_MultiLabel_selectTasks, self).__init__()
        assert len(imratio)==num_classes, 'Length of imratio needs to be same as num_classes!'
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        self.margin = margin
        self.p = torch.FloatTensor(imratio).to(self.device)
        self.num_classes = num_classes
        # Allocate directly on the target device. Creating on "cuda" and then
        # calling .to(self.device) fails on CPU-only hosts and, whenever the
        # device actually changes, yields non-leaf copies whose .grad is never
        # populated.
        self.a = torch.zeros(num_classes, dtype=torch.float32, device=self.device, requires_grad=True)
        self.b = torch.zeros(num_classes, dtype=torch.float32, device=self.device, requires_grad=True)
        self.alpha = torch.zeros(num_classes, dtype=torch.float32, device=self.device, requires_grad=True)

    @property
    def get_a(self):
        """Mean of ``a`` over all classes."""
        return self.a.mean()

    @property
    def get_b(self):
        """Mean of ``b`` over all classes."""
        return self.b.mean()

    @property
    def get_alpha(self):
        """Mean of ``alpha`` over all classes."""
        return self.alpha.mean()

    def forward(self, y_pred, y_true, selectTasks):
        """Sum the per-column loss over the column indices in ``selectTasks``.

        Args:
            y_pred: 2-D tensor of scores, indexed as ``y_pred[:, idx]``.
            y_true: 2-D tensor of labels with the same layout; entries equal
                to 1 and 0 select the two groups, anything else is ignored.
            selectTasks: Iterable of column indices to include.

        Returns:
            The summed loss tensor, or the integer ``0`` when ``selectTasks``
            is empty.
        """
        total_loss = 0
        for idx in selectTasks:
            p = self.p[idx]
            a, b, alpha = self.a[idx], self.b[idx], self.alpha[idx]
            y_pred_i = y_pred[:, idx].reshape(-1, 1)
            y_true_i = y_true[:, idx].reshape(-1, 1)
            pos = (1 == y_true_i).float()
            neg = (0 == y_true_i).float()
            # The margin scales the p*(1-p) constant; with the default
            # margin of 1.0 this equals the un-scaled expression.
            loss = (1 - p) * torch.mean((y_pred_i - a) ** 2 * pos) + \
                   p * torch.mean((y_pred_i - b) ** 2 * neg) + \
                   2 * alpha * (p * (1 - p) * self.margin +
                                torch.mean(p * y_pred_i * neg - (1 - p) * y_pred_i * pos)) - \
                   p * (1 - p) * alpha ** 2
            total_loss += loss
        return total_loss

class CrossEntropyBinaryLoss(torch.nn.Module):
    """Module wrapper around ``F.binary_cross_entropy_with_logits``.

    The functional is stored on ``self.criterion`` and called with its
    default arguments, so ``y_pred`` is expected to hold raw logits.
    """

    def __init__(self):
        super().__init__()
        # The *_with_logits variant applies the sigmoid internally.
        self.criterion = F.binary_cross_entropy_with_logits

    def forward(self, y_pred, y_true):
        """Return the binary cross-entropy between logits ``y_pred`` and targets ``y_true``."""
        bce = self.criterion(y_pred, y_true)
        return bce


class CrossEntropyBinaryLoss_MultiLabel(torch.nn.Module):
    """Binary cross-entropy (on logits) summed over selected label columns.

    ``num_classes`` is only stored on the instance; ``forward`` relies solely
    on the indices passed in ``selectTasks``.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # The *_with_logits variant applies the sigmoid internally.
        self.criterion = F.binary_cross_entropy_with_logits
        self.num_classes = num_classes

    def forward(self, y_pred, y_true, selectTasks):
        """Add up the per-column loss for every index in ``selectTasks``.

        Only the listed columns influence the result, which lets the caller
        restrict the loss to one or several tasks. Returns the integer ``0``
        when ``selectTasks`` is empty.
        """
        def column_loss(idx):
            # Pick out a single task's column as an (N, 1) tensor.
            logits = y_pred[:, idx].reshape(-1, 1)
            labels = y_true[:, idx].reshape(-1, 1).float()
            return self.criterion(logits, labels)

        return sum(column_loss(idx) for idx in selectTasks)