import os 
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import pickle
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax
import gymnax
from wrappers import LogWrapper, FlattenObservationWrapper
import gymnasium as gym
import wandb
import os
import time
import matplotlib.pyplot as plt
from models import ActorCriticDiscreteAction

import datetime
import re

class Transition(NamedTuple):
    """One environment step recorded while collecting a rollout.

    Instances are built positionally inside ``_env_step`` (so the field order
    below is part of the contract) and are emitted as the per-step output of
    ``jax.lax.scan``, which stacks them along a leading time axis.
    """

    # Episode-termination flag returned by ``env.step`` for this step.
    done: jnp.ndarray
    # Action sampled from the policy distribution for the stored observation.
    action: jnp.ndarray
    # Critic output for the stored observation.
    value: jnp.ndarray
    # Reward returned by ``env.step`` for this step.
    reward: jnp.ndarray
    # Log-probability of ``action`` under the policy that sampled it.
    log_prob: jnp.ndarray
    # Observation the action was chosen from (the pre-step observation).
    obs: jnp.ndarray
    # Extra step info returned by the wrapped ``env.step``.
    # NOTE(review): annotated as an array, but presumably a dict/pytree
    # produced by LogWrapper -- confirm against wrappers.py.
    info: jnp.ndarray

def make_train(config):
    """Build an A2C training function for a gymnax environment.

    ``config`` is mutated in place: the derived entries ``NUM_UPDATES`` and
    ``MINIBATCH_SIZE`` are added to it.

    Args:
        config: Dict of hyper-parameters. Keys read here are ENV_NAME,
            TOTAL_TIMESTEPS, NUM_STEPS, NUM_ENVS, NUM_MINIBATCHES, LR,
            ANNEAL_LR, MAX_GRAD_NORM, ACTIVATION, GAMMA, GAE_LAMBDA,
            VF_COEF and ENT_COEF.

    Returns:
        A function ``train(rng)`` that runs the full training loop and
        returns a dict with:
            ``runner_state``: final ``(train_state, env_state, obs, rng)``.
            ``metrics``: ``info`` from every rollout step, stacked over all
                updates.
            ``params``: network parameters after every update, stacked over
                all updates.
    """
    # Derived quantities.
    # TOTAL_TIMESTEPS may be a float (e.g. 5e5 from the CLI), which would make
    # the floor-division result a float too; NUM_UPDATES is a count, so store
    # it as an int.
    config["NUM_UPDATES"] = int(
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    env, env_params = gymnax.make(config["ENV_NAME"])
    env = FlattenObservationWrapper(env)
    env = LogWrapper(env)

    def linear_schedule(count):
        """Scale LR linearly down with the index of the current update."""
        # ``count`` is the optimizer step counter. Each update performs exactly
        # NUM_MINIBATCHES gradient steps (a single pass, no epoch loop), so
        # dividing by NUM_MINIBATCHES recovers the update index.
        frac = 1.0 - (count // config["NUM_MINIBATCHES"]) / config["NUM_UPDATES"]
        return config["LR"] * frac

    def train(rng):
        """Run the whole training loop starting from PRNG key ``rng``."""
        # ------------------------
        # INIT NETWORK & OPTIMIZER
        # ------------------------
        network = ActorCriticDiscreteAction(env.action_space(env_params).n, activation=config["ACTIVATION"])
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        # Gradient clipping followed by Adam; only the learning rate differs
        # between the annealed and the constant variant.
        learning_rate = linear_schedule if config["ANNEAL_LR"] else config["LR"]
        tx = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(learning_rate=learning_rate, eps=1e-5),
        )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # ------------------------
        # INIT ENVIRONMENT
        # ------------------------
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)

        # ------------------------
        # TRAIN LOOP: UPDATE STEP
        # ------------------------
        def _update_step(runner_state, unused):
            """Collect one rollout, compute GAE and apply one A2C update."""

            # --- COLLECT TRAJECTORIES ---
            def _env_step(runner_state, unused):
                """Act in every environment once and record the transition."""
                train_state, env_state, last_obs, rng = runner_state

                # SELECT ACTION & VALUE
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)

                # STEP THE ENVIRONMENT
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = jax.vmap(
                    env.step, in_axes=(0, 0, 0, None)
                )(rng_step, env_state, action, env_params)
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs, info
                )
                new_state = (train_state, env_state, obsv, rng)
                return new_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # --- COMPUTE ADVANTAGES & RETURNS (using GAE) ---
            train_state, env_state, last_obs, rng = runner_state
            # Bootstrap value for the observation that follows the rollout.
            _, last_val = network.apply(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                """Return ``(advantages, value_targets)`` for the rollout."""

                def _get_advantages(carry, transition):
                    gae, next_value = carry
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    # ``(1 - done)`` cuts bootstrapping at episode boundaries.
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = delta + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    return (gae, value), gae

                # Scan backwards in time so each step sees its successor's
                # value and accumulated advantage.
                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # --- UPDATE NETWORK (A2C UPDATE) ---
            # In A2C we use a single update pass (optionally with mini-batches)
            batch_size = config["NUM_ENVS"] * config["NUM_STEPS"]
            # Flatten rollout from [NUM_STEPS, NUM_ENVS, ...] to [batch_size, ...]
            b_obs = traj_batch.obs.reshape((batch_size,) + traj_batch.obs.shape[2:])
            b_actions = traj_batch.action.reshape((batch_size,) + traj_batch.action.shape[2:])
            b_advantages = advantages.reshape((batch_size,))
            b_returns = targets.reshape((batch_size,))

            # Shuffle the batch while keeping correspondence.
            # Split first: the carried ``rng`` is split again by the next
            # rollout, so consuming it directly here would reuse the same key.
            rng, _rng = jax.random.split(rng)
            permutation = jax.random.permutation(_rng, batch_size)
            b_obs = jnp.take(b_obs, permutation, axis=0)
            b_actions = jnp.take(b_actions, permutation, axis=0)
            b_advantages = jnp.take(b_advantages, permutation, axis=0)
            b_returns = jnp.take(b_returns, permutation, axis=0)

            # Split into NUM_MINIBATCHES mini-batches (1 means full batch).
            num_minibatches = config["NUM_MINIBATCHES"]
            minibatch_size = batch_size // num_minibatches
            b_obs = b_obs.reshape((num_minibatches, minibatch_size) + b_obs.shape[1:])
            b_actions = b_actions.reshape((num_minibatches, minibatch_size) + b_actions.shape[1:])
            b_advantages = b_advantages.reshape((num_minibatches, minibatch_size))
            b_returns = b_returns.reshape((num_minibatches, minibatch_size))

            # Define A2C loss (without PPO-specific clipping).
            def loss_fn(params, obs, actions, advantages, returns):
                """Policy-gradient + value + entropy loss for one mini-batch."""
                pi, value = network.apply(params, obs)
                log_prob = pi.log_prob(actions)
                policy_loss = -jnp.mean(log_prob * advantages)
                value_loss = 0.5 * jnp.mean((value - returns) ** 2)
                entropy = jnp.mean(pi.entropy())
                total_loss = policy_loss + config["VF_COEF"] * value_loss - config["ENT_COEF"] * entropy
                return total_loss, (value_loss, policy_loss, entropy)

            # For A2C, we perform one update pass over the collected batch.
            grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

            def update_minibatch(train_state, minibatch):
                """Apply one gradient step computed on ``minibatch``."""
                obs_mb, actions_mb, adv_mb, ret_mb = minibatch
                (loss_value, aux), grads = grad_fn(train_state.params, obs_mb, actions_mb, adv_mb, ret_mb)
                new_train_state = train_state.apply_gradients(grads=grads)
                return new_train_state, (loss_value, aux)

            # One gradient step per mini-batch; the per-step losses are not
            # logged, so the stacked scan output is discarded.
            train_state, _ = jax.lax.scan(
                update_minibatch, train_state, (b_obs, b_actions, b_advantages, b_returns)
            )
            # -------------------------
            # LOGGING: The PPO version logs traj_batch.info.
            # Here we preserve the same logging by passing along the info from the rollout.
            metric = traj_batch.info

            # Update the runner state with the new train_state.
            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, (metric, train_state.params)

        # Initialize runner state.
        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        runner_state, (metric, params) = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )
        return {"runner_state": runner_state, "metrics": metric, "params": params}

    return train

