import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from loss.EFCE_nT_loss import EFCEnTLoss

'''
WAEFCEnTLoss: cross-entropy loss for non-targeted attacks, based on a query prior and on eliminating the fuzzy domain.
'''
class WeightCrossEntropyLoss(nn.Module):
    """Cross-entropy over several candidate labels, weighted per sample.

    For each sample the loss is ``sum_i w_i * -log(p[label_i])``, where ``p``
    is the temperature-scaled, epsilon-smoothed softmax of ``outputs`` and
    ``w`` is the softmax of ``topn_max_values`` along the candidate axis.
    The per-sample values are averaged over the batch.
    """

    def __init__(self, temperature_scale=1.0, num_classes=10):
        """Store the softmax temperature and the class count.

        Args:
            temperature_scale: divisor applied to ``outputs`` before softmax.
            num_classes: number of classes; used to renormalise the
                epsilon-smoothed probabilities.
        """
        super(WeightCrossEntropyLoss, self).__init__()
        self.num_classes = num_classes
        self.temperature_scale = temperature_scale

    def forward(self, outputs, topn_max_labels, topn_max_values):
        """Return the weighted cross-entropy as a 0-dim tensor.

        Args:
            outputs: 2-D tensor of logits, classes along dim 1.
            topn_max_labels: 2-D int64 tensor of candidate class indices,
                one column per candidate.
            topn_max_values: 2-D tensor with the same layout as
                ``topn_max_labels``; softmaxed along dim 1 to obtain the
                per-candidate weights.

        Returns:
            Scalar tensor: sum over candidates of the batch-mean weighted
            negative log-probability.
        """
        epsilon = 1e-10

        probs = F.softmax(outputs / self.temperature_scale, dim=1)
        # Smooth exactly once. Doing this inside the per-candidate loop (as
        # the previous implementation did) compounds the smoothing, so later
        # candidates were scored against a different distribution.
        probs = (probs + epsilon) / (1.0 + epsilon * self.num_classes)

        weights = F.softmax(topn_max_values, dim=1)

        # Picking the candidate columns directly is equivalent to the
        # one-hot multiply followed by a sum over the class axis.
        picked = probs.gather(1, topn_max_labels)
        weighted_nll = -(weights * torch.log(picked))

        # Batch mean per candidate, then sum over candidates.
        return weighted_nll.mean(dim=0).sum()


class WAEFCEnTLoss(nn.Module):
    """EFCEnT loss minus a weighted cross-entropy over the top-n other classes.

    ``forward`` evaluates ``EFCEnTLoss`` on ``(outputs, ys)`` and subtracts a
    ``WeightCrossEntropyLoss`` term computed on the ``topn`` classes, other
    than ``ys``, that score highest in ``outputs_victim``.
    """

    def __init__(self, topn=1, num_classes=10, temperature_scale=1.0, fuzzy_scale=1.0):
        """Build the two component losses.

        Args:
            topn: how many non-``ys`` classes to select per sample.
            num_classes: forwarded to ``WeightCrossEntropyLoss``.
            temperature_scale: forwarded to both component losses.
            fuzzy_scale: forwarded to ``EFCEnTLoss``.
        """
        super(WAEFCEnTLoss, self).__init__()
        self.criterion = EFCEnTLoss(temperature_scale=temperature_scale, fuzzy_scale=fuzzy_scale)
        self.wceloss = WeightCrossEntropyLoss(temperature_scale=temperature_scale, num_classes=num_classes)
        self.topn = topn

    def forward(self, outputs, ys, outputs_victim):
        """Return ``criterion(outputs, ys) - wceloss(outputs, labels, values)``.

        Args:
            outputs: 2-D tensor of logits being optimised, classes on dim 1.
            ys: integer class index per sample (the class to move away from).
            outputs_victim: 2-D score tensor with the same layout as
                ``outputs`` (presumably the victim model's logits; confirm
                against the caller). It is only read, never modified.

        Returns:
            The combined loss tensor.
        """
        # Minimise the output value that corresponds to ys.
        loss = self.criterion(outputs, ys)

        device = outputs.device
        # clone(): detach() shares storage with the caller's tensor, and the
        # masking below must not leak back into outputs_victim.
        masked = outputs_victim.detach().to(device).clone()
        target = ys.detach().to(device=device, dtype=torch.int64).reshape(-1, 1)

        # Exclude ys with -inf rather than the row minimum: a row-minimum mask
        # ties with the genuine minimum class, so argmax could hand back ys
        # (or index 0) instead of a real non-ys class.
        masked.scatter_(1, target, float("-inf"))

        # Maximise the outputs of the highest-scoring classes other than ys.
        # Only C-1 such classes exist, so never ask for more than that.
        k = min(self.topn, masked.shape[1] - 1)
        topn_values, topn_max_labels = torch.topk(masked, k, dim=1)
        # The values were previously staged in a float32 buffer; keep that dtype.
        topn_max_values = topn_values.to(torch.float32)

        # Out-of-place so the tensor returned by the criterion is not modified.
        return loss - self.wceloss(outputs, topn_max_labels, topn_max_values)
