import argparse
import os
import sys
import time
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

# JAX / Flax / Optax
import jax
import jax.numpy as jnp
from flax.training import orbax_utils
from flax.training.train_state import TrainState
from flax.core.frozen_dict import freeze
from orbax.checkpoint import (
    CheckpointManager,
    CheckpointManagerOptions,
    PyTreeCheckpointer,
)
import optax
import distrax

# Other ML libraries
import numpy as np
import ml_collections
import wandb

# Models
from models.simple import (
    ActorCritic,
    ActorCriticRNN,
    ScannedRNN,
    AdversaryActorCritic,
    AdversaryActorCriticRNN,
    RSCritic,
)

# Custom utilities
from utils import (
    load_env,
    save_sarsa,
    load_checkpoint_with_norm,  # for future use / resuming
)
from batch_logging import EpisodeLogger
from attacks import make_rs_attack
from typing import NamedTuple, Optional

class Transition(NamedTuple):
    """One environment step as recorded during a rollout.

    Instances are produced by the rollout step and stacked by ``jax.lax.scan``.
    Fields with a ``None`` default are optional extras.

    ``hstate`` carries a ``None`` default because ``typing.NamedTuple`` does
    not allow a field without a default to follow fields that have one; the
    previous definition raised ``TypeError`` as soon as the class was created.
    The field order is kept so keyword and positional construction both work.
    """

    done: jnp.ndarray       # episode-termination flag returned by env.step
    action: jnp.ndarray     # action passed to env.step
    value: jnp.ndarray      # value estimate returned by the policy network
    reward: jnp.ndarray     # reward returned by env.step
    log_prob: jnp.ndarray   # log-probability of `action` under the policy
    obs: jnp.ndarray        # observation the action was selected from
    info: Optional[jnp.ndarray] = None         # info structure returned by env.step
    action_hist: Optional[jnp.ndarray] = None  # not populated anywhere in this file
    next_obs: Optional[jnp.ndarray] = None     # observation returned by env.step
    hstate: Optional[jnp.ndarray] = None       # recurrent carry after the policy call


