import os
# This must run before `import jax` below. With preallocation disabled, XLA
# allocates GPU memory on demand instead of reserving a large block up front.
os.environ.update(XLA_PYTHON_CLIENT_PREALLOCATE="false")

import argparse
import functools
import time
import pickle
import traceback
from typing import Dict, Any

import jax
import jax.numpy as jnp
import optax
import gymnax
import wandb
from flax.training.train_state import TrainState
import flashbax as fbx

from wrappers import FlattenObservationWrapper, LogWrapper
from models import ActorCriticDiscreteAction, FeatExtractorDiscreteAction, PredictabilityHead
from utils import Transition


def make_train(config, driver_states, driver_env_states, feat_extractor_params, predictor_params):
    """Build a ``train(rng_key)`` function for EvA-RL on a discrete-action gymnax env.

    The returned function trains an A2C actor-critic. Alongside it, it trains a
    ``PredictabilityHead``, initialised from ``predictor_params``, that predicts a
    query state's return from the returns the current policy achieves on ``k``
    fixed driver start states. Once the replay buffer can be sampled, the actor's
    gradient is augmented with ``PREDICTABILITY_COEF`` times the gradient of the
    squared prediction error ``(V_hat - V)^2``.

    Args:
        config: Mapping of hyperparameters. It is shallow-copied, so the
            caller's dict is not modified. Keys read include TOTAL_TIMESTEPS,
            NUM_STEPS, NUM_ENVS, NUM_MINIBATCHES, ENV_NAME,
            MAX/MIN_POLICIES_IN_BUFFER, BUFFER_SAMPLE_BSIZE,
            RETAIN_PAST_PORTION, LR, ANNEAL_LR, MAX_GRAD_NORM, ACTIVATION,
            num_heads, hidden_dim, num_layers, PRED_LR, driver_traj_len,
            GAMMA, GAE_LAMBDA, NUM_PRED_UPDATES, VF_COEF, ENT_COEF and
            PREDICTABILITY_COEF.
        driver_states: Driver start observations, unpacked as ``(k, obs_dim)``.
        driver_env_states: Environment states matching ``driver_states``.
            Presumably batched along the same leading ``k`` axis (TODO confirm).
        feat_extractor_params: Parameters for ``FeatExtractorDiscreteAction``.
            They are only ever applied here, never updated.
        predictor_params: Initial parameters of the ``PredictabilityHead``.

    Returns:
        ``train(train_key)``. It runs ``NUM_UPDATES`` updates in a
        ``jax.lax.scan`` and returns the per-update stacked ``loop_data`` dict:
        losses, LogWrapper ``info`` as ``metric``, and the actor/predictor
        params after each update.
    """

    # Derived config values
    config = dict(config)  # shallow copy so we can extend in-place
    config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    config["MINIBATCH_SIZE"] = config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]

    # Environment setup
    base_env, base_env_params = gymnax.make(config["ENV_NAME"])
    base_env = FlattenObservationWrapper(base_env)
    base_env = LogWrapper(base_env)

    # Driver-test state setup: k fixed start states whose returns condition the predictor.
    driver_start_obs, driver_start_internal = driver_states, driver_env_states
    k, obs_dim = driver_start_obs.shape

    # Replay buffer setup: one item per environment step of a main rollout.
    dummy_experience = {
        "driver_returns": jnp.zeros((k,), jnp.float32),
        "query_state": jnp.zeros((obs_dim,), jnp.float32),
        "query_value": jnp.zeros((), jnp.float32),
    }

    # Capacity is measured in whole rollouts ("policies"). Each update adds
    # NUM_ENVS * NUM_STEPS items produced by a single policy snapshot.
    config["BUFFER_MAX_LENGTH"] = config["NUM_ENVS"] * config["NUM_STEPS"] * config["MAX_POLICIES_IN_BUFFER"]
    config["BUFFER_MIN_LENGTH"] = config["NUM_ENVS"] * config["NUM_STEPS"] * config["MIN_POLICIES_IN_BUFFER"]
    # NOTE(review): NUM_TO_REPLACE is not read anywhere in this function.
    config["NUM_TO_REPLACE"] = int(config["BUFFER_SAMPLE_BSIZE"] * config["RETAIN_PAST_PORTION"])

    flat_buffer = fbx.make_flat_buffer(
        max_length=config["BUFFER_MAX_LENGTH"],
        min_length=config["BUFFER_MIN_LENGTH"],
        sample_batch_size=config["BUFFER_SAMPLE_BSIZE"],
        add_sequences=False,
        add_batch_size=config["NUM_ENVS"] * config["NUM_STEPS"],
    )
    initial_buffer_state = flat_buffer.init(dummy_experience)

    # Learning-rate schedule: linear decay from LR to 0 over training.
    # The actor optimizer steps exactly NUM_MINIBATCHES times per update: one
    # scan over the minibatches, with no epoch loop. The optimizer step count is
    # therefore converted to an update index by dividing by NUM_MINIBATCHES
    # alone. Also dividing by UPDATE_EPOCHS, as PPO-style code does, would stop
    # the decay at (1 - 1/UPDATE_EPOCHS) * LR whenever UPDATE_EPOCHS > 1.
    def lr_schedule(step_count):
        progress = 1.0 - (step_count // config["NUM_MINIBATCHES"]) / config["NUM_UPDATES"]
        return config["LR"] * progress

    def train(train_key):
        """Run NUM_UPDATES updates from ``train_key``; return the stacked per-update metrics."""
        # Networks & optimizers setup
        train_key, key_policy_init = jax.random.split(train_key)
        actor_critic = ActorCriticDiscreteAction(
            action_dim=base_env.action_space(base_env_params).n,
            activation=config["ACTIVATION"]
        )
        actor_params = actor_critic.init(
            key_policy_init,
            jnp.zeros(base_env.observation_space(base_env_params).shape)
        )
        actor_optimizer = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(lr_schedule if config["ANNEAL_LR"] else config["LR"], eps=1e-5)
        )
        actor_state = TrainState.create(
            apply_fn=actor_critic.apply,
            params=actor_params,
            tx=actor_optimizer
        )

        # Predictability head setup. The feature extractor is only applied with
        # the fixed `feat_extractor_params` and has no optimizer.
        feat_extractor = FeatExtractorDiscreteAction(activation=config["ACTIVATION"])

        # The predictor key goes unused because the params are supplied. The
        # split is kept so the env-reset key below is unchanged.
        train_key, key_predictor_init = jax.random.split(train_key)
        predictor = PredictabilityHead(
            config["num_heads"],
            config["hidden_dim"],
            config["num_layers"]
        )
        predictor_state = TrainState.create(
            apply_fn=predictor.apply,
            params=predictor_params,
            tx=optax.adam(config["PRED_LR"])
        )

        # Environment initialization
        train_key, key_env_reset = jax.random.split(train_key)
        env_reset_keys = jax.random.split(key_env_reset, config["NUM_ENVS"])
        observations, env_state = jax.vmap(base_env.reset, in_axes=(0, None))(env_reset_keys, base_env_params)
        driver_observations, driver_internal_state = driver_start_obs, driver_start_internal

        # Runner state (carried across updates)
        runner_state = (
            actor_state, env_state, observations,
            driver_internal_state, driver_observations,
            predictor_state, initial_buffer_state,
            train_key
        )

        def one_update(runner, _):
            """Run one update step (rollouts, predictor training, actor-critic update)."""
            (
                actor_state,
                env_state,
                observations,
                driver_internal_state,
                driver_observations,
                predictor_state,
                buffer_state,
                rng,
            ) = runner

            # 1. Rollouts
            def rollout_step(carry, _):
                """Sample actions from the current policy and step every env in the batch once."""
                (
                    actor_state,
                    env_state,
                    observations,
                    rng,
                ) = carry
                rng, rng_policy, rng_driver = jax.random.split(rng, 3)

                policy, value = actor_critic.apply(actor_state.params, observations)
                actions = policy.sample(seed=rng_policy)
                log_probs = policy.log_prob(actions)
                step_keys = jax.random.split(rng_driver, observations.shape[0])
                (next_obs, env_state, rewards, done, info) = jax.vmap(
                    base_env.step, in_axes=(0, 0, 0, None)
                )(step_keys, env_state, actions, base_env_params)
                main_transition = Transition(done, actions, value, rewards, log_probs, observations, info)

                carry = (
                    actor_state,
                    env_state,
                    next_obs,
                    rng,
                )
                return carry, main_transition

            # Main trajectory rollout: NUM_STEPS steps in NUM_ENVS parallel envs.
            (actor_state,
             env_state,
             observations,
             rng), trajectory = jax.lax.scan(
                rollout_step,
                (actor_state, env_state, observations, rng),
                None,
                length=config["NUM_STEPS"]
            )

            # Driver trajectory rollout. It always starts from the fixed driver
            # states: the runner carries `driver_internal_state` and
            # `driver_observations` forward unchanged, and the final driver env
            # state returned here is not kept.
            (actor_state,
             last_driver_env_state,
             last_driver_observations,
             rng), driver_trajectory = jax.lax.scan(
                rollout_step,
                (actor_state, driver_internal_state, driver_observations, rng),
                None,
                length=config["driver_traj_len"]
            )

            # 2. GAE advantages + returns
            def compute_gae(tr: Transition, last_v):
                """Return (advantages, returns) for time-major trajectory `tr`, bootstrapping from `last_v`."""
                def _scan(carry, timestep):
                    gae, next_value = carry
                    delta = timestep.reward + config["GAMMA"] * next_value * (1 - timestep.done) - timestep.value
                    gae = delta + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - timestep.done) * gae
                    return (gae, timestep.value), gae
                (_, advantages) = jax.lax.scan(_scan, (jnp.zeros_like(last_v), last_v), tr, reverse=True, unroll=16)
                return advantages, advantages + tr.value

            _, last_value_main = actor_critic.apply(actor_state.params, observations)
            _, last_value_driver = actor_critic.apply(actor_state.params, last_driver_observations)
            advantages, returns = compute_gae(trajectory, last_value_main)
            driver_advantages, driver_returns = compute_gae(driver_trajectory, last_value_driver)

            # 3. Flatten rollout into a (B,) batch and split into minibatches.
            batch_size = config["NUM_ENVS"] * config["NUM_STEPS"]
            flat_obs = trajectory.obs.reshape((batch_size,) + trajectory.obs.shape[2:])
            flat_actions = trajectory.action.reshape((batch_size,) + trajectory.action.shape[2:])
            flat_advantages = advantages.reshape((batch_size,))
            flat_returns = returns.reshape((batch_size,))

            mini_batch_size = batch_size // config["NUM_MINIBATCHES"]
            split_mb = lambda x: x.reshape((config["NUM_MINIBATCHES"], mini_batch_size) + x.shape[1:])
            mb_obs, mb_actions, mb_advs, mb_rets = map(split_mb, (flat_obs, flat_actions, flat_advantages, flat_returns))

            # 4. Add trajectory slices to the replay buffer (for the predictor).
            # Every item is paired with the driver returns at the first driver
            # time step, i.e. the return estimate from each driver start state.
            driver_returns_batch = jnp.broadcast_to(driver_returns[0], (batch_size, k))
            experience_batch = {
                "driver_returns": driver_returns_batch,
                "query_state": flat_obs,
                "query_value": flat_returns,
            }
            buffer_state = flat_buffer.add(buffer_state, experience_batch)

            # Per-step info from the LogWrapper (episode statistics)
            metric = trajectory.info

            # 5. Update predictor from replay buffer (if ready)
            driver_states_processed = feat_extractor.apply(feat_extractor_params, driver_observations)
            driver_states_processed = jnp.broadcast_to(
                driver_states_processed,
                (config["BUFFER_SAMPLE_BSIZE"], k, driver_states_processed.shape[-1])
            )  # [B,k,feat]

            def predictor_update(_):
                """Run NUM_PRED_UPDATES Adam steps on one sampled batch; return (state, last loss, rng)."""
                rng_sample, rng_next = jax.random.split(rng)
                sample = flat_buffer.sample(buffer_state, rng_sample)
                batch = sample.experience.first

                query_states_processed = feat_extractor.apply(feat_extractor_params, batch["query_state"])

                def loss_fn(params):
                    # NOTE(review): assumes `value_hat` has shape [B] like
                    # batch["query_value"]; a [B, 1] output would broadcast to
                    # [B, B] here. TODO confirm the PredictabilityHead output shape.
                    value_hat = predictor.apply(
                        params,
                        driver_states_processed,
                        batch["driver_returns"],
                        query_states_processed
                    )
                    return jnp.mean((value_hat - batch["query_value"]) ** 2)

                def single_step(state, _):
                    loss, grads = jax.value_and_grad(loss_fn)(state.params)
                    new_state = state.apply_gradients(grads=grads)
                    return new_state, loss

                predictor_state_new, losses = jax.lax.scan(
                    single_step,
                    predictor_state,
                    None,
                    length=config["NUM_PRED_UPDATES"]
                )
                return predictor_state_new, losses[-1], rng_next

            def predictor_skip(_):
                """Leave the predictor untouched; report NaN loss while the buffer is not ready."""
                return predictor_state, jnp.nan, rng

            predictor_state, predictor_loss, rng = jax.lax.cond(
                flat_buffer.can_sample(buffer_state),
                predictor_update,
                predictor_skip,
                operand=None
            )

            # 6. Per-driver-env policy-gradient estimates dR_i/dθ.
            # driver_trajectory.obs shape: [T, k, obs]
            driver_obs_by_env = jnp.transpose(driver_trajectory.obs, (1, 0, 2))  # [k,T,obs]
            driver_act_by_env = jnp.transpose(driver_trajectory.action, (1, 0))  # [k,T]
            driver_adv_by_env = jnp.transpose(driver_advantages, (1, 0))  # [k,T]

            def surrogate_value_function(params, obs_seq, act_seq, adv_seq):
                """Policy-gradient surrogate mean(log π(a|s) · A); its grad estimates dV/dθ."""
                policy, _ = actor_critic.apply(params, obs_seq)
                return jnp.mean(policy.log_prob(act_seq) * adv_seq)

            per_env_grad = jax.vmap(
                lambda o, a, adv: jax.grad(surrogate_value_function)(actor_state.params, o, a, adv),
                in_axes=(0, 0, 0)
            )(driver_obs_by_env, driver_act_by_env, driver_adv_by_env)  # PyTree with leading dim k

            # 7. Main actor-critic update over mini-batches
            def aggregate_by_coeff(coeffs, grads_tree):
                """Weighted sum Σ_i coeffs[i] * grads_tree[i] across driver envs."""
                return jax.tree_util.tree_map(lambda g: jnp.tensordot(coeffs, g, axes=1), grads_tree)

            def a2c_loss(params, obs_mb, act_mb, adv_mb, ret_mb):
                """Standard A2C loss; returns (total, (policy loss, value loss, entropy))."""
                policy_mb, value_mb = actor_critic.apply(params, obs_mb)
                loss_policy = -jnp.mean(policy_mb.log_prob(act_mb) * adv_mb)
                loss_value = 0.5 * jnp.mean((value_mb - ret_mb) ** 2)
                entropy_bonus = jnp.mean(policy_mb.entropy())
                return (
                    loss_policy
                    + config["VF_COEF"] * loss_value
                    - config["ENT_COEF"] * entropy_bonus
                ), (loss_pi_placeholder := loss_policy, loss_value, entropy_bonus)

            def minibatch_update(state: TrainState, minibatch):
                """Apply one A2C step, plus predictability gradients when the buffer is ready."""
                obs_mb, act_mb, adv_mb, ret_mb = minibatch
                (total_loss, (loss_pi, loss_v, ent)), grads = jax.value_and_grad(
                    a2c_loss, has_aux=True
                )(state.params, obs_mb, act_mb, adv_mb, ret_mb)

                # Gradient of the predictability loss (V_hat - V)^2 for one sample:
                #   2 (V_hat - V) * (dV_hat/dθ - dV/dθ)
                # NOTE(review): both terms use `actor_state.params` from before
                # this minibatch scan, not the current `state.params`, so that
                # they match the params used for `per_env_grad`. Confirm intended.
                def pred_grad_for_one_sample(q_state, action, adv, q_return):
                    q_state_processed = feat_extractor.apply(feat_extractor_params, q_state[None, :])

                    def value_fn(driver_ret_vec):
                        return predictor.apply(
                            predictor_state.params,
                            driver_states=driver_states_processed[0, :, :],
                            driver_returns=driver_ret_vec,
                            query_state=q_state_processed
                        )[0]

                    # One forward pass yields both V_hat and dV_hat/dR ([k]).
                    predicted_value, dv_hat_dR = jax.value_and_grad(value_fn)(driver_returns[0])
                    scale = 2.0 * (predicted_value - q_return)

                    # term1: 2 (V_hat - V) * dV_hat/dθ, via the chain rule through the driver returns.
                    dv_hat_dtheta = aggregate_by_coeff(dv_hat_dR, per_env_grad)  # PyTree same as params
                    term1 = jax.tree_util.tree_map(lambda g: scale * g, dv_hat_dtheta)

                    # term2: 2 (V_hat - V) * dV/dθ at this query state.
                    dv_dtheta = jax.grad(surrogate_value_function)(actor_state.params, q_state, action, adv)
                    term2 = jax.tree_util.tree_map(lambda g: scale * g, dv_dtheta)

                    return jax.tree_util.tree_map(lambda t1, t2: t1 - t2, term1, term2)

                def add_predictor_grads(_):
                    pred_grads_mb = jax.vmap(pred_grad_for_one_sample)(obs_mb, act_mb, adv_mb, ret_mb)
                    pred_grads = jax.tree_util.tree_map(lambda g: jnp.mean(g, axis=0), pred_grads_mb)  # average over minibatch
                    return jax.tree_util.tree_map(
                        lambda g, pg: g + config["PREDICTABILITY_COEF"] * pg,
                        grads,
                        pred_grads
                    )

                # Logged metric only: predictor error on this minibatch's returns.
                query_state_processed = feat_extractor.apply(feat_extractor_params, obs_mb)
                predicted_state_values = predictor.apply(
                    predictor_state.params,
                    driver_states=driver_states_processed[0, :, :],
                    driver_returns=driver_returns[0],
                    query_state=query_state_processed
                )
                pred_loss = jnp.mean((predicted_state_values - ret_mb) ** 2)

                grads_total = jax.lax.cond(
                    flat_buffer.can_sample(buffer_state),
                    add_predictor_grads,
                    lambda _: grads,
                    operand=None
                )
                new_state = state.apply_gradients(grads=grads_total)
                return new_state, (loss_pi, pred_loss)

            actor_state, (loss_pi, pred_loss) = jax.lax.scan(
                minibatch_update,
                actor_state,
                (mb_obs, mb_actions, mb_advs, mb_rets)
            )

            # 8. Return updated runner state & metrics for logging
            new_runner = (
                actor_state,
                env_state,
                observations,
                driver_internal_state,
                driver_observations,
                predictor_state,
                buffer_state,
                rng,
            )

            loop_data = dict(
                actorcritic_policy_loss=loss_pi,
                actorcritic_pred_loss=pred_loss,
                predictor_pred_loss=predictor_loss,
                metric=metric,
                actor_critic_params=actor_state.params,
                predictor_params=predictor_state.params,
            )
            return new_runner, loop_data

        # Run training loop
        runner_state, loop_data_over_time = jax.lax.scan(
            one_update,
            runner_state,
            None,
            length=config["NUM_UPDATES"]
        )

        return loop_data_over_time

    # Return the (not yet jitted) train() function
    return train


