import os
# Do not preallocate memory for JAX
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import argparse
import time
import pickle
import gc
from functools import partial
from typing import Sequence, NamedTuple, Any

import numpy as np
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
import orbax.checkpoint as oc
import distrax
import gymnax
import gymnasium as gym
import matplotlib.pyplot as plt
import wandb

# Local modules
from wrappers import (
    LogWrapper, FlattenObservationWrapper, LogEnvState,
    BraxGymnaxWrapper, VecEnv, NormalizeVecObservation,
    NormalizeVecReward, ClipAction
)
from a2c_continuous import make_train
from models import ActorCriticContinuousAction, PredictabilityHead, FeatExtractorDiscreteAction
from utils import Transition, load_feat_extractor_params, extract_submodel

# Re-applies the XLA_PYTHON_CLIENT_PREALLOCATE setting made at the top of the
# file (before `import jax`). Redundant unless one of the local modules
# imported above changed it.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


def get_merged_parser() -> argparse.ArgumentParser:
    """Build the single CLI parser shared by every stage of the pipeline.

    Upper-case options configure A2C training, the lower-case options below
    them configure rollout collection, the predictability transformer and
    experiment logging. Options are registered in table order.

    Returns:
        A fully populated ``argparse.ArgumentParser``.
    """
    # BooleanOptionalAction (Python 3.9+) also provides ``--no-FLAG``; older
    # interpreters fall back to a plain ``store_true`` switch.
    toggle = getattr(argparse, "BooleanOptionalAction", "store_true")

    option_table = (
        # --- A2C hyper-parameters ---
        ("--ENV_NAME", {"default": "hopper", "type": str}),
        ("--ENV_BACKEND", {"type": str, "default": "positional"}),
        ("--LR", {"default": 3e-4, "type": float}),
        ("--NUM_ENVS", {"default": 256, "type": int}),
        ("--NUM_STEPS", {"default": 10, "type": int}),
        ("--TOTAL_TIMESTEPS", {"default": 1e6, "type": float}),
        ("--UPDATE_EPOCHS", {"default": 1, "type": int}),
        ("--NUM_MINIBATCHES", {"default": 1, "type": int}),
        ("--GAMMA", {"default": 0.99, "type": float}),
        ("--GAE_LAMBDA", {"default": 0.95, "type": float}),
        ("--CLIP_EPS", {"default": 0.2, "type": float}),
        ("--ENT_COEF", {"default": 0.0, "type": float}),
        ("--VF_COEF", {"default": 0.5, "type": float}),
        ("--MAX_GRAD_NORM", {"default": 0.5, "type": float}),
        ("--ACTIVATION", {"default": "tanh", "choices": ["tanh", "relu"]}),
        ("--ANNEAL_LR", {"dest": "ANNEAL_LR", "default": True, "action": toggle}),
        ("--NORMALIZE_ENV", {
            "dest": "NORMALIZE_ENV",
            "default": True,
            "action": toggle,
            "help": "Normalize observations / rewards (default: True)",
        }),
        ("--DEBUG", {"dest": "DEBUG", "default": True, "action": toggle}),
        ("--SAVE_DIR", {"default": "./complete_continuous_longrun/", "type": str}),
        ("--SEED", {"default": 0, "type": int}),
        # --- rollout / data collection ---
        ("--num_eps", {"default": 1, "type": int}),
        ("--driver_traj_len", {"default": 25, "type": int}),
        ("--general_traj_len", {"default": 1000, "type": int}),
        ("--num_drivertest_states", {"default": 20, "type": int}),
        # --- predictability transformer ---
        ("--env_name", {"default": "hopper", "type": str}),
        ("--batch_size", {"default": 16, "type": int}),
        ("--learning_rate", {"default": 1e-3, "type": float}),
        ("--num_epochs", {"default": 10, "type": int}),
        ("--num_heads", {"default": 8, "type": int}),
        ("--num_layers", {"default": 4, "type": int}),
        ("--hidden_dim", {"default": 128, "type": int}),
        ("--grad_clipping", {"default": 1.0, "type": float}),
        ("--warmup_steps", {"default": 1000, "type": int}),
        ("--val_log_freq", {"default": 10, "type": int}),
        # --- experiment bookkeeping ---
        ("--experiment_name", {
            "default": "pred_transformer_experiment",
            "type": str,
            "help": "Name of the experiment for logging purposes",
        }),
        ("--wandb_project", {"default": "complete-continuous-longrun", "type": str}),
    )

    cli = argparse.ArgumentParser(
        description="A2C → trajectory collection → Predictability Transformer"
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli


# Parse the CLI once; ``config`` is the plain-dict view passed to make_train
# and logged to Weights & Biases.
parser = get_merged_parser()
config = vars(parser.parse_args())
print("CONFIG:", config)

# NOTE(review): entity="" — confirm which W&B entity/team runs should go to.
wandb.init(
    project=config["wandb_project"],
    entity="",
    config=config,
    mode="online",
    dir="./wandb",
)

# Root PRNG key; all later randomness in this script is derived from it.
rng = jax.random.PRNGKey(config["SEED"])

# Train A2C
print("\n=== Training A2C ===")
t0 = time.time()
# make_train(config) returns a function of a PRNG key; the whole training run
# is jit-compiled as one program.
train_jit = jax.jit(make_train(config))
rng, rng_a2c = jax.random.split(rng)
a2c_out = train_jit(rng_a2c)
# The leading axis of the first Dense kernel is reported as the number of
# policies — assumes make_train returns params stacked over several policies
# (TODO confirm against a2c_continuous.make_train).
print(f"A2C training completed in {time.time() - t0:.1f}s | "
      f"num policies = {a2c_out['params']['params']['Dense_0']['kernel'].shape[0]}")

# Rollout collection
print("\n=== Collecting rollouts ===")
# Rebuild the environment wrapper stack for data collection; env_params is
# None for the Brax wrapper.
# NOTE(review): presumably this must mirror the stack used inside make_train so
# observations match what the policies were trained on — confirm.
env, env_params = BraxGymnaxWrapper(config["ENV_NAME"], backend=config["ENV_BACKEND"]), None
env = LogWrapper(env)
env = ClipAction(env)
env = VecEnv(env)
if config.get("NORMALIZE_ENV", False):
    env = NormalizeVecObservation(env)
    env = NormalizeVecReward(env, config["GAMMA"])

# Module definition used only through ``network.apply`` with the trained params.
network = ActorCriticContinuousAction(
    action_dim=env.action_space(env_params).shape[0],
    activation=config["ACTIVATION"]
)

num_eps = config["num_eps"]
driver_traj_len = config["driver_traj_len"]
general_traj_len = config["general_traj_len"]

# Reset ``num_eps`` parallel environments; this single batch of initial states
# is shared by every policy's rollout below (in_axes=None for env_state/obsv).
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, num_eps)
obsv, env_state = env.reset(reset_rng, env_params)
print("Observation shape:", obsv.shape)

