# %%
# %%
# %%
# --------------------------------------------------------------------------- #
# 0.  Imports & global JAX flags                                              #
# --------------------------------------------------------------------------- #
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import argparse
import time
import pickle, gc
import jax
import jax.numpy as jnp
import numpy as np
import optax
import flax.linen as nn
from flax.training.train_state import TrainState
from flax.linen.initializers import constant, orthogonal
import orbax.checkpoint as oc

import distrax
import gymnax
import gymnasium as gym

import matplotlib.pyplot as plt
import wandb                               # unchanged – even if unused
from typing import Sequence, NamedTuple, Any

# Local modules
from wrappers import (
    LogWrapper, FlattenObservationWrapper, LogEnvState
)
from a2c_discrete import make_train
from models import ActorCriticDiscreteAction, PredictabilityHead
from utils import Transition
import wandb

# --------------------------------------------------------------------------- #
# 1.  CLI / configuration                                                     #
# --------------------------------------------------------------------------- #
def get_merged_parser() -> argparse.ArgumentParser:
    """Build the single command-line parser shared by every stage of the script.

    The registered options fall into four groups:

      • A2C training hyper-parameters (upper-case flags)
      • roll-out / data-collection sizes
      • Predictability-Transformer hyper-parameters
      • experiment bookkeeping (run name, wandb project)

    Returns:
        An ``argparse.ArgumentParser`` with every option registered. All
        options carry a default, so parsing an empty argument list succeeds.
    """
    cli = argparse.ArgumentParser(
        description="A2C → trajectory collection → Predictability Transformer"
    )
    # ``BooleanOptionalAction`` gives ``--FLAG`` / ``--no-FLAG`` pairs where
    # argparse offers it; otherwise fall back to a plain ``store_true`` switch.
    toggle = getattr(argparse, "BooleanOptionalAction", "store_true")

    def register(*specs):
        """Add one ``(flag, default, type)`` option per spec, in order."""
        for flag, default, kind in specs:
            cli.add_argument(flag, default=default, type=kind)

    # -----------------------  A2C hyper-parameters  ----------------------- #
    register(
        ("--ENV_NAME", "Breakout-MinAtar", str),
        ("--LR", 5e-3, float),
        ("--NUM_ENVS", 64, int),
        ("--NUM_STEPS", 100, int),
        ("--TOTAL_TIMESTEPS", 1e7, float),
        ("--UPDATE_EPOCHS", 1, int),
        ("--NUM_MINIBATCHES", 1, int),
        ("--GAMMA", 0.99, float),
        ("--GAE_LAMBDA", 0.95, float),
        ("--CLIP_EPS", 0.2, float),
        ("--ENT_COEF", 0.01, float),
        ("--VF_COEF", 0.5, float),
        ("--MAX_GRAD_NORM", 0.5, float),
    )
    cli.add_argument("--ACTIVATION", default="tanh", choices=["tanh", "relu"])
    for name in ("ANNEAL_LR", "DEBUG"):
        cli.add_argument(f"--{name}", dest=name, default=True, action=toggle)
    register(
        ("--SAVE_DIR", "./complete_discrete_long_run/", str),
        ("--SEED", 0, int),
    )

    # -----------------------  Data-collection args  ----------------------- #
    register(
        ("--num_eps", 1, int),
        ("--driver_traj_len", 10, int),
        ("--general_traj_len", 200, int),
        ("--num_drivertest_states", 25, int),
    )

    # -------------  Predictability-Transformer parameters  --------------- #
    register(
        ("--env_name", "Breakout-MinAtar", str),
        ("--batch_size", 16, int),
        ("--learning_rate", 1e-3, float),
        ("--num_epochs", 100, int),
        ("--num_heads", 8, int),
        ("--num_layers", 4, int),
        ("--hidden_dim", 128, int),
        ("--grad_clipping", 1.0, float),
        ("--warmup_steps", 1000, int),
        ("--val_log_freq", 10, int),
    )

    # Experiment-related
    cli.add_argument(
        "--experiment_name",
        default="pred_transformer_experiment",
        type=str,
        help="Name of the experiment for logging purposes",
    )
    register(("--wandb_project", "complete-discrete-longrun", str))
    return cli



