import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from loss.EECE_Loss import EECELoss

class WeightCrossEntropyLoss(nn.Module):
    """Cross-entropy toward several candidate labels, weighted per candidate.

    For every sample the loss is ``sum_i w_i * -log(p[label_i])`` where ``p``
    is the (epsilon-smoothed) softmax of ``outputs`` and the weights ``w`` are
    the softmax of ``topn_max_values`` across the candidates. Each candidate
    column is averaged over the batch and the column means are summed.
    """

    def __init__(self, num_classes=10):
        """Store the class count used to renormalise the smoothed softmax.

        Args:
            num_classes: number of classes; only used in the denominator of
                the epsilon smoothing in ``forward``.
        """
        super(WeightCrossEntropyLoss, self).__init__()
        self.num_classes = num_classes

    def forward(self, outputs, topn_max_labels, topn_max_values):
        """Return the weighted cross-entropy as a 0-dim tensor.

        Args:
            outputs: logits, indexed as ``[sample, class]`` (softmax is taken
                over dim 1).
            topn_max_labels: int64 class indices, indexed as
                ``[sample, candidate]``.
            topn_max_values: scores with the same layout as
                ``topn_max_labels``; their softmax over dim 1 gives the
                per-candidate weights.

        Returns:
            Scalar tensor: sum over candidates of the batch-mean weighted
            negative log-probability.
        """
        epsilon = 1e-10

        pred = F.softmax(outputs, dim=1)
        # Smooth exactly once so log() never sees 0. This used to sit inside
        # the per-candidate loop, re-smoothing the already smoothed tensor on
        # every iteration.
        pred = (pred + epsilon) / (1.0 + epsilon * self.num_classes)
        weights = F.softmax(topn_max_values, dim=1)

        # Probability assigned to each candidate label: [sample, candidate].
        picked = pred.gather(1, topn_max_labels)
        weighted_nll = -weights * torch.log(picked)

        # Mean over the batch per candidate, then sum over candidates.
        return weighted_nll.mean(dim=0).sum()


class EWACELoss(nn.Module):
    """Untargeted loss: ``EECELoss`` minus a weighted CE toward wrong classes.

    The base term comes from ``EECELoss`` (defined elsewhere in the project).
    From it is subtracted a ``WeightCrossEntropyLoss`` toward the ``topn``
    classes that the victim model scores highest, excluding the true label.
    Per the original author's notes, the intent is to minimise the output for
    ``ys`` while maximising the output of the best-scoring other classes.

    The targeted setting is not implemented.
    """

    def __init__(self, topn=1, num_classes=10):
        """Build the two component losses.

        Args:
            topn: how many top-scoring non-true classes to pull towards.
            num_classes: forwarded to ``WeightCrossEntropyLoss``.
        """
        super(EWACELoss, self).__init__()
        self.criterion = EECELoss()
        self.wceloss = WeightCrossEntropyLoss(num_classes=num_classes)
        self.topn = topn

    def forward(self, outputs, ys, outputs_victim, targeted=False, target_labels=None):
        """Compute the untargeted loss.

        Args:
            outputs: logits of the model being optimised, ``[sample, class]``.
            ys: true class index per sample (assumed integer-valued, one per
                sample).
            outputs_victim: victim-model scores, ``[sample, class]``. Read
                only; it is never modified.
            targeted: must be falsy; the targeted variant is not implemented.
            target_labels: unused (reserved for the targeted variant).

        Returns:
            ``criterion(outputs, ys, outputs_victim) - wceloss(...)``.

        Raises:
            NotImplementedError: if ``targeted`` is truthy.
        """
        if targeted:
            # Previously this path fell through to ``return loss`` with
            # ``loss`` unbound and died with an UnboundLocalError.
            raise NotImplementedError(
                "EWACELoss does not implement the targeted setting"
            )

        # Term that (per the original notes) minimises the output for ys.
        loss = self.criterion(outputs, ys, outputs_victim)
        topn_max_labels, topn_max_values = self._top_wrong_classes(
            outputs, ys, outputs_victim
        )
        # Subtracting the CE towards the best wrong classes maximises their
        # outputs. Out-of-place so the criterion's tensor is left untouched.
        return loss - self.wceloss(outputs, topn_max_labels, topn_max_values)

    def _top_wrong_classes(self, outputs, ys, outputs_victim):
        """Return labels/values of the ``topn`` best victim classes other than ``ys``."""
        # Work on a private copy: the masking below writes into the scores,
        # and a detached view would share storage with the caller's tensor.
        scores = outputs_victim.detach().clone()
        row_min = scores.min(dim=1, keepdim=True)[0]

        batch_size = outputs.shape[0]
        device = outputs.device
        topn_max_labels = torch.zeros(
            batch_size, self.topn, dtype=torch.int64, device=device
        )
        topn_max_values = torch.zeros(batch_size, self.topn, device=device)

        # First exclude the true label; afterwards exclude the class just picked.
        exclude = ys.detach().to(device=scores.device, dtype=torch.int64).reshape(-1, 1)
        for j in range(self.topn):
            # An excluded class is pushed down to the row minimum so the
            # next max() cannot prefer it over any remaining class.
            scores.scatter_(1, exclude, row_min)
            best_values, exclude = scores.max(dim=1, keepdim=True)
            topn_max_labels[:, j] = exclude[:, 0].to(device)
            topn_max_values[:, j] = best_values[:, 0].to(device)

        return topn_max_labels, topn_max_values