@partial(jax.jit, static_argnames=("traj_len",))
def collect_trajectory(params, env_state, obsv, rng, traj_len):
    """Roll out the actor-critic for ``traj_len`` steps in the batched env.

    Args:
        params: parameters for ``network.apply``.
        env_state: batched environment state to start from.
        obsv: observations belonging to ``env_state``; axis 0 is the env batch.
        rng: PRNG key, split every step for action sampling and env stepping.
        traj_len: number of steps; static because it fixes the scan length.

    Returns:
        ``(traj_batch, env_states)``, each stacked along a new leading time
        axis. ``env_states[t]`` is the state the env was in when step ``t`` was
        taken, i.e. the state paired with ``traj_batch.obs[t]``. Callers restart
        rollouts from ``(env_states[t], traj_batch.obs[t])``, so these two must
        describe the same moment.
    """
    def _env_step(runner_state, _):
        params, env_state, last_obs, rng = runner_state
        B = last_obs.shape[0]
        rng, sub = jax.random.split(rng)
        pi, value = network.apply(params, last_obs)
        action = pi.sample(seed=sub)
        log_prob = pi.log_prob(action)
        rng, sub = jax.random.split(rng)
        rng_step = jax.random.split(sub, B)
        obsv, next_env_state, reward, done, info = env.step(
            rng_step, env_state, action, env_params
        )
        transition = Transition(done, action, value, reward, log_prob, last_obs, info)
        # Emit the *pre-step* state so it lines up with ``last_obs`` (stored in
        # ``transition``); the post-step state is only carried forward.
        return (params, next_env_state, obsv, rng), (transition, env_state)

    runner_state, (traj_batch, env_states) = jax.lax.scan(
        _env_step, (params, env_state, obsv, rng), xs=None, length=traj_len)
    return traj_batch, env_states