# %%

# Parse the command line once; `config` is the plain-dict view that is passed
# to wandb and to `make_train`.
parser = get_merged_parser()
config = vars(parser.parse_args())          # pass [] for notebook / demo
print("CONFIG:", config)

wandb.init(
    project=config["wandb_project"],
    entity="",
    config=config,
    mode="online",
    dir="./wandb",
)

# Root PRNG key for the whole script, derived from --SEED.
rng = jax.random.PRNGKey(config["SEED"])

# --------------------------------------------------------------------------- #
# 2.  Train A2C                                                               #
# --------------------------------------------------------------------------- #
print("\n=== Training A2C ===")
t0 = time.time()
train_jit = jax.jit(make_train(config))
rng, rng_a2c = jax.random.split(rng)
a2c_out = train_jit(rng_a2c)
# JAX dispatches device work asynchronously: the call above can return before
# training has finished.  Wait for the results so the elapsed time printed
# below covers the actual computation and not only tracing/dispatch.
a2c_out = jax.block_until_ready(a2c_out)
# The leading axis of the first Dense kernel is reported as the number of
# policies (the parameters are treated as a stack along that axis below).
print(f"A2C training completed in {time.time() - t0:.1f}s | "
      f"num policies = "
      f"{a2c_out['params']['params']['Dense_0']['kernel'].shape[0]}")


# %%

# --------------------------------------------------------------------------- #
# 3.  Collect rollouts                                                        #
# --------------------------------------------------------------------------- #
print("\n=== Collecting rollouts ===")
# Build the environment and wrap it: observations are flattened first, then
# the logging wrapper is applied on top.
env, env_params = gymnax.make(config["ENV_NAME"])
env = LogWrapper(FlattenObservationWrapper(env))

# Policy/value network used to act during roll-outs.
network = ActorCriticDiscreteAction(
    activation=config["ACTIVATION"],
    action_dim=env.action_space().n,
)

# Roll-out sizes taken from the CLI.
num_eps, driver_traj_len, general_traj_len = (
    config[key] for key in ("num_eps", "driver_traj_len", "general_traj_len")
)

# One reset key -- and therefore one initial env state -- per policy; the
# number of policies is read off the leading axis of the first Dense kernel.
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(
    _rng, a2c_out['params']['params']['Dense_0']['kernel'].shape[0]
)
obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)

# ----------------  rollout function (JIT, static traj_len)  ---------------- #
from functools import partial



# %%

@partial(jax.jit, static_argnames=("traj_len",))
def collect_trajectory(params, env_state, obsv, rng, traj_len):
    """Roll out `traj_len` environment steps under the policy given by `params`.

    Uses the module-level `network`, `env` and `env_params`.

    Args:
        params: parameters passed to `network.apply`.
        env_state: environment state to start from.
        obsv: observation matching `env_state`.
        rng: PRNG key; split every step for action sampling and for the
            environment step.
        traj_len: number of steps to roll out (static under `jax.jit`).

    Returns:
        `(traj_batch, env_states)`, both stacked along a leading axis of
        length `traj_len` by `jax.lax.scan`:

        * `traj_batch` is a `Transition` whose `obs` field holds the
          observation the action was chosen from;
        * `env_states[t]` is the environment state *before* step `t`, i.e.
          the state that `traj_batch.obs[t]` was observed in.
    """
    def _env_step(runner_state, _):
        params, env_state, last_obs, rng = runner_state
        rng, action_rng = jax.random.split(rng)
        pi, value = network.apply(params, last_obs)
        action    = pi.sample(seed=action_rng)
        log_prob  = pi.log_prob(action)
        rng, rng_step = jax.random.split(rng)

        next_obs, next_env_state, reward, done, info = env.step(
            rng_step, env_state, action, env_params
        )

        transition = Transition(
            done, action, value, reward, log_prob, last_obs, info
        )
        # Emit the pre-step state so that it lines up index-for-index with
        # `transition.obs` (= last_obs).  Downstream code restarts roll-outs
        # from (obs[i], env_states[i]) pairs; storing the post-step state here
        # would pair each observation with the state one step later.
        return (params, next_env_state, next_obs, rng), (transition, env_state)

    runner_state, (traj_batch, env_states) = jax.lax.scan(
        _env_step,
        (params, env_state, obsv, rng),
        xs=None,
        length=traj_len,
    )
    return traj_batch, env_states


