"""
Inspired by https://github.com/sotetsuk/pgx/blob/main/examples/minatar-ppo/train.py
"""

from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp
import optax
from haiku import Params
from jax import Array

from medium_rl.config import Config
from medium_rl.data.transition import Transition
from medium_rl.envs.sequence_env import SequenceEnv
from medium_rl.init import RunState


@partial(jax.jit, static_argnames=["forward", "policy_fn", "env"])
def gen_transition(run_state: RunState, forward: Callable, policy_fn: Callable, env: SequenceEnv):
    """Roll the current policy out for ``env.max_len`` steps, then reset the envs.

    The whole rollout is jit-compiled; ``forward``, ``policy_fn`` and ``env``
    are static arguments of the compiled function.

    Args:
        run_state: Carries ``params``, ``env_state`` and ``rng``; it is updated
            through ``run_state.replace``.
        forward: Object whose ``apply(params, rng, obs, is_training=...)``
            returns an indexable output with logits at ``[0]`` and values at
            ``[1]``.
        policy_fn: Called as ``policy_fn(logits, legal_action_mask, step, rng)``
            and returns the actions to take.
        env: Provides ``step_fn``, ``reset_fn`` and ``max_len``.

    Returns:
        A tuple ``(run_state, transitions)``. ``run_state`` holds the freshly
        reset environment state and an advanced rng. ``transitions`` is a
        ``Transition`` whose leaves were stacked by ``jax.lax.scan``, so they
        have a leading axis of length ``env.max_len`` (T x B x ...).
    """

    def _rollout_step(carry: RunState, t):
        next_rng, sample_rng = jax.random.split(carry.rng)

        # Snapshot what the policy sees before acting.
        prev_state = carry.env_state
        prev_obs = prev_state.obs
        prev_mask = prev_state.legal_action_mask

        # Query the network (no rng is passed for this evaluation-mode call) and pick an action.
        net_out = forward.apply(carry.params, None, prev_obs, is_training=False)
        all_logits, all_values = net_out[0], net_out[1]
        chosen = policy_fn(all_logits, prev_mask, t, sample_rng)

        # Log-probability of the chosen action: keep only the logits of the current
        # position and push illegal actions to -inf before normalising.
        step_logits = jnp.where(prev_mask == 1, all_logits[:, t, :], -jnp.inf)
        step_log_probs = jax.nn.log_softmax(step_logits, axis=-1)
        chosen_log_prob = jnp.take_along_axis(step_log_probs, chosen[..., None], axis=-1).squeeze(-1)

        # Advance the environments.
        new_state = env.step_fn(prev_state, chosen)

        batch = prev_obs.shape[0]
        record = Transition(
            prev_obs,
            new_state.obs,
            prev_mask,
            new_state.legal_action_mask,
            chosen,
            new_state.terminating.astype(int),  # Whether this action terminated the env
            new_state.terminated.astype(int),
            jnp.zeros(batch),  # Placeholder value (filler)
            all_values[:, t, 0],
            chosen_log_prob,
            t * jnp.ones(batch, dtype=jnp.int32),  # Step index, repeated for every env
        )
        return carry.replace(env_state=new_state, rng=next_rng), record

    # scan stacks the per-step records into a sub-trajectory of shape T x B x ...
    final_state, trajectory = jax.lax.scan(_rollout_step, run_state, jnp.arange(env.max_len))

    # Reset every environment with its own key so the next call starts fresh.
    carry_rng, reset_rng = jax.random.split(final_state.rng)
    num_envs = final_state.env_state.obs.shape[0]
    reset_keys = jax.random.split(reset_rng, num_envs)
    fresh_env_state = env.reset_fn(final_state.env_state, reset_keys)

    return final_state.replace(env_state=fresh_env_state, rng=carry_rng), trajectory


