from functools import partial
from typing import Mapping, NamedTuple

import flashbax as fbx
import jax
import jax.numpy as jnp
import optax

import src.lib.util as util


class ExItTransition(NamedTuple):
    """Transition record used as the item type of the trajectory buffer.

    `make_buffer` builds a dummy instance of this tuple to initialise the
    buffer. The shape comments below use `b` for the leading batch dimension.
    """

    # Episode-termination flag for the step.
    done: jnp.ndarray  # (b,)
    # Action index taken at the step.
    action: jnp.ndarray  # (b,)
    # Reward associated with the step.
    reward: jnp.ndarray  # (b,)
    # Policy produced by search, one entry per action.
    search_policy: jnp.ndarray  # (b, num_actions)
    # Observation at the step.
    obs: jnp.ndarray  # (b, *obs_shape)
    # Free-form per-step metadata; `make_buffer` populates the keys
    # "episode_return", "episode_length" and "is_terminal_step".
    info: Mapping[
        str, float | int | bool | jnp.ndarray
    ]  # Additional info, e.g., episode return, length


def scale_gradient(g: jnp.ndarray, scale: float = 1) -> jnp.ndarray:
    """Return `g` with its forward value intact and its gradient multiplied by `scale`.

    The result is the sum of a differentiable share (`g * scale`) and a
    gradient-blocked share (`stop_gradient(g) * (1 - scale)`). The two shares
    add up to `g` in the forward pass, while only the first contributes to the
    backward pass.

    Args:
        g: array whose gradient should be rescaled.
        scale: factor applied to the gradient flowing back through `g`.

    Returns:
        An array with the same value as `g`.
    """
    differentiable_share = g * scale
    blocked_share = jax.lax.stop_gradient(g) * (1.0 - scale)
    return differentiable_share + blocked_share


def init_model_and_optim(env, model_init_fn, args):
    """Initialise model parameters together with a clipped Adam optimizer.

    Args:
        env: environment exposing `observation_shape` and `num_actions`.
        model_init_fn: object whose `init(key, obs, action)` returns the
            model parameters.
        args: configuration providing `seed`, `lr`, `min_lr`,
            `lr_linear_decay`, `lr_anneal_iterations`,
            `train_epochs_per_iter`, `max_grad_norm` and `optim_eps`.

    Returns:
        A `(params, optimizer, opt_state)` tuple.
    """
    # Placeholder inputs with a batch dimension of 1, used only for init.
    sample_obs = jnp.zeros((1, *env.observation_shape))  # (b, *obs_shape), b=1
    # The same seed key is used for the placeholder action and for init.
    seed_key = jax.random.PRNGKey(args.seed)
    sample_action = jax.random.randint(seed_key, (1,), 0, env.num_actions)
    params = model_init_fn.init(seed_key, sample_obs, sample_action)

    # Either a schedule callable (linear decay) or a constant learning rate.
    if args.lr_linear_decay:
        learning_rate = partial(
            util.linear_schedule,
            num_updates=args.lr_anneal_iterations * args.train_epochs_per_iter,
            lr=args.lr,
            min_lr=args.min_lr,
        )
    else:
        learning_rate = args.lr

    # Gradients are clipped by global norm before the Adam update.
    clip_step = optax.clip_by_global_norm(args.max_grad_norm)
    adam_step = optax.adam(learning_rate=learning_rate, eps=args.optim_eps)  # type: ignore
    optimizer = optax.chain(clip_step, adam_step)

    return params, optimizer, optimizer.init(params)


