import torch
import torch.nn as nn
import numpy as np


class ACELoss(nn.Module):
    """Cross-entropy loss combined with CE terms for the victim's runner-up labels.

    Write ``CE(x, y)`` for mean cross-entropy, and ``r_k`` for the label that
    ``outputs_victim`` ranks k-th highest in each row (``r_1`` is the victim's
    top class). Then:

    * untargeted::

          CE(outputs, ys) - (1 / topn) * sum_{k=2}^{topn+1} CE(outputs, r_k)

    * targeted::

          -CE(outputs, target_labels)
          + (1 / topn) * (CE(outputs, ys) + sum_{k=2}^{topn} CE(outputs, r_k))

    NOTE(review): ``outputs`` / ``outputs_victim`` are presumably (batch,
    num_classes) logits from the model being optimized and from a separate
    "victim" model -- inferred from names and the dim=1 reductions; confirm.
    """

    def __init__(self, topn=1):
        """Create the loss.

        Args:
            topn: Number of victim runner-up labels mixed into the loss; also
                the divisor applied to each of those CE terms.
        """
        super().__init__()
        self.criterion = torch.nn.CrossEntropyLoss()
        self.topn = topn

    @staticmethod
    def _runner_up_labels(outputs_victim, count, device):
        """Return the victim's 2nd, 3rd, ..., (count+1)-th ranked labels per row.

        Each element of the returned list is a 1-D label tensor on ``device``.
        Returns an empty list when ``count <= 0``.
        """
        # Work on a private copy. For CPU tensors, .cpu()/.detach()/.numpy()
        # all share storage with the caller's tensor, so masking in place
        # would silently clobber outputs_victim.
        scores = outputs_victim.detach().clone()
        rows = torch.arange(scores.size(0), device=scores.device)
        labels = []
        for _ in range(count):
            # Lower the current best class to the row minimum so that the
            # next argmax yields the next-ranked class.
            scores[rows, scores.argmax(dim=1)] = scores.min(dim=1).values
            labels.append(scores.argmax(dim=1).to(device))
        return labels

    def forward(self, outputs, ys, outputs_victim, targeted=False, target_labels=None):
        """Compute the loss.

        Args:
            outputs: Logits scored by every CE term (the only input that
                receives gradients through this loss).
            ys: Class labels for the ``CE(outputs, ys)`` term.
            outputs_victim: Scores used only to rank labels; never modified.
            targeted: Selects the targeted formulation when truthy.
            target_labels: Required when ``targeted`` is truthy.

        Returns:
            A scalar loss tensor.

        Raises:
            ValueError: If ``targeted`` is truthy and ``target_labels`` is None.
        """
        if not targeted:
            loss = self.criterion(outputs, ys)
            for labels in self._runner_up_labels(outputs_victim, self.topn, outputs.device):
                loss = loss - self.criterion(outputs, labels) / self.topn
            return loss

        if target_labels is None:
            raise ValueError("target_labels must be provided when targeted=True")
        loss = -self.criterion(outputs, target_labels)
        loss = loss + self.criterion(outputs, ys) / self.topn
        # In targeted mode the ys term takes one of the topn slots, so only
        # topn - 1 runner-up labels are added.
        for labels in self._runner_up_labels(outputs_victim, self.topn - 1, outputs.device):
            loss = loss + self.criterion(outputs, labels) / self.topn
        return loss