# Roll out every A2C policy (vmapped over the leading axis of the stacked
# params) from the same initial env batch. A dedicated key is split off so the
# action/step noise is not reused by later stages that also draw from ``rng``.
rng, rollout_rng = jax.random.split(rng)
trajs, env_states = jax.vmap(
    collect_trajectory, in_axes=(0, None, None, None, None)
)(a2c_out["params"], env_state, obsv, rollout_rng, general_traj_len)

# trajs leaves are laid out (num_policies, general_traj_len, num_eps, ...), so
# there is one trajectory per (policy, env) pair.
n_transitions = trajs.reward.size
n_trajectories = trajs.obs.shape[0] * trajs.obs.shape[2]
print(f"Collected {n_transitions} transitions in {n_trajectories} trajectories.")


@jax.jit
def compute_discounted_episodic_return(rewards, gamma=0.99):
    """Discounted return-to-go along axis 0 of ``rewards``.

    Evaluates G_t = r_t + gamma * G_{t+1}, with G = 0 after the last step, by
    scanning the rewards backwards in time.

    Args:
        rewards: per-step rewards with time on axis 0.
        gamma: discount factor.

    Returns:
        ``(first_return, returns_to_go)``: ``returns_to_go[t]`` is G_t and
        ``first_return`` equals ``returns_to_go[0]``, the return from the first
        step. No termination mask is applied, so returns accumulate straight
        across any episode boundary contained in ``rewards``.
    """
    def _accumulate(future_return, reward):
        g = reward + gamma * future_return
        return g, g

    first_return, returns_to_go = jax.lax.scan(
        _accumulate, 0.0, rewards, reverse=True
    )
    return first_return, returns_to_go


# Discounted return-to-go for every step of every (policy, env) trajectory.
# Rewards are reordered from (policy, time, env) to rows of length
# general_traj_len, so traj_returns is (num_policies * num_eps, general_traj_len)
# and traj_disc_ret holds each row's return from its first step.
# NOTE(review): returns accumulate across ``done`` boundaries inside a row —
# confirm episodes do not terminate early or that this is intended.
traj_disc_ret, traj_returns = jax.vmap(
    compute_discounted_episodic_return, in_axes=(0, None)
)(trajs.reward.transpose((0, 2, 1)).reshape(-1, general_traj_len), config["GAMMA"])

print(f"num trajectories={traj_returns.shape[0]}  |  traj len={traj_returns.shape[1]}")

# Sample driver ("test") states uniformly, without replacement, from every
# visited state. The index space must match the ``reshape(-1, ...)`` flattening
# of trajs.obs (and of env_states in replace_inner_env_state), which spans all
# leading axes (policy, time and env), not just policy and time.
rng, choice_rng, driver_reset_key = jax.random.split(rng, 3)
num_visited_states = int(np.prod(trajs.obs.shape[:-1]))
driver_states_idx = jax.random.choice(
    choice_rng, num_visited_states,
    shape=(config["num_drivertest_states"],),
    replace=False,
)
driver_states = trajs.obs.reshape(-1, trajs.obs.shape[-1])[driver_states_idx]
print("Driver states shape:", driver_states.shape)

