import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # Disable XLA preallocation to avoid memory issues
import pickle
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax
from wrappers import (
    LogWrapper,
    BraxGymnaxWrapper,
    VecEnv,
    NormalizeVecObservation,
    NormalizeVecReward,
    ClipAction,
)
import matplotlib.pyplot as plt
import time
from models import ActorCriticContinuousAction

# Container for one rollout step, batched over all vectorised environments.
class Transition(NamedTuple):
    """One rollout step collected by ``_env_step`` inside ``make_train``.

    Fields are built positionally (``Transition(done, action, value, reward,
    log_prob, obs, info)``), so their order must not change. After
    ``jax.lax.scan`` over the rollout, every array field gains a leading
    ``NUM_STEPS`` axis in front of the per-environment axis.
    """

    done: jnp.ndarray      # episode-termination flag returned by env.step
    action: jnp.ndarray    # action sampled from the policy for `obs`
    value: jnp.ndarray     # critic estimate for `obs`
    reward: jnp.ndarray    # reward returned by env.step for `action`
    log_prob: jnp.ndarray  # log-probability of `action` under the rollout policy
    obs: jnp.ndarray       # observation the action was chosen from
    info: Any              # env.step info output (structure set by the env wrappers)


# --- A2C Training Function ---
def make_train(config):
    """Build a jittable A2C training function for a Brax environment.

    The derived quantities ``NUM_UPDATES`` and ``MINIBATCH_SIZE`` are written
    back into ``config`` (the script pickles the mutated dict with the results).

    Args:
        config: Hyper-parameter dict. Keys read: TOTAL_TIMESTEPS, NUM_STEPS,
            NUM_ENVS, NUM_MINIBATCHES, UPDATE_EPOCHS (defaults to 1), ENV_NAME,
            ENV_BACKEND, NORMALIZE_ENV (optional), GAMMA, GAE_LAMBDA, LR,
            ANNEAL_LR, MAX_GRAD_NORM, ACTIVATION, VF_COEF, ENT_COEF.

    Returns:
        ``train(rng)``, which runs ``NUM_UPDATES`` collect/update iterations and
        returns a dict with the final ``runner_state``, the per-update env
        ``metrics`` (the rollout ``info``) and the per-update network ``params``.

    Raises:
        ValueError: if the timestep budget is smaller than one rollout, or the
            rollout batch cannot be split evenly into NUM_MINIBATCHES.
    """
    # Derived quantities. TOTAL_TIMESTEPS may be a float (the CLI parses it with
    # type=float); NUM_UPDATES is used as a scan length, so keep it an int.
    batch_size = config["NUM_ENVS"] * config["NUM_STEPS"]
    config["NUM_UPDATES"] = int(config["TOTAL_TIMESTEPS"]) // config["NUM_STEPS"] // config["NUM_ENVS"]
    if config["NUM_UPDATES"] < 1:
        raise ValueError(
            f"TOTAL_TIMESTEPS={config['TOTAL_TIMESTEPS']} is smaller than one rollout "
            f"(NUM_ENVS * NUM_STEPS = {batch_size}); no update would be performed."
        )
    if batch_size % config["NUM_MINIBATCHES"] != 0:
        raise ValueError(
            f"NUM_ENVS * NUM_STEPS = {batch_size} is not divisible by "
            f"NUM_MINIBATCHES = {config['NUM_MINIBATCHES']}."
        )
    config["MINIBATCH_SIZE"] = batch_size // config["NUM_MINIBATCHES"]
    num_minibatches = config["NUM_MINIBATCHES"]
    minibatch_size = config["MINIBATCH_SIZE"]
    update_epochs = config.get("UPDATE_EPOCHS", 1)

    # Create environment using BraxGymnaxWrapper and additional wrappers.
    env, env_params = BraxGymnaxWrapper(config["ENV_NAME"], backend=config["ENV_BACKEND"]), None
    env = LogWrapper(env)
    env = ClipAction(env)
    env = VecEnv(env)
    if config.get("NORMALIZE_ENV", False):
        env = NormalizeVecObservation(env)
        env = NormalizeVecReward(env, config["GAMMA"])

    def linear_schedule(count):
        # `count` advances once per optimizer step; each update performs
        # update_epochs * num_minibatches steps, so this is a per-update decay.
        frac = 1.0 - (count // (num_minibatches * update_epochs)) / config["NUM_UPDATES"]
        return config["LR"] * frac

    def train(rng):
        # --- Initialize Network & Optimizer ---
        network = ActorCriticContinuousAction(
            action_dim=env.action_space(env_params).shape[0],
            activation=config["ACTIVATION"],
        )
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        learning_rate = linear_schedule if config["ANNEAL_LR"] else config["LR"]
        tx = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(learning_rate=learning_rate, eps=1e-5),
        )
        train_state = TrainState.create(apply_fn=network.apply, params=network_params, tx=tx)

        def _loss_fn(params, obs, actions, advantages, returns):
            """A2C loss: policy gradient + weighted value MSE - entropy bonus."""
            pi, value = network.apply(params, obs)
            log_prob = pi.log_prob(actions)
            policy_loss = -jnp.mean(log_prob * advantages)
            value_loss = 0.5 * jnp.mean((value - returns) ** 2)
            entropy = jnp.mean(pi.entropy())
            total_loss = policy_loss + config["VF_COEF"] * value_loss - config["ENT_COEF"] * entropy
            return total_loss, (value_loss, policy_loss, entropy)

        grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)

        def _update_minibatch(train_state, minibatch):
            """Apply one gradient step on a single minibatch."""
            obs_mb, actions_mb, adv_mb, ret_mb = minibatch
            (loss_value, aux), grads = grad_fn(train_state.params, obs_mb, actions_mb, adv_mb, ret_mb)
            return train_state.apply_gradients(grads=grads), (loss_value, aux)

        # --- Initialize Environment ---
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = env.reset(reset_rng, env_params)

        # --- Update Step: Collect Trajectories, Compute Advantages, Update Network ---
        def _update_step(runner_state, unused):
            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = env.step(rng_step, env_state, action, env_params)
                transition = Transition(done, action, value, reward, log_prob, last_obs, info)
                return (train_state, env_state, obsv, rng), transition

            runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["NUM_STEPS"])

            # --- Compute Advantages using GAE ---
            train_state, env_state, last_obs, rng = runner_state
            _, last_val = network.apply(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(carry, transition):
                    gae, next_value = carry
                    done, value, reward = transition.done, transition.value, transition.reward
                    # (1 - done) stops bootstrapping across episode boundaries.
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = delta + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    return (gae, value), gae

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # --- A2C Update ---
            # Flatten the (NUM_STEPS, NUM_ENVS, ...) rollout into one batch axis.
            batch = (
                traj_batch.obs.reshape((batch_size,) + traj_batch.obs.shape[2:]),
                traj_batch.action.reshape((batch_size,) + traj_batch.action.shape[2:]),
                advantages.reshape((batch_size,)),
                targets.reshape((batch_size,)),
            )

            def _update_epoch(update_state, unused):
                """Shuffle the batch, split it into minibatches and take one step per minibatch."""
                train_state, rng = update_state
                # Split off a dedicated key: the carried `rng` must never be
                # consumed directly, or the next rollout would reuse it.
                rng, perm_rng = jax.random.split(rng)
                permutation = jax.random.permutation(perm_rng, batch_size)
                shuffled = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=0), batch)
                minibatches = jax.tree_util.tree_map(
                    lambda x: x.reshape((num_minibatches, minibatch_size) + x.shape[1:]),
                    shuffled,
                )
                train_state, losses = jax.lax.scan(_update_minibatch, train_state, minibatches)
                return (train_state, rng), losses

            (train_state, rng), _ = jax.lax.scan(
                _update_epoch, (train_state, rng), None, update_epochs
            )

            # Preserve logging by passing the environment info.
            metric = traj_batch.info
            runner_state = (train_state, env_state, last_obs, rng)
            # Params are emitted every update, so the stacked history grows
            # linearly with NUM_UPDATES.
            return runner_state, (metric, train_state.params)

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        runner_state, (metric, params) = jax.lax.scan(_update_step, runner_state, None, config["NUM_UPDATES"])
        return {"runner_state": runner_state, "metrics": metric, "params": params}

    return train

