from __future__ import annotations

import jax
import jax.numpy as jnp
from functools import partial
from dataclasses import dataclass
from typing import Callable, Tuple, Optional, Any
from utils import load_sarsa
Array = jax.Array

def random_attack(obs: jnp.ndarray, rng: jax.Array, epsilon: float):
    """
    Random L∞ attack: adds uniform noise drawn from [-epsilon, epsilon) to the observation.

    The radius is kept as an array instead of being cast with ``float()``, so
    the function also works when ``epsilon`` is a traced value (for example
    when the attack is wrapped in ``jax.jit``). For a plain Python float the
    sampled noise is the same as before.

    Args:
        obs: Observation array; its shape (and dtype, when floating) determine
            the shape and dtype of the sampled noise.
        rng: PRNG key. It is split once; one half drives the noise sample.
        epsilon: Perturbation radius (Python scalar or scalar array).

    Returns:
        ``(corrupted, new_rng)`` where ``corrupted = obs + noise`` and
        ``new_rng`` is the half of the split key that was not consumed.
    """
    new_rng, subkey = jax.random.split(rng)

    # Non-floating observations cannot carry fractional noise: sample in float32
    # and let the addition below promote the result.
    noise_dtype = obs.dtype if jnp.issubdtype(obs.dtype, jnp.floating) else jnp.float32

    # jnp.asarray accepts both concrete scalars and tracers, unlike float().
    bound = jnp.asarray(epsilon, dtype=noise_dtype)

    noise = jax.random.uniform(
        subkey,
        shape=obs.shape,
        minval=-bound,
        maxval=bound,
        dtype=noise_dtype,
    )
    corrupted = obs + noise
    return corrupted, new_rng


@dataclass(frozen=True)
class PGDConfig:
    """Immutable settings consumed by ``_pgd_linf_optimize``."""

    # Number of signed-gradient iterations; a value <= 0 makes the optimiser
    # return its input unchanged.
    steps: int = 10

    # Size of each iteration's move. ``None`` means "use epsilon / steps".
    step_eps: Optional[float] = None

    # When True, Gaussian noise is added to the gradient before its sign is taken.
    use_sgld: bool = False

    # Starting point of the search: "uniform" or "gaussian-sign" perturb the
    # observation randomly; any other value (e.g. "none") starts at the
    # observation itself.
    init: str = "none"


def _pgd_linf_optimize(
    *,
    obs: jax.Array,
    rng: jax.random.PRNGKey,
    epsilon: float,
    objective: Callable[[jax.Array], jax.Array],
    cfg: PGDConfig,
) -> Tuple[jax.Array, jax.random.PRNGKey]:
    """Maximise ``objective`` by signed-gradient ascent inside an L∞ ball around ``obs``.

    Args:
        obs: Centre of the ball and (depending on ``cfg.init``) the start point.
        rng: PRNG key used for the random start and the optional SGLD noise.
        epsilon: Radius of the ball; iterates are clipped to
            ``[obs - epsilon, obs + epsilon]`` after every step.
        objective: Scalar function of the perturbed observation; it is
            differentiated with ``jax.grad`` and *ascended*.
        cfg: Iteration count, step size, SGLD switch and init scheme.

    Returns:
        ``(adv_obs, rng)``. When ``cfg.steps <= 0`` both inputs are returned
        untouched (the key is not split).
    """
    n_iters = int(cfg.steps)
    if n_iters <= 0:
        return obs, rng

    if cfg.step_eps is None:
        step_size = epsilon / n_iters
    else:
        step_size = float(cfg.step_eps)

    lower = obs - epsilon
    upper = obs + epsilon

    # The key is split once here regardless of the init scheme, so the returned
    # key differs from the input even for a deterministic start.
    rng, init_key = jax.random.split(rng)
    if cfg.init == "uniform":
        start = obs + jax.random.uniform(
            init_key, shape=obs.shape, minval=-epsilon, maxval=epsilon
        )
    elif cfg.init == "gaussian-sign":
        # Random ±step_size offset per coordinate.
        start = obs + jnp.sign(jax.random.normal(init_key, shape=obs.shape)) * step_size
    else:
        start = obs

    grad_fn = jax.grad(objective)

    def _ascent_step(it, state):
        point, key = state
        direction = grad_fn(point)

        if cfg.use_sgld:
            # Noise magnitude shrinks with the iteration index.
            key, noise_key = jax.random.split(key)
            sigma = jnp.sqrt(2.0 * step_size) / (it.astype(jnp.float32) + 2.0)
            direction = direction + jax.random.normal(noise_key, shape=obs.shape) * sigma

        moved = point + jnp.sign(direction) * step_size
        return (jnp.clip(moved, lower, upper), key)

    adv_obs, rng = jax.lax.fori_loop(0, n_iters, _ascent_step, (start, rng))
    return adv_obs, rng