# Derive a dedicated key for the roll-outs and advance `rng`, so that later
# stages which draw from `rng` do not reuse the key consumed here.
rng, rollout_rng = jax.random.split(rng)

# One trajectory per policy: parameters, start state and start observation are
# mapped over their leading axis; the key and the (static) length are shared.
trajs, env_states = jax.vmap(
    collect_trajectory,
    in_axes=(0, 0, 0, None, None)
)(a2c_out["params"], env_state, obsv, rollout_rng, general_traj_len)

n_transitions = trajs.reward.size
print(f"Collected {n_transitions} transitions "
      f"in {trajs.obs.shape[0]} trajectories.")


# %%


# --------------------------------------------------------------------------- #
# 4.  Returns & discounted returns                                            #
# --------------------------------------------------------------------------- #
@jax.jit
def compute_discounted_episodic_return(rewards, gamma=0.99):
    """Compute discounted returns-to-go along the leading axis of `rewards`.

    The rewards are scanned from the last step to the first, accumulating
    ``G_t = r_t + gamma * G_{t+1}`` with ``G`` initialised to zero.  No `done`
    flags are consulted, so the whole sequence is treated as one episode.

    Args:
        rewards: array of per-step rewards; the scan runs over axis 0.
        gamma: discount factor.

    Returns:
        `(first_step_return, returns_to_go)`: the discounted return from the
        first step, and the discounted return-to-go for every step (same
        leading length as `rewards`).
    """
    def _accumulate(running, reward):
        updated = running * gamma + reward
        return updated, updated

    first_step_return, returns_to_go = jax.lax.scan(
        _accumulate, 0.0, rewards, reverse=True
    )
    return first_step_return, returns_to_go

# Apply the return computation to every trajectory (leading axis of
# `trajs.reward`), using the A2C discount factor for all of them.
# `traj_disc_ret` is the discounted return from the first step of each
# trajectory; `traj_returns` holds the per-step returns-to-go.
traj_disc_ret, traj_returns = jax.vmap(
    lambda reward_seq: compute_discounted_episodic_return(
        reward_seq, config["GAMMA"]
    )
)(trajs.reward)

print(f"num trajectories={traj_returns.shape[0]}  |  traj len={traj_returns.shape[1]}")


# %%

# --------------------------------------------------------------------------- #
# 5.  Driver-test states & env-state slicing                                  #
# --------------------------------------------------------------------------- #
# Use a fresh key for the sampling below instead of consuming `rng` directly,
# so the draw does not share a key with other stages that read `rng`.
rng, driver_idx_rng = jax.random.split(rng)

# Sample, without replacement, `num_drivertest_states` flat indices into the
# pool formed by the first two axes of `trajs.obs`.
driver_states_idx = jax.random.choice(
    driver_idx_rng, trajs.obs.shape[0] * trajs.obs.shape[1],
    shape=(config["num_drivertest_states"],),
    replace=False,
)
# Flatten the observations to (pool, obs_dim) and pick the sampled rows.
driver_states = trajs.obs.reshape(-1, trajs.obs.shape[-1])[driver_states_idx]


# %%


