"""This PPO implementation is modified from PureJaxRL:

  https://github.com/luchris429/purejaxrl

Please refer to their work if you use this example in your research."""

import pickle
import sys
import time
from typing import Any, Literal, NamedTuple

import distrax
import haiku as hk
import jax
import jax.numpy as jnp
import optax
import pgx
from omegaconf import OmegaConf
from pgx.experimental import auto_reset
from pydantic import BaseModel, ConfigDict

import src.lib.util as util
import wandb
from src.util import make_env

# Run with python3 -m src.baselines.ppo env_name=grid-risk


class PPOConfig(BaseModel):
    """PPO hyperparameters, parsed from ``key=value`` command-line overrides.

    Unknown keys are rejected (``extra="forbid"``) so a misspelled CLI
    argument fails loudly instead of being silently ignored.
    """

    # Pydantic v2 configuration; the class-based ``Config`` is deprecated in v2
    # (this file already relies on the v2-only ``model_dump``).
    model_config = ConfigDict(extra="forbid")

    env_name: Literal[
        "minatar-breakout",
        "minatar-freeway",
        "minatar-space_invaders",
        "minatar-asterix",
        "minatar-seaquest",
        "tree-risk",
        "grid-risk",
        "grid-risk-v2",
        "mountain-car-risk",
        "space-invaders-risk",
        "space-invaders-risk-v2",
        "space-invaders-risk-v2-naive",
        "breakout-risk",
        "breakout-risk-v2",
    ] = "space-invaders-risk-v2-naive"
    # Overwritten at startup with the flag returned by ``make_env``; the
    # network uses it to choose between the CNN and the flat MLP input path.
    is_state_vector: bool = False
    seed: int = 0
    lr: float = 0.0003  # Adam learning rate
    num_envs: int = 4096  # parallel training environments
    num_eval_envs: int = 100  # environments rolled out per evaluation
    num_steps: int = 128  # rollout length per env per update
    total_timesteps: int = 20000000
    update_epochs: int = 3  # passes over each rollout
    minibatch_size: int = 4096
    gamma: float = 0.99  # discount factor
    gae_lambda: float = 0.95
    clip_eps: float = 0.2  # clipping range for the policy ratio and value
    ent_coef: float = 0.01
    vf_coef: float = 0.5
    max_grad_norm: float = 0.5
    use_wandb: bool = False
    wandb_project: str = "pgx-minatar-ppo"
    save_model: bool = False  # pickle the final params after training


# Parse CLI overrides (e.g. `env_name=grid-risk lr=1e-4`) into a validated config.
args = PPOConfig(**OmegaConf.to_object(OmegaConf.from_cli()))  # type: ignore
print(args, file=sys.stderr)

env, is_state_vector, _num_actions = make_env(args.env_name)
# ActorCritic reads this flag to pick the CNN vs. flat-MLP input path.
args.is_state_vector = is_state_vector

# Each update collects num_envs * num_steps transitions and splits them into
# minibatches of minibatch_size. Validate the split here, with a readable
# message, instead of failing deep inside the jitted update (where the only
# guard is an `assert` that is stripped under `python -O`).
_rollout_size = args.num_envs * args.num_steps
if args.minibatch_size <= 0 or _rollout_size % args.minibatch_size != 0:
    raise ValueError(
        f"minibatch_size ({args.minibatch_size}) must be positive and evenly "
        f"divide num_envs * num_steps ({_rollout_size})"
    )

num_updates = args.total_timesteps // _rollout_size
num_minibatches = _rollout_size // args.minibatch_size
if num_updates == 0:
    print(
        f"Warning: total_timesteps ({args.total_timesteps}) is smaller than "
        f"num_envs * num_steps ({_rollout_size}); no updates will run.",
        file=sys.stderr,
    )
print(
    f"Running {num_updates} updates with {num_minibatches} minibatches each.",
    file=sys.stderr,
)