def make_ppo_loss_fn(cfg: Config):
    """Build the PPO loss (clipped actor + clipped critic + entropy bonus).

    Args:
        cfg: Configuration; ``cfg.alg.clip_eps``, ``cfg.alg.val_coef`` and
            ``cfg.alg.omega`` are read when the loss is evaluated. The entropy
            bonus is weighted by ``1 / cfg.alg.omega``.

    Returns:
        ``loss_fn(all_outputs, transitions, gae, val_targets)`` returning a
        scalar loss averaged over the transitions that are not marked done.
    """

    def loss_fn(all_outputs: Array, transitions: Transition, gae: Array, val_targets: Array):
        """Compute the scalar PPO loss for one minibatch.

        Args:
            all_outputs: Network output; ``all_outputs[0]`` holds logits and
                ``all_outputs[1]`` holds values, both indexed as
                ``[sample, step, ...]``.
            transitions: Minibatch of transitions; ``action``, ``step``,
                ``legal_action_mask``, ``val``, ``log_prob`` and ``done`` are
                read.
            gae: Advantages, one per sample. They are standardised over the
                minibatch before use.
            val_targets: Regression targets for the critic, one per sample.

        Returns:
            Scalar loss. If every transition in the minibatch is done the
            result is 0 (and not NaN).
        """
        minibatch_size = gae.shape[0]
        action = transitions.action
        step = transitions.step
        clip_eps = cfg.alg.clip_eps
        sample_idx = jnp.arange(minibatch_size)

        # Pick the logits of the position each transition was taken at and
        # remove illegal actions from the distribution.
        logits = all_outputs[0][sample_idx, step]
        logits = jnp.where(transitions.legal_action_mask, logits, -jnp.inf)
        pi, log_pi = jax.nn.softmax(logits, axis=-1), jax.nn.log_softmax(logits, axis=-1)
        # Zero the illegal entries so -inf never enters later products or sums.
        pi = jnp.where(transitions.legal_action_mask, pi, 0)
        log_pi = jnp.where(transitions.legal_action_mask, log_pi, 0)
        log_selected_pi = jnp.take_along_axis(log_pi, action[..., None], axis=-1).squeeze(axis=-1)

        # Critic loss (value clipped around the value recorded at rollout time)
        val = all_outputs[1][sample_idx, step].squeeze(-1)
        clipped_val = transitions.val + (val - transitions.val).clip(-clip_eps, clip_eps)
        val_losses = jnp.square(val - val_targets)
        clipped_val_losses = jnp.square(clipped_val - val_targets)
        val_loss = 0.5 * jnp.maximum(val_losses, clipped_val_losses)

        # Actor loss: pessimistic minimum of the unclipped and clipped surrogates.
        prob_ratio = jnp.exp(log_selected_pi - transitions.log_prob)
        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
        surrogate = prob_ratio * gae
        # The clipped ratio must be scaled by the advantage as well; otherwise
        # the minimum compares an advantage-weighted term against a bare ratio.
        clipped_surrogate = jnp.clip(prob_ratio, 1.0 - clip_eps, 1.0 + clip_eps) * gae
        actor_loss = -jnp.minimum(surrogate, clipped_surrogate)

        # Entropy of the (masked) policy
        entropy = -(pi * log_pi)
        entropy = jnp.where(transitions.legal_action_mask, entropy, 0).sum(axis=-1)

        # Combine
        total_loss = actor_loss + cfg.alg.val_coef * val_loss - (1 / cfg.alg.omega) * entropy

        # Average over transitions that are not done. The denominator is clamped
        # to 1 so a minibatch made only of done transitions yields 0, not 0/0.
        not_done = 1 - transitions.done
        total_loss = (total_loss * not_done).sum() / jnp.maximum(not_done.sum(), 1)
        return total_loss

    return loss_fn


