import argparse
import os
import sys
import time
# Example: set CUDA device via environment variable
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# JAX / Flax / Optax
import jax
import jax.numpy as jnp
from flax.training import orbax_utils
from flax.training.train_state import TrainState
from flax.core.frozen_dict import freeze
from orbax.checkpoint import (
    CheckpointManager,
    CheckpointManagerOptions,
    PyTreeCheckpointer,
)
import optax
import distrax

# Other ML libraries
import numpy as np
import ml_collections
import wandb

# Logging utilities
from batch_logging import EpisodeLogger

# Models
from models.simple import (
    ActorCritic,
    ActorCriticRNN,
    ScannedRNN,
    AdversaryActorCritic,
    AdversaryActorCriticRNN,
)

from models.losses import (
    _adv_ppo_loss_fn,
    _ppo_loss_fn,
)

# Custom utilities
from rl.env_step import env_step
from rl.update import update_epoch
from rl.setup import setup_protagonist, setup_antagonist
from utils import (
    Transition, 
    load_env, 
    filter_adv, 
    filter_pro, 
    calculate_gae, 
    calculate_rnn_gae,
    bind,
    _arch_str,
    _prune_str,
    _format_ckpt_dir,
    _extract_norm_stats_from_env_state,
    save_checkpoint_with_norm,
    load_checkpoint_with_norm,  # for future use / resuming
    lth_rewind,
    clamp_to_mask,
    param_sparsity
)
from batch_logging import EpisodeLogger


