import jax
import jax.numpy as jnp
from functools import partial

def _project_linf(x, x0, eps, low=None, high=None):
    """Project ``x`` onto the L_inf ball of radius ``eps`` around ``x0``.

    Optionally also clamp to the box ``[low, high]``. Each bound is applied
    independently, so a caller may supply only ``low`` or only ``high``
    (previously a lone bound was silently ignored).

    Args:
        x: Point(s) to project.
        x0: Center of the L_inf ball; broadcast against ``x``.
        eps: Ball radius (scalar or broadcastable array, may be traced).
        low: Optional elementwise lower bound, or ``None`` for no bound.
        high: Optional elementwise upper bound, or ``None`` for no bound.

    Returns:
        The projected array.
    """
    x = jnp.clip(x, x0 - eps, x0 + eps)
    if low is not None:
        x = jnp.maximum(x, low)
    if high is not None:
        x = jnp.minimum(x, high)
    return x

def _kl_categorical_from_logits(logits_p, logits_q):
    """Return KL(P || Q) for categorical distributions given by logits.

    Both logit arrays are normalized with ``log_softmax`` over the last axis,
    which keeps the computation numerically stable. The last axis is reduced.
    """
    log_probs_p = jax.nn.log_softmax(logits_p, axis=-1)
    log_ratio = log_probs_p - jax.nn.log_softmax(logits_q, axis=-1)
    return jnp.sum(jnp.exp(log_probs_p) * log_ratio, axis=-1)

def _kl_gaussian_diag(mu_p, std_p, mu_q, std_q):
    """Return KL(N_p || N_q) for diagonal Gaussians, summed over the last axis.

    Inputs are means and standard deviations with matching trailing action
    axis. A small jitter (1e-8) is added to variances to avoid division by zero
    and log(0).
    """
    jitter = 1e-8
    var_p = std_p**2
    var_q = std_q**2
    safe_var_q = var_q + jitter
    ratio_term = (var_p + (mu_p - mu_q) ** 2) / safe_var_q
    log_term = jnp.log(safe_var_q / (var_p + jitter))
    return 0.5 * jnp.sum(ratio_term - 1.0 + log_term, axis=-1)

def _policy_kl(pi_s, pi_s_adv, sigma=0.1):
    """Squared-mean-difference divergence between two policies.

    Computes ``0.5 * ||mean(pi_s) - mean(pi_s_adv)||^2 / sigma^2`` with the sum
    over the last axis. Mathematically this equals KL between two Gaussians
    that share the fixed isotropic std ``sigma``; it ignores any std the
    distributions themselves carry. Both arguments only need a ``.mean()``
    method.
    """
    delta = pi_s.mean() - pi_s_adv.mean()
    squared_norm = jnp.sum(delta * delta, axis=-1)
    return 0.5 * squared_norm / (sigma**2)

