import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # Disable preallocation for JAX

import argparse
import functools
import pickle
import time
from typing import Tuple, Dict, Any

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import wandb
from flax.training.train_state import TrainState
import gymnax
import flashbax as fbx

from wrappers import BraxGymnaxWrapper, LogWrapper, ClipAction, VecEnv
from wrappers import NormalizeVecObservation, NormalizeVecReward
from models import ActorCriticContinuousAction, FeatExtractorContinuousAction, PredictabilityHead
from utils import Transition


# Helper functions
def linear_schedule(lr: float, num_updates: int, num_steps: int, num_minibatches: int):
    """Build a learning-rate schedule that decays linearly from `lr` towards zero.

    The decay spans `num_updates * num_minibatches` optimizer steps (one step per
    minibatch gradient update). `num_steps` is accepted for signature
    compatibility only and does not influence the schedule.

    Returns:
        A callable mapping the optimizer step count to the learning rate.
    """
    total_steps = num_updates * num_minibatches

    def schedule(step):
        remaining_fraction = 1.0 - step / total_steps
        return lr * remaining_fraction

    return schedule


def compute_gae(batch: Transition, last_value: jnp.ndarray, gamma: float, lam: float):
    """Generalized Advantage Estimation over a trajectory batch.

    `batch` is scanned backwards along its leading axis, bootstrapping from
    `last_value`. A transition's `done` flag masks both the bootstrapped next value
    and the advantage accumulated from later steps.

    Returns:
        `(advantages, returns)`, where `returns = advantages + batch.value`; both
        share the leading dims of `batch.value`.
    """
    def backward_step(carry, transition: Transition):
        running_adv, value_after = carry
        not_done = 1.0 - transition.done
        td_error = transition.reward + gamma * value_after * not_done - transition.value
        running_adv = td_error + gamma * lam * not_done * running_adv
        return (running_adv, transition.value), running_adv

    init_carry = (jnp.zeros_like(last_value), last_value)
    _, advantages = jax.lax.scan(backward_step, init_carry, batch, reverse=True, unroll=16)
    returns = advantages + batch.value
    return advantages, returns