class ActorCritic(hk.Module):
    """Actor-critic network with a shared trunk and two separate heads.

    When ``args.is_state_vector`` is False the input is treated as an image:
    a 2x2 convolution with ReLU followed by a 2x2, stride-2 average pool.
    The result is flattened, passed through a 64-unit ReLU layer, and fed to
    two heads of two hidden layers each. The actor head emits
    ``num_actions`` logits and the critic head emits one scalar per example.
    """

    def __init__(self, num_actions: int, activation: str = "tanh") -> None:
        super().__init__()
        self.num_actions = num_actions
        self.activation = activation
        assert activation in ["relu", "tanh"]

    def __call__(self, x: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(logits, value)`` with shapes ``(B, num_actions)`` and ``(B,)``.

        The leading axis of ``x`` is the batch axis, e.g. ``(B, 10, 10, 4)``
        ``[batch, height, width, frames]`` for MinAtar observations.
        """
        head_act = jax.nn.relu if self.activation == "relu" else jax.nn.tanh
        h = x.astype(jnp.float32)

        if not args.is_state_vector:
            # Haiku's Conv2D defaults to SAME padding, so spatial size is kept.
            h = jax.nn.relu(hk.Conv2D(32, kernel_shape=2)(h))
            # 2x2 average pool (stride 2, VALID): sum each window, then
            # divide by the window area.
            window_sum = jax.lax.reduce_window(
                h,
                init_value=0.0,
                computation=jax.lax.add,
                window_dimensions=(1, 2, 2, 1),
                window_strides=(1, 2, 2, 1),
                padding="VALID",
            )
            h = window_sum / (2 * 2)

        h = h.reshape((h.shape[0], -1))
        trunk = jax.nn.relu(hk.Linear(64)(h))

        def two_layer_head(out_dim: int) -> jnp.ndarray:
            # A plain closure (not a module method) keeps the layers in
            # __call__'s name scope, so parameter names and creation order
            # are identical to writing the layers inline.
            z = head_act(hk.Linear(64)(trunk))
            z = head_act(hk.Linear(64)(z))
            return hk.Linear(out_dim)(z)

        # The actor head must be built before the critic head to preserve
        # the parameter creation order.
        logits = two_layer_head(self.num_actions)
        value = two_layer_head(1)
        return logits, jnp.squeeze(value, axis=-1)


def forward_fn(x: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Map a batch of observations to ``(policy_logits, state_value)``.

    Shapes are ``(B, num_actions)`` and ``(B,)``; the network uses tanh
    activations in its actor and critic heads.
    """
    network = ActorCritic(env.num_actions, activation="tanh")
    policy_logits, state_value = network(x)
    return policy_logits, state_value


# Pure (init, apply) pair for forward_fn; apply() takes no RNG argument.
forward = hk.without_apply_rng(hk.transform(forward_fn))

# Clip gradients to a global norm of args.max_grad_norm, then take an Adam step.
optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm=args.max_grad_norm),
    optax.adam(learning_rate=args.lr, eps=1e-5),
)


class Transition(NamedTuple):
    """One recorded environment step; during a rollout every leaf is batched over envs.

    Instances are constructed positionally, so the field order is part of
    the contract.
    """

    done: jnp.ndarray  # whether the step ended the episode
    action: jnp.ndarray  # action sampled from the policy
    value: jnp.ndarray  # critic estimate for `obs`
    reward: jnp.ndarray  # reward returned by the step
    log_prob: jnp.ndarray  # log-probability of `action` under the behaviour policy
    obs: jnp.ndarray  # observation the action was chosen from


# Carry threaded through the rollout scan and from one update step to the
# next, in the order (params, opt_state, env_state, last_obs, rng).
RunnerState = tuple[
    optax.Params,
    optax.OptState,
    pgx.State,  # batched environment state
    jnp.ndarray,  # batched most-recent observation
    jax.Array,  # PRNG key
]

