"""
Based on PureJaxRL Implementation of PPO
"""

import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any, Dict
from flax.training.train_state import TrainState
import distrax
from jaxmarl.wrappers.baselines import LogWrapper
import jaxmarl
import wandb
import functools
import matplotlib.pyplot as plt
import hydra
from omegaconf import OmegaConf
from functools import partial
import gymnax
from gymnax.wrappers.purerl import LogWrapper, FlattenObservationWrapper

class ActorCritic(nn.Module):
    """Discrete-action actor-critic with separate policy and value MLPs.

    Both heads read the same input ``x`` and each use two hidden layers of
    64 units. Calling the module returns a ``distrax.Categorical`` over
    ``action_dim`` logits together with the value estimate, whose trailing
    singleton axis is squeezed away.
    """

    action_dim: Sequence[int]
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        # Anything other than "relu" falls back to tanh.
        squash = nn.relu if self.activation == "relu" else nn.tanh

        def dense(features, gain):
            # Orthogonal kernel scaled by ``gain`` with a zero bias.
            return nn.Dense(
                features,
                kernel_init=orthogonal(gain),
                bias_init=constant(0.0),
            )

        # Policy head. Its layers are created first so the automatic
        # submodule naming order stays actor-then-critic.
        hidden = squash(dense(64, np.sqrt(2))(x))
        hidden = squash(dense(64, np.sqrt(2))(hidden))
        logits = dense(self.action_dim, 0.01)(hidden)
        pi = distrax.Categorical(logits=logits)

        # Value head.
        hidden = squash(dense(64, np.sqrt(2))(x))
        hidden = squash(dense(64, np.sqrt(2))(hidden))
        value = dense(1, 1.0)(hidden)

        return pi, jnp.squeeze(value, axis=-1)

class Transition(NamedTuple):
    """One environment step as recorded during rollout collection.

    ``make_train`` builds an instance inside its per-step scan body, so the
    scanned result stacks each field along a leading step axis.
    """

    # Episode-termination flag returned by ``env.step``.
    done: jnp.ndarray
    # Action sampled from the policy for this step.
    action: jnp.ndarray
    # Critic output for ``obs``.
    value: jnp.ndarray
    # Reward returned by ``env.step``.
    reward: jnp.ndarray
    # Log-probability of ``action`` under the policy that sampled it.
    log_prob: jnp.ndarray
    # Observation the action was chosen from (i.e. before stepping the env).
    obs: jnp.ndarray
    # Info returned by ``env.step``. NOTE(review): annotated as an array, but
    # ``make_train`` indexes it with "returned_episode_returns", so it is
    # presumably a mapping of arrays -- confirm against the env wrapper.
    info: jnp.ndarray