def make_train(config, feature_extractor_params, predictor_params, driver_states, driver_env_states):
    """Build the jittable `train` function for continuous-action environments.

    The returned `train(master_key, feature_extractor_params, predictor_params)`
    runs A2C on `config["NUM_ENVS"]` vectorised environments. It adds a
    "predictability" gradient, weighted by `config["PREDICTABILITY_COEF"]`: the
    predictor head estimates the value of freshly visited states from the returns
    obtained by rolling out from the fixed driver states. The policy is pushed to
    make that estimate match its own returns. The predictor head itself is trained
    online from a flashbax replay buffer.

    Args:
        config: Experiment configuration. A copy is taken; derived keys
            (`NUM_UPDATES`, `MINIBATCH_SIZE`, buffer sizes) are added to the copy.
        feature_extractor_params: Not used by the factory itself (kept for
            backward compatibility); `train` receives the params it uses.
        predictor_params: Same as above.
        driver_states: Driver-state observations; `shape[0]` is the number of
            driver states `k` and `shape[1]` the observation size.
        driver_env_states: Environment states matching `driver_states`. Every
            update re-rolls the driver trajectories from these same states.

    Returns:
        `train`, which returns a dict of per-update outputs stacked along a leading
        `NUM_UPDATES` axis.
    """
    # Work on a copy so the caller's config is not mutated, then add derived values.
    # int(): TOTAL_TIMESTEPS may be a float (e.g. 1e7), but NUM_UPDATES is used as a
    # scan length and printed as an update counter.
    config = dict(config)
    config["NUM_UPDATES"] = int(config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"])
    config["MINIBATCH_SIZE"] = config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]

    # Environment: Brax env -> episode logging -> action clipping -> vectorisation,
    # optionally followed by observation / reward normalisation.
    base_env, base_params = BraxGymnaxWrapper(config["ENV_NAME"]), None
    base_env = LogWrapper(base_env)
    base_env = ClipAction(base_env)
    base_env = VecEnv(base_env)
    if config["NORMALIZE_ENV"]:
        base_env = NormalizeVecObservation(base_env)
        base_env = NormalizeVecReward(base_env, config["GAMMA"])

    k, obs_dim = driver_states.shape[0], driver_states.shape[1]
    # Transitions produced by one rollout, i.e. rows added to the buffer per update.
    rollout_batch = config["NUM_ENVS"] * config["NUM_STEPS"]

    # Replay buffer of (driver returns, query state, query value) rows used to
    # train the predictor head.
    dummy_experience = {
        "driver_returns": jnp.zeros((k,)),              # [k]
        "query_state": jnp.zeros((obs_dim,)),           # [obs]
        "query_value": jnp.zeros(()),                   # scalar
    }
    config["BUFFER_MAX_LENGTH"] = rollout_batch * config["MAX_POLICIES_IN_BUFFER"]
    config["BUFFER_MIN_LENGTH"] = rollout_batch * config["MIN_POLICIES_IN_BUFFER"]
    # Number of fresh rows overwritten by replayed rows before each add.
    config["NUM_TO_REPLACE"] = int(config["BUFFER_SAMPLE_BSIZE"] * config["RETAIN_PAST_PORTION"])
    num_replace = config["NUM_TO_REPLACE"]
    flat_buffer = fbx.make_flat_buffer(
        max_length=config["BUFFER_MAX_LENGTH"],
        min_length=config["BUFFER_MIN_LENGTH"],
        sample_batch_size=config["BUFFER_SAMPLE_BSIZE"],
        add_sequences=False,
        add_batch_size=rollout_batch,
    )
    buffer_state0 = flat_buffer.init(dummy_experience)

    def train(master_key, feature_extractor_params, predictor_params):
        # One independent key per consumer, so no PRNG key is used twice.
        key_net, key_pred, key_env, key_loop = jax.random.split(master_key, 4)

        # Actor-critic network and optimizer.
        actor_critic = ActorCriticContinuousAction(
            action_dim=base_env.action_space(base_params).shape[0],
            activation=config["ACTIVATION"])
        actor_params = actor_critic.init(key_net, jnp.zeros(base_env.observation_space(base_params).shape))

        if config["ANNEAL_LR"]:
            lr_fn = linear_schedule(config["LR"], config["NUM_UPDATES"],
                                    config["NUM_STEPS"], config["NUM_MINIBATCHES"])
        else:
            lr_fn = config["LR"]
        actor_opt = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(lr_fn, eps=1e-5),
        )
        actor_state = TrainState.create(apply_fn=actor_critic.apply, params=actor_params, tx=actor_opt)

        # Initial environment reset.
        env_reset_keys = jax.random.split(key_env, config["NUM_ENVS"])
        obs, env_state = base_env.reset(env_reset_keys, base_params)

        driver_obs, driver_int = driver_states, driver_env_states

        # Frozen feature extractor; driver-state embeddings are computed once.
        feat_extractor = FeatExtractorContinuousAction(activation=config["ACTIVATION"])
        driver_states_processed = feat_extractor.apply(feature_extractor_params, driver_obs)  # [k, emb_dim]

        predictor = PredictabilityHead(
            config["num_heads"],
            config["hidden_dim"],
            config["num_layers"],
        )
        # The flag is a static Python value, so a plain `if` is enough (no lax.cond
        # that traces both branches).
        if config["use_pretrained_transformer"] != 1:
            predictor_params = predictor.init(
                key_pred, driver_states_processed, jnp.zeros((k,)),
                jnp.zeros((config["BUFFER_SAMPLE_BSIZE"], driver_states_processed.shape[-1])))
        predictor_state = TrainState.create(
            apply_fn=predictor.apply,
            params=predictor_params,
            tx=optax.adam(config["PRED_LR"]),
        )

        # Runner tuple carries the mutable state across updates.
        runner = (actor_state, env_state, obs,
                  driver_int, driver_obs,
                  predictor_state, buffer_state0,
                  key_loop)

        update_idx = jnp.arange(config["NUM_UPDATES"])  # For progress printing

        def one_update(runner_state, idx):
            (actor_state, env_state, obs,
             driver_int, driver_obs,
             predictor_state, buffer_state,
             rng) = runner_state

            jax.lax.cond(
                idx % 100 == 0,
                lambda i: jax.debug.print("Update {}/{}", i + 1, config["NUM_UPDATES"]),
                lambda _: None,
                idx,
            )

            # 1. Rollouts (main envs and driver states) with the current policy.
            def rollout_step(carry, _):
                (actor_state, env_state, obs, rng) = carry
                rng, key_pi, key_drv = jax.random.split(rng, 3)

                pi, v = actor_critic.apply(actor_state.params, obs)
                act = pi.sample(seed=key_pi)
                logp = pi.log_prob(act)
                step_keys = jax.random.split(key_drv, obs.shape[0])
                nxt_obs, env_state, rew, done, info = base_env.step(step_keys, env_state, act, base_params)
                main_tr = Transition(done, act, v, rew, logp, obs, info)

                return (actor_state, env_state, nxt_obs, rng), main_tr

            (actor_state, env_state, obs, rng), traj = jax.lax.scan(
                rollout_step, (actor_state, env_state, obs, rng), None, length=config["NUM_STEPS"])

            # Driver rollouts always restart from the fixed driver states, so the
            # final driver env state is discarded.
            (actor_state, _, last_driver_obs, rng), drv_traj = jax.lax.scan(
                rollout_step, (actor_state, driver_int, driver_obs, rng), None, length=config["driver_traj_len"])

            # 2. GAE advantages / returns.
            _, last_v = actor_critic.apply(actor_state.params, obs)
            _, last_drv_v = actor_critic.apply(actor_state.params, last_driver_obs)
            adv, rets = compute_gae(traj, last_v, config["GAMMA"], config["GAE_LAMBDA"])
            drv_adv, drv_r = compute_gae(drv_traj, last_drv_v, config["GAMMA"], config["GAE_LAMBDA"])
            drv_returns = drv_r[0]  # [k] return estimate from each driver state

            # 3. Flatten [NUM_STEPS, NUM_ENVS, ...] -> [B, ...] and split into minibatches.
            B = rollout_batch
            flat_obs = traj.obs.reshape((B,) + traj.obs.shape[2:])
            flat_act = traj.action.reshape((B,) + traj.action.shape[2:])
            flat_adv = adv.reshape((B,))
            flat_ret = rets.reshape((B,))

            MB = config["NUM_MINIBATCHES"]
            mb = B // MB
            split = lambda x: x.reshape((MB, mb) + x.shape[1:])
            mb_obs, mb_act, mb_adv, mb_ret = map(split, (flat_obs, flat_act, flat_adv, flat_ret))

            # 4. Add to the replay buffer. Once it can be sampled, `num_replace` rows of the
            # fresh experience are replaced by replayed rows to retain past data.
            experience = {
                "driver_returns": jnp.broadcast_to(drv_returns, (B, k)),
                "query_state": flat_obs,
                "query_value": flat_ret,
            }
            rng, rng_sample, rng_swap, rng_pred = jax.random.split(rng, 4)

            # Rows of the fresh experience that will be overwritten.
            choose_idx = jax.random.choice(rng_swap, B, shape=(num_replace,), replace=False)

            def _add_only(_):
                return flat_buffer.add(buffer_state, experience)

            def _sample_splice_and_add(_):
                batch_samp = flat_buffer.sample(buffer_state, rng_sample).experience.first
                # The sample is already a random draw of BUFFER_SAMPLE_BSIZE rows, so its
                # first `num_replace` rows are used. Indexing it with `choose_idx` (values
                # up to B-1) would read past its end, and JAX silently clamps such indices.
                mixed_exp = {
                    key: experience[key].at[choose_idx].set(batch_samp[key][:num_replace])
                    for key in experience
                }
                return flat_buffer.add(buffer_state, mixed_exp)

            # Exactly one add per update: the fresh experience is part of the added batch.
            buffer_state = jax.lax.cond(
                flat_buffer.can_sample(buffer_state),
                _sample_splice_and_add,
                _add_only,
                operand=None,
            )

            # 5. Predictor-head regression on replayed data (if the buffer is ready).
            def predictor_step(predictor_state, step_key):
                batch = flat_buffer.sample(buffer_state, step_key).experience.first

                query_states_processed = feat_extractor.apply(feature_extractor_params, batch["query_state"])
                driver_states_processed_broadcasted = jnp.broadcast_to(
                    driver_states_processed,
                    (batch["driver_returns"].shape[0], k, driver_states_processed.shape[1])
                )

                def mse(params):
                    v_hat = predictor.apply(
                        params,
                        driver_states=driver_states_processed_broadcasted,
                        driver_returns=batch["driver_returns"],
                        query_state=query_states_processed
                    )
                    return jnp.mean((v_hat - batch["query_value"]) ** 2)

                loss, grads = jax.value_and_grad(mse)(predictor_state.params)
                return predictor_state.apply_gradients(grads=grads), loss

            def do_predictor_updates(state):
                # A distinct key per step so each step trains on a fresh batch.
                step_keys = jax.random.split(rng_pred, config["NUM_PRED_UPDATES"])
                final_state, losses = jax.lax.scan(predictor_step, state, step_keys)
                return final_state, losses[-1]

            predictor_state, predictor_pred_loss = jax.lax.cond(
                flat_buffer.can_sample(buffer_state),
                do_predictor_updates,
                lambda state: (state, jnp.nan),
                predictor_state
            )

            # 6. A2C update across minibatches, plus the predictability gradient.
            def loss_pi_v(params, o, a, ad, r):
                pi, v = actor_critic.apply(params, o)
                pol = -jnp.mean(pi.log_prob(a) * ad)
                val = 0.5 * jnp.mean((v - r) ** 2)
                ent = jnp.mean(pi.entropy())
                return pol + config["VF_COEF"] * val - config["ENT_COEF"] * ent, (pol, val, ent)

            def surrogate_policy_performance(p, o, a, ad):
                # Its gradient w.r.t. p is a policy-gradient estimate of dV/dp
                # (note: no minus sign).
                pi, _ = actor_critic.apply(p, o)
                return jnp.mean(pi.log_prob(a) * ad)

            # Driver rollouts rearranged to [k, driver_traj_len, ...]; constant across minibatches.
            drv_obs_env = jnp.swapaxes(drv_traj.obs, 0, 1)
            drv_act_env = jnp.swapaxes(drv_traj.action, 0, 1)
            drv_adv_env = jnp.swapaxes(drv_adv, 0, 1)
            batched_perf = jax.vmap(surrogate_policy_performance, in_axes=(None, 0, 0, 0))

            def mb_update(state: TrainState, minibatch):
                o_mb, a_mb, adv_mb, ret_mb = minibatch

                (_, (pol_batch, _, _)), ac_grads = jax.value_and_grad(
                    loss_pi_v, has_aux=True)(state.params, o_mb, a_mb, adv_mb, ret_mb)

                def add_pred(ac_grads):
                    # Predictability loss L = mean_mb (V̂(R(θ)) - Vπ(s;θ))², with Vπ
                    # approximated by the GAE returns of the minibatch.
                    query_mb = feat_extractor.apply(feature_extractor_params, o_mb)  # [mb, emb_dim]

                    def value_hat(R_vec):
                        return predictor.apply(
                            predictor_state.params,
                            driver_states=driver_states_processed,  # [k, emb_dim]
                            driver_returns=R_vec,                   # [k]
                            query_state=query_mb,                   # [mb, emb_dim]
                        )

                    resid_mb = value_hat(drv_returns) - ret_mb          # (mb,)
                    actorcritic_pred_loss = jnp.mean(resid_mb ** 2)    # logged

                    # dL/dR_k, averaged over the minibatch.
                    jac_mb_k = jax.jacrev(value_hat)(drv_returns)                 # (mb, k)
                    coeff_k = (2.0 * resid_mb[:, None] * jac_mb_k).mean(axis=0)  # (k,)

                    # dL/dθ = Σ_k coeff_k dR_k/dθ - mean_mb(2·resid · dVπ/dθ).
                    # Both terms are averaged over the minibatch and evaluated at the
                    # current params; one grad call replaces per-example gradients.
                    def pred_surrogate(p):
                        drv_perf = batched_perf(p, drv_obs_env, drv_act_env, drv_adv_env)  # [k]
                        mb_perf = batched_perf(p, o_mb, a_mb, adv_mb)                      # [mb]
                        return jnp.dot(coeff_k, drv_perf) - jnp.mean(2.0 * resid_mb * mb_perf)

                    pred_grad = jax.grad(pred_surrogate)(state.params)
                    ac_grads = jax.tree_util.tree_map(
                        lambda g, pg: g + config["PREDICTABILITY_COEF"] * pg,
                        ac_grads, pred_grad)
                    return ac_grads, actorcritic_pred_loss

                ac_grads, actorcritic_pred_loss = jax.lax.cond(
                    flat_buffer.can_sample(buffer_state),
                    add_pred,
                    lambda g: (g, jnp.nan),  # no predictability term yet
                    ac_grads
                )
                new_state = state.apply_gradients(grads=ac_grads)
                return new_state, (pol_batch, actorcritic_pred_loss)

            actor_state, (pol_batch, actorcritic_pred_loss) = jax.lax.scan(
                mb_update, actor_state, (mb_obs, mb_act, mb_adv, mb_ret))

            # 7. Next runner state and per-update outputs. Driver state/obs are
            # intentionally not advanced.
            new_runner = (actor_state, env_state, obs,
                          driver_int, driver_obs,
                          predictor_state, buffer_state, rng)

            # NOTE(review): params and env states are stacked for every update (O(NUM_UPDATES)
            # device memory), and callers receive the whole history rather than only the
            # final params — consider returning only the final values.
            loop_data = dict(
                predictor_pred_loss=predictor_pred_loss,
                actorcritic_pred_loss=actorcritic_pred_loss,
                actorcritic_policy_loss=pol_batch,
                metric=traj.info,
                actor_critic_params=actor_state.params,
                predictor_params=predictor_state.params,
                env_state=env_state,
                obs_batch=obs,
                driver_env_state=driver_int,
                driver_obs_batch=driver_obs,
                rng_step=rng,
            )
            return new_runner, loop_data

        _, loop_data_over_time = jax.lax.scan(one_update, runner, update_idx)
        return loop_data_over_time

    return train