def make_train(config):
    """Build the training closure for protagonist (and optional adversary) PPO.

    Derives NUM_UPDATES, MINIBATCH_SIZE and NUM_LTH_UPDATES from ``config``,
    loads the environment, records ACTION_SHAPE / OBS_SHAPE, and then freezes
    the config. Note that the derived keys are written into the ``config``
    mapping that was passed in *before* it is frozen, so the caller's dict is
    mutated.

    Args:
        config: mutable mapping of upper-case hyperparameter keys. Keys read
            here include TOTAL_TIMESTEPS, NUM_STEPS, NUM_ENVS,
            NUM_MINIBATCHES, UPDATE_EPOCHS, USE_LTH, LTH_TIMESTEPS (only when
            USE_LTH), ENV_NAME, LR, USE_RNN, USE_ADV_RNN, USE_ATLA and
            USE_WANDB.

    Returns:
        A function ``train(rng)`` that runs the whole training scan and
        returns ``{"runner_state": ..., "metric": ...}``.
    """

    logger = EpisodeLogger(config)
    # Number of outer updates: each update consumes NUM_STEPS * NUM_ENVS steps.
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    # Extra outer updates appended after the regular ones when USE_LTH is set.
    config["NUM_LTH_UPDATES"] = 0
    if config["USE_LTH"]:
        config["NUM_LTH_UPDATES"] = (
        config["LTH_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )

    env, env_params = load_env(config)
    action_space = env.action_space(env_params)
    # Environments whose name contains "fort" get a fixed action size of 4;
    # everything else uses the first dimension of the action-space shape.
    config['ACTION_SHAPE'] = 4 if 'fort' in config['ENV_NAME'].lower() else action_space.shape[0]
    config["OBS_SHAPE"] = env.observation_space(env_params).shape

    # NOTE(review): linear_schedule is defined but never referenced inside
    # make_train; presumably the optimizer schedule is built in rl.setup —
    # confirm before removing.
    def linear_schedule(count):
        """Return LR scaled by the fraction of outer updates remaining."""
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac

    # From here on the config is immutable.
    config = freeze(config)
    filter_p = lambda a, b: filter_pro(a, b, config)
    filter_a = lambda a: filter_adv(a, config)

    def train(rng):
        """Run NUM_UPDATES + NUM_LTH_UPDATES outer updates from seed ``rng``.

        Returns a dict with the final ``runner_state`` tuple and ``metric``,
        which is the stacked per-update output of ``_update_step`` (an empty
        dict per step).
        """
        # NOTE(review): obs_shape and action_dim are not used below.
        obs_shape = config["OBS_SHAPE"]
        action_dim = config["ACTION_SHAPE"]
        rng, network, train_state, init_hstate, pruner = setup_protagonist(rng, config)
        # Snapshot of the initial protagonist params / optimizer state; only
        # read by the USE_LTH branch further down.
        save_victim_params = train_state.params
        save_opt_state = train_state.opt_state
        rng, adv_network, adv_train_state, init_adv_hstate = setup_antagonist(rng, config)

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = env.reset(reset_rng, env_params)
        networks = (network, adv_network)  # no encoder

        # Bind the shared context into the step / GAE / update helpers.
        _env_step = bind(env_step, config, env, env_params, networks, pruner)
        if config['USE_RNN']:
            _calculate_gae = bind(calculate_rnn_gae, config, env, env_params, networks, pruner)
        else: 
            _calculate_gae = bind(calculate_gae, config, env, env_params, networks, pruner)
        _update_epoch = bind(update_epoch, config, env, env_params, networks, pruner)

        def _update_step(runner_state, unused):
            """One outer update: a protagonist phase, then (if USE_ATLA) an adversary phase."""
            # The names unpacked here (notably update_step) are also visible to
            # the nested loops below as closure variables.
            (
                train_state,
                adv_train_state,
                env_state,
                last_obs,
                last_clean_obs,
                last_done,
                hstate,
                adv_hstate,
                rng,
                update_step
            ) = runner_state

            # (Removed project-specific apply_mask usage.)

            def protagonist_update_loop(runner_state, unused):
                """Collect a rollout and run the protagonist update epochs."""
                # Index 6 of the runner state is the protagonist hidden state
                # as it was before this rollout.
                initial_hstate = runner_state[6]
                # Third positional argument is presumably is_adv=False (the
                # adversary loop passes is_adv=True) — confirm in rl.env_step.
                env_step_fn = lambda a, b: _env_step(a, b, False)
                runner_state, traj_batch = jax.lax.scan(
                    env_step_fn, runner_state, None, config["NUM_STEPS"]
                )
                (
                    train_state,
                    adv_train_state,
                    env_state,
                    last_obs,
                    last_clean_obs,
                    last_done,
                    hstate,
                    adv_hstate,
                    rng,
                    update_step
                ) = runner_state

                if config["USE_LTH"]:
                    # Exactly when update_step == NUM_UPDATES, swap in params
                    # derived from the saved initial params and the current
                    # optimizer state, and swap in the saved optimizer state
                    # carrying over the current masks / count /
                    # target_sparsities. Presumably this is the lottery-ticket
                    # rewind — confirm against the pruner implementation.
                    def revert_params(p1, p2, os):
                        return pruner.post_gradient_update(p2, os)
                    def revert_opt_state(os1, os2):
                        return os2._replace(masks=os1.masks,
                                 count=os1.count, target_sparsities=os1.target_sparsities)

                    lth_params = jax.lax.cond(
                        update_step == config["NUM_UPDATES"],
                        revert_params,
                        lambda p1, p2, os: p1,
                        train_state.params,
                        save_victim_params,
                        train_state.opt_state,
                    )

                    lth_opt_state = jax.lax.cond(
                        update_step == config["NUM_UPDATES"],
                        revert_opt_state,
                        lambda os1, os2: os1,
                        train_state.opt_state,
                        save_opt_state,
                    )

                    train_state = train_state.replace(params=lth_params, opt_state=lth_opt_state)
                filtered_last_obs, filtered_last_done = filter_p(last_obs, last_done)

                # Bootstrap value for the state following the rollout.
                if config["USE_RNN"]:
                    # Add a leading axis to the hidden state and inputs for
                    # the recurrent network call.
                    init_hstate = initial_hstate[None, :]
                    ac_in = (filtered_last_obs[np.newaxis, :], filtered_last_done[np.newaxis, :])
                    _, _, last_val = network.apply(train_state.params, hstate, ac_in)
                    last_val = last_val.squeeze(0)
                else:
                    init_hstate = initial_hstate
                    _, last_val = network.apply(train_state.params, filtered_last_obs)

                advantages, targets = _calculate_gae(traj_batch, last_val, filtered_last_done)

                update_state = (
                    train_state,
                    init_hstate,
                    traj_batch,
                    advantages,
                    targets,
                    rng,
                    update_step,
                )
                update_fn = lambda a, b: _update_epoch(a, b, uses_hstate=config['USE_RNN'], is_adv=False)
                update_state, _ = jax.lax.scan(
                    update_fn, update_state, None, config["UPDATE_EPOCHS"]
                )
                train_state = update_state[0]
                # rng is the second-to-last entry of update_state.
                rng = update_state[-2]

                # Average every info field over the entries flagged by
                # info["returned_episode"]; the denominator is zero (NaN/inf
                # result) when no entry is flagged.
                metric = jax.tree.map(
                    lambda x: (x * traj_batch.info["returned_episode"]).sum()
                    / traj_batch.info["returned_episode"].sum(),
                    traj_batch.info,
                )
                metric["sparsity"] = param_sparsity(train_state.params)
                if config["USE_WANDB"]:

                    def callback(metric, update_step):
                        """Host-side hook that forwards metrics to the logger."""
                        logger.log_step(int(update_step), metric)

                    jax.debug.callback(callback, metric, update_step)

                # Only the protagonist loop advances the update counter.
                runner_state = (
                    train_state,
                    adv_train_state,
                    env_state,
                    last_obs,
                    last_clean_obs,
                    last_done,
                    hstate,
                    adv_hstate,
                    rng,
                    update_step + 1
                )
                return runner_state, traj_batch

            def antagonist_update_loop(runner_state, unused):
                """Collect an adversary rollout and run the adversary update epochs."""
                # Index 7 of the runner state is the adversary hidden state
                # as it was before this rollout.
                initial_adv_hstate = runner_state[7]
                env_step_fn = lambda st, _: _env_step(st, None, is_adv=True)
                runner_state, adv_traj_batch = jax.lax.scan(
                    env_step_fn, runner_state, None, length=config["NUM_STEPS"]
                )
                (
                    train_state,
                    adv_train_state,
                    env_state,
                    last_obs,
                    last_clean_obs,
                    last_done,
                    hstate,
                    adv_hstate,
                    rng,
                    steps
                ) = runner_state

                # The adversary bootstraps from the clean (last_clean_obs)
                # observation rather than last_obs.
                filtered_last_obs, filtered_last_done = filter_a(last_clean_obs), filter_a(last_done)

                if config['USE_ADV_RNN']:
                    init_adv_hstate = initial_adv_hstate[None, :]
                    ac_in = (filtered_last_obs[np.newaxis, :], filtered_last_done[np.newaxis, :])
                    # NOTE(review): unlike the protagonist branch, the value
                    # is not squeezed on axis 0 here — confirm this is
                    # intended.
                    _, _, adv_last_val = adv_network.apply(
                        adv_train_state.params, adv_hstate, ac_in
                    )
                else:
                    init_adv_hstate = initial_adv_hstate
                    _, adv_last_val = adv_network.apply(
                        adv_train_state.params, filtered_last_obs
                    )

                # NOTE(review): _calculate_gae was selected on USE_RNN, not
                # USE_ADV_RNN, so the adversary uses the protagonist's GAE
                # variant — confirm this is intended.
                adv_advantages, adv_targets = _calculate_gae(
                    adv_traj_batch, adv_last_val, filtered_last_done.astype(jnp.bool_)
                )

                # update_step below is the value unpacked in the enclosing
                # _update_step (this function names its own counter `steps`).
                update_state = (
                    adv_train_state,
                    init_adv_hstate,
                    adv_traj_batch,
                    adv_advantages,
                    adv_targets,
                    rng,
                    update_step,
                )
                update_fn = lambda a, b: _update_epoch(a, b, uses_hstate=config['USE_ADV_RNN'], is_adv=True)
                update_state, _ = jax.lax.scan(
                    update_fn, update_state, None, config["UPDATE_EPOCHS"]
                )
                adv_train_state = update_state[0]
                rng = update_state[-2]

                # The counter is passed through unchanged.
                runner_state = (
                    train_state,
                    adv_train_state,
                    env_state,
                    last_obs,
                    last_clean_obs,
                    last_done,
                    hstate,
                    adv_hstate,
                    rng,
                    steps
                )
                return runner_state, adv_traj_batch

            # Both phase counts are hard-coded to a single pass per outer
            # update; no config key is consulted for them here.
            update_epochs = 1
            adv_update_epochs = 1

            runner_state, _ = jax.lax.scan(protagonist_update_loop, runner_state, None, update_epochs)
            if config['USE_ATLA']:
                runner_state, _ = jax.lax.scan(antagonist_update_loop, runner_state, None, adv_update_epochs)

            return runner_state, {}

        rng, _rng = jax.random.split(rng)
        done = jnp.zeros((config["NUM_ENVS"]), dtype=bool)

        # No explicit history buffers or state/done padding — hstate carries memory.
        # Order: train_state, adv_train_state, env_state, last_obs,
        # last_clean_obs, last_done, hstate, adv_hstate, rng, update_step.
        runner_state = (
            train_state,
            adv_train_state,
            env_state,
            obsv,
            obsv,
            done,
            init_hstate,
            init_adv_hstate,
            _rng,
            0
        )
        runner_state, metric = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"] + config["NUM_LTH_UPDATES"]
        )
        return {"runner_state": runner_state, "metric": metric}

    return train


def run_ppo(config):
    """Run a full training experiment and optionally checkpoint the result.

    Builds a descriptive run name, optionally initialises Weights & Biases,
    trains ``NUM_REPEATS`` seeds in parallel via ``vmap`` and reports the
    wall-clock time and steps-per-second.

    Args:
        config: mutable mapping of upper-case hyperparameter keys. It is
            passed to ``make_train``, which adds derived keys to it.

    Returns:
        The output of the vmapped train function: a dict with
        ``"runner_state"`` and ``"metric"``, each carrying a leading
        ``NUM_REPEATS`` axis.
    """
    # --- Build run name for logs/W&B ---
    net_type = "RNN" if config["USE_RNN"] else "MLP"
    algo_type = "ATLA" if config["USE_ATLA"] else "PPO"

    # The adversary learning rate is only relevant (and only shown) for ATLA.
    lr_str = (
        f"lr={config['LR']}-adv_lr={config['ADV_LR']}"
        if config["USE_ATLA"]
        else f"lr={config['LR']}"
    )

    # Pruning info (only if pruning is enabled)
    prune_str = ""
    if config.get("USE_PRUNING", False):
        prune_str = f"-prune={config['PRUNER_TYPE']}-{int(config['PRUNE_PERCENTAGE']*100)}pct"

    # Adversarial training info
    adv_str = "-SA" if config.get("USE_SA_PPO", False) else ""

    run_name = (
        f"{algo_type}-{net_type}"
        f"-env={config['ENV_NAME']}"
        f"{prune_str}{adv_str}"
        f"-{lr_str}"
        f"-{int(config['TOTAL_TIMESTEPS'] // 1e6)}M"
        f"-seed={config['SEED']}"
    )

    # --- Optional W&B logging ---
    if config.get("USE_WANDB", False):
        wandb_kwargs = dict(
            project=config["WANDB_PROJECT"],
            config=config,
            name=run_name,
            anonymous="allow",
        )
        # An empty entity means "use the account default".
        if config.get("WANDB_ENTITY"):
            wandb_kwargs["entity"] = config["WANDB_ENTITY"]
        wandb.init(**wandb_kwargs)

    # PRNGs and compiled train fn: one key per repeat, vmapped over repeats.
    rng = jax.random.PRNGKey(config["SEED"])
    rngs = jax.random.split(rng, config["NUM_REPEATS"])
    train_jit = jax.jit(make_train(config))
    train_vmap = jax.vmap(train_jit)

    # Train. JAX dispatches work asynchronously, so block until every output
    # buffer is ready; otherwise the clock can stop before the computation
    # has actually finished and the reported SPS is inflated.
    t0 = time.perf_counter()
    out = jax.block_until_ready(train_vmap(rngs))
    elapsed = time.perf_counter() - t0

    # The measured time includes compilation; SPS is based on TOTAL_TIMESTEPS
    # for a single repeat.
    print("Time to run experiment:", elapsed)
    print("SPS:", config["TOTAL_TIMESTEPS"] / elapsed)

    # Checkpoints are separated by architecture and by SA-PPO usage.
    if config['USE_RNN']:
        ckpt_root = "./rnn_checkpoints" if config.get("USE_SA_PPO", False) else "./rnn_no_sa_checkpoints"
    else:
        ckpt_root = "./checkpoints" if config.get("USE_SA_PPO", False) else "./no_sa_checkpoints"

    # Save checkpoint if requested (only the first repeat is stored).
    if config.get("SAVE_POLICY", False):
        state0 = jax.tree.map(lambda x: x[0], out["runner_state"])
        ckdir = save_checkpoint_with_norm(
            config=config,
            runner_state=state0,
            step=int(config["TOTAL_TIMESTEPS"]),
            ckpt_root=ckpt_root,
        )
        print(f"✔ Checkpoint (with normalization stats) saved to {ckdir}")

    return out


def parse_args():
    """Parse command-line options into an upper-cased config dict.

    Unknown command-line arguments are ignored (``parse_known_args``). Every
    option name is upper-cased to form the config key, e.g. ``--num_envs``
    becomes ``config["NUM_ENVS"]``.

    When ``--use_rnn`` is set, NUM_STEPS and NUM_MINIBATCHES are overridden
    with 128 and 4 regardless of what was passed on the command line.

    Returns:
        dict mapping upper-case option names to parsed values.
    """

    def _timesteps(text):
        # Accept both plain integers and scientific notation such as "5e7".
        return int(float(text))

    def _flag_value(text):
        # argparse's type=bool treats every non-empty string (including
        # "False") as True. Only explicit false-like tokens map to False here;
        # anything else stays True, as it did with type=bool.
        return text.lower() not in {"", "0", "false", "f", "no", "n", "off"}

    parser = argparse.ArgumentParser()
    parser.add_argument("--env_name", type=str, default="hopper")
    parser.add_argument("--num_envs", type=int, default=2048)
    # argparse does not run `type` on non-string defaults, so the default is
    # given as an int to match what a command-line value would produce.
    parser.add_argument("--total_timesteps", type=_timesteps, default=int(5e7))
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--num_steps", type=int, default=10)
    parser.add_argument("--update_epochs", type=int, default=4)
    parser.add_argument("--num_minibatches", type=int, default=32)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--gae_lambda", type=float, default=0.95)
    parser.add_argument("--clip_eps", type=float, default=0.2)
    parser.add_argument("--ent_coef", type=float, default=0.0)
    parser.add_argument("--vf_coef", type=float, default=0.5)
    parser.add_argument("--max_grad_norm", type=float, default=0.5)
    parser.add_argument("--activation", type=str, default="tanh")
    parser.add_argument("--anneal_lr", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--jit", action=argparse.BooleanOptionalAction, default=True)
    # The random default is drawn once, when the parser is built.
    parser.add_argument("--seed", type=int, default=np.random.randint(2**31))
    parser.add_argument("--use_wandb", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--save_policy", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--num_repeats", type=int, default=1)
    parser.add_argument("--layer_size", type=int, default=256)
    parser.add_argument("--wandb_project", type=str, default="anonymous-project")
    parser.add_argument("--wandb_entity", type=str, default="")
    parser.add_argument("--use_optimistic_resets", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--optimistic_reset_ratio", type=int, default=16)
    # store_true with default=True could never be disabled; the optional
    # action keeps `--normalize_env` working and adds `--no-normalize_env`.
    parser.add_argument("--normalize_env", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--use_rnn", action=argparse.BooleanOptionalAction, default=False)

    # Adversary
    parser.add_argument("--adv_epsilon", type=float, default=0.075)
    parser.add_argument("--adv_lr", type=float, default=5e-4)
    parser.add_argument("--adv_clip_eps", type=float, default=0.2)
    parser.add_argument("--adv_ent_coef", type=float, default=0.01)
    parser.add_argument("--adv_vf_coef", type=float, default=0.5)
    parser.add_argument("--adv_layer_size", type=int, default=128)
    parser.add_argument("--adv_eps", type=float, default=0.8)
    parser.add_argument("--adv_update_epochs", type=int, default=1)
    parser.add_argument("--use_adv_rnn", action=argparse.BooleanOptionalAction, default=False)

    # ATLA flag
    parser.add_argument("--use_ATLA", action=argparse.BooleanOptionalAction, default=False)

    parser.add_argument("--use_state_history", action=argparse.BooleanOptionalAction, default=False)

    # Pruning (optional)
    parser.add_argument("--use_pruning", action=argparse.BooleanOptionalAction, default=False,
                        help="Enable parameter pruning during training.")
    parser.add_argument("--pruner_type", type=str, default="magnitude",
                        help="Pruning algorithm (e.g., 'rigl', 'magnitude').")
    parser.add_argument("--prune_percentage", type=float, default=0.9,
                        help="Target global sparsity in [0, 1].")
    parser.add_argument("--prune_burnin", type=float, default=0.25,
                        help="Outer updates to wait before pruning starts (multiplied by UPDATE_EPOCHS*NUM_MINIBATCHES).")
    parser.add_argument("--prune_dist_type", type=str, default="erk",
                        help="Layer-wise sparsity distribution (e.g., 'erk', 'uniform').")
    parser.add_argument("--prune_schedule_power", type=float, default=2,
                        help="Optional schedule shaping power; leave None to omit.")

    # --- SA-PPO (state adversarial regularization) ---
    parser.add_argument("--use_sa_ppo", action=argparse.BooleanOptionalAction, default=True,
                        help="Enable SA-PPO KL regularization against adversarial observation perturbations.")
    parser.add_argument("--sa_steps", type=int, default=3,
                        help="PGD ascent steps for the inner max over observations.")
    parser.add_argument("--sa_step_size", type=float, default=0.0,
                        help="PGD step size; if 0, defaults to sa_eps/sa_steps.")
    parser.add_argument("--sa_lambda", type=float, default=0.07,
                        help="Weight for the SA-PPO KL regularizer.")
    parser.add_argument("--sa_subsample", type=int, default=0,
                        help="If >0, compute SA regularizer on a random subsample of this many batch items.")
    parser.add_argument("--policy_sigma", type=float, default=0.1,
                        help="Fixed diagonal Gaussian std for the policy (per-dimension).")

    # Optional: observation clamping (set if your obs are bounded post-normalization)
    parser.add_argument("--obs_clip_low", type=float, default=None,
                        help="Lower clamp for observations during PGD (None = no clamp).")
    parser.add_argument("--obs_clip_high", type=float, default=None,
                        help="Upper clamp for observations during PGD (None = no clamp).")

    # Lottery-ticket options. `--use_lth` may be given bare (-> True) or with
    # an explicit value such as `--use_lth False`.
    parser.add_argument("--use_lth", type=_flag_value, nargs="?", const=True, default=False)
    parser.add_argument("--lth_timesteps", type=_timesteps, default=int(5e7))

    args = parser.parse_known_args(sys.argv[1:])[0]
    config = {name.upper(): value for name, value in vars(args).items()}

    # Defensive fallback; with the random default above SEED is never None.
    if config.get("SEED") is None:
        config["SEED"] = 1

    # Recurrent policies use a fixed rollout length and minibatch count.
    if config["USE_RNN"]:
        config["NUM_STEPS"] = 128
        config["NUM_MINIBATCHES"] = 4
    return config


def main():
    """Script entry point: parse the CLI into a config and run training."""
    run_ppo(parse_args())


# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