import argparse

def _add_bool_flag(parser: argparse.ArgumentParser, name: str, default: bool, help_text: str) -> None:
    """Register an on/off switch ``--NAME`` / ``--no-NAME`` stored under ``NAME``."""
    bool_act = getattr(argparse, "BooleanOptionalAction", None)
    if bool_act is not None:
        parser.add_argument(
            f"--{name}", dest=name,
            default=default, action=bool_act,
            help=help_text,
        )
        return
    # Python < 3.9 has no BooleanOptionalAction. A lone ``store_true`` flag
    # with ``default=True`` could never be switched off, so register an
    # explicit ``--no-NAME`` twin that writes False to the same destination.
    parser.add_argument(
        f"--{name}", dest=name,
        default=default, action="store_true",
        help=help_text,
    )
    parser.add_argument(
        f"--no-{name}", dest=name,
        default=default, action="store_false",
        help=f"Disable --{name}",
    )


def build_arg_parser() -> argparse.ArgumentParser:
    """
    Return an ArgumentParser whose options correspond one-for-one to the
    keys in the default A2C `config` dictionary.

    Every option has a default, so parsing an empty command line yields the
    complete default configuration. The boolean options ``--ANNEAL_LR`` and
    ``--DEBUG`` default to True and can be disabled with ``--no-ANNEAL_LR`` /
    ``--no-DEBUG`` on every supported Python version.
    """
    parser = argparse.ArgumentParser(
        description="A2C training configuration for Pong-misc"
    )

    # ─── scalar hyper-parameters ───────────────────────────────────────────
    parser.add_argument("--ENV_NAME",        type=str,   default="Pong-misc")
    parser.add_argument("--LR",              type=float, default=2.5e-4)
    parser.add_argument("--NUM_ENVS",        type=int,   default=32)
    parser.add_argument("--NUM_STEPS",       type=int,   default=128)
    # Parsed as float so that scientific notation such as ``1e6`` is accepted.
    parser.add_argument("--TOTAL_TIMESTEPS", type=float, default=5e5)
    parser.add_argument("--UPDATE_EPOCHS",   type=int,   default=1)      # A2C: single epoch
    parser.add_argument("--NUM_MINIBATCHES", type=int,   default=1)      # A2C: no minibatching
    parser.add_argument("--GAMMA",           type=float, default=0.99)
    parser.add_argument("--GAE_LAMBDA",      type=float, default=0.95)
    parser.add_argument("--CLIP_EPS",        type=float, default=0.2)    # kept for completeness
    parser.add_argument("--ENT_COEF",        type=float, default=0.01)
    parser.add_argument("--VF_COEF",         type=float, default=0.5)
    parser.add_argument("--MAX_GRAD_NORM",   type=float, default=0.5)
    parser.add_argument(
        "--ACTIVATION",
        type=str,
        choices=["tanh", "relu", "gelu", "swish"],
        default="tanh",
    )
    # ─── boolean flags ─────────────────────────────────────────────────────
    _add_bool_flag(
        parser, "ANNEAL_LR", True,
        "Linearly anneal the learning rate (default: True)",
    )
    _add_bool_flag(
        parser, "DEBUG", True,
        "Enable extra logging / assertions (default: True)",
    )
    # ─── output / reproducibility ──────────────────────────────────────────
    parser.add_argument("--SAVE_DIR", type=str, default="./a2c_minatar")
    parser.add_argument("--SEED", type=int, default=0)

    return parser