# ---------- helper functions for env-state replacement --------------------- #
def _flatten_first_two(x):
    """Merge the two leading axes of `x`; arrays of rank <= 1 pass through."""
    if x.ndim > 1:
        lead, second = x.shape[:2]
        return x.reshape(lead * second, *x.shape[2:])
    return x


def _select_rows(x, idx):
    """Index the flattened leading axis of `x` with `idx`."""
    flat = _flatten_first_two(x)
    return flat[idx]


def _match_shape(template, x):
    """Return `x`, reshaped to `template.shape` when the shapes differ."""
    if x.shape != template.shape:
        x = jnp.reshape(x, template.shape)
    return x


def replace_inner_env_state(driver_env_states, env_states, idx):
    """Swap the inner `env_state` of `driver_env_states` for selected rows.

    Every leaf of `env_states.env_state` has its two leading axes merged and
    is then indexed with `idx`; each selected leaf is reshaped, if necessary,
    to the shape of the matching leaf in `driver_env_states.env_state`.

    Args:
        driver_env_states: wrapper state (dataclass, or an object providing
            `_replace`) whose `env_state` field is to be replaced; all other
            fields are kept.
        env_states: wrapper state whose `env_state` leaves supply the rows.
        idx: indices into the merged leading axis.

    Returns:
        A copy of `driver_env_states` carrying the new inner `env_state`.
    """
    import dataclasses

    picked = jax.tree_util.tree_map(
        lambda leaf: _select_rows(leaf, idx), env_states.env_state
    )
    new_inner = jax.tree_util.tree_map(
        _match_shape, driver_env_states.env_state, picked
    )

    # Dataclass states go through `dataclasses.replace`; anything else is
    # expected to offer the namedtuple-style `_replace`.
    if not dataclasses.is_dataclass(driver_env_states):
        return driver_env_states._replace(env_state=new_inner)
    return dataclasses.replace(driver_env_states, env_state=new_inner)


# ---------- build driver_env_states --------------------------------------- #
# Reset one environment per driver-test state purely to obtain a wrapper
# state of the right structure and batch size ...
driver_reset_rng = jax.random.split(_rng, num=config["num_drivertest_states"])
_, driver_env_states = jax.vmap(
    env.reset,
    in_axes=(0, None),
)(driver_reset_rng, env_params)
# ... then overwrite its inner env state with the roll-out states picked by
# `driver_states_idx`.
driver_env_states = replace_inner_env_state(
    driver_env_states,
    env_states,
    driver_states_idx,
)

# --------------------------------------------------------------------------- #
# 6.  Roll out driver trajectories                                            #
# --------------------------------------------------------------------------- #
@partial(jax.jit, static_argnames=("traj_len",))
def collect_driver_trajectory(params, env_state, obsv, rng, traj_len):
    """Roll out `traj_len` steps from a *batch* of start states under one policy.

    Unlike `collect_trajectory`, the environment is stepped with `jax.vmap`
    over the leading axis of `env_state` / `obsv`, so a single call advances
    all start states in lock-step.  Uses the module-level `network`, `env`
    and `env_params`.

    Args:
        params: parameters passed to `network.apply`.
        env_state: batched environment state (leading axis = start states).
        obsv: batched observations matching `env_state`; `obsv.shape[0]`
            determines how many per-environment step keys are drawn.
        rng: PRNG key; split every step.
        traj_len: number of steps to roll out (static under `jax.jit`).

    Returns:
        `(traj_batch, env_states)`, both stacked by `jax.lax.scan` along a
        leading axis of length `traj_len`; `env_states` holds the state
        returned by each step.
    """
    def _env_step(runner_state, _):
        params, env_state, last_obs, rng = runner_state
        rng, action_rng = jax.random.split(rng)
        pi, value = network.apply(params, last_obs)
        action    = pi.sample(seed=action_rng)
        log_prob  = pi.log_prob(action)
        # Split off a dedicated key before fanning out to per-environment
        # keys: deriving them from `rng` itself and then carrying `rng`
        # forward would use the same key twice.
        rng, step_base_rng = jax.random.split(rng)
        rng_step = jax.random.split(step_base_rng, last_obs.shape[0])

        obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
            rng_step, env_state, action, env_params
        )

        transition = Transition(
            done, action, value, reward, log_prob, last_obs, info
        )
        return (params, env_state, obsv, rng), (transition, env_state)

    runner_state, (traj_batch, env_states) = jax.lax.scan(
        _env_step,
        (params, env_state, obsv, rng),
        xs=None,
        length=traj_len,
    )
    return traj_batch, env_states