def make_train(config):
    """Build the training function for the RS-SARSA critic of a frozen victim policy.

    The victim parameters (and optional normalisation statistics) are loaded
    from ``config["VICTIM_CHECKPOINT_DIR"]``.  The derived entries
    ``NUM_UPDATES``, ``MINIBATCH_SIZE``, ``ACTION_SHAPE`` and ``OBS_SHAPE`` are
    written into the caller's ``config`` mapping before a frozen copy is taken
    for use inside the traced code.

    Args:
        config: Mutable mapping of upper-case hyper-parameters (as produced by
            ``parse_args``).  It is mutated in place as described above.

    Returns:
        A function ``train(rng)`` that collects rollouts with the victim
        policy, fits an ``RSCritic`` on them and returns
        ``{"runner_state": ..., "metric": ...}``.
    """
    # LOAD VICTIM AND ENV STATS
    payload = load_checkpoint_with_norm(os.path.abspath(config["VICTIM_CHECKPOINT_DIR"]))
    runner_state = payload["runner_state"]
    victim_params = runner_state[0]["params"]
    norm_stats = payload.get("norm_stats", None)

    logger = EpisodeLogger(config)
    # TOTAL_TIMESTEPS may arrive as a float (e.g. 2.5e7); cast so the derived
    # update count is an int, as needed for a scan length.
    config["NUM_UPDATES"] = (
        int(config["TOTAL_TIMESTEPS"]) // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )

    env, env_params = load_env(config, norm_stats=norm_stats)
    action_space = env.action_space(env_params)
    config["ACTION_SHAPE"] = 4 if "fort" in config["ENV_NAME"].lower() else action_space.shape[0]
    config["OBS_SHAPE"] = env.observation_space(env_params).shape

    # NOTE(review): this schedule is defined but not passed to the optimizer
    # below, which uses the constant config["LR"].
    def linear_schedule(count):
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac

    config = freeze(config)

    action_dim = action_space.shape[0]
    # Read once so the networks and the initial recurrent carry always agree.
    layer_size = config.get("LAYER_SIZE", 256)

    if config["USE_RNN"]:
        victim_network = ActorCriticRNN(action_dim=action_dim, layer_size=layer_size)
    else:
        victim_network = ActorCritic(action_dim=action_dim, layer_size=layer_size)

    # NOTE(review): the attack is constructed but never invoked in this
    # function; kept because its construction may be relied upon elsewhere.
    rs_attack = make_rs_attack(
        policy_apply=victim_network.apply,
        policy_params=victim_params,
        checkpoint_path="",
        steps=config.get("RS_PGD_STEPS", 50),
        step_eps=config.get("RS_STEP_EPS", None),
        init="uniform",
        takes_params=True,
    )

    init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], layer_size)

    def train(rng):
        """Run the full training loop for one PRNG key and return the final state."""
        action_space = env.action_space(env_params)
        act_dim = 4 if "fort" in config["ENV_NAME"].lower() else action_space.shape[0]
        obs_dim = env.observation_space(env_params).shape[0]
        rs_critic = RSCritic()
        dummy_s = jnp.zeros((config["NUM_ENVS"], obs_dim))
        dummy_a = jnp.zeros((config["NUM_ENVS"], act_dim))
        rs_params = rs_critic.init(rng, dummy_s, dummy_a)

        tx = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(config["LR"], eps=1e-5),
        )

        train_state = TrainState.create(
            apply_fn=rs_critic.apply,
            params=rs_params,
            tx=tx,
        )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = env.reset(reset_rng, env_params)

        def _update_step(runner_state, unused):
            """One outer iteration: a rollout followed by critic updates."""

            def _env_step(runner_state, unused):
                """Act with the victim policy and advance every environment by one step."""
                train_state, env_state, last_obs, last_done, rng, update_step, hstate = runner_state

                # === SELECT ACTION ===
                rng, _rng = jax.random.split(rng)

                # Observation perturbation is currently disabled: the victim
                # acts on the clean observation.
                obs_for_policy = last_obs

                # If using RNN, include hidden state; otherwise use MLP apply
                if config["USE_RNN"]:
                    ac_in = (obs_for_policy[jnp.newaxis], last_done[jnp.newaxis])
                    hstate, pi, value = victim_network.apply(
                        victim_params, hstate, ac_in
                    )
                    value = value.squeeze(0)
                    action = pi.sample(seed=_rng).squeeze(0).astype(jnp.float32)
                    log_prob = pi.log_prob(action).squeeze(0)
                else:
                    pi, value = victim_network.apply(victim_params, obs_for_policy)
                    action = pi.sample(seed=_rng)
                    log_prob = pi.log_prob(action)

                # === STEP ENV ===
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = env.step(
                    rng_step, env_state, action, env_params
                )

                transition = Transition(
                    done=done,
                    action=action,
                    value=value,
                    reward=reward,
                    log_prob=log_prob,
                    obs=last_obs,
                    next_obs=obsv,
                    info=info,
                    hstate=hstate,  # carry after consuming `last_obs`
                )
                # Carry the *new* done flags forward.  Previously the stale
                # `last_done` was re-inserted, so the flags fed to the
                # recurrent policy never changed from their initial zeros.
                runner_state = (train_state, env_state, obsv, done, rng, update_step, hstate)
                return runner_state, transition

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                """Shuffle the rollout into minibatches and take one pass over them."""
                train_state, traj_batch, advantages, targets, rng, update_step = update_state
                rng, _rng = jax.random.split(rng)

                def _update_minbatch(train_state, batch_info):
                    """Apply one gradient step of the critic loss on a minibatch."""
                    # advantages / targets are placeholders and are not used by the loss.
                    traj_batch, _advantages, _targets = batch_info

                    def _rs_act_eps_schedule(step_num):
                        """Linearly warm the action epsilon up to its target value."""
                        eps_target = float(config.get("EPSILON", config.get("ADV_EPSILON", 0.075)))
                        warmup_frac = float(config.get("RS_EPS_WARMUP_FRAC", 0.75))

                        # Fraction of training completed, in [0, 1].
                        progress = jnp.clip(
                            step_num / jnp.maximum(config["NUM_UPDATES"] - 1, 1), 0.0, 1.0
                        )

                        # Linear warmup over the first `warmup_frac` of training.
                        return jnp.where(
                            progress < warmup_frac,
                            eps_target * (progress / jnp.maximum(warmup_frac, 1e-8)),
                            eps_target,
                        )

                    def _loss_fn(params, traj_batch, step_num):
                        """RS-SARSA critic loss: TD error plus an optional robustness penalty.

                        Returns ``(total_loss, (td_loss, robust_term))``.
                        """
                        gamma = config["GAMMA"]
                        use_penalty = config.get("RS_USE_PENALTY", True)
                        rs_lambda = config.get("RS_LAMBDA", 1e3)
                        inner_steps = int(config.get("RS_INNER_STEPS", 10))
                        inner_step_eps = config.get("RS_INNER_STEP_EPS", None)
                        act_low = config.get("ACT_LOW", None)
                        act_high = config.get("ACT_HIGH", None)

                        # scheduled action epsilon
                        act_eps = _rs_act_eps_schedule(step_num)
                        step_eps = (
                            inner_step_eps
                            if inner_step_eps is not None
                            else (act_eps / max(inner_steps, 1))
                        )

                        # unpack batch
                        s = traj_batch.obs
                        a = traj_batch.action
                        r = traj_batch.reward
                        d = traj_batch.done.astype(jnp.float32)
                        s_n = traj_batch.next_obs
                        hs = traj_batch.hstate

                        # SARSA target with deterministic next action (policy mean).
                        if config["USE_RNN"]:
                            # Use the carry stored with each transition (the
                            # state after consuming `obs`) so its batch size
                            # matches the minibatch.
                            # assumes the carry is an array with a leading
                            # batch axis — TODO confirm against ScannedRNN
                            ac_in = (s_n[jnp.newaxis], traj_batch.done[jnp.newaxis])
                            _, pi_next, _ = victim_network.apply(victim_params, hs, ac_in)
                            a_n = pi_next.mean().squeeze(0)
                        else:
                            pi_next, _ = victim_network.apply(victim_params, s_n)
                            a_n = pi_next.mean()

                        q_sa = rs_critic.apply(params, s, a)
                        q_sn = rs_critic.apply(params, s_n, a_n)

                        target = r + (1.0 - d) * gamma * jax.lax.stop_gradient(q_sn)
                        td_loss = jnp.mean((target - q_sa) ** 2)

                        # --- robust action-space term ------------------------------------------------
                        def _proj_linf_to_ball(a_hat, a_ctr, eps):
                            return jnp.clip(a_hat, a_ctr - eps, a_ctr + eps)

                        def _proj_to_box(a_hat):
                            if (act_low is None) or (act_high is None):
                                return a_hat
                            return jnp.clip(a_hat, act_low, act_high)

                        def _maximize_delta_sq_single(si, ai):
                            """Squared critic change at the worst action found near ``ai``."""
                            # Reference value; it does not depend on the
                            # perturbed action, so compute it once.
                            q_ref = rs_critic.apply(params, si, ai)

                            # warm start away from ai to avoid zero-gradient saddle
                            g0 = jax.grad(lambda _a: jnp.mean(rs_critic.apply(params, si, _a)))(ai)
                            a_hat = ai + step_eps * jnp.sign(g0)
                            a_hat = _proj_linf_to_ball(a_hat, ai, act_eps)
                            a_hat = _proj_to_box(a_hat)

                            def body(_, ah):
                                def obj(_ah):
                                    return jnp.mean((rs_critic.apply(params, si, _ah) - q_ref) ** 2)

                                grad = jax.grad(obj)(ah)
                                ah = ah + step_eps * jnp.sign(grad)
                                ah = _proj_linf_to_ball(ah, ai, act_eps)
                                ah = _proj_to_box(ah)
                                return ah

                            a_hatK = jax.lax.fori_loop(0, inner_steps, body, a_hat)
                            # The perturbed action is treated as a constant
                            # when differentiating with respect to `params`.
                            a_hatK = jax.lax.stop_gradient(a_hatK)

                            return (rs_critic.apply(params, si, a_hatK) - q_ref) ** 2

                        if use_penalty:
                            robust_term = jnp.mean(jax.vmap(_maximize_delta_sq_single)(s, a))
                            total_loss = td_loss + rs_lambda * robust_term
                        else:
                            robust_term = jnp.array(0.0, dtype=td_loss.dtype)
                            total_loss = td_loss

                        return total_loss, (td_loss, robust_term)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    # Pass the update counter explicitly so the epsilon
                    # schedule uses its own argument.
                    (total_loss, _aux), grads = grad_fn(
                        train_state.params, traj_batch, update_step
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
                ), "batch size must be equal to number of steps * number of envs"
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                # Merge the leading two axes of every leaf into one batch axis.
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng, update_step)
                return update_state, total_loss

            def sarsa_update_loop(runner_state, unused):
                """Collect ``NUM_STEPS`` transitions, fit the critic and log metrics."""
                # rollout
                runner_state, traj_batch = jax.lax.scan(
                    _env_step, runner_state, None, config["NUM_STEPS"]
                )

                (
                    train_state,
                    env_state,
                    last_obs,
                    last_done,
                    rng,
                    update_step,
                    hstate,
                ) = runner_state

                # Placeholders to satisfy existing update API.
                advantages = jnp.zeros_like(traj_batch.reward)
                targets = jnp.zeros_like(traj_batch.reward)

                update_state = (
                    train_state,
                    traj_batch,
                    advantages,
                    targets,
                    rng,
                    update_step,
                )

                update_state, total_loss = jax.lax.scan(
                    _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
                )
                train_state = update_state[0]
                rng = update_state[-2]

                # Average each info entry, weighted by the "returned_episode" entry.
                metric = jax.tree.map(
                    lambda x: (x * traj_batch.info["returned_episode"]).sum()
                    / traj_batch.info["returned_episode"].sum(),
                    traj_batch.info,
                )
                metric["average_loss"] = jnp.mean(jnp.array(total_loss))
                if config["USE_WANDB"]:
                    def callback(metric, update_step):
                        logger.log_step(int(update_step), metric)
                    jax.debug.callback(callback, metric, update_step)

                runner_state = (
                    train_state,
                    env_state,
                    last_obs,
                    last_done,
                    rng,
                    update_step + 1,
                    hstate,
                )
                return runner_state, traj_batch

            # A single rollout/update cycle per outer step.
            inner_iterations = 1

            runner_state, _ = jax.lax.scan(
                sarsa_update_loop, runner_state, None, inner_iterations
            )

            return runner_state, {}

        rng, _rng = jax.random.split(rng)
        done = jnp.zeros((config["NUM_ENVS"]), dtype=bool)

        runner_state = (
            train_state,
            env_state,
            obsv,
            done,
            _rng,
            0,
            init_hstate,
        )
        runner_state, metric = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )
        return {"runner_state": runner_state, "metric": metric}

    return train