def make_short_tag(cfg: dict) -> str:
    """
    Build a short run identifier from ENV_NAME, TOTAL_TIMESTEPS, GAMMA and
    SEED, joined by underscores. TOTAL_TIMESTEPS is truncated to an int so
    that a value such as 5e5 is rendered as 500000.
    """
    pieces = (
        f"{cfg['ENV_NAME']}",
        f"steps{int(cfg['TOTAL_TIMESTEPS'])}",
        f"gamma{cfg['GAMMA']}",
        f"seed{cfg['SEED']}",
    )
    return "_".join(pieces)

if __name__ == "__main__":
    # Parse the command line into a plain dict; make_train adds derived keys
    # to this same dict, so they end up in the saved config as well.
    config = vars(build_arg_parser().parse_args())
    experiment_name = make_short_tag(config)

    # Create the output directory before training: otherwise a missing
    # SAVE_DIR is only discovered by open() after the whole run has finished,
    # and the results are lost.
    if config["SAVE_DIR"]:
        os.makedirs(config["SAVE_DIR"], exist_ok=True)

    print("Training env:", config["ENV_NAME"])
    rng = jax.random.PRNGKey(config["SEED"])
    t0 = time.time()
    train_jit = jax.jit(make_train(config))
    # block_until_ready so the reported time covers compilation and execution
    # rather than just the asynchronous dispatch.
    out = jax.block_until_ready(train_jit(rng))
    print(f"time: {time.time() - t0:.2f} s")

    # Save the output: metrics and parameters go to separate pickle files,
    # each bundled with the config that produced them.
    with open(f"{config['SAVE_DIR']}/{experiment_name}_metrics.pkl", "wb") as f:
        pickle.dump(
            {
                "config": config,
                "metrics": out["metrics"],
            }, f
        )

    with open(f"{config['SAVE_DIR']}/{experiment_name}_params.pkl", "wb") as f:
        pickle.dump(
            {
                "config": config,
                "params": out["params"],
            }, f
        )

    print("Successfully trained env:", config["ENV_NAME"])
