import torch
from .restartattack import RestartAttack
from .utils import project_perturbation, normalize_perturbation, initialize_perturbation


class FGM(RestartAttack):
    # one step attack with l2 or inf norm constraint
    def __init__(
        self,
        eps,
        num_classes,
        norm="inf",
        loss="CrossEntropy",
        restarts=0,
        init_noise_generator=None,
        model=None,
        save_trajectory=False,
    ):
        super().__init__(
            loss, restarts, num_classes, model=model, save_trajectory=save_trajectory
        )
        self.eps = eps
        self.norm = norm
        self.init_noise_generator = init_noise_generator

    def _get_trajectory_depth(self):
        return 2

    def get_config_dict(self):
        dict = {}
        dict["type"] = "FGM"
        dict["eps"] = self.eps
        dict["norm"] = self.norm
        if isinstance(self.loss, str):
            dict["loss"] = self.loss
        else:
            dict["loss"] = "custom"
        dict["restarts"] = self.restarts
        return dict

    def perturb_inner(self, x, y, targeted=False, x_init=None):
        l_f = self._get_loss_f(x, y, targeted, "none")

        pert = initialize_perturbation(
            x, self.eps, self.norm, x_init, self.init_noise_generator
        )

        pert.requires_grad_(True)

        with torch.enable_grad():
            p_data = x + pert
            out = self.model(p_data)
            loss_expanded = l_f(p_data, out)
            loss = loss_expanded.mean()
            grad = torch.autograd.grad(loss, pert)[0]

        with torch.no_grad():
            pert = project_perturbation(
                pert - self.eps * normalize_perturbation(grad, self.norm),
                self.eps,
                self.norm,
            )
            p_data = x + pert
            p_data = torch.clamp(p_data, 0, 1)
            final_loss = l_f(p_data, self.model(p_data))

            if self.save_trajectory:
                trajectory = torch.zeros((2,) + x.shape, device=x.device)
                trajectory[0, :] = x
                trajectory[1, :] = p_data
            else:
                trajectory = None

        return p_data, final_loss, trajectory
