# Code adapted (with modifications) from FARE:
# https://github.com/chs20/RobustVLM/blob/main/train/adversarial_training_clip.py

import torch

import sys

sys.path.append("RobustVLM")
from vlm_eval.attacks.utils import project_perturbation, normalize_grad


def pgd(
    forward,
    images_normalize,  # added: applied to perturbed images before `forward`
    loss_fn,
    data_clean,
    targets,
    norm,
    eps,
    iterations,
    stepsize,
    output_normalize,
    perturbation=None,
    mode="min",
    momentum=0.9,
    verbose=False,
):
    """Run momentum PGD on ``data_clean`` to minimize or maximize ``loss_fn``.

    Adapted from FARE's ``pgd``. Unlike the original, the perturbed image is
    passed through ``images_normalize`` before ``forward`` is called.

    Args:
        forward: Callable applied to ``images_normalize(data_clean + delta)``;
            its output is passed to ``loss_fn``.
        images_normalize: Callable mapping images in [0, 1] to the input
            expected by ``forward``.
        loss_fn: Called as ``loss_fn(outs, targets)``. It must return a scalar
            tensor, because it is differentiated with ``torch.autograd.grad``
            without ``grad_outputs``.
        data_clean: Clean images. All values must lie in [0, 1] (asserted).
        targets: Second argument passed to ``loss_fn``.
        norm: Norm identifier forwarded unchanged to ``normalize_grad`` and
            ``project_perturbation``. Accepted values are defined in
            ``vlm_eval.attacks.utils``, not here (TODO confirm).
        eps: Perturbation budget passed to ``project_perturbation``.
        iterations: Number of PGD steps.
        stepsize: Step size applied to the normalized velocity at each step.
        output_normalize: Must be falsy. Output normalization is not
            supported because it was not used in the original paper.
        perturbation: Optional initial perturbation, same shape as
            ``data_clean``. Defaults to zeros. The caller's tensor is neither
            modified nor has its ``requires_grad`` flag changed.
        mode: ``"min"`` to descend the loss, ``"max"`` to ascend it.
        momentum: Decay factor for the velocity accumulator.
        verbose: If True, print the loss at every iteration.

    Returns:
        ``data_clean + perturbation`` with the perturbation detached. After
        every step the perturbation is projected with ``project_perturbation``
        and clipped so that ``data_clean + perturbation`` stays in [0, 1].

    Raises:
        ValueError: If ``output_normalize`` is truthy or ``mode`` is neither
            ``"min"`` nor ``"max"``.
    """
    # make sure data is in image space
    assert torch.max(data_clean) < 1.0 + 1e-6 and torch.min(data_clean) > -1e-6

    # Validate options before any forward pass. Previously these errors only
    # surfaced inside the loop, so they were never reported when iterations == 0.
    if output_normalize:
        raise ValueError("output_normalize not used in original paper")
    if mode not in ("min", "max"):
        raise ValueError(f"Unknown mode: {mode}")
    # Signed step: "min" moves against the velocity, "max" along it.
    signed_step = -stepsize if mode == "min" else stepsize

    if perturbation is None:
        perturbation = torch.zeros_like(data_clean)
    else:
        # Work on a detached alias. Setting requires_grad on the caller's
        # tensor would mutate it, and it raises if that tensor is non-leaf.
        perturbation = perturbation.detach()

    velocity = torch.zeros_like(data_clean)
    for i in range(iterations):
        # `perturbation` is always a fresh leaf here: it is either the
        # initial tensor or was rebuilt under no_grad in the previous step.
        perturbation.requires_grad_(True)
        with torch.enable_grad():
            outs = forward(images_normalize(data_clean + perturbation))
            loss = loss_fn(outs, targets)

        if verbose:
            print(f"[{i}] {loss.item():.5f}")

        with torch.no_grad():
            gradient = torch.autograd.grad(loss, perturbation)[0]
            nan_mask = gradient.isnan()
            if nan_mask.any():
                print(f"attention: nan in gradient ({nan_mask.sum()})")
                gradient[nan_mask] = 0.0
            # Normalize the raw gradient, accumulate momentum, then
            # renormalize the velocity before taking the step.
            gradient = normalize_grad(gradient, p=norm)
            velocity = normalize_grad(momentum * velocity + gradient, p=norm)
            perturbation = perturbation + signed_step * velocity
            # Project onto the eps-ball, then clip so the image stays in [0, 1].
            perturbation = project_perturbation(perturbation, eps, norm)
            perturbation = torch.clamp(data_clean + perturbation, 0, 1) - data_clean
            assert not perturbation.isnan().any()
            assert (
                torch.max(data_clean + perturbation) < 1.0 + 1e-6
                and torch.min(data_clean + perturbation) > -1e-6
            )

    # TODO: return the best perturbation seen rather than the last one; the
    # model currently does not output a per-sample (expanded) loss.
    return data_clean + perturbation.detach()