def make_mad_attack(
    policy_apply: Callable[[object, jax.Array], Tuple[jax.Array, jax.Array]],
    params: object,
    *,
    steps: int = 10,
    step_eps: Optional[float] = None,
    use_sgld: bool = True,
    init: str = "gaussian-sign",
) -> Callable[[jax.Array, jax.random.PRNGKey, float], Tuple[jax.Array, jax.random.PRNGKey]]:
    """Build a MAD attack that pushes the policy mean away from its clean value.

    Args:
        policy_apply: ``policy_apply(params, obs) -> (mean, std)``.
        params: Policy parameters, captured by the returned closure.
        steps, step_eps, use_sgld, init: Forwarded to ``PGDConfig``.

    Returns:
        ``mad_attack(obs, rng, epsilon) -> (adv_obs, rng)``.
    """
    pgd_cfg = PGDConfig(steps=steps, step_eps=step_eps, use_sgld=use_sgld, init=init)

    def mad_attack(obs: jax.Array, rng: jax.random.PRNGKey, epsilon: float) -> Tuple[jax.Array, jax.random.PRNGKey]:
        # Reference statistics at the clean observation; constant w.r.t. s_hat.
        clean_mu, clean_std = policy_apply(params, obs)
        # Std rescaled to have mean 1, used as a per-dimension weight.
        scale = clean_std / jnp.mean(clean_std)

        def objective(s_hat: jax.Array) -> jax.Array:
            # Squared, std-weighted shift of the policy mean.
            adv_mu, _ = policy_apply(params, s_hat)
            shift = (adv_mu - clean_mu) / scale
            return jnp.sum(shift * shift)

        return _pgd_linf_optimize(
            obs=obs,
            rng=rng,
            epsilon=epsilon,
            objective=objective,
            cfg=pgd_cfg,
        )

    return mad_attack


def make_critic_attack(
    critic_apply: Callable[[object, jax.Array], jax.Array],
    params: object,
    *,
    steps: int = 10,
    step_eps: Optional[float] = None,
    init: str = "uniform",
) -> Callable[[jax.Array, jax.random.PRNGKey, float], Tuple[jax.Array, jax.random.PRNGKey]]:
    """Build an attack that searches for the observation with the lowest mean critic value.

    Args:
        critic_apply: ``critic_apply(params, obs) -> value``.
        params: Critic parameters, captured by the returned closure.
        steps, step_eps, init: Forwarded to ``PGDConfig`` (SGLD is always off).

    Returns:
        ``critic_attack(obs, rng, epsilon) -> (adv_obs, rng)``.
    """
    pgd_cfg = PGDConfig(steps=steps, step_eps=step_eps, use_sgld=False, init=init)

    def critic_attack(obs: jax.Array, rng: jax.random.PRNGKey, epsilon: float) -> Tuple[jax.Array, jax.random.PRNGKey]:
        def objective(s_hat: jax.Array) -> jax.Array:
            # The optimiser ascends, so negate to drive the value down.
            return -jnp.mean(critic_apply(params, s_hat))

        return _pgd_linf_optimize(
            obs=obs,
            rng=rng,
            epsilon=epsilon,
            objective=objective,
            cfg=pgd_cfg,
        )

    return critic_attack


from functools import partial as _partial
from typing import Any, Callable, Optional, Tuple
import os
import jax
import jax.numpy as jnp

