from functools import partial
from typing import NamedTuple

import flashbax as fbx
import jax
import jax.numpy as jnp
import optax

import src.lib.util as util


class Transition(NamedTuple):
    """A single stored step of experience together with its leading history.

    The field order defines the tuple layout, so it must stay as declared.
    """

    # History of observations
    obs_history: jnp.ndarray
    # The final observation in the sequence
    obs: jnp.ndarray
    # History of actions
    act_history: jnp.ndarray
    # If the final step in the sequence is done
    done: jnp.ndarray
    # The reward at the final transition
    reward: jnp.ndarray
    # The discounted return up to second to last transition
    discounted_accumulated_return: jnp.ndarray
    # Additional info, e.g. episode return, length, etc.
    info: dict[str, jnp.ndarray]
    # Step count (distance from the root)
    step_count: jnp.ndarray


def init_model_and_optim(env, model_init_fn, args):
    """Create model parameters, the optimiser and the optimiser state.

    Args:
        env: Environment; ``env.init(key).observation`` supplies the example
            observation used to initialise the model.
        model_init_fn: Object exposing ``init(key, example_input)`` that
            returns the parameters.
        args: Config namespace. Reads ``seed``, ``history_length``, ``lr``,
            ``lr_linear_decay``, ``max_grad_norm`` and ``optim_eps``; when
            ``lr_linear_decay`` is truthy also ``lr_anneal_iterations``,
            ``train_epochs_per_iter`` and ``min_lr``.

    Returns:
        Tuple ``(params, optimizer, opt_state)``.
    """
    env_key, param_key = jax.random.split(jax.random.PRNGKey(args.seed))

    # Tile the initial observation along a new leading history axis, then add
    # one more leading axis of size one before handing it to the model.
    first_obs = env.init(env_key).observation
    history = jnp.repeat(first_obs[jnp.newaxis, ...], args.history_length, axis=0)
    params = model_init_fn.init(param_key, history[jnp.newaxis, ...])

    # Either a schedule callable or a constant learning rate.
    if args.lr_linear_decay:
        learning_rate = partial(
            util.linear_schedule,
            num_updates=args.lr_anneal_iterations * args.train_epochs_per_iter,
            lr=args.lr,
            min_lr=args.min_lr,
        )
    else:
        learning_rate = args.lr

    # Gradients are clipped by global norm before the Adam update.
    grad_clip = optax.clip_by_global_norm(args.max_grad_norm)
    adam = optax.adam(learning_rate=learning_rate, eps=args.optim_eps)  # type: ignore
    optimizer = optax.chain(grad_clip, adam)

    return params, optimizer, optimizer.init(params)


def make_buffer(env, args):
    """Build a jitted flashbax trajectory buffer and its initial state.

    Args:
        env: Environment; ``env.init`` is vmapped over a single key to obtain
            a placeholder observation.
        args: Config namespace. Reads ``buffer_size``, ``n_step``,
            ``buffer_batch_size``, ``selfplay_batch_size`` and
            ``history_length``.

    Returns:
        Tuple ``(buffer_fn, buffer_state)``.
    """
    # Sampled sequences span n_step + 1 consecutive entries; the same value is
    # used as the minimum time-axis length.
    seq_len = args.n_step + 1
    buffer_fn = fbx.make_trajectory_buffer(
        max_size=args.buffer_size,
        min_length_time_axis=seq_len,
        sample_batch_size=args.buffer_batch_size,
        sample_sequence_length=seq_len,
        period=1,
        add_batch_size=args.selfplay_batch_size,
    )

    # Swap every entry point for a jitted version. `add` donates its first
    # argument (presumably the buffer state -- confirm against flashbax).
    jitted = {
        "init": jax.jit(buffer_fn.init),
        "add": jax.jit(buffer_fn.add, donate_argnums=0),
        "sample": jax.jit(buffer_fn.sample),
        "can_sample": jax.jit(buffer_fn.can_sample),
    }
    buffer_fn = buffer_fn.replace(**jitted)  # type: ignore

    # Placeholder transition handed to `init` to create the buffer state.
    init_keys = jax.random.split(jax.random.PRNGKey(0), 1)
    dummy_state = jax.vmap(env.init)(init_keys)
    hist_len = args.history_length + 1
    dummy_info = {
        "episode_return": jnp.array(0.0),
        "episode_length": jnp.array(0),
        "is_terminal_step": jnp.array(False),
    }
    dummy_transition = Transition(
        # (1, *obs_shape) repeated along axis 0 -> (hist_len, *obs_shape)
        obs_history=jnp.repeat(dummy_state.observation, hist_len, axis=0),
        obs=dummy_state.observation[0],
        act_history=jnp.zeros((hist_len,), dtype=jnp.int32),
        done=jnp.array(False),
        reward=jnp.array(0.0),
        discounted_accumulated_return=jnp.array(0.0),
        info=dummy_info,
        step_count=jnp.array(0),
    )

    return buffer_fn, buffer_fn.init(dummy_transition)