def main():
    """Command-line entry point: load artifacts, run EvA-RL training, log to wandb, save params.

    Reads the following with ``pickle`` (only use trusted files):
      * ``{save_dir}/{experiment_name}_a2c_params.pkl``, for the base ``config``;
      * ``{SAVE_DIR}/{experiment_name}_pred_transformer.pkl``, for the feature
        extractor and predictor parameters;
      * ``{SAVE_DIR}/{experiment_name}_predtran_train_data.pkl``, for the
        driver start states.
    Writes ``*_pred_transformer_evarl.pkl`` and ``*_a2c_evarl.pkl`` into ``SAVE_DIR``.

    If training, logging or saving raises, it prints the traceback, marks the
    wandb run as failed and exits with status 1.
    """
    parser = argparse.ArgumentParser(description="EvA-RL discrete action space")
    parser.add_argument("--save_dir", type=str, default="./complete_discrete_long_run/")
    parser.add_argument("--experiment_name", type=str, default="Pong-misc_4envs_100steps_500000.0ts_0seed_10dtr_200gtr_20dt_Pong-misc_128bs_0.001lr_100ep_4h_2l_16hd")
    parser.add_argument("--wandb_project", type=str, default="evarl_discrete")
    parser.add_argument("--PREDICTABILITY_COEF", type=float, default=0.1)
    parser.add_argument("--PRED_LR", type=float, default=1e-5)
    parser.add_argument("--use_pretrained_transformer", type=int, default=1)
    parser.add_argument("--NUM_PRED_UPDATES", type=int, default=10)
    args = parser.parse_args()

    save_dir = args.save_dir
    experiment_name = args.experiment_name

    # Load the pre-trained A2C pickle; only its stored config is used here.
    # NOTE: pickle.load can execute arbitrary code, so load trusted files only.
    with open(f'{save_dir}/{experiment_name}_a2c_params.pkl', "rb") as f:
        a2c_params = pickle.load(f)
    config = a2c_params["config"]
    del a2c_params

    # Predictor learning rate.
    # NOTE(review): --PRED_LR is only compared against 0. Any non-zero value is
    # replaced by 1e-4 (pretrained transformer) or 1e-3 (otherwise). Confirm
    # this is intended.
    if args.PRED_LR != 0:
        config["PRED_LR"] = 1e-4 if args.use_pretrained_transformer == 1 else 1e-3
    else:
        config["PRED_LR"] = 0

    config.update({
        "MIN_POLICIES_IN_BUFFER": 1,
        "MAX_POLICIES_IN_BUFFER": 10,
        "BUFFER_SAMPLE_BSIZE": 256,
        "RETAIN_PAST_PORTION": 0.05,
        "PREDICTABILITY_COEF": args.PREDICTABILITY_COEF,
        "use_pretrained_transformer": args.use_pretrained_transformer,
        "NUM_PRED_UPDATES": args.NUM_PRED_UPDATES,
    })

    # Load predictability transformer weights
    with open(f'{config["SAVE_DIR"]}/{experiment_name}_pred_transformer.pkl', "rb") as f:
        pred_transformer_params = pickle.load(f)
    feat_extractor_params = pred_transformer_params["feat_extractor_params"]
    predictor_params = pred_transformer_params["predictor_params"]

    # Load driver's test data
    with open(f'{config["SAVE_DIR"]}/{experiment_name}_predtran_train_data.pkl', "rb") as f:
        predtran_train_data = pickle.load(f)

    driver_states = predtran_train_data["driver_states"]
    driver_env_states = predtran_train_data["driver_env_states"]

    # Initialize wandb
    wandb.init(
        project=args.wandb_project,
        entity="",
        config=config,
        dir="./wandb"
    )

    # Training loop
    try:
        print("Training env:", config["ENV_NAME"])
        rng = jax.random.PRNGKey(config["SEED"])
        train_jit = jax.jit(make_train(config, driver_states, driver_env_states, feat_extractor_params, predictor_params))
        print("Compiled training function")
        t0 = time.time()
        out = jax.block_until_ready(train_jit(rng))
        print(f"Training completed in {time.time() - t0:.2f} seconds")

        # Copy the logged series to the host once. Indexing device arrays
        # element by element in the loops below would do one transfer per value.
        policy_losses, ac_pred_losses, predictor_losses = jax.device_get((
            out["actorcritic_policy_loss"],
            out["actorcritic_pred_loss"],
            out["predictor_pred_loss"],
        ))

        # Log metrics to wandb. The actor-critic losses are stacked per
        # minibatch within each update; only the first minibatch's value is logged.
        for i in range(predictor_losses.shape[0]):
            wandb.log({
                "actorcritic_policy_loss": policy_losses[i][0],
                "actorcritic_pred_loss": ac_pred_losses[i][0],
                "predictor_pred_loss": predictor_losses[i],
                "update": i,
            })

        returns_to_log = jax.device_get(
            out["metric"]["returned_episode_returns"].mean(axis=-1).reshape(-1)
        )
        for episodic_return in returns_to_log:
            wandb.log({
                "episodic_returns": episodic_return,
            })

        # Save models
        save_experiment_name = f"{experiment_name}_{config['PREDICTABILITY_COEF']}pc_{config['use_pretrained_transformer']}usepretrained_{config['PRED_LR']}predlr"

        with open(f"{config['SAVE_DIR']}/{save_experiment_name}_pred_transformer_evarl.pkl", "wb") as f:
            pickle.dump({
                "predictor_params": out["predictor_params"],
                "config": config,
            }, f)

        with open(f"{config['SAVE_DIR']}/{save_experiment_name}_a2c_evarl.pkl", "wb") as f:
            pickle.dump({
                "actor_params": out["actor_critic_params"],
                "config": config,
            }, f)

        print("Successfully saved the models")

    except Exception:
        traceback.print_exc()
        # Mark the run as failed and exit non-zero so callers and job
        # schedulers see the failure instead of a clean exit.
        wandb.finish(exit_code=1)
        raise SystemExit(1)

    wandb.finish()


if __name__ == "__main__":
    # Run the full load -> train -> log -> save pipeline only when executed
    # as a script, not when this module is imported.
    main()