def make_buffer(
    env, args
) -> tuple[
    fbx.trajectory_buffer.TrajectoryBuffer, fbx.trajectory_buffer.TrajectoryBufferState
]:
    """Create a jitted flashbax trajectory buffer and its initial state.

    Args:
        env: environment exposing `init(key)`; the returned state must have
            an `observation` attribute.
        args: configuration providing `num_actions`, `total_buffer_size`,
            `sample_sequence_length`, `n_step`, `train_batch_size` and
            `selfplay_batch_size`.

    Returns:
        A `(buffer, buffer_state)` tuple, where the buffer's `init`, `add`,
        `sample` and `can_sample` functions are wrapped in `jax.jit`.
    """
    # Initialise a batch of one environment to get an observation template.
    init_keys = jax.random.split(jax.random.PRNGKey(0), 1)
    template_state = jax.vmap(env.init)(init_keys)  # (1, *env.observation_shape)

    # Un-batched example item; the buffer is initialised from this structure.
    template_info = {
        "episode_return": jnp.array(0.0),
        "episode_length": jnp.array(0),
        "is_terminal_step": jnp.array(False),
    }
    template_transition = ExItTransition(
        done=jnp.array(False),
        action=jnp.array(0),
        reward=jnp.array(0.0),
        search_policy=jnp.zeros((args.num_actions,)),
        obs=template_state.observation[0],
        info=template_info,
    )

    # Sequences hold the training unroll plus `n_step` extra steps.
    sequence_length = args.sample_sequence_length + args.n_step
    buffer = fbx.make_trajectory_buffer(
        max_size=args.total_buffer_size,
        min_length_time_axis=sequence_length,
        sample_batch_size=args.train_batch_size,
        sample_sequence_length=sequence_length,
        period=1,
        add_batch_size=args.selfplay_batch_size,
    )

    # Swap every buffer function for a jitted version; `add` additionally
    # donates its first argument.
    jitted_fns = {
        "init": jax.jit(buffer.init),
        "add": jax.jit(buffer.add, donate_argnums=0),
        "sample": jax.jit(buffer.sample),
        "can_sample": jax.jit(buffer.can_sample),
    }
    buffer = buffer.replace(**jitted_fns)  # type: ignore

    return buffer, buffer.init(template_transition)