def get_value_targets(
    learn_batch, _qr_model, _reward_history_model, target_params, args, rng_key
):
    """Compute n-step quantile targets from a batch of sampled sequences.

    The bootstrap action at the last step of each sequence is chosen greedily
    on the score ``util.cvar`` assigns to sampled returns (samples from the
    per-action quantiles plus samples from the reward-history quantiles). The
    target network's quantiles for that action are then backed up through the
    first ``args.n_step`` rewards of the sequence.

    Args:
        learn_batch: Transition-like pytree whose leaves are laid out as
            ``(batch, time, ...)``. Reads ``obs``, ``obs_history``, ``done``,
            ``reward`` and ``step_count``.
        _qr_model: Object whose ``apply(target_params, obs)`` result has a
            ``q_dist`` attribute, indexed here as
            ``(batch, num_quantiles, num_actions)``.
        _reward_history_model: Object whose ``apply(target_params, history)``
            result is used as ``(batch, num_quantiles)``.
        target_params: Parameters passed to both ``apply`` calls.
        args: Config namespace. Reads ``n_step``, ``gamma``, ``num_quantiles``,
            ``num_actions``, ``num_quantile_samples`` and ``alpha_cvar``.
        rng_key: PRNG key used to draw the quantile samples.

    Returns:
        Array of shape ``(batch, num_quantiles)`` with the target quantiles.
    """
    # (b, t, *) -> (t, b, *)
    batch = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), learn_batch)

    # Last step doesn't reset since we bootstrap from it (the preceding done is
    # the reset).
    # NOTE(review): the backup below only reads the first n_step entries, so
    # with sequences of length n_step + 1 this overwritten entry is never
    # consumed -- confirm this matches the intended done alignment.
    dones = batch.done.at[-1].set(False)

    # Target-network outputs at the last time step of every sequence.
    qr_out = _qr_model.apply(target_params, batch.obs[-1])
    # The history input drops its first entry; the output is zeroed wherever
    # the last step has step_count == 0.
    rh_out = (
        _reward_history_model.apply(target_params, batch.obs_history[-1, :, 1:])
        * (1 - (batch.step_count[-1] == 0))[:, None]
    )  # (b, num_quantiles)

    def draw_samples(key, dists):
        # One PRNG key per row of `dists`; returns (rows, num_quantile_samples).
        keys = jax.random.split(key, dists.shape[0])
        return jax.vmap(util.sample_quantile_distribution, in_axes=(0, 0, None))(
            keys, dists, args.num_quantile_samples
        )

    # Join batch and action dims so every (state, action) pair is one row.
    qr_dist = jnp.permute_dims(qr_out.q_dist, (0, 2, 1)).reshape(
        (-1, args.num_quantiles)
    )  # (b * num_actions, num_quantiles)

    rng_key, qr_key = jax.random.split(rng_key)
    qr_samples = draw_samples(qr_key, qr_dist).reshape(
        (-1, args.num_actions, args.num_quantile_samples)
    )  # (b, num_actions, num_samples)

    _, rh_key = jax.random.split(rng_key)
    rh_samples = draw_samples(rh_key, rh_out)  # (b, num_samples)

    # Add the reward-history samples to every action's samples.
    samples = (qr_samples + rh_samples[:, None, :]).reshape(
        (-1, args.num_quantile_samples)
    )  # (b * num_actions, num_samples)
    q_values = jax.vmap(util.cvar, in_axes=(0, None))(samples, args.alpha_cvar)
    q_values = q_values.reshape((-1, args.num_actions))  # (b, num_actions)

    # Greedy over all actions (no legality mask is applied here).
    greedy_action = jnp.argmax(q_values, axis=-1)  # (b,)

    # Select the greedy action's quantiles; the (b, 1, 1) index broadcasts
    # over the quantile axis.
    bootstrap = jnp.take_along_axis(
        qr_out.q_dist,
        greedy_action[:, None, None],
        axis=-1,
    ).squeeze(-1)  # (b, num_quantiles)

    def backup(next_target, step):
        reward, done = step  # each (b,)
        target = reward[:, None] + (1 - done[:, None]) * args.gamma * next_target
        return target, None  # carry is (b, num_quantiles)

    # Walk the first n_step entries from last to first. The final carry is the
    # target for the first step, so no per-step outputs need to be stacked.
    target_qr, _ = jax.lax.scan(
        backup,
        bootstrap,
        (batch.reward[: args.n_step], dones[: args.n_step]),
        reverse=True,
    )

    return target_qr  # (b, num_quantiles)