# Use a dedicated key for the driver roll-outs and advance `rng`, so these
# roll-outs do not share a key with other stages that draw from `rng`.
rng, driver_rollout_rng = jax.random.split(rng)

# Every policy (leading axis of the parameters) is rolled out from the same
# batch of driver start states; only the parameters are mapped.
driver_trajs, driver_env_states_rollout = jax.vmap(
    collect_driver_trajectory,
    in_axes=(0, None, None, None, None)
)(a2c_out["params"], driver_env_states, driver_states, driver_rollout_rng, driver_traj_len)


# %%

# Discounted return of every driver roll-out.  The last two axes of the
# rewards are swapped so the step axis comes last, and the result is flattened
# to one row of `driver_traj_len` rewards per roll-out.
# (assumes driver_trajs.reward is (policy, step, driver_state) -- TODO confirm)
driver_returns, _ = jax.vmap(
    compute_discounted_episodic_return, in_axes=(0, None)
)(
    jnp.swapaxes(driver_trajs.reward, 1, 2).reshape(-1, driver_traj_len),
    config["GAMMA"],
)
# Regroup into one row per policy with one column per driver-test state.
driver_returns = driver_returns.reshape(-1, config["num_drivertest_states"])

# --------------------------------------------------------------------------- #
# 7.  Build PT training dataset                                               #
# --------------------------------------------------------------------------- #
num_policies, _, obs_dim = trajs.obs.shape
# Queries: every visited observation paired with its return-to-go.
query_returns = traj_returns.reshape(-1)
query_states  = trajs.obs.reshape(-1, obs_dim)
# How many query rows correspond to each row of `driver_returns`.
repeat_factor = query_returns.shape[0] // driver_returns.shape[0]


# %%

os.makedirs(config["SAVE_DIR"], exist_ok=True)

experiment_name = config['experiment_name']

# Raw roll-out data, written as an Orbax PyTree checkpoint.
offline_data = {
    "config": config,
    "trajs":  trajs,
    "driver_trajs": driver_trajs,
    "env_states": env_states,
    "driver_env_states_rollout": driver_env_states_rollout,
}
# Orbax requires an absolute checkpoint directory, whereas --SAVE_DIR is
# relative by default, so resolve the path before saving.
ckpt_dir = os.path.abspath(
    os.path.join(config["SAVE_DIR"], f"{experiment_name}_offline_ckpt")
)
ckptr    = oc.Checkpointer(oc.PyTreeCheckpointHandler())   # TensorStore backend
ckptr.save(ckpt_dir, offline_data, force=True)
# Drop the module-level names of the large roll-out arrays.
# NOTE(review): `offline_data` still references `trajs`, `driver_trajs` and
# `env_states`, so their memory is only released once it is dropped as well.
del env_states, traj_disc_ret, traj_returns
del trajs, driver_trajs

# Inputs of the Predictability Transformer.  `driver_returns` is stored here
# in its compact form (before the per-query repetition further down).
with open(f'{config["SAVE_DIR"]}/{experiment_name}_predtran_train_data.pkl', "wb") as f:
    pickle.dump({
        "config": config,
        "driver_states": driver_states,
        "driver_env_states": driver_env_states,
        "driver_returns": driver_returns,
        "query_states": query_states,
        "query_returns": query_returns,
    }, f)

