from advertorch.attacks import LinfPGDAttack
from torch import nn


class AdversaryCreator(object):
    """A factory producing adversaries from a short attack name.

    The attack name may encode its own hyper-parameters:

    * a trailing ``_eps<N>`` sets the default L-inf budget to ``int(N)`` and is
      stripped from ``self.attack`` (e.g. ``'LinfPGD20_eps16'`` ->
      ``attack='LinfPGD20'``, ``eps=16``);
    * a numeric suffix on ``'LinfPGD'`` or ``'MIA'`` fixes the number of steps
      (e.g. ``'LinfPGD20'`` -> 20 steps). Passing ``steps=`` as well is an error.

    Only ``'LinfPGD*'`` names and ``'none'`` are turned into an attack by
    ``__call__``. The other names listed in ``supported_adv`` are parsed for
    eps/steps here, but ``__call__`` raises ``ValueError`` for them.

    Args:
        attack: Attack name, e.g. one of ``supported_adv``. ``'Free'`` and
            ``'Affine'`` are training-only and use their own eps/steps defaults.
        joint_noise_detector: If not None, the attack maximizes a joint loss
            ``(1 - a) * CE(model logits) + a * CE(2-way detector logits)``, where
            ``a`` is this value. Otherwise plain summed cross-entropy is used.
        eps: Optional budget. It overrides the name-encoded or default value:
            8 for most attacks, 4 for 'Free', 0.1 for 'Affine'. For LinfPGD it is
            divided by 255 when the attack is built.
        steps: Optional number of attack steps. Defaults are 7 in general,
            4 for 'Free' and 2 for 'Affine'. It must not be combined with a step
            count encoded in the attack name.

    Raises:
        ValueError: if ``steps`` is passed while the name already sets it.
    """
    supported_adv = ['LinfPGD', 'jointLinfPGD', 'LinfPGD20', 'LinfPGD20_eps16', 'LinfPGD100', 'jointLinfPGD100', 'LinfPGD100_eps16',
                     'LinfPGD4_eps4', 'LinfPGD3_eps4', 'LinfPGD7_eps4',  # combined for LinfPGD7_eps8
                     'MIA', 'MIA20', 'MIA20_eps16', 'MIA100', 'MIA100_eps16', 'Free', 'LSA',
                     'LinfAA', 'LinfAA+',
                     'TrnLinfPGD',  # transfer attack
                     ]

    def __init__(self, attack: str, joint_noise_detector=None, **kwargs):
        self.attack = attack
        self.joint_noise_detector = joint_noise_detector

        # Training-only attacks with their own defaults; the name is not parsed.
        if attack == 'Free':
            self.eps = kwargs.get('eps', 4.)
            self.steps = kwargs.get('steps', 4)
            return
        if attack == 'Affine':
            self.eps = kwargs.get('eps', 0.1)  # = 1/lambda. Need to tune
            self.steps = kwargs.get('steps', 2)
            return

        # A trailing '_eps<N>' supplies the default budget and is removed from the name.
        if '_eps' in self.attack:
            self.attack, default_eps = self.attack.split('_eps')
            self.eps = kwargs.get('eps', int(default_eps))
        else:
            self.eps = kwargs.get('eps', 8.)

        # A numeric suffix on a known prefix fixes the step count.
        for prefix in ('LinfPGD', 'MIA'):
            suffix = self.attack[len(prefix):]
            if self.attack.startswith(prefix) and suffix.isdigit():
                # Explicit raise (not assert) so the check survives `python -O`.
                if 'steps' in kwargs:
                    raise ValueError("The steps is set by the attack name while "
                                     "found additional set in kwargs.")
                self.steps = int(suffix)
                break
        else:
            self.steps = kwargs.get('steps', 7)

    def __call__(self, model, eps=None):
        """Build the attack for ``model``.

        Args:
            model: The model to attack. It is passed straight to the advertorch attack.
            eps: Optional budget overriding ``self.eps``. For LinfPGD it is divided by 255.

        Returns:
            An untargeted ``LinfPGDAttack`` for ``'LinfPGD*'`` names, or ``None``
            for ``'none'``.

        Raises:
            ValueError: for any other attack name.
        """
        if eps is None:
            eps = self.eps
        if self.joint_noise_detector is not None:
            from torch.nn.modules.loss import _WeightedLoss
            from torch.nn import functional as F

            class JointCrossEntropyLoss(_WeightedLoss):
                """Weighted sum of a classifier CE and a 2-way detector CE.

                ``input`` is split column-wise. ``input[:, :-2]`` holds the
                classifier logits and ``input[:, -2:]`` the detector logits.
                ``target[:, 0]`` and ``target[:, 1]`` are the matching labels.
                The detector term is weighted by ``alpha``.
                """
                __constants__ = ['ignore_index', 'reduction']
                ignore_index: int

                def __init__(self, weight = None, size_average=None, ignore_index: int = -100,
                             reduce=None, reduction: str = 'mean', alpha=0.5) -> None:
                    super(JointCrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
                    self.ignore_index = ignore_index
                    self.alpha = alpha

                def forward(self, input, target):
                    model_pred, dct_pred = input[:, :-2], input[:, -2:]
                    model_trg, dct_trg = target[:, 0], target[:, 1]
                    # NOTE(review): `weight` is applied to both terms even though
                    # their class counts differ. It is always None below, so this is
                    # harmless here, but a non-None weight could fit only one term.
                    return (1 - self.alpha) * F.cross_entropy(model_pred, model_trg, weight=self.weight,
                                                              ignore_index=self.ignore_index, reduction=self.reduction) \
                        + self.alpha * F.cross_entropy(dct_pred, dct_trg, weight=self.weight,
                                                       ignore_index=self.ignore_index, reduction=self.reduction)

            loss_fn = JointCrossEntropyLoss(reduction="sum", alpha=self.joint_noise_detector)
        else:
            loss_fn = nn.CrossEntropyLoss(reduction="sum")

        if self.attack.startswith('LinfPGD'):
            eps_unit = eps / 255  # budget rescaled to the [0, 1] input range
            # Total step length: 1.25 * eps, capped at eps + 4/255; spread over all steps.
            eps_iter = min(eps_unit * 1.25, eps_unit + 4. / 255) / self.steps
            return LinfPGDAttack(
                model, loss_fn=loss_fn, eps=eps_unit,
                nb_iter=self.steps, eps_iter=eps_iter, rand_init=True,
                clip_min=0.0, clip_max=1.0,
                targeted=False)
        if self.attack == 'none':
            return None
        raise ValueError(f"attack: {self.attack}")