def make_ppo_train_step_fn(
    env: SequenceEnv,
    buffer,
    forward: Callable,
    policy_fn: Callable,
    loss_fn: Callable,
    optimizer,
    cfg: Config,
):
    """Build the PPO ``train_step`` closure.

    Args:
        env: Environment; passed to ``gen_transition`` and queried through
            ``env.get_rewards``.
        buffer: Not used by the code in this function.
        forward: Network object exposing ``apply``.
        policy_fn: Action-selection function forwarded to ``gen_transition``.
        loss_fn: Called as ``loss_fn(outputs, transitions, advantages, targets)``.
        optimizer: Object exposing ``update(grads, opt_state, params=...)``;
            its updates are applied with ``optax.apply_updates``.
        cfg: Configuration; ``cfg.alg.gae_lambda`` and
            ``cfg.alg.num_minibatches`` are read.

    Returns:
        ``train_step(run_state, buffer_state)`` which returns
        ``(run_state, buffer_state, total_loss, gen_samples, rewards,
        extra_oracle_info)``.
    """

    @jax.jit
    def update_network(run_state: RunState, transitions: Transition, advantages, targets):
        """Take one optimizer step on a single minibatch and return the loss."""
        next_rng, apply_rng = jax.random.split(run_state.rng)

        def _minibatch_loss(params: Params):
            net_out = forward.apply(params, apply_rng, transitions.obs, is_training=True)
            return loss_fn(net_out, transitions, advantages, targets)

        # Differentiate the loss with respect to the parameters only.
        loss_value, grads = jax.value_and_grad(_minibatch_loss)(run_state.params)

        updates, new_opt_state = optimizer.update(
            grads, run_state.opt_state, params=run_state.params
        )
        new_params = optax.apply_updates(run_state.params, updates)

        updated_state = run_state.replace(params=new_params, opt_state=new_opt_state, rng=next_rng)
        return updated_state, loss_value

    @jax.jit
    def get_advantages(transitions: Transition, rewards: jax.Array):
        """Scan the trajectory backwards to get advantages and value targets.

        The carry starts as ``(0, reward)``; no per-step reward or discount
        factor enters the recursion.
        """

        def _backward_step(carry, transition):
            running_gae, value_ahead = carry
            is_done = transition.done
            value_here = transition.val

            candidate = (value_ahead - value_here) + cfg.alg.gae_lambda * running_gae

            # Done transitions get zero advantage and simply hand the value they
            # received further back, so the terminal reward reaches the last
            # transition that is not done.
            masked_gae = jnp.where(is_done, 0, candidate)
            value_to_pass = jnp.where(is_done, value_ahead, value_here)
            return (masked_gae, value_to_pass), masked_gae

        terminal_rewards = rewards.squeeze(-1)
        initial_carry = (jnp.zeros_like(terminal_rewards), terminal_rewards)
        _, advantages = jax.lax.scan(
            _backward_step,
            initial_carry,
            transitions,
            reverse=True,
        )
        return advantages, advantages + transitions.val

    def train_step(run_state: RunState, buffer_state):
        """Collect one rollout and run one pass of minibatch updates over it.

        ``buffer_state`` is returned unchanged.
        """
        # Collect data
        run_state, transitions = gen_transition(run_state, forward, policy_fn, env)
        # Observations recorded at the last rollout step.
        # NOTE(review): if ``Transition.obs`` is the pre-action observation this
        # excludes the effect of the final action - confirm this is intended.
        gen_samples = transitions.obs[-1, :]
        rewards, extra_oracle_info = env.get_rewards(gen_samples)
        advantages, val_targets = get_advantages(transitions, rewards)

        def _to_flat_batch(arr):
            # [T, B, ...] -> [B, T, ...] -> [B * T, ...]
            batch_first = arr.swapaxes(0, 1)
            num_rows = batch_first.shape[0] * batch_first.shape[1]
            return batch_first.reshape(num_rows, *batch_first.shape[2:])

        flat_advantages = _to_flat_batch(advantages)
        flat_targets = _to_flat_batch(val_targets)
        flat_transitions = jax.tree_util.tree_map(_to_flat_batch, transitions)

        # Shuffle the sample indices once, then walk through them in slices.
        next_rng, shuffle_rng = jax.random.split(run_state.rng)
        run_state = run_state.replace(rng=next_rng)
        num_samples = flat_advantages.shape[0]
        shuffled = jax.random.permutation(shuffle_rng, num_samples)

        # Samples left over when num_samples is not a multiple of
        # num_minibatches are not used.
        minibatch_size = int(num_samples / cfg.alg.num_minibatches)
        for minibatch in range(cfg.alg.num_minibatches):
            lo = minibatch * minibatch_size
            hi = lo + minibatch_size
            chosen = shuffled[lo:hi]
            run_state, total_loss = update_network(
                run_state,
                jax.tree_util.tree_map(lambda leaf: leaf[chosen], flat_transitions),
                flat_advantages[chosen],
                flat_targets[chosen],
            )

        # total_loss is the loss of the last minibatch only.
        return run_state, buffer_state, total_loss, gen_samples, rewards, extra_oracle_info

    return train_step