# Template env states for the drivers (one per driver state); leaves with a
# per-driver leading axis are overwritten by replace_inner_env_state below.
driver_reset_rng = jax.random.split(driver_reset_key, config["num_drivertest_states"])
_, driver_env_states = env.reset(driver_reset_rng, env_params)



def replace_inner_env_state(driver_env_states: Any, env_states: Any, idx: jnp.ndarray):
    """Overwrite per-driver leaves of ``driver_env_states`` with recorded states.

    Both pytrees must share the same structure. For each leaf whose leading
    axis in ``driver_env_states`` has length ``k = idx.size``:

    * the matching ``env_states`` leaf is flattened over all of its leading
      axes to ``(-1, *trailing)``, where ``trailing`` is the driver leaf's
      shape after axis 0 (or to ``(-1,)`` when there is no trailing shape);
    * rows ``idx`` are gathered and reshaped to the driver leaf's shape.

    Scalar leaves and leaves whose leading axis is not ``k`` are returned
    unchanged.

    NOTE(review): a leaf whose leading axis equals ``k`` only by coincidence
    (e.g. a per-feature statistic when obs_dim == k) is also replaced —
    confirm no such leaf exists for the env wrappers in use.

    Args:
        driver_env_states: pytree of template states, one per driver.
        env_states: pytree of recorded states with the same structure.
        idx: flat row indices into the flattened ``env_states`` leaves; any
            array-like (e.g. a list) is accepted.

    Returns:
        A pytree like ``driver_env_states`` with the per-driver leaves replaced.

    Raises:
        ValueError: if a leaf to be replaced has a trailing shape in
            ``env_states`` that differs from the driver leaf's.
    """
    idx = jnp.asarray(idx)
    k = idx.size

    def _replace(driver_leaf, env_leaf):
        d = jnp.asarray(driver_leaf)
        if d.ndim == 0 or d.shape[0] != k:
            return driver_leaf
        e = jnp.asarray(env_leaf)
        rest_shape = d.shape[1:]
        if rest_shape:
            # Explicit check (not ``assert``) so it also runs under ``python -O``.
            if e.shape[-len(rest_shape):] != rest_shape:
                raise ValueError(
                    f"Trailing shape mismatch: driver {d.shape}, env {e.shape}"
                )
            e_flat = e.reshape((-1, *rest_shape))
        else:
            e_flat = e.reshape((-1,))
        return e_flat[idx].reshape(d.shape)

    return jax.tree_util.tree_map(_replace, driver_env_states, env_states)


# Restart the environment at each sampled driver state.
new_driver_env_states = replace_inner_env_state(
    driver_env_states,
    env_states,
    driver_states_idx
)

# Short rollouts of every policy from the driver states. A fresh key keeps the
# sampling noise independent of the main rollouts.
rng, driver_rollout_rng = jax.random.split(rng)
driver_trajs, driver_env_states_rollout = jax.vmap(
    collect_trajectory, in_axes=(0, None, None, None, None)
)(a2c_out["params"], new_driver_env_states, driver_states, driver_rollout_rng, driver_traj_len)

# Discounted return from the first step of each (policy, driver) rollout,
# reshaped to (num_policies, num_drivertest_states).
driver_returns, _ = jax.vmap(
    compute_discounted_episodic_return, in_axes=(0, None)
)(driver_trajs.reward.transpose((0, 2, 1)).reshape(-1, driver_traj_len), config["GAMMA"])

driver_returns = driver_returns.reshape(-1, config["num_drivertest_states"])

# Query samples: every visited state with its return-to-go. trajs.obs is
# (policy, time, env, obs) while traj_returns rows are (policy, env) x time, so
# the returns are reordered to (policy, time, env) before flattening. This keeps
# query_returns[i] aligned with query_states[i] when num_eps > 1.
num_policies, rollout_len, rollout_envs, obs_dim = trajs.obs.shape
query_states = trajs.obs.reshape(-1, obs_dim)
query_returns = (
    traj_returns.reshape(num_policies, rollout_envs, rollout_len)
    .transpose((0, 2, 1))
    .reshape(-1)
)
# Number of query samples per policy (queries are grouped policy-major).
repeat_factor = query_returns.shape[0] // driver_returns.shape[0]