with open(f'{config["SAVE_DIR"]}/{experiment_name}_a2c_params.pkl', "wb") as f:
    pickle.dump({
        "config": config,
        "a2c_out": a2c_out["params"],
    }, f)
print("Saved a2c params")

# Repeat each row of driver returns so that it has one row per query sample.
driver_returns = jnp.repeat(driver_returns, repeat_factor, axis=0)
print("transformer training data.")




# %%

# --------------------------------------------------------------------------- #
# 8.  Predictability Transformer                                              #
# --------------------------------------------------------------------------- #
# Delete variables to free memory
# print("Clearing jax cache")
# jax.clear_caches()

# Same CLI options as `config`, as a Namespace for attribute access.
args = parser.parse_args()

# Process data:
from utils import load_feat_extractor_params, extract_submodel
# NOTE(review): index -1 presumably selects the last policy of the stacked
# A2C parameters -- confirm against `utils.extract_submodel`.
actor_critic_params = extract_submodel(a2c_out["params"]["params"], -1)
del a2c_out
feat_extractor_params = load_feat_extractor_params(actor_critic_params)
from models import FeatExtractorDiscreteAction

feat_extractor = FeatExtractorDiscreteAction(activation=config["ACTIVATION"])

# Replace raw observations by extracted features; from here on
# `driver_states` / `query_states` hold features, not observations.
driver_states = feat_extractor.apply(feat_extractor_params, driver_states)
query_states = feat_extractor.apply(feat_extractor_params, query_states)

predictor = PredictabilityHead(
    num_heads=args.num_heads,
    num_layers=args.num_layers,
    hidden_dim=args.hidden_dim,
)

batch_size = args.batch_size
# Initialise with a dedicated key rather than `rng` itself, which other
# stages also draw from.
rng, predictor_init_rng = jax.random.split(rng)
# Dummy inputs: (driver features, driver returns, query features).
predictor_params = predictor.init(
    predictor_init_rng,
    jnp.zeros((batch_size, args.num_drivertest_states, driver_states.shape[-1])),
    jnp.zeros((batch_size, args.num_drivertest_states)),
    jnp.zeros((batch_size, driver_states.shape[-1])),
)

# Number of optimiser steps that will actually be taken: each epoch iterates
# over the training split only, which is the first 80 % of the samples (this
# fraction must match the train/val split performed later in this script).
train_steps_per_epoch = int(0.8 * query_states.shape[0]) // batch_size
total_train_steps = args.num_epochs * train_steps_per_epoch

lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value = 0.0,                          # start from 0
    peak_value = args.learning_rate,           # e.g. 1e-3
    warmup_steps = args.warmup_steps,          # e.g. 1 000
    # `decay_steps` is the total schedule length *including* warm-up, so it is
    # the full step count; it must exceed `warmup_steps` to leave a positive
    # number of cosine-decay steps.
    decay_steps  = max(total_train_steps, args.warmup_steps + 1),
    end_value    = 0.03 * args.learning_rate,  # ~3 % of peak
)

tx = optax.chain(
    optax.clip_by_global_norm(args.grad_clipping),   # honour --grad_clipping
    optax.adamw(                               # Adam + decoupled weight-decay
        learning_rate = lr_schedule,
        weight_decay  = 1e-2,
        b1 = 0.9, b2 = 0.999, eps = 1e-8,
    ),
)

predictor_state = TrainState.create(
    apply_fn = predictor.apply,
    params   = predictor_params,
    tx       = tx,
)

# --------------------- utility for dynamic slice -------------------------- #
def take_batch(arr: jnp.ndarray, batch_idx: int, B: int):
    """Return the `batch_idx`-th block of `B` consecutive rows of `arr`.

    The slice is taken along axis 0 with a dynamic start index, so
    `batch_idx` may be a traced value (e.g. a `lax.scan` counter) while `B`
    must be a static Python int.  All trailing axes are kept whole.  As with
    any dynamic slice, a start index that would run past the end is clamped
    so that the slice stays inside `arr`.
    """
    return jax.lax.dynamic_slice_in_dim(arr, batch_idx * B, B, axis=0)