def run_ppo(config):
    """Train the critic for ``config["NUM_REPEATS"]`` seeds and optionally save it.

    Builds a run name from the configuration, optionally initialises
    Weights & Biases, runs the vmapped training function, prints wall-clock
    time and steps per second, and (when ``SAVE_POLICY`` is set) saves the
    critic parameters of the first repeat via ``save_sarsa``.

    Args:
        config: Mapping of upper-case hyper-parameters (see ``parse_args``).

    Returns:
        The output of the vmapped training function: a dict with
        ``"runner_state"`` and ``"metric"`` entries.
    """
    # --- Build run name for logs/W&B ---
    # Reflect the architecture actually selected instead of always "MLP".
    net_type = "RNN" if config.get("USE_RNN", False) else "MLP"
    algo_type = "PPO"

    lr_str = f"lr={config['LR']}"

    prune_str = ""
    if config.get("USE_PRUNING", False):
        prune_str = f"-prune={config['PRUNER_TYPE']}-{int(config['PRUNE_PERCENTAGE'] * 100)}pct"

    adv_str = "-SA" if config.get("USE_SA_PPO", False) else ""

    run_name = (
        f"{algo_type}-{net_type}"
        f"-env={config['ENV_NAME']}"
        f"{prune_str}{adv_str}"
        f"-{lr_str}"
        f"-{int(config['TOTAL_TIMESTEPS'] // 1e6)}M"
        f"-seed={config['SEED']}"
    )

    # --- Optional W&B logging ---
    if config.get("USE_WANDB", False):
        wandb_kwargs = dict(
            project=config["WANDB_PROJECT"],
            config=config,
            name=run_name,
            anonymous="allow",
        )
        if config.get("WANDB_ENTITY"):
            wandb_kwargs["entity"] = config["WANDB_ENTITY"]
        wandb.init(**wandb_kwargs)

    # PRNGs and compiled train fn
    rng = jax.random.PRNGKey(config["SEED"])
    rngs = jax.random.split(rng, config["NUM_REPEATS"])
    train_jit = jax.jit(make_train(config))
    train_vmap = jax.vmap(train_jit)

    # Train
    t0 = time.time()
    out = train_vmap(rngs)
    # JAX dispatches work asynchronously; wait for the results so the timing
    # below covers the actual computation rather than just the dispatch.
    out = jax.block_until_ready(out)
    t1 = time.time()

    elapsed = t1 - t0
    print("Time to run experiment:", elapsed)
    print("SPS:", config["TOTAL_TIMESTEPS"] / elapsed)

    # Save RS/SARSA critic params (not the whole runner state)
    if config.get("SAVE_POLICY", False):
        # Index 0 along the leading (repeat) axis added by vmap.
        runner_state0 = jax.tree.map(lambda x: x[0], out["runner_state"])
        train_state0 = runner_state0[0]
        rs_params = train_state0.params

        ckdir = save_sarsa(
            config=config,
            params=rs_params,
            step=int(config["TOTAL_TIMESTEPS"]),
            ckpt_root="./sarsa_checkpoints",
        )
        print(f"✔ RS critic checkpoint saved to {ckdir}")

    return out