os.makedirs(config["SAVE_DIR"], exist_ok=True)

experiment_name = config['experiment_name']

# Raw query and driver rollouts plus their env states, saved with Orbax so they
# can be reloaded without re-running A2C.
offline_data = {
    "config": config,
    "trajs": trajs,
    "driver_trajs": driver_trajs,
    "env_states": env_states,
    "driver_env_states_rollout": driver_env_states_rollout,
}
ckpt_dir = f'{config["SAVE_DIR"]}/{experiment_name}_offline_ckpt'
ckptr = oc.Checkpointer(oc.PyTreeCheckpointHandler())
# force=True lets the save replace an existing checkpoint at ckpt_dir.
ckptr.save(ckpt_dir, offline_data, force=True)


# Release the large rollout arrays; only the flattened query/driver arrays are
# needed from here on.
del env_states, traj_disc_ret, traj_returns
del trajs, driver_trajs

# Raw-observation training data for the predictability transformer (before
# feature extraction and before driver_returns is tiled below).
with open(f'{config["SAVE_DIR"]}/{experiment_name}_predtran_train_data.pkl', "wb") as f:
    pickle.dump({
        "config": config,
        "driver_states": driver_states,
        "driver_env_states": new_driver_env_states,
        "driver_returns": driver_returns,
        "query_states": query_states,
        "query_returns": query_returns,
    }, f)

with open(f'{config["SAVE_DIR"]}/{experiment_name}_a2c_params.pkl', "wb") as f:
    pickle.dump({
        "config": config,
        "a2c_out": a2c_out["params"],
    }, f)
print("Saved a2c params")

# Tile each policy's driver-return row ``repeat_factor`` times so that row i
# of driver_returns belongs to the same policy as query sample i (query samples
# are grouped policy-major).
driver_returns = jnp.repeat(driver_returns, repeat_factor, axis=0)
print("transformer training data.")

# Predictability Transformer
# ``args`` re-parses the same command line as ``config``; the transformer stage
# reads its options through attribute access.
args = parser.parse_args()

# NOTE(review): extract_submodel(..., -1) presumably selects the last policy
# from the stacked A2C params — confirm against utils.extract_submodel.
actor_critic_params = extract_submodel(a2c_out["params"]["params"], -1)
del a2c_out
feat_extractor_params = load_feat_extractor_params(actor_critic_params)
# NOTE(review): a *DiscreteAction* feature extractor is paired with a
# continuous-action actor-critic — confirm the layer layout matches.
feat_extractor = FeatExtractorDiscreteAction(activation=config["ACTIVATION"])

# Map driver and query observations into the extractor's feature space.
driver_states = feat_extractor.apply(feat_extractor_params, driver_states)
query_states = feat_extractor.apply(feat_extractor_params, query_states)

predictor = PredictabilityHead(
    num_heads=args.num_heads,
    num_layers=args.num_layers,
    hidden_dim=args.hidden_dim,
)


batch_size = args.batch_size
# Dedicated init key so parameter init does not reuse the rollout noise key.
rng, init_rng = jax.random.split(rng)
predictor_params = predictor.init(
    init_rng,
    jnp.zeros((batch_size, args.num_drivertest_states, driver_states.shape[-1])),
    jnp.zeros((batch_size, args.num_drivertest_states)),
    jnp.zeros((batch_size, driver_states.shape[-1])),
)

# The schedule has to span the optimiser steps actually taken. Training uses
# the 80% train split (``split`` below) in full batches only. optax's
# ``decay_steps`` is the *total* schedule length including warmup, so warmup
# must not be subtracted here. Keep at least one decay step because optax
# requires decay_steps > warmup_steps.
n_train = int(0.8 * query_states.shape[0])
total_train_steps = args.num_epochs * (n_train // batch_size)
lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=args.learning_rate,
    warmup_steps=args.warmup_steps,
    decay_steps=max(total_train_steps, args.warmup_steps + 1),
    end_value=0.03 * args.learning_rate,
)