def _sa_regularizer_mlp(params, network, obs, init_pi=None,
                        eps=0.05, steps=3, step_size=0.01, low=None, high=None,
                        rng=None):
    """Approximate max_{s_adv} KL(pi(s) || pi(s_adv)) via L_inf PGD on observations.

    The adversarial observation is searched for inside the L_inf ball of radius
    ``eps`` around ``obs`` (optionally clamped to ``[low, high]``). It is then
    treated as a constant when the returned value is differentiated w.r.t.
    ``params``.

    Args:
        params: Policy parameters passed to ``network.apply``.
        network: Object whose ``apply(params, obs)`` returns ``(pi, value)``.
        obs: Clean observations, ``[B, ...]``.
        init_pi: Unused; kept for call-site compatibility.
        eps: L_inf radius of the perturbation ball (may be a traced scalar).
        steps: Number of PGD iterations.
        step_size: Magnitude of each signed-gradient step.
        low: Optional observation lower bound forwarded to the projection.
        high: Optional observation upper bound forwarded to the projection.
        rng: Optional PRNG key for the random start. Defaults to a fixed key;
            pass a fresh key per call to vary the start point.

    Returns:
        Scalar mean of the policy divergence at the adversarial observations.
    """
    # Reference policy on the clean states.
    pi_s, _ = network.apply(params, obs)  # expecting (pi, value)
    obs0 = obs

    # Random start inside the eps-ball. The divergence and its gradient are
    # exactly zero at obs_adv == obs, and jnp.sign(0) == 0, so PGD started at the
    # clean observation would never move and the regularizer would always be 0.
    if rng is None:
        rng = jax.random.PRNGKey(0)
    noise = jax.random.uniform(rng, jnp.shape(obs0), minval=-1.0, maxval=1.0)
    obs_adv = _project_linf(obs0 + eps * noise, obs0, eps, low, high)

    def pgd_body(_, x_cur):
        # Objective: KL(pi(s) || pi(s_adv)); s is fixed, optimize over s_adv.
        def kl_on_obs_adv(x_adv):
            pi_adv, _ = network.apply(params, x_adv)
            return _policy_kl(pi_s, pi_adv).mean()

        grad = jax.grad(kl_on_obs_adv)(x_cur)
        x_next = x_cur + step_size * jnp.sign(grad)
        return _project_linf(x_next, obs0, eps, low, high)

    # The attack point is a constant for the outer parameter gradient; this
    # avoids differentiating through the inner jax.grad loop (its contribution
    # is zero anyway because d sign(.)/d params == 0).
    obs_adv = jax.lax.stop_gradient(jax.lax.fori_loop(0, steps, pgd_body, obs_adv))

    # Final divergence on the adversarial states.
    pi_adv, _ = network.apply(params, obs_adv)
    kl = _policy_kl(pi_s, pi_adv)  # [B]
    return kl.mean()