# Carry threaded through the epoch scan, in the order
# (params, opt_state, traj_batch, advantages, targets, rng).
UpdateState = tuple[
    optax.Params,
    optax.OptState,
    Transition,  # the whole collected rollout
    jnp.ndarray,  # GAE advantages
    jnp.ndarray,  # value targets (advantages + values)
    jax.Array,  # PRNG key
]


def make_update_fn():
    """Build the jittable PPO update step.

    The returned ``_update_step(runner_state)`` collects ``args.num_steps``
    transitions from every env, computes GAE advantages, and then runs
    ``args.update_epochs`` epochs of clipped-objective PPO over shuffled
    minibatches. It reads the module-level ``env``, ``forward``,
    ``optimizer``, ``args`` and ``num_minibatches``.

    Returns:
        ``_update_step``, which maps a ``RunnerState`` to
        ``(new_runner_state, loss_info)``. ``loss_info`` stacks, per epoch and
        per minibatch, ``(total_loss, (value_loss, actor_loss, entropy))``.
    """

    # TRAIN LOOP
    def _update_step(runner_state: RunnerState) -> tuple[RunnerState, jnp.ndarray]:
        """Run one rollout, then the PPO epochs that train on it."""
        # Envs whose episode ended are re-initialised by pgx's auto_reset
        # wrapper, so the fixed-length rollout never stalls.
        step_fn = jax.vmap(auto_reset(env.step, env.init))

        def _env_step(
            runner_state: RunnerState, unused: Any
        ) -> tuple[RunnerState, Transition]:
            """Sample an action for the last observation and step every env.

            Records ``(done, a, V(s), r, log pi(a|s), s)``. The next
            observation s' is carried forward in the updated env state.
            """
            params, opt_state, env_state, last_obs, rng = runner_state
            # last_obs: (b, *env.observation_shape), b = number of envs
            # rng: (2,)

            # SELECT ACTION
            rng, _rng = jax.random.split(rng)
            logits, value = forward.apply(params, last_obs)  # (b, num_actions), (b,)
            pi = distrax.Categorical(logits=logits)
            action = pi.sample(seed=_rng)  # (b,)
            log_prob = pi.log_prob(action)  # (b,)

            # STEP ENV
            rng, _rng = jax.random.split(rng)
            keys = jax.random.split(_rng, env_state.observation.shape[0])  # (b, 2)
            env_state = step_fn(env_state, action, keys)  # batched env state

            # Count truncation as an episode end as well. With auto-reset, a
            # finished env's observation already belongs to a new episode, so
            # if only `terminated` were used, GAE would bootstrap a truncated
            # step from the *next* episode's value.
            # NOTE(review): relies on pgx.experimental.auto_reset resetting on
            # `terminated | truncated`; confirm for the pinned pgx version.
            done = env_state.terminated | env_state.truncated  # (b,)

            transition = Transition(
                done,  # (b,)
                action,  # (b,)
                value,  # (b,)
                jnp.squeeze(env_state.rewards),  # (b,); assumes one reward per env
                jnp.asarray(log_prob),  # (b,)
                last_obs,  # (b, *env.observation_shape)
            )
            runner_state = (params, opt_state, env_state, env_state.observation, rng)
            return runner_state, transition

        # Collect a trajectory; every leaf of traj_batch is (num_steps, b, ...).
        runner_state, traj_batch = jax.lax.scan(
            _env_step, runner_state, None, args.num_steps
        )

        # CALCULATE ADVANTAGE
        params, opt_state, env_state, last_obs, rng = runner_state
        # Value of the final observation, used to bootstrap the last step.
        _, last_val = forward.apply(params, last_obs)  # (b,)

        def _calculate_gae(
            traj_batch: Transition, last_val: jnp.ndarray
        ) -> tuple[jnp.ndarray, jnp.ndarray]:
            """Return ``(advantages, targets)``, each (num_steps, b).

            Targets are the advantages plus the recorded values (the
            lambda-returns used for the value loss).
            """

            def _get_advantages(
                gae_and_next_value: tuple[jnp.ndarray, jnp.ndarray],
                transition: Transition,
            ) -> tuple[tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
                gae, next_value = gae_and_next_value  # (b,), (b,)
                done, value, reward = (
                    transition.done,  # (b,)
                    transition.value,  # (b,)
                    transition.reward,  # (b,)
                )
                # `done` marks that s' ended the episode: don't bootstrap past it.
                delta = reward + args.gamma * next_value * (1 - done) - value
                gae = delta + args.gamma * args.gae_lambda * (1 - done) * gae
                return (gae, value), gae

            # Scan backwards in time so each step sees its successor's GAE.
            _, advantages = jax.lax.scan(
                _get_advantages,
                (jnp.zeros_like(last_val), last_val),
                traj_batch,
                reverse=True,
                unroll=16,
            )
            return advantages, advantages + traj_batch.value

        advantages, targets = _calculate_gae(traj_batch, last_val)

        # UPDATE NETWORK
        def _update_epoch(
            update_state: UpdateState, unused: Any
        ) -> tuple[UpdateState, jnp.ndarray]:
            """One pass over the full rollout in shuffled minibatches."""

            def _update_minbatch(
                tup_params_opt: tuple[optax.Params, optax.OptState],
                batch_info: tuple[Transition, jnp.ndarray, jnp.ndarray],
            ) -> tuple[tuple[optax.Params, optax.OptState], jnp.ndarray]:
                """Take one optimizer step on a single minibatch."""
                params, opt_state = tup_params_opt
                traj_batch, advantages, targets = batch_info
                # scan slices off the leading minibatch axis, so here:
                # traj_batch leaves: (minibatch_size, *leaf_shape)
                # advantages, targets: (minibatch_size,)

                def _loss_fn(
                    params: optax.Params,
                    traj_batch: Transition,
                    gae: jnp.ndarray,
                    targets: jnp.ndarray,
                ) -> tuple[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
                    """Clipped PPO loss; returns ``(total, (value, actor, entropy))``."""
                    # RERUN NETWORK
                    logits, value = forward.apply(params, traj_batch.obs)
                    pi = distrax.Categorical(logits=logits)
                    log_prob = pi.log_prob(traj_batch.action)  # (minibatch_size,)

                    # CALCULATE VALUE LOSS (clipped around the rollout values)
                    value_pred_clipped = traj_batch.value + (
                        value - traj_batch.value
                    ).clip(-args.clip_eps, args.clip_eps)  # (minibatch_size,)
                    value_losses = jnp.square(value - targets)  # (minibatch_size,)
                    value_losses_clipped = jnp.square(
                        value_pred_clipped - targets
                    )  # (minibatch_size,)
                    value_loss = (
                        0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                    )  # ()

                    # CALCULATE ACTOR LOSS
                    ratio = jnp.exp(
                        log_prob - traj_batch.log_prob
                    )  # (minibatch_size,)
                    # Normalise advantages per minibatch.
                    gae = (gae - gae.mean()) / (gae.std() + 1e-8)  # (minibatch_size,)
                    loss_actor1 = ratio * gae  # (minibatch_size,)
                    loss_actor2 = (
                        jnp.clip(
                            ratio,
                            1.0 - args.clip_eps,
                            1.0 + args.clip_eps,
                        )
                        * gae
                    )  # (minibatch_size,)
                    loss_actor = -jnp.minimum(
                        loss_actor1, loss_actor2
                    )  # (minibatch_size,)
                    loss_actor = loss_actor.mean()  # ()
                    entropy = pi.entropy().mean()  # ()

                    total_loss = (
                        loss_actor + args.vf_coef * value_loss - args.ent_coef * entropy
                    )  # ()
                    return total_loss, (value_loss, loss_actor, entropy)

                grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                total_loss, grads = grad_fn(params, traj_batch, advantages, targets)
                updates, opt_state = optimizer.update(grads, opt_state)
                params = optax.apply_updates(params, updates)
                return (params, opt_state), total_loss

            params, opt_state, traj_batch, advantages, targets, rng = update_state
            # traj_batch leaves: (num_steps, b, ...)
            # advantages, targets: (num_steps, b)

            rng, _rng = jax.random.split(rng)
            batch_size = args.minibatch_size * num_minibatches
            assert batch_size == args.num_steps * args.num_envs, (
                "batch size must be equal to number of steps * number of envs"
            )
            permutation = jax.random.permutation(_rng, batch_size)  # (batch_size,)

            batch = (traj_batch, advantages, targets)
            # Merge the (num_steps, b) axes: every leaf becomes (batch_size, ...).
            batch = jax.tree_util.tree_map(
                lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
            )
            # Shuffle transitions across time and envs.
            shuffled_batch = jax.tree_util.tree_map(
                lambda x: jnp.take(x, permutation, axis=0), batch
            )
            # Split into (num_minibatches, minibatch_size, ...).
            minibatches = jax.tree_util.tree_map(
                lambda x: jnp.reshape(x, [num_minibatches, -1] + list(x.shape[1:])),
                shuffled_batch,
            )
            (params, opt_state), total_loss = jax.lax.scan(
                _update_minbatch, (params, opt_state), minibatches
            )
            update_state = (params, opt_state, traj_batch, advantages, targets, rng)
            return update_state, total_loss

        update_state = (params, opt_state, traj_batch, advantages, targets, rng)
        update_state, loss_info = jax.lax.scan(
            _update_epoch, update_state, None, args.update_epochs
        )
        params, opt_state, _, _, _, rng = update_state

        runner_state = (params, opt_state, env_state, last_obs, rng)
        return runner_state, loss_info

    return _update_step


@jax.jit
def evaluate(
    params: optax.Params, rng_key: jax.Array
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Roll out the policy on ``args.num_eval_envs`` freshly initialised envs.

    Actions are sampled stochastically from the policy. Rewards are summed
    (undiscounted) until every env's episode has ended, whether by
    termination or truncation.

    Args:
        params: Network parameters for ``forward``.
        rng_key: PRNG key used for env initialisation, action sampling and
            env steps.

    Returns:
        ``(mean_return, cvar_return)``: the mean of the flattened per-env
        returns, and ``util.cvar(returns, 0.25)`` (presumably the mean of the
        worst 25% of returns; confirm against ``src.lib.util``).
    """
    step_fn = jax.vmap(env.step)
    rng_key, sub_key = jax.random.split(rng_key)
    subkeys = jax.random.split(sub_key, args.num_eval_envs)  # (num_eval_envs, 2)
    state = jax.vmap(env.init)(subkeys)  # batched env state
    ep_return = jnp.zeros_like(state.rewards)  # running return per env

    def cond_fn(tup: tuple[pgx.State, Any, Any]) -> jax.Array:
        """Keep looping while at least one env is still mid-episode."""
        state, _, _ = tup
        # Truncation must count as "finished" too: an env that is truncated
        # but never terminated would otherwise keep this loop running forever.
        finished = state.terminated | state.truncated
        return ~finished.all()

    def loop_fn(
        tup: tuple[pgx.State, jax.Array, jax.Array],
    ) -> tuple[pgx.State, jnp.ndarray, jax.Array]:
        """Sample an action for every env, step it and accumulate its reward."""
        state, R, rng_key = tup
        logits, _ = forward.apply(
            params, state.observation
        )  # (num_eval_envs, num_actions)
        pi = distrax.Categorical(logits=logits)
        rng_key, _rng = jax.random.split(rng_key)
        action = pi.sample(seed=_rng)  # (num_eval_envs,)
        rng_key, _rng = jax.random.split(rng_key)
        keys = jax.random.split(_rng, state.observation.shape[0])
        state = step_fn(state, action, keys)
        # NOTE(review): assumes env.step yields zero rewards for envs whose
        # episode already ended (pgx's base Env does), so finished envs stop
        # accumulating; confirm for custom envs.
        return state, R + state.rewards, rng_key

    # Loop until all environments are done.
    state, ep_return, _ = jax.lax.while_loop(
        cond_fn, loop_fn, (state, ep_return, rng_key)
    )

    # Flatten the episode returns.
    ep_return = ep_return.reshape((-1,))
    cvar_ep_return = util.cvar(ep_return, 0.25)
    return ep_return.mean(), cvar_ep_return


def train(rng: jax.Array) -> RunnerState:
    """Initialise the network, optimiser and envs, then run ``num_updates`` PPO updates.

    The policy is evaluated before training and after every update. Each
    evaluation is printed as a dict and, when ``args.use_wandb`` is set,
    logged to wandb. The logged ``"sec"`` is the accumulated wall time
    excluding evaluations; the first value also includes initialisation and
    compilation.

    Args:
        rng: PRNG key for parameter init, env resets, training and evaluation.

    Returns:
        The final runner state ``(params, opt_state, env_state, last_obs, rng)``.
    """
    tt = 0.0  # accumulated wall time, excluding evaluation
    st = time.time()

    # INIT NETWORK
    rng, _rng = jax.random.split(rng)
    init_x = jnp.zeros((1,) + env.observation_shape)  # dummy (1, *obs_shape) input
    print("Initializing (compiling)...", file=sys.stderr)
    params = forward.init(_rng, init_x)
    opt_state = optimizer.init(params=params)

    # INIT UPDATE FUNCTION
    jitted_update_step = jax.jit(make_update_fn())

    # INIT ENV
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, args.num_envs)  # (num_envs, 2)
    env_state = jax.jit(jax.vmap(env.init))(reset_rng)  # batched env state

    rng, _rng = jax.random.split(rng)
    runner_state = (params, opt_state, env_state, env_state.observation, _rng)

    # Warm up: compile the update step. The result is discarded; block until
    # it has actually run so compilation isn't billed to the first update.
    jax.block_until_ready(jitted_update_step(runner_state))
    print("Done initializing network", file=sys.stderr)

    def _evaluate_and_log(
        rng: jax.Array, params: optax.Params, elapsed: float, steps: int
    ) -> jax.Array:
        """Evaluate ``params``, print/log the metrics and return the advanced key."""
        rng, _rng = jax.random.split(rng)
        eval_R, cvar_R = evaluate(params, _rng)
        log = {
            "sec": elapsed,
            f"{args.env_name}/eval_R": float(eval_R),
            f"{args.env_name}/cvar_R": float(cvar_R),
            "steps": steps,
        }
        print(log)
        if args.use_wandb:
            wandb.log(log)
        return rng

    steps = 0

    # Initial evaluation (evaluation time itself is excluded from tt).
    tt += time.time() - st
    rng = _evaluate_and_log(rng, runner_state[0], tt, steps)
    st = time.time()

    for _ in range(num_updates):
        runner_state, _ = jitted_update_step(runner_state)
        # jit dispatch is asynchronous: wait for the update to finish so its
        # compute time is counted in tt rather than in the excluded evaluation.
        runner_state = jax.block_until_ready(runner_state)
        steps += args.num_envs * args.num_steps

        tt += time.time() - st
        rng = _evaluate_and_log(rng, runner_state[0], tt, steps)
        st = time.time()

    return runner_state


if __name__ == "__main__":
    # Optionally start a wandb run carrying the full config.
    if args.use_wandb:
        wandb.init(project=args.wandb_project, config=args.model_dump())

    final_runner_state = train(jax.random.PRNGKey(args.seed))

    # Persist only the trained parameters (first element of the runner state).
    if args.save_model:
        ckpt_path = f"{args.env_name}-seed={args.seed}.ckpt"
        with open(ckpt_path, "wb") as ckpt_file:
            pickle.dump(final_runner_state[0], ckpt_file)
