""" Much of the code for this project was taken from the publicly available
    'Witches' Brew' repository <https://github.com/JonasGeiping/poisoning-gradient-matching>.
    As such, we include below relevant functions/classes to adapt to our poison dataset
    method. NOTE: THIS CODE DOES NOT RUN ON ITS OWN. IT INTEGRATES INTO THE
    REPOSITORY LISTED ABOVE.
"""

# Below is the main mechanism for crafting the perturbations.
class WitchGradientMatching(_Witch):
    """Craft perturbations by matching poison gradients to a target gradient.

    This adapts the WitchMatching class to poison datasets. We call our crafting
    class 'Forgemaster'.
    """

    def _define_objective(self, inputs, labels, targets, intended_classes=None, true_classes=None):
        """Build the closure that evaluates and backpropagates the alignment loss.

        Args:
            inputs: Batch passed to the model. As noted in the closure, this is
                clean_inputs + perturbations, so gradients reach the perturbations.
            labels: Labels used for the poison loss and for counting correct predictions.
            targets: Unused here; kept for interface compatibility.
            intended_classes: Unused here; kept for interface compatibility.
            true_classes: Unused here; kept for interface compatibility.

        Returns:
            A callable ``closure(model, criterion, optimizer, target_grad, target_gnorm)``
            that returns ``(alignment_loss, prediction)`` as detached CPU tensors.
        """
        def closure(model, criterion, optimizer, target_grad, target_gnorm):
            """This function will be evaluated on all GPUs.

            inputs = clean_inputs + perturbations, so the perturbations are included
            in the computation graph when alignment_loss.backward() is called.
            The ``optimizer`` argument is accepted but not used in this closure.
            """
            outputs = model(inputs)
            if self.args.target_criterion in ['cw', 'carlini-wagner']:
                # Override the supplied criterion; otherwise it is used unchanged.
                criterion = cw_loss
            poison_loss = criterion(outputs, labels)
            # Number of samples in the batch whose arg-max prediction equals its label.
            prediction = (outputs.data.argmax(dim=1) == labels).sum()
            # create_graph=True keeps the parameter gradients differentiable, so the
            # alignment loss built from them can itself be backpropagated.
            poison_grad = torch.autograd.grad(poison_loss, model.parameters(), retain_graph=True, create_graph=True)

            alignment_loss = self._alignment_loss(poison_grad, target_grad, target_gnorm)
            alignment_loss.backward(retain_graph=self.retain)
            return alignment_loss.detach().cpu(), prediction.detach().cpu()
        return closure

    def _alignment_loss(self, poison_grad, target_grad, target_gnorm):
        """Compute the alignment loss term as in Eq. (3).

        Args:
            poison_grad: Sequence of per-parameter gradients of the poison loss.
            target_grad: Sequence of per-parameter target gradients, in the same order.
            target_gnorm: Norm of the target gradient; treated as a constant divisor.

        Returns:
            The alignment loss selected by ``self.args.loss``. For the similarity
            losses this is ``1 - <target_grad, poison_grad> / (target_gnorm * ||poison_grad||)``.
        """
        alignment_loss = 0
        poison_norm = 0
        SIM_TYPE = ['similarity', 'similarity-narrow', 'top5-similarity', 'top10-similarity', 'top20-similarity']
        loss_name = self.args.loss
        is_similarity = loss_name in SIM_TYPE

        for target_g, poison_g in zip(target_grad, poison_grad):
            if loss_name in ['scalar_product', *SIM_TYPE]:  # Default loss is 'similarity'
                alignment_loss -= (target_g * poison_g).sum()
            elif loss_name == 'cosine1':
                alignment_loss -= torch.nn.functional.cosine_similarity(target_g.flatten(), poison_g.flatten(), dim=0)
            elif loss_name == 'SE':
                alignment_loss += 0.5 * (target_g - poison_g).pow(2).sum()
            elif loss_name == 'MSE':
                alignment_loss += torch.nn.functional.mse_loss(target_g, poison_g)

            if is_similarity:
                # Accumulate the squared L2 norm of the poison gradient. Without this,
                # poison_norm stays the int 0 and the .sqrt() call below fails.
                poison_norm += poison_g.pow(2).sum()

        alignment_loss = alignment_loss / target_gnorm  # this is a constant

        if is_similarity:
            poison_gnorm = poison_norm.sqrt()
            if self.args.independent_brewing:
                # Detach denominator to perform independent_brewing
                poison_gnorm = poison_gnorm.detach()
            alignment_loss = 1 + alignment_loss / poison_gnorm

        return alignment_loss


# Here we include code for calculating reverse cross entropy loss. This is used
# when calculating the target gradient.
def reverse_xent_avg(outputs, intended_classes):
    max_exp = outputs.max(dim=1, keepdim=True)[0]
    denominator = torch.log(torch.exp(outputs - max_exp).sum(dim=1)) + max_exp
    other_class_map = torch.tensor([[i for i in range(outputs.shape[1]) if i!=j] for j in range(outputs.shape[1])], device=intended_classes.device)
    selected_indices = other_class_map[intended_classes]
    other_outputs = outputs.gather(dim=1, index=selected_indices)
    other_max_exp = other_outputs.max(dim=1, keepdim=True)[0]
    numerator = -torch.log(torch.exp(other_outputs - other_max_exp).sum(dim=1)) - other_max_exp
    return torch.mean(numerator + denominator)
