
import jax
import jax.numpy as jnp

# Custom utilities
from utils import Transition, filter_adv, filter_pro



# --- Helpers --------------------------------------------------------------

def select_action(rng, obs, done, hstate, train_state, network, config):
    """Draw an action from the policy for the current observations.

    Args:
        rng: PRNG key; it is split once, one half is used for sampling and
            the other half is returned to the caller.
        obs: Observations fed to the network.
        done: Done flags; only consumed when ``config["USE_RNN"]`` is truthy.
        hstate: Recurrent carry. Updated by the recurrent network, returned
            untouched in the feed-forward case.
        train_state: Object exposing the network parameters as ``.params``.
        network: Object exposing ``apply``; called as
            ``apply(params, hstate, (obs, done))`` in recurrent mode and
            ``apply(params, obs)`` otherwise.
        config: Mapping that must contain the ``"USE_RNN"`` key.

    Returns:
        Tuple ``(rng, action, log_prob, value, hstate)``. ``action`` is cast
        to ``float32``.
    """
    next_rng, action_key = jax.random.split(rng)
    params = train_state.params

    if not config["USE_RNN"]:
        # Feed-forward policy: outputs are used exactly as produced.
        dist, value = network.apply(params, obs)
        action = dist.sample(seed=action_key).astype(jnp.float32)
        return next_rng, action, dist.log_prob(action), value, hstate

    # Recurrent policy: prepend a length-1 leading axis to the inputs and
    # strip it again from every output.
    seq_inputs = (obs[jnp.newaxis, :], done[jnp.newaxis, :])
    hstate, dist, value = network.apply(params, hstate, seq_inputs)
    action = dist.sample(seed=action_key).squeeze(0).astype(jnp.float32)
    log_prob = dist.log_prob(action).squeeze(0)
    return next_rng, action, log_prob, value.squeeze(0), hstate


def step_env(rng, env_state, action, env, env_params, config):
    """Advance the environment by one step using fresh per-env PRNG keys.

    Args:
        rng: PRNG key; split once, with one half returned to the caller.
        env_state: Current environment state, forwarded to ``env.step``.
        action: Action(s) forwarded to ``env.step``.
        env: Object exposing ``step(keys, state, action, params)``.
            NOTE(review): it receives ``config['NUM_ENVS']`` keys at once, so
            it is presumably a batched/vmapped environment -- confirm.
        env_params: Environment parameters, forwarded to ``env.step``.
        config: Mapping that must contain the ``'NUM_ENVS'`` key.

    Returns:
        Tuple ``(rng, env_state, obsv, clean_obsv, done, reward, info)``.
        ``clean_obsv`` is the very same object as ``obsv``; no history
        window or extra processing is applied here.
    """
    next_rng, batch_key = jax.random.split(rng)
    per_env_keys = jax.random.split(batch_key, config['NUM_ENVS'])

    obsv, next_state, reward, done, info = env.step(
        per_env_keys, env_state, action, env_params
    )

    # The observation is returned twice: once as-is and once in the "clean"
    # slot, which is just a pass-through of the same value.
    return next_rng, next_state, obsv, obsv, done, reward, info


# --- Main step ------------------------------------------------------------