def parse_args():
    """Parse command-line flags into an upper-cased configuration dict.

    Unknown flags are ignored (``parse_known_args``).  Every argparse
    destination ``foo_bar`` becomes the key ``"FOO_BAR"``.

    Returns:
        dict mapping upper-case option names to their parsed values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_name", type=str, default="hopper")
    parser.add_argument("--num_envs", type=int, default=2048)
    # argparse only runs `type` on string defaults, so the default must
    # already be an int; a float here would leak into integer arithmetic.
    parser.add_argument("--total_timesteps", type=lambda x: int(float(x)), default=int(2.5e7))
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--num_steps", type=int, default=10)
    parser.add_argument("--update_epochs", type=int, default=4)
    parser.add_argument("--num_minibatches", type=int, default=32)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--clip_eps", type=float, default=0.2)
    parser.add_argument("--ent_coef", type=float, default=0.0)
    parser.add_argument("--vf_coef", type=float, default=0.5)
    parser.add_argument("--max_grad_norm", type=float, default=0.5)
    parser.add_argument("--activation", type=str, default="tanh")
    parser.add_argument("--anneal_lr", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--jit", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--seed", type=int, default=np.random.randint(2**31))
    parser.add_argument("--use_wandb", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--save_policy", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--num_repeats", type=int, default=1)
    parser.add_argument("--layer_size", type=int, default=256)
    parser.add_argument("--wandb_project", type=str, default="anonymous-project")
    parser.add_argument("--wandb_entity", type=str, default="")
    parser.add_argument("--use_optimistic_resets", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--optimistic_reset_ratio", type=int, default=16)
    # `store_true` with default=True could never be switched off; the optional
    # action keeps `--normalize_env` working and adds `--no-normalize_env`.
    parser.add_argument("--normalize_env", action=argparse.BooleanOptionalAction, default=True)

    # Victim checkpoint path (the underscore spelling is accepted as an alias
    # so the flag matches the style of every other option).
    parser.add_argument(
        "--victim-checkpoint-dir",
        "--victim_checkpoint_dir",
        dest="victim_checkpoint_dir",
        type=str,
        default="",
    )
    parser.add_argument("--epsilon", type=float, default=0.075)
    parser.add_argument("--rs_lambda", type=float, default=100)
    parser.add_argument(
        "--use_rnn",
        action="store_true",
        default=False,
        help="Enable RNN architecture (default: False, uses MLP)",
    )

    args, _unknown = parser.parse_known_args(sys.argv[1:])
    config = {k.upper(): v for k, v in vars(args).items()}

    if config.get("SEED") is None:
        config["SEED"] = 1

    return config


def main():
    """Script entry point: parse the command line and launch training."""
    run_ppo(parse_args())


if __name__ == "__main__":
    main()