import argparse

def _add_bool_flag(parser: argparse.ArgumentParser, name: str, *, default: bool, help_text: str) -> None:
    """Register a ``--NAME`` / ``--no-NAME`` boolean option stored in ``dest=NAME``.

    Uses ``argparse.BooleanOptionalAction`` when available (Python >= 3.9).
    Otherwise it falls back to an equivalent mutually exclusive pair, so a flag
    whose default is True can still be switched off. A bare ``store_true``
    fallback could never disable such a flag.
    """
    if hasattr(argparse, "BooleanOptionalAction"):
        parser.add_argument(
            f"--{name}", dest=name,
            default=default, action=argparse.BooleanOptionalAction,
            help=help_text,
        )
        return
    group = parser.add_mutually_exclusive_group()
    group.add_argument(f"--{name}", dest=name, action="store_true", help=help_text)
    group.add_argument(f"--no-{name}", dest=name, action="store_false")
    parser.set_defaults(**{name: default})


def build_arg_parser() -> argparse.ArgumentParser:
    """
    Return an ArgumentParser for the training configuration.

    Each option's ``dest`` is the config key it populates, so
    ``vars(parser.parse_args())`` yields the config dict consumed by
    ``make_train`` and the script entry point.
    """
    parser = argparse.ArgumentParser(
        description="Generalised A2C / PPO configuration (halfcheetah-positional)"
    )

    # ── scalar hyper-parameters ────────────────────────────────────────────
    parser.add_argument("--LR",              type=float, default=3e-4)
    parser.add_argument("--NUM_ENVS",        type=int,   default=256)
    parser.add_argument("--NUM_STEPS",       type=int,   default=10)
    parser.add_argument("--TOTAL_TIMESTEPS", type=float, default=1e6)
    parser.add_argument("--UPDATE_EPOCHS",   type=int,   default=1)
    parser.add_argument("--NUM_MINIBATCHES", type=int,   default=1)
    parser.add_argument("--GAMMA",           type=float, default=0.99)
    parser.add_argument("--GAE_LAMBDA",      type=float, default=0.95)
    # NOTE: CLIP_EPS is not read by make_train's A2C loss.
    parser.add_argument("--CLIP_EPS",        type=float, default=0.2)
    parser.add_argument("--ENT_COEF",        type=float, default=0.0)
    parser.add_argument("--VF_COEF",         type=float, default=0.5)
    parser.add_argument("--MAX_GRAD_NORM",   type=float, default=0.5)

    parser.add_argument(
        "--ACTIVATION",
        type=str,
        choices=["tanh", "relu", "gelu", "swish"],
        default="tanh",
    )
    parser.add_argument("--ENV_NAME",        type=str,   default="halfcheetah")
    parser.add_argument("--ENV_BACKEND",     type=str,   default="positional")

    # ── boolean flags: each gets --X / --no-X ───────────────────────────────
    _add_bool_flag(
        parser, "ANNEAL_LR", default=False,
        help_text="Linearly anneal the learning rate (default: False)",
    )
    _add_bool_flag(
        parser, "NORMALIZE_ENV", default=True,
        help_text="Normalize observations / rewards (default: True)",
    )
    _add_bool_flag(
        parser, "DEBUG", default=True,
        help_text="Enable extra logging / assertions (default: True)",
    )
    parser.add_argument("--SAVE_DIR", type=str, default="./a2c_brax")
    parser.add_argument("--SEED", type=int, default=0)

    return parser