def env_step(
    runner_state, unused,
    is_adv=False,
    config=None
):
    """Run one rollout step: optional adversarial perturbation, victim
    action selection, environment step, and transition construction.

    This function reads ``network``, ``adv_network``, ``env`` and
    ``env_params`` as free names; they must be resolvable from the enclosing
    or module scope (e.g. via a closure) when it is called.

    Args:
        runner_state: 10-tuple ``(train_state, adv_train_state, env_state,
            last_obs, last_clean_obs, last_done, hstate, adv_hstate, rng,
            update_step)``.
        unused: Ignored. NOTE(review): presumably the per-step input slot of
            a scan-style loop -- confirm against the caller.
        is_adv: When true the returned transition describes the adversary
            (perturbation, adversary value/log-prob, adversary reward);
            otherwise it describes the victim policy.
        config: Mapping of settings. Reads ``"USE_ATLA"`` and
            ``"USE_ADV_RNN"`` (both optional, default ``False``), and
            ``"ENV_NAME"`` / ``"ADV_EPS"`` when the adversary is active.
            ``select_action`` / ``step_env`` read further keys.

    Returns:
        ``(runner_state, transition)`` where ``runner_state`` has the same
        layout as the input and ``transition`` is a ``Transition``.

    Raises:
        ValueError: If ``config`` is ``None``.
    """
    (train_state, adv_train_state,
     env_state, last_obs, last_clean_obs, last_done,
     hstate, adv_hstate,
     rng, update_step) = runner_state

    if config is None:
        # Fail with a clear message instead of an AttributeError on None.get.
        raise ValueError("env_step requires a config mapping; got None")

    # 1) Adversary branch (ATLA-only; no OAAT / no latent)
    # Read the flag once with a default so every later use agrees with it;
    # indexing config["USE_ATLA"] directly would raise KeyError when absent.
    use_atla = config.get("USE_ATLA", False)
    use_adv = use_atla

    if not use_adv:
        # No adversary: observations pass through unperturbed.
        delta = jnp.zeros_like(last_obs)
        adv_value = 0.0
        adv_log_prob = 0.0
        mod_last_obs = last_obs
        first_delta = None
    else:
        use_adv_rnn = config.get('USE_ADV_RNN', False)

        if use_adv_rnn:
            # Single-timestep inputs (no history window) with a leading
            # length-1 axis, mirroring select_action's recurrent path.
            # NOTE(review): unlike the feed-forward path, last_obs is not
            # passed through filter_adv here -- confirm this is intended.
            ac_in = (last_obs[jnp.newaxis], last_done[jnp.newaxis])
            adv_hstate, delta_dist, adv_value = adv_network.apply(
                adv_train_state.params,
                adv_hstate,
                ac_in,
            )
        else:
            delta_dist, adv_value = adv_network.apply(
                adv_train_state.params,
                filter_adv(last_obs, config)
            )

        # Sample the perturbation.
        rng, d_rng = jax.random.split(rng)
        delta = delta_dist.sample(seed=d_rng)
        per_dim_log_prob = delta_dist.log_prob(delta)

        if use_adv_rnn:
            # Drop the leading length-1 axis added for the recurrent network
            # (same as select_action does). Without this, the axis=1 sum
            # below would reduce over the wrong axis and last_obs + delta
            # would pick up an extra leading dimension.
            delta = delta.squeeze(0)
            per_dim_log_prob = per_dim_log_prob.squeeze(0)
            adv_value = adv_value.squeeze(0)

        # Keep the raw (pre-clip) sample: it is what adv_log_prob refers to.
        first_delta = delta
        adv_log_prob = jnp.sum(per_dim_log_prob, axis=1)

        if "fort" not in config["ENV_NAME"]:
            # Bounded additive perturbation; no delta history buffer.
            eps = config["ADV_EPS"]
            delta = jnp.clip(delta, -eps, eps)
            mod_last_obs = last_obs + delta
        else:
            # Special handling for 'fort' envs using countdown logic: the
            # first four observation columns are replaced wherever the
            # sampled value does not exceed env.max_countdown.
            countdown, bordered = last_obs[:, :4], last_obs[:, 4:]
            # NOTE: env must be provided via closure/bind
            attack_mask = delta <= env.max_countdown
            cd_perturbed = jnp.where(
                attack_mask,
                delta / env.max_countdown,
                countdown
            )
            mod_last_obs = jnp.concatenate([cd_perturbed, bordered], axis=-1)

    # 2) Victim action selection (no latent)
    filtered_obs, filtered_done = filter_pro(mod_last_obs, last_done, config)
    rng, action, log_prob, value, hstate = select_action(
        rng, filtered_obs, filtered_done,
        hstate,
        train_state, network,
        config
    )

    # 3) Environment step
    rng, env_state, obsv, clean_obsv, done, reward, info = step_env(
        rng, env_state, action, env, env_params, config
    )

    # 4) Build transition (no action_hist field)
    if is_adv:
        # Adversary gets the negated victim reward in ATLA; otherwise zero.
        adv_reward = -reward if use_atla else 0.0
        # In ATLA the adversary's stored observation is the clean one.
        adv_obs = last_clean_obs if use_atla else mod_last_obs
        transition = Transition(
            last_done,
            first_delta,
            adv_value,
            adv_reward,
            adv_log_prob,
            adv_obs,
            info
        )
    else:
        transition = Transition(
            last_done,
            action,
            value,
            reward,
            log_prob,
            mod_last_obs,
            info
        )

    # 5) Pack new state (no delta_hist/action_hist in runner_state)
    runner_state = (
        train_state, adv_train_state,
        env_state, obsv, clean_obsv, done,
        hstate, adv_hstate, rng, update_step
    )
    return runner_state, transition