def make_rs_attack(
    policy_apply: Callable[[object, jax.Array], Tuple[Any, Any]],  # returns (pi, value)
    policy_params: object,
    checkpoint_path: str,
    *,
    steps: int = 50,
    step_eps: Optional[float] = None,
    init: str = "uniform",
    takes_params: bool = True,
    log: bool = False,
) -> Callable[..., Tuple[jax.Array, jax.random.PRNGKey]]:
    """
    Robust SARSA (RS) state attack (param-passing variant):
        argmin_{||ŝ - s0||_∞ ≤ ε} Q_RS(s0, π(ŝ))

    Only the policy sees the perturbed observation; the Q-function is always
    evaluated at the clean observation ``s0`` together with the policy's mean
    action under the perturbed observation.

    Args:
        policy_apply: ``policy_apply(policy_params, obs) -> (pi, value)``. ``pi``
            must provide ``mean()`` (and ``variance()`` when ``log=True``).
        policy_params: Policy parameters, captured by the returned closure.
        checkpoint_path: Accepted for interface compatibility; not read here.
        steps: Number of PGD iterations.
        step_eps: Per-step size; ``None`` means ``epsilon / steps``.
        init: Start-point scheme forwarded to ``PGDConfig``.
        takes_params: Accepted for interface compatibility; not read here.
        log: When True, every call prints attack diagnostics (mean KL between
            the clean and attacked policy, max/mean L∞ perturbation, mean change
            in Q) via ``jax.debug.print``. Defaults to False, which matches
            the previous behaviour where diagnostics were disabled.

    Returns:
        rs_attack(obs, rng, epsilon, q_params, q_apply) -> (adv_obs, rng)

        ``rs_attack`` is jitted with ``q_apply`` as a static argument, so
        ``q_apply`` has to be hashable.
    """
    cfg = PGDConfig(steps=steps, step_eps=step_eps, use_sgld=False, init=init)

    def _kl_diag_gaussian(mu0, var0, mu1, var1):
        # KL(N(mu0, var0) || N(mu1, var1)) for diagonal Gaussians, summed over
        # the last axis; variances are floored to keep the logs finite.
        eps = 1e-8
        var0 = jnp.maximum(var0, eps)
        var1 = jnp.maximum(var1, eps)
        log_ratio = jnp.log(var1) - jnp.log(var0)
        quad = (var0 + (mu0 - mu1) ** 2) / var1
        return 0.5 * jnp.sum(log_ratio + quad - 1.0, axis=-1)

    @_partial(jax.jit, static_argnums=(4,))
    def rs_attack(
        obs: jax.Array,
        rng: jax.random.PRNGKey,
        epsilon: float,
        q_params: Any,
        q_apply: Any,
    ) -> Tuple[jax.Array, jax.random.PRNGKey]:
        s0 = obs

        def objective(s_hat: jax.Array) -> jax.Array:
            pi, _ = policy_apply(policy_params, s_hat)
            a_hat = pi.mean()
            q_val = q_apply(q_params, s0, a_hat)
            # The optimiser ascends, so negate to minimise Q.
            return -jnp.mean(q_val)

        adv_obs, rng = _pgd_linf_optimize(
            obs=s0, rng=rng, epsilon=epsilon, objective=objective, cfg=cfg
        )

        # `log` is a Python bool fixed at factory time, so this branch is
        # resolved during tracing and costs nothing when disabled.
        if log:
            pi_clean, _ = policy_apply(policy_params, s0)
            pi_adv, _ = policy_apply(policy_params, adv_obs)

            mu0, var0 = pi_clean.mean(), pi_clean.variance()
            mu1, var1 = pi_adv.mean(), pi_adv.variance()
            kl_mean = jnp.mean(_kl_diag_gaussian(mu0, var0, mu1, var1))

            # Per-sample L∞ norm over the last (feature) axis.
            linf = jnp.max(jnp.abs(adv_obs - s0), axis=-1)
            linf_max = jnp.max(linf)
            linf_mean = jnp.mean(linf)

            q_clean = q_apply(q_params, s0, mu0)
            q_adv = q_apply(q_params, s0, mu1)
            dq_mean = jnp.mean(q_adv - q_clean)

            # jax.debug.print is used because a plain print() would only run
            # once, at trace time, inside a jitted function.
            jax.debug.print(
                "rs_attack: kl_mean={kl} linf_max={lmax} linf_mean={lmean} dq_mean={dq}",
                kl=kl_mean,
                lmax=linf_max,
                lmean=linf_mean,
                dq=dq_mean,
            )

        return adv_obs, rng

    return rs_attack


from functools import partial as _partial
from typing import Any, Callable, Optional, Tuple
import jax
import jax.numpy as jnp