def main():
    """Fine-tune a pretrained A2C agent with the predictability objective.

    Reads the pretrained A2C config, the pretrained feature-extractor / predictor
    weights and the driver-state data from disk, and runs the jitted training
    function. It then logs per-update losses and episodic returns to wandb and
    pickles the predictor and actor-critic params returned by training. Errors
    during training are printed and the wandb run is marked as failed.
    """
    import traceback

    parser = argparse.ArgumentParser(description="evarl continuous")
    parser.add_argument("--save_dir", type=str, default="./complete_continuous_longrun/")
    parser.add_argument("--experiment_name", type=str, default="inverted_double_pendulum_256envs_10steps_1000000.0ts_0seed_25dtr_1000gtr_10dt_inverted_double_pendulum_32bs_0.001lr_100ep_4h_2l_16hd")
    parser.add_argument("--wandb_project", type=str, default="evarl_continuous")
    parser.add_argument("--PREDICTABILITY_COEF", type=float, default=.1)
    parser.add_argument("--PRED_LR", type=float, default=1e-5)
    parser.add_argument("--use_pretrained_transformer", type=int, default=1)
    parser.add_argument("--NUM_PRED_UPDATES", type=int, default=10)
    args = parser.parse_args()

    save_dir = args.save_dir
    experiment_name = args.experiment_name

    # Only the config of the pretrained A2C run is needed.
    with open(f'{save_dir}/{experiment_name}_a2c_params.pkl', "rb") as f:
        config = pickle.load(f)["config"]

    # NOTE: --PRED_LR only acts as an on/off switch. Any non-zero value is replaced
    # by a fixed rate that depends on --use_pretrained_transformer.
    if args.PRED_LR == 0:
        config["PRED_LR"] = 0
    elif args.use_pretrained_transformer == 1:
        config["PRED_LR"] = 1e-4
    else:
        config["PRED_LR"] = 1e-3

    config.update({
        "MIN_POLICIES_IN_BUFFER": 1,  # Minimum number of policies in the buffer before training the predictor
        "MAX_POLICIES_IN_BUFFER": 5,  # Maximum number of policies in the buffer
        "BUFFER_SAMPLE_BSIZE": 256,
        "RETAIN_PAST_PORTION": 0.10,
        "NUM_ENVS": 64,
        "NUM_STEPS": 15,
        "TOTAL_TIMESTEPS": 1e7,
        "PREDICTABILITY_COEF": args.PREDICTABILITY_COEF,
        "use_pretrained_transformer": args.use_pretrained_transformer,
        "NUM_PRED_UPDATES": args.NUM_PRED_UPDATES
    })

    # Pretrained predictability-transformer weights (feature extractor + head).
    with open(f'{config["SAVE_DIR"]}/{experiment_name}_pred_transformer.pkl', "rb") as f:
        pred_transformer_params = pickle.load(f)
    feat_extractor_params = pred_transformer_params["feat_extractor_params"]
    predictor_params = pred_transformer_params["predictor_params"]

    # Driver-state data produced by earlier runs.
    with open(f'{config["SAVE_DIR"]}/{experiment_name}_predtran_train_data.pkl', "rb") as f:
        predtran_train_data = pickle.load(f)
    driver_states = predtran_train_data["driver_states"]
    driver_env_states = predtran_train_data["driver_env_states"]

    wandb.init(
        project=args.wandb_project,
        entity="",
        config=config,
        dir="./wandb"
    )

    try:
        print("Training env:", config["ENV_NAME"])
        rng = jax.random.PRNGKey(config["SEED"])
        t0 = time.time()
        # jax.jit compiles lazily, on the first call below.
        train_jit = jax.jit(make_train(config, feat_extractor_params, predictor_params, driver_states, driver_env_states))
        print("Compiled training function")
        out = jax.block_until_ready(train_jit(rng, feat_extractor_params, predictor_params))
        print(f"Training completed in {time.time() - t0:.2f} seconds")

        # Move the logged series to host once rather than indexing device arrays per step.
        policy_losses = np.asarray(out["actorcritic_policy_loss"])[:, 0]   # first minibatch
        ac_pred_losses = np.asarray(out["actorcritic_pred_loss"])[:, 0]    # first minibatch
        pred_losses = np.asarray(out["predictor_pred_loss"])
        for i, (pol, ac_pred, pred) in enumerate(zip(policy_losses, ac_pred_losses, pred_losses)):
            wandb.log({
                "actorcritic_policy_loss": pol,
                "actorcritic_pred_loss": ac_pred,
                "predictor_pred_loss": pred,
                "update": i,
            })

        # Mean over environments, flattened across updates and rollout steps.
        returns_to_log = np.asarray(out["metric"]["returned_episode_returns"]).mean(axis=-1).reshape(-1)
        for episodic_return in returns_to_log:
            wandb.log({
                "episodic_returns": episodic_return,
            })

        save_experiment_name = f"{experiment_name}_{config['PREDICTABILITY_COEF']}pc_{config['use_pretrained_transformer']}usepretrained_{config['PRED_LR']}predlr"

        with open(f"{config['SAVE_DIR']}/{save_experiment_name}_pred_transformer_evarl.pkl", "wb") as f:
            pickle.dump({
                "predictor_params": out["predictor_params"],
                "config": config,
            }, f)

        with open(f"{config['SAVE_DIR']}/{save_experiment_name}_a2c_evarl.pkl", "wb") as f:
            pickle.dump({
                "actor_params": out["actor_critic_params"],
                "config": config,
            }, f)

        print("Successfully saved the models")

    except Exception:
        # Report the failure and mark the wandb run as failed instead of leaving it open.
        traceback.print_exc()
        wandb.finish(exit_code=1)
    else:
        wandb.finish()


# Script entry point: run the experiment only when this file is executed directly.
if __name__ == "__main__":
    main()