tx = optax.chain(
    # Honour --grad_clipping (default 1.0) instead of a hard-coded threshold.
    optax.clip_by_global_norm(args.grad_clipping),
    optax.adamw(
        learning_rate=lr_schedule,
        weight_decay=1e-2,
        b1=0.9, b2=0.999, eps=1e-8,
    ),
)

predictor_state = TrainState.create(
    apply_fn=predictor.apply,
    params=predictor_params,
    tx=tx,
)


def take_batch(arr: jnp.ndarray, batch_idx: int, B: int):
    """Return rows ``[batch_idx * B, batch_idx * B + B)`` of ``arr``.

    The slice size ``B`` is static, so ``batch_idx`` may be a traced value
    (e.g. a ``lax.scan`` counter). As with ``lax.dynamic_slice``, a start that
    would overrun the end is clamped so the slice stays in bounds.
    """
    return jax.lax.dynamic_slice_in_dim(arr, batch_idx * B, B, axis=0)


@jax.jit
def _update(state, d_states, d_returns, q_states, q_returns):
    def loss_fn(params):
        preds = state.apply_fn(params, d_states, d_returns, q_states)
        return jnp.mean((preds - q_returns) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss


@jax.jit
def _epoch_update(state, d_states_all, d_returns_all, q_states_all, q_returns_all, rng):
    """Run one shuffled training epoch of the predictor.

    Row ``i`` of ``d_returns_all``, ``q_states_all`` and ``q_returns_all`` is
    assumed to describe the same sample (the caller indexes all three with the
    same ``train_idx``). All three are therefore shuffled with one shared
    permutation, so a query keeps the driver returns of its own policy. The
    shuffled data is consumed as ``n // batch_size`` full batches; a trailing
    partial batch is dropped. ``d_states_all`` is shared by every sample and
    is broadcast to batch shape.

    Args:
        state: predictor train state.
        d_states_all: driver-state features shared by every sample.
        d_returns_all: per-sample driver returns.
        q_states_all: per-sample query-state features.
        q_returns_all: per-sample regression targets.
        rng: PRNG key used for the shuffle.

    Returns:
        ``(state, mean_loss)`` — the updated state and the mean batch loss.

    Raises:
        ValueError: (at trace time) if there are fewer samples than one batch,
            in which case no update would run and the mean loss would be NaN.
    """
    n = q_states_all.shape[0]
    n_batches = n // batch_size
    if n_batches == 0:
        raise ValueError(
            f"_epoch_update: {n} training samples is fewer than batch_size={batch_size}"
        )
    _, key = jax.random.split(rng)
    perm = jax.random.permutation(key, n)
    # One permutation for all per-sample arrays keeps (driver returns, query
    # state, query return) triples together.
    dr, qs, qr = d_returns_all[perm], q_states_all[perm], q_returns_all[perm]
    d_states_fixed = jnp.broadcast_to(
        d_states_all,
        (batch_size, args.num_drivertest_states, driver_states.shape[-1]),
    )

    def step(s, i):
        d_r = take_batch(dr, i, batch_size)
        q_s = take_batch(qs, i, batch_size)
        q_r = take_batch(qr, i, batch_size)
        return _update(s, d_states_fixed, d_r, q_s, q_r)

    state, losses = jax.lax.scan(step, state, jnp.arange(n_batches))
    return state, jnp.mean(losses)


@jax.jit
def _val_epoch(state, d_states_all, d_returns_all, q_states_all, q_returns_all):
    """Evaluate the predictor on a held-out split without updating it.

    The split is consumed in order as ``n // batch_size`` full batches; a
    trailing partial batch is skipped. ``d_states_all`` is broadcast to every
    batch.

    Args:
        state: predictor train state (only ``apply_fn``/``params`` are read).
        d_states_all: driver-state features shared by every sample.
        d_returns_all: per-sample driver returns.
        q_states_all: per-sample query-state features.
        q_returns_all: per-sample regression targets.

    Returns:
        Mean over batches of the per-batch mean squared error.

    Raises:
        ValueError: (at trace time) if the split has fewer samples than one
            batch; the mean would otherwise divide by zero.
    """
    n = q_states_all.shape[0]
    n_batches = n // batch_size
    if n_batches == 0:
        raise ValueError(
            f"_val_epoch: {n} validation samples is fewer than batch_size={batch_size}"
        )
    d_states_fixed = jnp.broadcast_to(
        d_states_all,
        (batch_size, args.num_drivertest_states, driver_states.shape[-1]),
    )

    def step(loss_acc, i):
        d_r = take_batch(d_returns_all, i, batch_size)
        q_s = take_batch(q_states_all, i, batch_size)
        q_r = take_batch(q_returns_all, i, batch_size)
        preds = state.apply_fn(state.params, d_states_fixed, d_r, q_s)
        # Only the running sum is needed; emit no per-step output.
        return loss_acc + jnp.mean((preds - q_r) ** 2), None

    total, _ = jax.lax.scan(step, 0.0, jnp.arange(n_batches))
    return total / n_batches


# Random 80/20 train/validation split over query samples. The same index sets
# are applied to driver_returns, query_states and query_returns.
rng, split_key = jax.random.split(rng)
n_samples = query_states.shape[0]
num_epochs = args.num_epochs
perm = jax.random.permutation(split_key, n_samples)
split = int(0.8 * n_samples)
train_idx, val_idx = perm[:split], perm[split:]

print("\n=== Training Predictability Transformer ===")
# NOTE(review): val_log_freq is read but not used in what follows —
# validation runs every epoch.
val_log_freq = args.val_log_freq
t0 = time.time()


def _epoch_scan(carry, epoch_i):
    """``jax.lax.scan`` body that runs one predictor epoch.

    The carry is ``(train_state, rng)``. Validation is measured with the
    parameters *before* this epoch's updates; then one shuffled training epoch
    runs on a fresh sub-key.

    Returns:
        ``((train_state, rng), (train_loss, val_loss))`` for this epoch.
    """
    train_state, key = carry
    key, epoch_key = jax.random.split(key)

    val_inputs = (driver_returns[val_idx], query_states[val_idx], query_returns[val_idx])
    train_inputs = (driver_returns[train_idx], query_states[train_idx], query_returns[train_idx])

    val_loss = _val_epoch(train_state, driver_states, *val_inputs)
    train_state, tr_loss = _epoch_update(train_state, driver_states, *train_inputs, epoch_key)

    jax.debug.print("Epoch: {epoch_i}", epoch_i=epoch_i)
    return (train_state, key), (tr_loss, val_loss)


# Run every epoch inside a single lax.scan; per-epoch losses come back stacked.
(predictor_state, _), (train_losses, val_losses) = jax.lax.scan(
    _epoch_scan, (predictor_state, rng), xs=jnp.arange(num_epochs)
)

num_batches_per_epoch = n_samples // batch_size
# Bring both loss curves to the host once, rather than one device->host
# transfer per element inside the logging loop.
train_losses = np.asarray(train_losses).reshape(-1)
val_losses = np.asarray(val_losses).reshape(-1)

# val_losses[i] is measured before epoch i's updates (see _epoch_scan); the
# post-training validation loss is logged separately below.
for epoch, (tr_loss, va_loss) in enumerate(zip(train_losses, val_losses), start=1):
    wandb.log({
        "train_loss": float(tr_loss),
        "val_loss": float(va_loss),
        "epoch": epoch,
    })

final_val_loss = _val_epoch(predictor_state, driver_states, driver_returns[val_idx],
                            query_states[val_idx], query_returns[val_idx])
wandb.log({
    "val_loss": float(final_val_loss),
    "epoch": num_epochs + 1
})

print(f"Transformer training time: {time.time() - t0:.1f}s")

with open(f'{config["SAVE_DIR"]}/{experiment_name}_pred_transformer.pkl', "wb") as f:
    pickle.dump({
        "feat_extractor_params": feat_extractor_params,
        "predictor_params": predictor_state.params,
    }, f)