# ---------------------------------------------------------------------
# RNN MAD attack
# ---------------------------------------------------------------------
def make_mad_attack_rnn(
    policy_apply_rnn: Callable[[object, Any, Tuple[jax.Array, jax.Array]], Tuple[Any, Any, Any]],
    params: object,
    *,
    steps: int = 10,
    step_eps: Optional[float] = None,
    use_sgld: bool = True,
    init: str = "gaussian-sign",
) -> Callable[[jax.Array, jax.random.PRNGKey, float, Any, jax.Array], Tuple[jax.Array, jax.random.PRNGKey]]:
    """
    Recurrent-policy version of the MAD attack.

    The policy is queried with the caller-supplied hidden state and done flags;
    only the observation is perturbed.

    Args:
        policy_apply_rnn: ``policy_apply_rnn(params, hstate, (obs, dones))``
            returning a 3-tuple whose middle element is a distribution with
            ``mean()`` and ``variance()``.
        params: Policy parameters, captured by the returned closure.
        steps, step_eps, use_sgld, init: Forwarded to ``PGDConfig``.

    Returned callable:
        mad_attack_rnn(obs, rng, epsilon, hstate, done_flags) -> (adv_obs, rng)
    """
    pgd_cfg = PGDConfig(steps=steps, step_eps=step_eps, use_sgld=use_sgld, init=init)

    def _mean_and_std(state: jax.Array, carry: Any, dones: jax.Array):
        # A leading axis of size 1 is added before the call and squeezed from
        # the outputs afterwards (presumably the time axis — TODO confirm).
        rnn_input = (state[None, ...], dones[None, ...])
        _, dist, _ = policy_apply_rnn(params, carry, rnn_input)
        mean = jnp.squeeze(dist.mean(), axis=0)
        variance = jnp.squeeze(dist.variance(), axis=0)
        return mean, jnp.sqrt(variance)

    def mad_attack_rnn(
        obs: jax.Array,
        rng: jax.random.PRNGKey,
        epsilon: float,
        hstate: Any,
        done_flags: jax.Array,
    ) -> Tuple[jax.Array, jax.random.PRNGKey]:
        # Reference statistics at the clean observation; constant w.r.t. s_hat.
        clean_mu, clean_std = _mean_and_std(obs, hstate, done_flags)
        scale = clean_std / jnp.mean(clean_std)

        def objective(s_hat: jax.Array) -> jax.Array:
            adv_mu, _ = _mean_and_std(s_hat, hstate, done_flags)
            shift = (adv_mu - clean_mu) / scale
            return jnp.sum(shift * shift)

        return _pgd_linf_optimize(
            obs=obs,
            rng=rng,
            epsilon=epsilon,
            objective=objective,
            cfg=pgd_cfg,
        )

    return mad_attack_rnn


# ---------------------------------------------------------------------
# RNN Critic attack
# ---------------------------------------------------------------------
def make_critic_attack_rnn(
    critic_apply_rnn: Callable[[object, Any, Tuple[jax.Array, jax.Array]], Tuple[Any, jax.Array]],
    params: object,
    *,
    steps: int = 10,
    step_eps: Optional[float] = None,
    init: str = "uniform",
) -> Callable[[jax.Array, jax.random.PRNGKey, float, Any, jax.Array], Tuple[jax.Array, jax.random.PRNGKey]]:
    """
    Recurrent-critic version of the critic attack.

    Searches the L∞ ball for the observation with the lowest mean critic value,
    holding the hidden state and done flags fixed.

    Args:
        critic_apply_rnn: ``critic_apply_rnn(params, hstate, (obs, dones))``
            returning a 2-tuple whose second element is the value.
        params: Critic parameters, captured by the returned closure.
        steps, step_eps, init: Forwarded to ``PGDConfig`` (SGLD is always off).

    Returned callable:
        critic_attack_rnn(obs, rng, epsilon, hstate, done_flags) -> (adv_obs, rng)
    """
    pgd_cfg = PGDConfig(steps=steps, step_eps=step_eps, use_sgld=False, init=init)

    def critic_attack_rnn(
        obs: jax.Array,
        rng: jax.random.PRNGKey,
        epsilon: float,
        hstate: Any,
        done_flags: jax.Array,
    ) -> Tuple[jax.Array, jax.random.PRNGKey]:
        def objective(s_hat: jax.Array) -> jax.Array:
            # A leading axis of size 1 is added to both inputs before the call.
            rnn_input = (s_hat[None, ...], done_flags[None, ...])
            _, value = critic_apply_rnn(params, hstate, rnn_input)
            # The optimiser ascends, so negate to drive the value down.
            return -jnp.mean(value)

        return _pgd_linf_optimize(
            obs=obs,
            rng=rng,
            epsilon=epsilon,
            objective=objective,
            cfg=pgd_cfg,
        )

    return critic_attack_rnn


