import os
# Must run before jax is imported so the XLA client picks it up.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState

from ac import AC
from wrappers import LogWrapper, NormalizeVecReward, ACVecEnv
from network import FFNActorCritic, ResNetActorCritic

class Transition(NamedTuple):
    """One environment step of a PPO rollout.

    Instances are produced per step and stacked by ``jax.lax.scan`` in
    ``make_train``, so each leaf gains a leading time axis when collected.
    """

    done: jnp.ndarray  # termination flags returned by env.step
    action: jnp.ndarray  # actions sampled from the policy
    value: jnp.ndarray  # critic value estimates for `obs`
    reward: jnp.ndarray  # rewards returned by env.step
    log_prob: jnp.ndarray  # log-probability of `action` under the behaviour policy
    obs: jnp.ndarray  # observation the action was chosen from
    info: dict  # info dict returned by env.step (a pytree of arrays, indexed by key)


def make_train(config):
    """Build a PPO ``train(rng)`` function for the ``AC`` environment.

    ``config`` is mutated in place: ``NUM_UPDATES`` and ``MINIBATCH_SIZE`` are
    derived from the other entries and written back into it.

    Settings that define the environment must be fixed when ``make_train`` is
    called: ``GAMMA`` (baked into ``NormalizeVecReward``) and the ``AC``
    constructor arguments, which are hard-coded below. These cannot be swept
    over with ``jax.vmap``. Any other hyperparameter could be swept by removing
    it from ``config`` and passing it to ``train`` explicitly. That includes the
    two derived values above, which could then be computed inside ``train``.

    Args:
        config: Mapping of hyperparameters. Keys read here:

            - ``TOTAL_TIMESTEPS``, ``NUM_STEPS``, ``NUM_ENVS``,
              ``NUM_MINIBATCHES``, ``UPDATE_EPOCHS``
            - ``GAMMA``, ``GAE_LAMBDA``, ``CLIP_EPS``, ``VF_COEF``,
              ``ENT_COEF``
            - ``LR``, ``ANNEAL_LR``, ``MAX_GRAD_NORM``, ``ACTIVATION``
            - ``RES_NET`` (plus ``NUM_RES_BLOCKS`` when it is true)
            - the optional ``DEBUG`` and ``WANDB_MODE``

    Returns:
        A function ``train(rng)`` that runs ``NUM_UPDATES`` PPO iterations and
        returns ``{"runner_state": (train_state, env_state, last_obs, rng)}``.

    Raises:
        AssertionError: if ``NUM_STEPS * NUM_ENVS`` is not divisible by
            ``NUM_MINIBATCHES``, or if the environment's number of initial
            states differs from ``NUM_ENVS``.
    """
    # int(): TOTAL_TIMESTEPS may be given as a float (e.g. 1e10), but this
    # value is used as the length of the outer jax.lax.scan.
    config["NUM_UPDATES"] = int(
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    # Validate the minibatch split once, before any tracing, instead of every
    # time `_update_epoch` is traced.
    batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
    assert (
        batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
    ), "batch size must be equal to number of steps * number of envs"

    env = AC(n_gen=2, max_length=512, max_steps_in_episode=200, primed_actions=True, is_reward_sparse=True)
    # `env.reset` below receives one index per environment
    # (jnp.arange(NUM_ENVS)), presumably selecting from `init_states`.
    assert len(env.init_states) == config["NUM_ENVS"], "expect num_envs to be equal to num initial states"
    env_params = env.default_params
    env = LogWrapper(env)
    env = NormalizeVecReward(env, config["GAMMA"])
    env = ACVecEnv(env)

    # The network is built here (not inside `train`) because ACTIVATION is not
    # swept over.
    if config["RES_NET"]:
        network = ResNetActorCritic(
            env.action_space(env_params).n, activation=config["ACTIVATION"], num_residual_blocks=config["NUM_RES_BLOCKS"]
        )
        print(f"""
              Using ResNetActorCritic with {config['NUM_RES_BLOCKS']} residual blocks..
              Consider setting config["RES_NET"]=False on a CPU for less computational power
              """)
    else:
        network = FFNActorCritic(
            env.action_space(env_params).n, activation=config["ACTIVATION"]
        )

    def linear_schedule(count):
        """Linearly anneal the learning rate from LR towards 0 over NUM_UPDATES.

        ``count`` is the optimizer step count. Each update performs
        ``NUM_MINIBATCHES * UPDATE_EPOCHS`` optimizer steps, so the integer
        division recovers the update index.
        """
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac

    def train(rng):
        """Run NUM_UPDATES PPO iterations starting from PRNG key ``rng``."""
        # INIT NETWORK
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)

        # A Python if/else (rather than jax.lax.cond) is fine here. ANNEAL_LR
        # is a static config value, so only the selected branch is traced. It
        # must therefore stay a static config entry rather than a traced
        # argument of `train`.
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = env.reset(reset_rng, env_params, jnp.arange(config["NUM_ENVS"]))

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            """Collect one rollout, then run UPDATE_EPOCHS of PPO updates on it."""

            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                # TODO: here we can consider applying temperature to pi.
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = env.step(rng_step, env_state, action, env_params)
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            train_state, env_state, last_obs, rng = runner_state
            # Bootstrap value for the observation after the last collected step.
            _, last_val = network.apply(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                """Return (GAE advantages, value targets), scanning backwards in time."""

                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    # (1 - done) stops bootstrapping across episode boundaries.
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    return (gae, value), gae

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                """Shuffle the rollout and apply one gradient step per minibatch."""

                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        # Clip the new value prediction to within CLIP_EPS of
                        # the rollout-time value, and take the worse of the
                        # clipped and unclipped losses.
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        # Advantages are normalised per minibatch.
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"],
                                1.0 + config["CLIP_EPS"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    # With has_aux=True, `total_loss` is
                    # (loss, (value_loss, loss_actor, entropy)).
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                # Batching and Shuffling (batch_size is validated in make_train)
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                # Merge the leading (NUM_STEPS, NUM_ENVS) axes into one batch axis.
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                # Mini-batch Updates
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss

            # Updating Training State and Metrics.
            # traj_batch leaves have leading axes (NUM_STEPS, NUM_ENVS).
            update_state = (train_state, traj_batch, advantages, targets, rng)
            # loss_info = (total_loss, (value_loss, loss_actor, entropy)); each
            # leaf has shape (UPDATE_EPOCHS, NUM_MINIBATCHES).
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            train_state = update_state[0]
            rng = update_state[-1]
            # Per-step info from the env wrappers. The callbacks below read:
            #   "returned_episode", "returned_episode_returns", "timestep",
            #   "length", "terminated".
            # Leaves are presumably (NUM_STEPS, NUM_ENVS); TODO confirm against
            # LogWrapper/ACVecEnv.
            metric = traj_batch.info

            values_flat = traj_batch.value.flatten()
            targets_flat = targets.flatten()
            explained_var = 1 - jnp.var(values_flat - targets_flat) / jnp.var(targets_flat)

            # jax.debug.callback runs the Python function on the host every
            # update. Only the small pieces each callback needs are passed, and
            # host-side work uses NumPy so no device computations are dispatched
            # from inside a callback.
            # TODO: consider whether logging can be made asynchronous.
            if config.get("DEBUG"):
                def print_callback(info, solved_idx):
                    finished = np.asarray(info["returned_episode"])
                    return_values = np.asarray(info["returned_episode_returns"])[finished]
                    # TODO: this global-step estimate should be revisited.
                    timesteps = np.asarray(info["timestep"])[finished] * config["NUM_ENVS"]
                    num_solved = int(np.count_nonzero(solved_idx))
                    for timestep, episodic_return in zip(timesteps, return_values):
                        print(f"global step={timestep}, episodic return={episodic_return}, num solved={num_solved}")

                jax.debug.callback(print_callback, metric, env_state.solved_idx)

            if config.get("WANDB_MODE", "disabled") == "online":
                # Imported here so make_train also works when this file is
                # imported as a module. Otherwise `wandb` is only bound by the
                # script's __main__ block.
                import wandb

                def wandb_callback(info, loss_info, adv, solved_idx, explained_var):
                    finished = np.asarray(info["returned_episode"])
                    return_values = np.asarray(info["returned_episode_returns"])[finished]
                    timesteps = np.asarray(info["timestep"])[finished] * config["NUM_ENVS"]
                    solved_idx = np.asarray(solved_idx)
                    num_solved = int(np.count_nonzero(solved_idx))
                    largest_solved_idx = -1 if num_solved == 0 else int(solved_idx.nonzero()[0].max())
                    lengths = np.asarray(info["length"])
                    terminated_per_env = np.asarray(info["terminated"]).sum(axis=0)
                    value_loss, policy_loss, entropy = loss_info[1]

                    # These statistics are the same for every episode that
                    # finished in this update, so compute them once.
                    update_stats = {
                        "entropy_loss": float(np.mean(entropy)),
                        "value_loss": float(np.mean(value_loss)),
                        "policy_loss": float(np.mean(policy_loss)),
                        "adv_mean": float(np.mean(adv)),  # TODO: should I consider gae instead?
                        "adv_std": float(np.std(adv)),
                        "mean_length": float(lengths.mean()),
                        "min_length": float(lengths.min()),
                        "max_length": float(lengths.max()),
                        "num_solved": num_solved,
                        "largest_solved_idx": largest_solved_idx,  # the larger the better
                        "recently_solved_no_duplicates": int((terminated_per_env != 0).sum()),
                        "recently_solved_with_duplicates": int(terminated_per_env.sum()),
                        "explained_var": float(explained_var),
                    }
                    for timestep, episodic_return in zip(timesteps, return_values):
                        wandb.log({"global_step": int(timestep),  # TODO: improve the logging of time.
                                   "episodic_return": float(episodic_return),
                                   **update_stats})

                jax.debug.callback(wandb_callback, metric, loss_info, advantages, env_state.solved_idx, explained_var)

            runner_state = (train_state, env_state, last_obs, rng)
            # Metrics are consumed by the callbacks above, and `train` discards
            # the scan outputs. Emit nothing rather than stacking `metric`
            # across all NUM_UPDATES iterations.
            return runner_state, None

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        runner_state, _ = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )
        return {"runner_state": runner_state}

    return train


if __name__ == "__main__":
    import argparse

    def parse_args():
        """Parse the command-line overrides for the training config below."""
        parser = argparse.ArgumentParser()
        parser.add_argument("--w", type=int, default=0,
                            help="non-zero to log to wandb online")
        parser.add_argument("--d", type=int, default=0,
                            help="non-zero to print finished-episode info")
        parser.add_argument("--lr", type=float, default=5e-4,
                            help="learning rate")
        parser.add_argument("--ent_coef", type=float, default=0.01,
                            help="entropy bonus coefficient")
        parser.add_argument("--num_res_blocks", type=int, default=4,
                            help="residual blocks in ResNetActorCritic")
        return parser.parse_args()

    # TODO: also log target_kl, repeat_solved_prob, norm_rewards, norm_adv.

    args = parse_args()

    config = {
        "LR": args.lr,
        "NUM_ENVS": 1190,
        "NUM_STEPS": 200,
        "TOTAL_TIMESTEPS": 1e10,
        "UPDATE_EPOCHS": 3,
        "NUM_MINIBATCHES": 4,
        "RES_NET": True,
        "NUM_RES_BLOCKS": args.num_res_blocks,
        "GAMMA": 0.999,
        "GAE_LAMBDA": 0.95,
        "CLIP_EPS": 0.2,
        "ENT_COEF": args.ent_coef,
        "VF_COEF": 0.5,
        "MAX_GRAD_NORM": 0.5,
        "ACTIVATION": "tanh",
        "ENV_NAME": "AC-v0",
        "ANNEAL_LR": False,
        "DEBUG": args.d,
        # Any non-zero --w enables online logging. String repetition
        # ("online" * args.w) would give "onlineonline" for --w 2 and
        # silently disable it.
        "WANDB_MODE": "online" if args.w else "disabled",
        "ENTITY": "",
        "PROJECT": "purejaxrl_ppo_ac",
    }

    if config.get("WANDB_MODE", "disabled") == "online":
        # This `if` is at module level, so the import binds a module-global
        # `wandb` for the rest of this file when it runs as a script.
        import wandb

        wandb.init(
            # An empty ENTITY means "use wandb's default entity".
            entity=config["ENTITY"] or None,
            project=config["PROJECT"],
            tags=["PPO", config["ENV_NAME"].upper(), f"jax_{jax.__version__}"],
            name=f'purejaxrl_ppo_{config["ENV_NAME"]}',
            config=config,
            mode=config["WANDB_MODE"],
        )

    train_jit = jax.jit(make_train(config))
    rng = jax.random.PRNGKey(42)

    out = jax.block_until_ready(train_jit(rng))