# --------------------- training step functions ---------------------------- #
@jax.jit
def _update(state, d_states, d_returns, q_states, q_returns):
    """Take one gradient step on the mean-squared return-prediction error.

    Args:
        state: train state providing `apply_fn`, `params` and
            `apply_gradients`.
        d_states, d_returns, q_states: inputs forwarded to `state.apply_fn`.
        q_returns: regression targets.
            (assumes the predictions broadcast element-wise against
            `q_returns` -- TODO confirm the predictor's output shape)

    Returns:
        `(new_state, loss)`, where `loss` is evaluated before the update.
    """
    def mse(model_params):
        residual = state.apply_fn(model_params, d_states, d_returns, q_states) - q_returns
        return jnp.mean(residual ** 2)

    loss, grads = jax.value_and_grad(mse)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss


@jax.jit
def _epoch_update(state, d_states_all, d_returns_all,
                  q_states_all, q_returns_all, rng):
    """Run one shuffled training epoch and return the updated state.

    The samples are shuffled once, then consumed in `n // batch_size`
    consecutive mini-batches (a trailing partial batch is dropped).  Uses the
    module-level `batch_size`, `args` and `driver_states`.

    Args:
        state: train state to update.
        d_states_all: driver-state features shared by every sample; broadcast
            to one copy per batch element.
        d_returns_all: per-sample driver returns; row `i` belongs to the same
            sample as row `i` of `q_states_all` / `q_returns_all` (the caller
            slices all three with the same index array).
        q_states_all: per-sample query features.
        q_returns_all: per-sample regression targets.
        rng: PRNG key used for the shuffle.

    Returns:
        `(state, mean_loss)`, with `mean_loss` averaged over the mini-batches.
    """
    n = q_states_all.shape[0]
    n_batches = n // batch_size

    rng, key = jax.random.split(rng)
    perm = jax.random.permutation(key, n)
    # Apply the *same* permutation to every per-sample array.  The driver
    # returns must be shuffled together with the queries; otherwise each query
    # is paired with the driver returns of a different sample.
    qs, qr = q_states_all[perm], q_returns_all[perm]
    dr = d_returns_all[perm]

    d_states_fixed = jnp.broadcast_to(
        d_states_all,
        (batch_size, args.num_drivertest_states, driver_states.shape[-1]),
    )

    def step(s, i):
        d_r = take_batch(dr, i, batch_size)
        q_s = take_batch(qs, i, batch_size)
        q_r = take_batch(qr, i, batch_size)
        return _update(s, d_states_fixed, d_r, q_s, q_r)

    state, losses = jax.lax.scan(step, state, jnp.arange(n_batches))
    return state, jnp.mean(losses)


@jax.jit
def _val_epoch(state, d_states_all, d_returns_all, q_states_all, q_returns_all):
    """Evaluate the mean-squared error over the given samples, without updates.

    The samples are processed in order, in `n // batch_size` mini-batches (a
    trailing partial batch is dropped).  Uses the module-level `batch_size`,
    `args` and `driver_states`.

    Returns:
        The average of the per-batch mean-squared errors.
    """
    num_batches = q_states_all.shape[0] // batch_size

    context_shape = (
        batch_size, args.num_drivertest_states, driver_states.shape[-1]
    )
    d_states_fixed = jnp.broadcast_to(d_states_all, context_shape)

    def _accumulate(running_sum, batch_i):
        d_r = take_batch(d_returns_all, batch_i, batch_size)
        q_s = take_batch(q_states_all, batch_i, batch_size)
        q_r = take_batch(q_returns_all, batch_i, batch_size)
        preds = state.apply_fn(state.params, d_states_fixed, d_r, q_s)
        batch_mse = jnp.mean((preds - q_r) ** 2)
        # Carry the running sum forward; the per-step output is unused.
        return running_sum + batch_mse, running_sum

    summed, _ = jax.lax.scan(_accumulate, 0.0, jnp.arange(num_batches))
    return summed / num_batches