# ---------------------------------------------------------------------
# RNN Robust SARSA (RS) state attack
# ---------------------------------------------------------------------
def make_rs_attack_rnn(
    policy_apply_rnn: Callable[[object, Any, Tuple[jax.Array, jax.Array]], Tuple[Any, Any, Any]],
    policy_params: object,
    checkpoint_path: str,
    *,
    steps: int = 50,
    step_eps: Optional[float] = None,
    init: str = "uniform",
    log: bool = False,
) -> Callable[..., Tuple[jax.Array, jax.random.PRNGKey]]:
    """
    RNN variant of Robust SARSA (param-passing):
        argmin_{||ŝ - s0||_∞ ≤ ε} Q_RS(s0, π(ŝ))

    Only the policy sees the perturbed observation (with the caller-supplied
    hidden state and done flags); the Q-function is always evaluated at the
    clean observation ``s0`` with the policy's mean action.

    Args:
        policy_apply_rnn: ``policy_apply_rnn(policy_params, hstate, (obs, dones))``
            returning a 3-tuple whose middle element is a distribution with
            ``mean()`` and ``variance()``.
        policy_params: Policy parameters, captured by the returned closure.
        checkpoint_path: Accepted for interface compatibility; not read here.
        steps: Number of PGD iterations.
        step_eps: Per-step size; ``None`` means ``epsilon / steps``.
        init: Start-point scheme forwarded to ``PGDConfig``.
        log: When True, every call prints attack diagnostics (mean KL between
            the clean and attacked policy, max/mean L∞ perturbation, mean change
            in Q) via ``jax.debug.print``. Defaults to False, which matches
            the previous behaviour where diagnostics were disabled.

    Returned callable:
        rs_attack_rnn(obs, rng, epsilon, hstate, done_flags, q_params, q_apply) -> (adv_obs, rng)

        It is jitted with ``q_apply`` as a static argument, so ``q_apply`` has
        to be hashable.
    """
    cfg = PGDConfig(steps=steps, step_eps=step_eps, use_sgld=False, init=init)

    def _kl_diag_gaussian(mu0, var0, mu1, var1):
        # KL(N(mu0, var0) || N(mu1, var1)) for diagonal Gaussians, summed over
        # the last axis; variances are floored to keep the logs finite.
        eps = 1e-8
        var0 = jnp.maximum(var0, eps)
        var1 = jnp.maximum(var1, eps)
        log_ratio = jnp.log(var1) - jnp.log(var0)
        quad = (var0 + (mu0 - mu1) ** 2) / var1
        return 0.5 * jnp.sum(log_ratio + quad - 1.0, axis=-1)

    @_partial(jax.jit, static_argnums=(6,))
    def rs_attack_rnn(
        obs: jax.Array,
        rng: jax.random.PRNGKey,
        epsilon: float,
        hstate: Any,
        done_flags: jax.Array,
        q_params: Any,
        q_apply: Any,
    ) -> Tuple[jax.Array, jax.random.PRNGKey]:
        s0 = obs

        def _policy_moments(s: jax.Array):
            # A leading axis of size 1 is added before the call and squeezed
            # from the outputs (presumably the time axis — TODO confirm).
            _, pi, _ = policy_apply_rnn(
                policy_params, hstate, (s[jnp.newaxis, ...], done_flags[jnp.newaxis, ...])
            )
            mu = jnp.squeeze(pi.mean(), axis=0)
            var = jnp.squeeze(pi.variance(), axis=0)
            return mu, var

        def objective(s_hat: jax.Array) -> jax.Array:
            a_hat, _ = _policy_moments(s_hat)
            q_val = q_apply(q_params, s0, a_hat)
            # The optimiser ascends, so negate to minimise Q.
            return -jnp.mean(q_val)

        adv_obs, rng = _pgd_linf_optimize(
            obs=s0, rng=rng, epsilon=epsilon, objective=objective, cfg=cfg
        )

        # `log` is a Python bool fixed at factory time, so this branch is
        # resolved during tracing and costs nothing when disabled.
        if log:
            mu0, var0 = _policy_moments(s0)
            mu1, var1 = _policy_moments(adv_obs)
            kl_mean = jnp.mean(_kl_diag_gaussian(mu0, var0, mu1, var1))

            # Per-sample L∞ norm over the last (feature) axis.
            linf = jnp.max(jnp.abs(adv_obs - s0), axis=-1)
            linf_max = jnp.max(linf)
            linf_mean = jnp.mean(linf)

            q_clean = q_apply(q_params, s0, mu0)
            q_adv = q_apply(q_params, s0, mu1)
            dq_mean = jnp.mean(q_adv - q_clean)

            # jax.debug.print is used because a plain print() would only run
            # once, at trace time, inside a jitted function.
            jax.debug.print(
                "rs_attack_rnn: kl_mean={kl} linf_max={lmax} linf_mean={lmean} dq_mean={dq}",
                kl=kl_mean,
                lmax=linf_max,
                lmean=linf_mean,
                dq=dq_mean,
            )

        return adv_obs, rng

    return rs_attack_rnn