def _sa_regularizer_rnn(params, network, init_hstate, obs_seq, done_seq,
                        eps=0.05, steps=3, step_size=0.01, low=None, high=None,
                        rng=None):
    """Recurrent variant of the SA-PPO regularizer (PGD over the whole sequence).

    The entire observation sequence is perturbed jointly inside the L_inf ball
    of radius ``eps`` (optionally clamped to ``[low, high]``). The resulting
    adversarial sequence is treated as a constant for the gradient w.r.t.
    ``params``.

    Args:
        params: Policy parameters.
        network: Object whose ``apply(params, hstate, (obs, done))`` returns
            ``(hstate, pi, value)``.
        init_hstate: Initial recurrent state passed to ``network.apply``.
        obs_seq: Clean observation sequence, passed unchanged to
            ``network.apply`` (so it must already be in the layout the network
            expects).
        done_seq: Episode-done flags matching ``obs_seq``.
        eps: L_inf radius (may be a traced scalar).
        steps: Number of PGD iterations.
        step_size: Magnitude of each signed-gradient step.
        low: Optional observation lower bound forwarded to the projection.
        high: Optional observation upper bound forwarded to the projection.
        rng: Optional PRNG key for the random start. Defaults to a fixed key;
            pass a fresh key per call to vary the start point.

    Returns:
        Scalar mean of the per-timestep policy divergence.
    """
    obs0 = obs_seq

    # Reference policy on the clean sequence.
    _, pi_seq, _ = network.apply(params, init_hstate, (obs_seq, done_seq))

    # Random start: at obs_adv == obs_seq the divergence gradient is exactly 0
    # and jnp.sign(0) == 0, so PGD from the clean sequence would never move.
    if rng is None:
        rng = jax.random.PRNGKey(0)
    noise = jax.random.uniform(rng, jnp.shape(obs0), minval=-1.0, maxval=1.0)
    obs_adv = _project_linf(obs0 + eps * noise, obs0, eps, low, high)

    def pgd_body(_, x_cur):
        def kl_on_obs_adv(x_adv):
            _, pi_adv_seq, _ = network.apply(params, init_hstate, (x_adv, done_seq))
            # Timestep-wise divergence, averaged over every remaining axis.
            return _policy_kl(pi_seq, pi_adv_seq).mean()

        grad = jax.grad(kl_on_obs_adv)(x_cur)
        x_next = x_cur + step_size * jnp.sign(grad)
        return _project_linf(x_next, obs0, eps, low, high)

    # Treat the attack result as a constant for the outer parameter gradient.
    obs_adv = jax.lax.stop_gradient(jax.lax.fori_loop(0, steps, pgd_body, obs_adv))

    _, pi_adv_seq, _ = network.apply(params, init_hstate, (obs_adv, done_seq))
    kl_tb = _policy_kl(pi_seq, pi_adv_seq)
    return kl_tb.mean()

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=[6, 7])
def _ppo_loss_fn(params, init_hstate, traj_batch, gae, targets, step_num, config, network):
    """PPO clipped loss plus an optional SA-PPO observation-robustness term.

    Args:
        params: Network parameters passed to ``network.apply``.
        init_hstate: Initial recurrent state; only ``init_hstate[0]`` is used,
            and only when ``config["USE_RNN"]`` is truthy.
        traj_batch: Rollout batch exposing ``obs``, ``done``, ``action``,
            ``value`` and ``log_prob``.
        gae: Advantage estimates; normalized inside the loss.
        targets: Value-function regression targets.
        step_num: Update counter driving the SA-PPO epsilon schedule.
        config: Static configuration mapping (must be hashable for ``jit``).
        network: Static network module.

    Returns:
        ``(total_loss, (value_loss, loss_actor, entropy, sa_reg))``.
    """
    # Re-run policy/value to get current log_probs and values.
    if config["USE_RNN"]:
        _, pi, value = network.apply(
            params,
            init_hstate[0],
            (traj_batch.obs, traj_batch.done),
        )
    else:
        # MLP case
        pi, value = network.apply(params, traj_batch.obs)
    log_prob = pi.log_prob(traj_batch.action)

    # Value loss (clipped)
    value_pred_clipped = traj_batch.value + (
        value - traj_batch.value
    ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
    value_losses = jnp.square(value - targets)
    value_losses_clipped = jnp.square(value_pred_clipped - targets)
    value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()

    # Policy loss (clipped)
    ratio = jnp.exp(log_prob - traj_batch.log_prob)
    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
    loss_actor1 = ratio * gae
    loss_actor2 = jnp.clip(ratio, 1.0 - config["CLIP_EPS"], 1.0 + config["CLIP_EPS"]) * gae
    loss_actor = -jnp.minimum(loss_actor1, loss_actor2).mean()

    # Entropy bonus
    entropy = pi.entropy().mean()

    # === SA-PPO regularizer (with epsilon schedule) ===========================
    sa_reg = 0.0
    if config.get("USE_SA_PPO", False):
        # Target epsilon and schedule (reset-aware for LTH)
        eps_target = float(config.get("ADV_EPSILON", 0.075))
        warmup_frac = float(config.get("SA_EPS_WARMUP_FRAC", 0.75))  # fraction of a cycle spent ramping eps
        steps = int(config.get("SA_STEPS", 3))

        # Fallback estimate of total updates when NUM_UPDATES is not configured.
        denom = max(int(config["NUM_ENVS"]) * int(config["NUM_STEPS"]), 1)
        default_total_updates = max(int(config["TOTAL_TIMESTEPS"] // denom), 1)

        # Length of ONE cycle between LTH resets: explicit NUM_UPDATES if given,
        # otherwise the estimate above. Clamped to >= 1 so the modulus is safe.
        cycle_updates = max(int(config.get("NUM_UPDATES", default_total_updates)), 1)

        # Step within the current cycle, so the ramp restarts after each rewind.
        cycle_step = jnp.mod(step_num, cycle_updates)

        # Progress in [0, 1] within the current cycle.
        progress = jnp.clip(
            cycle_step / jnp.maximum(cycle_updates - 1, 1),
            0.0, 1.0
        )

        # Linear warmup of epsilon over warmup_frac of the cycle.
        eps = jnp.where(
            progress < warmup_frac,
            eps_target * (progress / jnp.maximum(warmup_frac, 1e-8)),
            eps_target,
        )

        # PGD step size: if SA_STEP_SIZE <= 0, use eps/steps; else fixed from config.
        cfg_step_size = float(config.get("SA_STEP_SIZE", 0.0))
        step_size = eps / max(steps, 1) if cfg_step_size <= 0.0 else cfg_step_size

        # Optional observation clamping (match argparse names: obs_clip_low/high).
        low = config.get("OBS_CLIP_LOW", None)
        high = config.get("OBS_CLIP_HIGH", None)

        if config["USE_RNN"]:
            # Pass obs/done in the same layout used for network.apply above.
            sa_reg = _sa_regularizer_rnn(
                params, network, init_hstate[0], traj_batch.obs, traj_batch.done,
                eps=eps, steps=steps, step_size=step_size, low=low, high=high
            )
        else:
            sa_reg = _sa_regularizer_mlp(
                params, network, traj_batch.obs,
                eps=eps, steps=steps, step_size=step_size, low=low, high=high
            )

    total_loss = (
        loss_actor
        + config["VF_COEF"] * value_loss
        - config["ENT_COEF"] * entropy
        + config.get("SA_LAMBDA", 1.0) * sa_reg
    )
    return total_loss, (value_loss, loss_actor, entropy, sa_reg)




@partial(jax.jit, static_argnums=[6, 7])
def _adv_ppo_loss_fn(params, init_hstate, traj_batch, gae, targets, step_num, config, network):
    """PPO clipped loss for the adversary policy.

    All hyper-parameters come from the ``ADV_*`` config keys. That includes the
    policy-ratio clip, which previously used the protagonist's ``CLIP_EPS``
    while the value clip already used ``ADV_CLIP_EPS``.

    Args:
        params: Adversary network parameters.
        init_hstate: Initial recurrent state; only ``init_hstate[0]`` is used,
            and only when ``config["USE_ADV_RNN"]`` is truthy.
        traj_batch: Rollout batch exposing ``obs``, ``done``, ``action``,
            ``value`` and ``log_prob``.
        gae: Advantage estimates; normalized inside the loss.
        targets: Value-function regression targets.
        step_num: Unused here; kept for signature parity with ``_ppo_loss_fn``.
        config: Static configuration mapping (must be hashable for ``jit``).
        network: Static adversary network module.

    Returns:
        ``(total_loss, (value_loss, loss_actor, entropy))``.
    """
    clip_eps = config["ADV_CLIP_EPS"]

    # Re-run adversary policy/value (no latent).
    if config["USE_ADV_RNN"]:
        # NOTE(review): obs/done layout must match what the rollout stored for
        # this call site; it is passed to network.apply unchanged.
        _, pi, value = network.apply(
            params,
            init_hstate[0],
            (traj_batch.obs, traj_batch.done),
        )
    else:
        pi, value = network.apply(params, traj_batch.obs)
    log_prob = pi.log_prob(traj_batch.action)

    if 'fort' in config['ENV_NAME']:
        # sum over multi-dimensional action for fort envs
        log_prob = log_prob.sum(axis=-1)

    # Value loss (clipped)
    value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
        -clip_eps, clip_eps
    )
    value_losses = (value - targets) ** 2
    value_losses_clip = (value_pred_clipped - targets) ** 2
    value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clip).mean()

    # Policy loss (clipped) -- uses the adversary's clip range, like the value loss.
    ratio = jnp.exp(log_prob - traj_batch.log_prob)
    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
    loss_1 = ratio * gae
    loss_2 = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * gae
    loss_actor = -jnp.minimum(loss_1, loss_2).mean()

    entropy = pi.entropy().mean()

    total_loss = (
        loss_actor
        + config["ADV_VF_COEF"] * value_loss
        - config["ADV_ENT_COEF"] * entropy
    )
    return total_loss, (value_loss, loss_actor, entropy)


from functools import partial
import jax
import jax.numpy as jnp
import wandb

# ---- Host-side logger (runs outside jit via jax.debug.callback) ----
def _wandb_log(step, metrics_dict):
    """Host-side W&B logger, meant to run outside ``jit`` via ``jax.debug.callback``.

    Args:
        step: Step counter (device array or scalar); used as the W&B step.
        metrics_dict: Mapping of metric name -> scalar (device array or number).
    """
    # step and values arrive as device arrays; convert on host.
    step = int(step)
    to_log = {k: float(v) for k, v in metrics_dict.items()}
    # Previously `step` was converted but never used. NOTE: W&B ignores logs whose
    # step is lower than its current step, so all calls should use this counter.
    wandb.log(to_log, step=step)