# %%
# -------------------  train/val split ------------------------------------ #
# Shuffle all sample indices once and cut them 80 % / 20 % into training and
# validation index sets.
num_epochs = args.num_epochs
n_samples = query_states.shape[0]
split = int(0.8 * n_samples)

rng, split_key = jax.random.split(rng)
perm = jax.random.permutation(split_key, n_samples)

train_idx = perm[:split]
val_idx = perm[split:]


# %%

# -------------------  Epoch loop ----------------------------------------- #
print("\n=== Training Predictability Transformer ===")
val_log_freq = args.val_log_freq
t0 = time.time()

# Gather the training / validation subsets once.  They do not change between
# epochs, so there is no need to redo these gathers inside the scanned body.
_train_d_returns = driver_returns[train_idx]
_train_q_states = query_states[train_idx]
_train_q_returns = query_returns[train_idx]
_val_d_returns = driver_returns[val_idx]
_val_q_states = query_states[val_idx]
_val_q_returns = query_returns[val_idx]


def _epoch_scan(carry, epoch_i):
    """One epoch: validate with the current parameters, then train.

    Args:
        carry: `(train_state, rng)`.
        epoch_i: epoch counter, only used for the progress print.

    Returns:
        `((train_state, rng), (train_loss, val_loss))`.  `val_loss` is
        measured *before* this epoch's parameter updates.
    """
    state, rng = carry
    rng, sub = jax.random.split(rng)

    val_loss = _val_epoch(
        state, driver_states, _val_d_returns, _val_q_states, _val_q_returns
    )

    state, tr_loss = _epoch_update(
        state, driver_states,
        _train_d_returns, _train_q_states, _train_q_returns, sub,
    )
    # Print the epoch number
    jax.debug.print(
        "Epoch: {epoch_i}",
        epoch_i=epoch_i
    )
    return (state, rng), (tr_loss, val_loss)


# All epochs run inside a single `lax.scan`; the per-epoch losses come back
# stacked along the leading axis.
(predictor_state, _), (train_losses, val_losses) = jax.lax.scan(
    _epoch_scan,
    (predictor_state, rng),
    xs=jnp.arange(num_epochs),
)

num_batches_per_epoch = n_samples // batch_size
train_losses = train_losses.reshape(-1)
val_losses = val_losses.reshape(-1)

# Copy both loss curves to the host once and log them pairwise, instead of
# reading one device scalar at a time.  Each `val_loss` was measured before
# the corresponding epoch's updates.
for epoch, (epoch_train_loss, epoch_val_loss) in enumerate(
    zip(np.asarray(train_losses).tolist(), np.asarray(val_losses).tolist()),
    start=1,
):
    wandb.log({
        "train_loss": epoch_train_loss,
        "val_loss": epoch_val_loss,
        "epoch": epoch,
    })

# Validation loss of the final parameters, logged one step past the last epoch.
final_val_loss = _val_epoch(predictor_state, driver_states, driver_returns[val_idx], query_states[val_idx], query_returns[val_idx])
wandb.log({"val_loss": float(final_val_loss), "epoch": num_epochs + 1})

print(f"Transformer training time: {time.time() - t0:.1f}s")


# Pickle the transformer along with the feature extractor
with open(f'{config["SAVE_DIR"]}/{experiment_name}_pred_transformer.pkl', "wb") as f:
    pickle.dump(
        {
            "feat_extractor_params": feat_extractor_params,
            "predictor_params": predictor_state.params,
        }, f
    )
    

# --------------------------------------------------------------------------- #
# 9.  (Optional) dump artefacts                                               #
# --------------------------------------------------------------------------- #
# with open("a2c_out.pkl", "wb") as f:
#     pickle.dump(a2c_out, f)
# --------------------------------------------------------------------------- #




# %%