def make_train(config):
    """Build a jitted PPO training function for one gymnax environment.

    ``config`` is updated in place with the derived ``NUM_UPDATES`` and
    ``MINIBATCH_SIZE`` entries, then the environment named by
    ``config["ENV_NAME"]`` is created and wrapped with
    ``FlattenObservationWrapper`` and ``LogWrapper``.

    Args:
        config: Mutable mapping holding the settings read below:
            ``TOTAL_TIMESTEPS``, ``NUM_STEPS``, ``NUM_ENVS``,
            ``NUM_MINIBATCHES``, ``UPDATE_EPOCHS``, ``CLIP_EPS``, ``VF_COEF``,
            ``MAX_GRAD_NORM`` and ``ENV_NAME``.

    Returns:
        A jitted ``train(runner_state, hyperparams)`` function.
        ``runner_state`` is ``(train_state, rng)`` and ``hyperparams`` is
        indexed by ``"LR"``, ``"ENT_COEF"``, ``"GAMMA"`` and ``"GAE_LAMBDA"``.
        It returns ``((train_state, rng), metrics)`` where ``metrics`` holds,
        for every update, the mean of ``info["returned_episode_returns"]``
        over that update's rollout.
    """
    # Coerce the derived counts to int: floor division keeps the float type
    # when TOTAL_TIMESTEPS is a float (e.g. ``3e5``), and these values are
    # used as scan lengths and batch sizes.
    config["NUM_UPDATES"] = int(
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = int(
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    env, env_params = gymnax.make(config["ENV_NAME"], use_minimal_action_set=False)
    env = FlattenObservationWrapper(env)
    env = LogWrapper(env)

    @jax.jit
    def train(runner_state, hyperparams):
        """Reset the envs, then run ``NUM_UPDATES`` PPO updates."""

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            """Collect one rollout, compute GAE and run the PPO epochs."""

            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                """Sample an action per env, step the envs, record the step."""
                train_state, env_state, last_obs, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = train_state.apply_fn(train_state.params, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                # One key, state and action per env; env_params is shared.
                obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0,None))(
                    rng_step, env_state, action, env_params
                )
                # Store the observation the action was chosen from, not the
                # one produced by the step.
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            train_state, env_state, last_obs, rng = runner_state
            # Value of the observation after the final step, used to bootstrap.
            _, last_val = train_state.apply_fn(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                """Return ``(advantages, value_targets)`` using GAE."""

                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    # ``(1 - done)`` cuts the bootstrap at episode boundaries.
                    delta = reward + hyperparams["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                        delta
                        + hyperparams["GAMMA"] * hyperparams["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    return (gae, value), gae

                # Scan backwards in time, starting from zero advantage and the
                # bootstrap value.
                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                """Shuffle the rollout and take one pass over its minibatches."""

                def _update_minbatch(train_state, batch_info):
                    """Apply one gradient step on a single minibatch."""
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = train_state.apply_fn(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        # Clip the value change relative to the rollout-time
                        # estimate and take the larger of the two losses.
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        # Advantages are normalised per minibatch.
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"],
                                1.0 + config["CLIP_EPS"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - hyperparams["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    # ``total_loss`` is ``(loss, aux)`` because of ``has_aux``.
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
                ), "batch size must be equal to number of steps * number of envs"
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                # Merge the two leading axes into one flat batch axis.
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                # Split the shuffled batch into NUM_MINIBATCHES equal chunks.
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss

            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            train_state = update_state[0]
            metric = traj_batch.info
            rng = update_state[-1]

            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, metric["returned_episode_returns"].mean()

        train_state, rng = runner_state
        # Rebuild the optimiser so the learning rate comes from ``hyperparams``;
        # the optimiser state already held by ``train_state`` is kept.
        tx = optax.chain(optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), optax.adam(hyperparams["LR"], eps=1e-5))
        train_state = train_state.replace(tx=tx)
        rng, _rng = jax.random.split(rng)
        # Every call starts from freshly reset environments.
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(jax.random.split(_rng, config["NUM_ENVS"]), env_params)
        runner_state = (train_state, env_state, obsv, rng)
        runner_state, metrics = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )
        # Only the train state and rng are handed back; env state is dropped.
        return (runner_state[0], runner_state[-1]), metrics

    return train

def main(config):
    """Run population-based training of PPO agents over four MinAtar games.

    Every population member owns one ``TrainState`` per game plus an rng and
    a stack of (meta-)hyperparameters. Each generation trains all members on
    the current game, lets low-fitness members copy a top member (exploit),
    perturbs the copied hyperparameters (explore), clips the first-order
    hyperparameters to their allowed ranges and logs to Weights & Biases.

    Args:
        config: Mapping with the keys read here (``PROJECT``, ``ENV_NAME``,
            ``ACTIVATION``, ``MAX_GRAD_NORM``, ``LR``, ``ENT_COEF``,
            ``GAMMA``, ``GAE_LAMBDA``, ``seed``, ``METAHYPER_ORDER``) plus
            those required by ``make_train``.
    """
    # Import the submodule explicitly: ``import jax`` alone does not
    # guarantee that ``jax.flatten_util`` has been loaded.
    from jax.flatten_util import ravel_pytree

    # ``jax.tree_map`` is deprecated; use the ``tree_util`` spelling.
    tree_map = jax.tree_util.tree_map

    num_generations = 1000
    gens_per_game = 50

    wandb.init(
        project=config["PROJECT"],
        tags=[config["ENV_NAME"]],
        config=config,
    )

    num_devices = jax.device_count()
    popsize = 16 * num_devices
    num_threads = popsize // num_devices
    assert num_devices * num_threads == popsize

    @jax.jit
    def wrap_pmap(tree):
        """Split the leading population axis into (device, per-device)."""
        return tree_map(
            lambda x: x.reshape(num_devices, num_threads, *x.shape[1:]), tree
        )

    @jax.jit
    def unwrap_pmap(tree):
        """Merge the (device, per-device) axes back into one population axis."""
        return tree_map(lambda x: x.reshape(-1, *x.shape[2:]), tree)

    games = ["Asterix-MinAtar", "Breakout-MinAtar", "Freeway-MinAtar", "SpaceInvaders-MinAtar"]

    @jax.jit
    def init(rng, hyperparam_ranges):
        """Create one member: a TrainState per game, an rng and hyperparams."""
        train_states = []
        for env_name in games:
            env, env_params = gymnax.make(env_name, use_minimal_action_set=False)
            env = FlattenObservationWrapper(env)
            env = LogWrapper(env)
            network = ActorCritic(env.action_space(env_params).n, activation=config["ACTIVATION"])
            rng, _rng = jax.random.split(rng)
            init_x = jnp.zeros(env.observation_space(env_params).shape)
            network_params = network.init(_rng, init_x)
            tx = optax.chain(optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), optax.adam(config["LR"], eps=1e-5))
            train_state = TrainState.create(
                apply_fn=network.apply,
                params=network_params,
                tx=tx,
            )
            train_states.append(train_state)

        # One key per hyperparameter, each sampled uniformly from its range.
        rng, _rng = jax.random.split(rng)
        rng_hyperparams = dict(
            zip(hyperparam_ranges, jax.random.split(_rng, len(hyperparam_ranges)))
        )
        hyperparams = tree_map(
            lambda key, bounds: jax.random.uniform(key, minval=bounds[0], maxval=bounds[1]),
            rng_hyperparams,
            hyperparam_ranges,
        )

        runner_state = (*train_states, rng)
        return (runner_state, hyperparams)

    # One pmapped-and-vmapped train function per game, in ``games`` order.
    train_fns = [
        jax.pmap(jax.vmap(make_train({**config, "ENV_NAME": env_name})))
        for env_name in games
    ]

    init = jax.vmap(init, (0, None))

    def step_and_eval(game_id, runner_states, hyperparams):
        """Train every member on game ``game_id`` and return its fitness.

        ``runner_states`` is ``(*per_game_train_states, rng)``. Only the
        selected game's train states and the rng are replaced; the tuple
        keeps its length.
        """
        old_runner_states = runner_states
        runner_states = (runner_states[game_id], runner_states[-1])

        runner_states, hyperparams = wrap_pmap((runner_states, hyperparams))
        runner_states, returns = train_fns[game_id](runner_states, hyperparams)
        runner_states, returns = unwrap_pmap((runner_states, returns))
        fitness = returns.mean(axis=-1)

        # Stop the tail slice before the final element: that is the old rng,
        # which is replaced by the new one. Including it would grow the
        # tuple by one entry every generation.
        runner_states = (
            *old_runner_states[:game_id],
            runner_states[0],
            *old_runner_states[game_id + 1:-1],
            runner_states[-1],
        )

        return runner_states, fitness

    @jax.jit
    def exploit(rng, h, theta, fitness):
        """Members below the elite threshold copy a random elite member.

        Returns the updated ``h`` and ``theta`` plus a boolean mask marking
        the members that copied.
        """
        fitness_sorted = jnp.sort(fitness)
        kth_best_fitness = fitness_sorted[-int(0.1 * popsize)]
        rngs = jax.random.split(rng, len(fitness))

        def member_exploit(rng, h_i, theta_i, fitness_i):
            exploit_bool = fitness_i < kth_best_fitness
            # The boolean mask acts as sampling weights over elite members.
            copy_id = jax.random.choice(rng, len(fitness), p=(fitness >= kth_best_fitness))
            theta_i = tree_map(
                lambda x, y: jax.lax.select(exploit_bool, x[copy_id], y), theta, theta_i
            )
            h_i = jax.lax.select(exploit_bool, h[copy_id], h_i)
            return h_i, theta_i, exploit_bool

        return jax.vmap(member_exploit)(rngs, h, theta, fitness)

    @jax.jit
    def explore(rng, h, theta, explore_mask):
        """Add Gaussian noise to the hyperparameters of masked members.

        The noise scale of each order is the value stored in the next order;
        the last order uses a fixed scale of ``1e-5``. ``theta`` is returned
        unchanged.
        """
        rngs = jax.random.split(rng, len(explore_mask))

        def member_explore(rng, h_i, theta_i, explore_bool):
            std = jnp.roll(h_i, -1, axis=0).at[-1].set(1e-5)
            offset = std * jax.random.normal(rng, h_i.shape)
            h_i = jax.lax.select(explore_bool, h_i + offset, h_i)
            # NOTE: a previously tried alternative used the next order as a
            # deterministic offset plus 1e-5-scaled noise instead of a std.
            return h_i, theta_i

        return jax.vmap(member_explore)(rngs, h, theta, explore_mask)

    rng = jax.random.PRNGKey(config["seed"])

    # Near-degenerate ranges: every member starts at (almost exactly) the
    # configured hyperparameter values.
    hyperparam_ranges = {
        "LR": jnp.array([config["LR"], config["LR"] + 1e-6]),
        "ENT_COEF": jnp.array([config["ENT_COEF"], config["ENT_COEF"] + 1e-6]),
        "GAMMA": jnp.array([config["GAMMA"], config["GAMMA"] + 1e-6]),
        "GAE_LAMBDA": jnp.array([config["GAE_LAMBDA"], config["GAE_LAMBDA"] + 1e-6]),
    }

    rng, rng_pop = jax.random.split(rng)
    runner_states, hyperparams = init(jax.random.split(rng_pop, popsize), hyperparam_ranges)

    # From here on the ranges are the bounds used when clipping.
    hyperparam_ranges = {
        "LR": jnp.array([1e-6, 0.1]),
        "ENT_COEF": jnp.array([1e-7, 0.1]),
        "GAMMA": jnp.array([0.9, 0.9999]),
        "GAE_LAMBDA": jnp.array([0.9, 0.9999]),
    }

    # Build flat-vector <-> dict converters from a single member's structure,
    # then vmap them over the population axis.
    _, unravel = ravel_pytree(tree_map(lambda x: x[0], hyperparams))
    unravel = jax.vmap(unravel)

    @jax.vmap
    def flatten(hyperparams):
        flat, _ = ravel_pytree(hyperparams)
        return flat

    @jax.jit
    def clip(hyperparams, ranges):
        """Clip the first-order hyperparameters (index 0) to ``ranges``."""
        hyperparams_tree = unravel(hyperparams[:, 0])
        hyperparams_tree = tree_map(
            lambda x, r: x.clip(r[0], r[1]), hyperparams_tree, ranges
        )
        flat = flatten(hyperparams_tree)
        hyperparams = hyperparams.at[:, 0].set(flat)
        return hyperparams

    hyperparams = flatten(hyperparams)

    # add arbitrary order meta-hyperparams
    # Layout: (population, order, hyperparameter). Order 0 holds the actual
    # hyperparameters; higher orders start as small Gaussian noise.
    hyperparams = hyperparams.reshape(hyperparams.shape[0], 1, hyperparams.shape[1])
    if config["METAHYPER_ORDER"] > 1:
        hyperparams = jnp.repeat(hyperparams, config["METAHYPER_ORDER"], axis=1)
        rng, rng_noise = jax.random.split(rng)
        metaparams = 1e-5 * jax.random.normal(rng_noise, shape=(hyperparams.shape[0], config["METAHYPER_ORDER"]-1, hyperparams.shape[2]))
        hyperparams = hyperparams.at[:, 1:].set(metaparams)
    print("(META) HYPERPARAM SHAPE:", hyperparams.shape)

    game_id = 0
    for gen in range(num_generations):
        rng, rng_exploit, rng_explore = jax.random.split(rng, 3)
        # Remember which game produced this fitness before ``game_id`` moves.
        trained_game = games[game_id]
        runner_states, fitness = step_and_eval(game_id, runner_states, unravel(hyperparams[:, 0]))
        hyperparams, runner_states, explore_mask = exploit(rng_exploit, hyperparams, runner_states, fitness)
        hyperparams, runner_states = explore(rng_explore, hyperparams, runner_states, explore_mask)
        hyperparams = clip(hyperparams, hyperparam_ranges)

        # Move to the next game once every ``gens_per_game`` generations.
        if (gen + 1) % gens_per_game == 0:
            game_id = (game_id + 1) % len(train_fns)

        log_dict = {f"fitness_{trained_game}": fitness.max()}
        best_id = fitness.argmax()
        for i in range(config["METAHYPER_ORDER"]):
            log_dict.update({f"{i}-th_order/{key}": value[best_id] for key, value in unravel(hyperparams[:, i]).items()})
        wandb.log(log_dict)
        
if __name__ == "__main__":
    # Command-line entry point: parse the PPO / PBT settings into a plain
    # dict, log in to Weights & Biases and start training.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--PROJECT", type=str, default="MINATAR")
    parser.add_argument("--seed", type=int, default=0)
    # === PPO ===
    parser.add_argument("--LR", type=float, default=5e-3)
    parser.add_argument("--NUM_ENVS", type=int, default=64)
    parser.add_argument("--NUM_STEPS", type=int, default=128)
    # argparse applies ``type`` only to string values, so a float default
    # such as ``3e5`` would reach the config unconverted; make it an int.
    parser.add_argument("--TOTAL_TIMESTEPS", type=int, default=int(3e5))
    parser.add_argument("--UPDATE_EPOCHS", type=int, default=4)
    parser.add_argument("--NUM_MINIBATCHES", type=int, default=8)
    parser.add_argument("--GAMMA", type=float, default=0.99)
    parser.add_argument("--GAE_LAMBDA", type=float, default=0.95)
    parser.add_argument("--CLIP_EPS", type=float, default=0.2)
    parser.add_argument("--ENT_COEF", type=float, default=0.01)
    parser.add_argument("--VF_COEF", type=float, default=0.5)
    parser.add_argument("--MAX_GRAD_NORM", type=float, default=0.5)
    parser.add_argument("--ACTIVATION", type=str, default="relu")
    parser.add_argument("--ENV_NAME", type=str, default="SpaceInvaders-MinAtar")
    parser.add_argument("--METAHYPER_ORDER", type=int, default=1)
    config = vars(parser.parse_args())

    wandb.login()

    main(config)