def make_short_tag(cfg: dict) -> str:
    """
    Build a compact experiment identifier from the config.

    Only ENV_NAME, TOTAL_TIMESTEPS (truncated to int), GAMMA and SEED are used,
    joined with underscores, e.g. ``halfcheetah_steps1000000_gamma0.99_seed0``.
    """
    pieces = (
        f"{cfg['ENV_NAME']}",
        f"steps{int(cfg['TOTAL_TIMESTEPS'])}",
        f"gamma{cfg['GAMMA']}",
        f"seed{cfg['SEED']}",
    )
    return "_".join(pieces)


if __name__ == "__main__":
    config = vars(build_arg_parser().parse_args())
    experiment_name = make_short_tag(config)

    # Create the output directory before training: if it were missing, the
    # open() calls below would only fail after the whole run, losing results.
    os.makedirs(config["SAVE_DIR"], exist_ok=True)
    metrics_path = os.path.join(config["SAVE_DIR"], f"{experiment_name}_metrics.pkl")
    params_path = os.path.join(config["SAVE_DIR"], f"{experiment_name}_params.pkl")

    print("Training env:", config["ENV_NAME"])
    rng = jax.random.PRNGKey(config["SEED"])
    t0 = time.time()
    # Timing includes JIT compilation of the whole training loop.
    train_jit = jax.jit(make_train(config))
    out = jax.block_until_ready(train_jit(rng))
    print(f"time: {time.time() - t0:.2f} s")

    # Save the output; `config` now also carries the derived keys make_train added.
    with open(metrics_path, "wb") as f:
        pickle.dump(
            {
                "config": config,
                "metrics": out["metrics"],
            }, f
        )

    with open(params_path, "wb") as f:
        pickle.dump(
            {
                "config": config,
                "params": out["params"],
            }, f
        )

    print("Successfully trained env:", config["ENV_NAME"])