def batch_n_step_bootstrapped_returns(
    r_t: jax.Array,
    discount_t: jax.Array,
    v_t: jax.Array,
    n: int,
    lambda_t: float = 1.0,
    stop_target_gradients: bool = True,
) -> jax.Array:
    """Computes n-step bootstrapped return targets over a batch of sequences.

    Adapted from rlax's `n_step_bootstrapped_returns`, with a leading batch
    axis and with rewards, discounts and values all indexed by the same step.

    The returns are computed according to the below equation iterated `n` times:

        Gₜ = rₜ + γₜ [(1 - λₜ) vₜ + λₜ Gₜ₊₁].

    When lambda_t == 1. (default), this reduces to

        Gₜ = rₜ + γₜ * (rₜ₊₁ + γₜ₊₁ * (... * (rₜ₊ₙ₋₁ + γₜ₊ₙ₋₁ * vₜ₊ₙ))).

    Steps past the end of the sequence are padded with zero reward, unit
    discount, unit lambda and the last available value `v[:, -1]`.

    Args:
        r_t: rewards, shape (B, T).
        discount_t: discounts, shape (B, T).
        v_t: state or state-action values to bootstrap from, shape (B, T).
        n: number of steps over which to accumulate reward before
            bootstrapping. Must be at least 1.
        lambda_t: scalar lambda, or an array of lambdas with shape (B, T).
        stop_target_gradients: bool indicating whether or not to apply stop
            gradient to targets.

    Returns:
        Estimated bootstrapped returns, shape (B, T).

    Raises:
        ValueError: if `n` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}.")

    # A batch-major lambda array has to be moved to time-major together with
    # the other inputs; a scalar lambda is left as is so it keeps broadcasting.
    if jnp.ndim(lambda_t) >= 2:
        lambda_t = jnp.swapaxes(lambda_t, 0, 1)  # type: ignore

    # Swap axes to make the time axis the first dimension.
    r_t, discount_t, v_t = (
        jnp.swapaxes(x, 0, 1) for x in (r_t, discount_t, v_t)
    )
    seq_len = r_t.shape[0]
    trailing_shape = r_t.shape[1:]

    # Maybe change scalar lambda to an array.
    lambda_t = jnp.ones_like(discount_t) * lambda_t  # type: ignore

    def _repeat_last_value(count: int) -> jax.Array:
        # broadcast_to keeps the batch axis even when count == 0, which
        # jnp.array([v_t[-1]] * 0) does not (it yields shape (0,) and made
        # n == 1 fail at concatenation time).
        return jnp.broadcast_to(v_t[-1], (count, *trailing_shape))

    # Shift bootstrap values by n and pad end of sequence with last value v_t[-1].
    targets = jnp.concatenate(
        [v_t[n:], _repeat_last_value(min(n, seq_len))], axis=0
    )

    # Pad sequences. Shape is now (T + n - 1, B).
    r_t = jnp.concatenate([r_t, jnp.zeros((n - 1, *trailing_shape))], axis=0)
    discount_t = jnp.concatenate(
        [discount_t, jnp.ones((n - 1, *trailing_shape))], axis=0
    )
    lambda_t = jnp.concatenate(
        [lambda_t, jnp.ones((n - 1, *trailing_shape))], axis=0
    )  # type: ignore
    v_t = jnp.concatenate([v_t, _repeat_last_value(n - 1)], axis=0)

    # Work backwards to compute n-step returns.
    for i in reversed(range(n)):
        r_ = r_t[i : i + seq_len]
        discount_ = discount_t[i : i + seq_len]
        lambda_ = lambda_t[i : i + seq_len]  # type: ignore
        v_ = v_t[i : i + seq_len]
        targets = r_ + discount_ * ((1.0 - lambda_) * v_ + lambda_ * targets)

    # Back to batch-major layout.
    targets = jnp.swapaxes(targets, 0, 1)
    return jax.lax.select(
        stop_target_gradients, jax.lax.stop_gradient(targets), targets
    )


def get_train_targets(batch, target_params, representation_apply, critic_apply, args):
    """Build policy, reward, value and state-embedding targets for a sampled batch.

    Values are computed with `target_params` for every observation in the
    batch and turned into n-step bootstrapped returns. Everything after the
    first `done` in each sequence is masked to zero.

    Args:
        batch: sampled sequences with `obs`, `reward`, `done` and
            `search_policy` fields whose two leading axes are (b, t).
        target_params: parameters passed to both `apply` calls.
        representation_apply: object whose `apply(params, obs)` returns
            observation embeddings.
        critic_apply: object whose `apply(params, embedding)` returns one
            value per observation.
        args: configuration providing `discount`, `n_step` and
            `sample_sequence_length`.

    Returns:
        `(policy_targets, reward_targets, value_targets, s_targets)`. The
        first three are truncated to `args.sample_sequence_length` steps;
        `s_targets` keeps the full sampled sequence length.
    """
    # Compute value for each obs
    B, S = batch.obs.shape[:2]  # (b, t)
    obs = jnp.reshape(batch.obs, (-1, *batch.obs.shape[2:]))  # (b*t, *obs_shape)
    obs_embedding = representation_apply.apply(target_params, obs)  # (b*t, num_hidden)
    values = critic_apply.apply(target_params, obs_embedding)  # (b*t,)
    values = jnp.reshape(values, (B, S))  # (b, t)

    # Return obs sequence dim
    s_targets = jnp.reshape(obs_embedding, (B, S, -1))

    # not_done[b, t] is 1 while no `done` has been seen at or before step t:
    # the cumulative sum along the row is >= 1 from the first True onwards.
    not_done = (jnp.cumsum(batch.done, axis=1) < 1).astype(jnp.int32)  # (b, t)

    # Per-step discount: `args.discount` while the episode is running and 0
    # from the terminal step on, so returns never bootstrap past a `done`.
    # (Previously the discount factor was computed and then overwritten by the
    # mask, which silently dropped `args.discount`.)
    discounts = not_done * args.discount  # (b, t)

    value_targets = batch_n_step_bootstrapped_returns(
        batch.reward,
        discounts,
        values,
        n=args.n_step,
    )

    # We want to mask everything _after_ the first done, keeping the terminal
    # step itself. This means shifting the not-done mask right by one and
    # padding the first value as a 1. The first step is always valid because
    # the earliest the first done can occur is at the first entry.
    valid = jnp.concatenate(
        [
            jnp.ones((B, 1), dtype=jnp.int32),
            not_done[:, :-1],
        ],
        axis=1,
    )  # (b, t)
    value_targets = value_targets * valid

    # We want to mask all rewards and search policies after the first done.
    reward_targets = batch.reward * valid
    policy_targets = (
        batch.search_policy * valid[:, :, jnp.newaxis]
    )  # (b, t, num_actions)

    # Get unroll length
    unroll = args.sample_sequence_length
    value_targets = value_targets[:, :unroll]
    reward_targets = reward_targets[:, :unroll]
    policy_targets = policy_targets[:, :unroll]

    return policy_targets, reward_targets, value_targets, s_targets
