import torch

from lambda_ac.rl_types import ActorModule, CriticModule


def get_robust_augmentation(
    state: torch.Tensor,
    action: torch.Tensor,
    actor: ActorModule,
    critic: CriticModule,
    epsilon: float = 0.1,
    sign_project: bool = False,
):
    """Return a perturbed state and an action sampled by the actor at that state.

    The perturbation is a single gradient step on a critic-consistency loss:

    1. Draw a noisy copy ``s_0`` of ``state`` with ``torch.normal``.
    2. Sample ``action_0`` from ``actor`` at ``s_0`` and evaluate both critic
       heads at ``(s_0, action_0)`` and at the reference ``(state, action)``.
    3. Take the gradient, with respect to ``s_0``, of the squared difference
       between the perturbed and reference critic values.
    4. Move ``s_0`` along that gradient (or along its sign), scaled by
       ``epsilon``, and clamp the result to ``[s_0 - epsilon, s_0 + epsilon]``.
    5. Sample a fresh action from ``actor`` at the perturbed state.

    The gradient is obtained with ``torch.autograd.grad`` so that no ``.grad``
    buffer of the actor or critic parameters is modified by this call.

    Args:
        state: Batch of reference states.
        action: Batch of reference actions paired with ``state``.
        actor: Object exposing ``forward_sample_log_prob(state)`` whose first
            return value is the sampled action.
        critic: Callable ``critic(state, action)`` returning two Q-value
            tensors. Assumes both are 2-D (batch first), because the loss is
            averaged over ``dim=1`` -- TODO confirm against ``CriticModule``.
        epsilon: Step size and half-width of the clamp box around ``s_0``.
        sign_project: If true, step along the sign of the gradient instead of
            the raw gradient.

    Returns:
        A tuple ``(s_prime, action_prime)``. ``s_prime`` is detached from the
        autograd graph; ``action_prime`` is whatever the actor returns for it.
    """
    # NOTE(review): the second argument of torch.normal is a standard
    # deviation; ``epsilon**2`` is kept from the original code -- confirm that
    # a std (rather than a variance) of epsilon**2 is what is intended.
    # ``state`` is detached first so that ``s_0`` is always a leaf tensor.
    s_0 = torch.normal(state.detach(), epsilon**2).requires_grad_(True)

    # enable_grad lets the function work even when called under no_grad.
    with torch.enable_grad():
        action_0, _, _ = actor.forward_sample_log_prob(s_0)
        qf1_0, qf2_0 = critic(s_0, action_0)
        # The reference values are constants for this gradient computation.
        with torch.no_grad():
            qf1, qf2 = critic(state, action)
        per_sample_loss = torch.mean(
            0.5 * (qf1_0 - qf1) ** 2 + 0.5 * (qf2_0 - qf2) ** 2, dim=1
        )
        # ``per_sample_loss`` has one entry per sample, so it has to be
        # reduced to a scalar before differentiating. Summing (rather than
        # averaging) keeps each sample's gradient independent of batch size.
        (s_0_grad,) = torch.autograd.grad(per_sample_loss.sum(), s_0)

    base = s_0.detach()
    direction = s_0_grad.sign() if sign_project else s_0_grad
    s_prime = base + epsilon * direction
    # Keep the perturbed state inside an epsilon-box around the noisy state.
    s_prime = torch.clamp(s_prime, min=base - epsilon, max=base + epsilon)

    action_prime, _, _ = actor.forward_sample_log_prob(s_prime)

    return s_prime, action_prime


def get_adversarial_augmentation(
    state: torch.Tensor,
    action: torch.Tensor,
    critic: CriticModule,
    epsilon: float = 0.1,
    max_projection: float = 0.1,
):
    """Shift ``state`` one clamped gradient step towards lower critic value.

    The objective is ``torch.mean(qf1 + qf2)`` where ``qf1, qf2`` are the two
    outputs of ``critic(state, action)``. Its gradient with respect to the
    state is scaled by ``epsilon``, clamped element-wise to
    ``[-max_projection, max_projection]`` and subtracted from the state.
    Because the objective is a mean over all critic outputs, the gradient of
    each entry is divided by the number of averaged elements.

    The gradient is computed on a detached copy with ``torch.autograd.grad``,
    so the caller's ``state`` tensor keeps its ``requires_grad`` flag and no
    ``.grad`` buffer (of the state or of the critic parameters) is written.

    Args:
        state: Batch of states to perturb.
        action: Batch of actions paired with ``state``; returned unchanged.
        critic: Callable ``critic(state, action)`` returning two Q-value
            tensors.
        epsilon: Scale applied to the gradient before clamping.
        max_projection: Maximum absolute per-element change of the state.

    Returns:
        A tuple ``(perturbed_state, action)``. ``perturbed_state`` is detached
        from the autograd graph; ``action`` is the object that was passed in.
    """
    # Differentiate with respect to a private leaf instead of flipping
    # ``requires_grad`` on the caller's tensor.
    probe = state.detach().requires_grad_(True)

    # enable_grad lets the function work even when called under no_grad.
    with torch.enable_grad():
        qf1, qf2 = critic(probe, action)
        q_objective = torch.mean(qf1 + qf2)
        (state_grad,) = torch.autograd.grad(q_objective, probe)

    step = torch.clamp(
        epsilon * state_grad, min=-max_projection, max=max_projection
    )
    return probe.detach() - step, action
