from functools import partial
from typing import Any, NamedTuple

import chex
import flashbax as fbx
import jax
import jax.numpy as jnp
import optax

import src.lib.util as util

# Loss Function


@partial(jax.jit, static_argnames=("delta",))
def huber_loss(x: jnp.ndarray, delta: float = 1.0) -> jnp.ndarray:
    """Element-wise Huber loss of ``x`` with threshold ``delta``.

    The loss is piecewise:

        0.5 * x^2                          if |x| <= delta
        0.5 * delta^2 + delta * (|x| - delta)  otherwise

    Args:
        x: Floating-point array of residuals (checked with ``chex.assert_type``).
        delta: Point at which the loss switches from quadratic to linear.
            It is a static jit argument, so each distinct value is compiled
            separately.

    Returns:
        An array with the same shape as ``x`` holding the per-element loss.
    """
    chex.assert_type(x, float)

    magnitude = jnp.abs(x)
    # Part of |x| that lies inside the quadratic region.
    inside = jnp.minimum(magnitude, delta)
    # Remainder beyond the threshold. Written as a difference rather than
    # max(|x| - delta, 0) so the gradient is not potentially doubled.
    outside = magnitude - inside

    quadratic_term = 0.5 * inside**2
    linear_term = delta * outside
    return quadratic_term + linear_term


@partial(jax.jit, static_argnames=("huber_param", "stop_target_gradients"))
def quantile_huber_loss(
    dist_src: jnp.ndarray,  # (num_quantiles, )
    tau_src: jnp.ndarray,  # (num_quantiles, )
    dist_target: jnp.ndarray,  # (num_quantiles, )
    huber_param: float = 0,
    stop_target_gradients: bool = True,
):
    """Quantile regression loss between two sets of quantile samples.

    Every source sample is compared with every target sample. Each pairwise
    error is weighted by ``|tau - 1{error < 0}|`` and penalised either with
    the absolute error or, when ``huber_param > 0``, with ``huber_loss``.

    Args:
        dist_src: Source (predicted) quantile values, one per ``tau_src`` entry.
        tau_src: Quantile fractions associated with ``dist_src``.
        dist_target: Target quantile values.
        huber_param: Huber threshold; ``0`` selects the plain absolute error.
            Static under jit.
        stop_target_gradients: When true, ``stop_gradient`` is applied to the
            negative-error indicator used in the weight. Static under jit.

    Returns:
        Scalar loss: mean over the target samples, summed over the source
        samples.
    """
    # Pairwise errors, indexed [source sample, target sample].
    pairwise_error = dist_target[None, :] - dist_src[:, None]

    # Indicator of negative errors, used for the asymmetric quantile weight.
    is_negative = (pairwise_error < 0.0).astype(jnp.float32)
    is_negative = jax.lax.select(
        stop_target_gradients, jax.lax.stop_gradient(is_negative), is_negative
    )
    quantile_weight = jnp.abs(tau_src[:, None] - is_negative)

    # Choose the element-wise penalty (resolved at trace time).
    use_huber = huber_param > 0.0
    penalty = (
        huber_loss(pairwise_error, huber_param)
        if use_huber
        else jnp.abs(pairwise_error)
    )
    weighted = penalty * quantile_weight

    # Average over target-samples dimension, sum over src-samples dimension.
    per_source = jnp.mean(weighted, axis=-1)
    return jnp.sum(per_source)


@partial(
    jax.jit,
    static_argnames=(
        "epsilon_start",
        "epsilon_finish",
        "epsilon_anneal_time",
    ),
)
def calc_eps(
    t: int,
    epsilon_start: float,
    epsilon_finish: float,
    epsilon_anneal_time: float,
) -> jnp.ndarray:
    """Linearly interpolate epsilon from ``epsilon_start`` to ``epsilon_finish``.

    Args:
        t: Current step. Negative values are treated as 0, so epsilon stays
            at ``epsilon_start`` until ``t`` becomes positive.
        epsilon_start: Value returned at ``t <= 0``.
        epsilon_finish: Value reached at ``t >= epsilon_anneal_time``.
        epsilon_anneal_time: Number of steps over which to interpolate. A
            non-positive value means "no annealing": ``epsilon_finish`` is
            returned for every ``t``.

    Returns:
        The epsilon for step ``t``, always within the closed interval spanned
        by ``epsilon_start`` and ``epsilon_finish``.
    """
    # The three schedule parameters are static, so this branch is resolved in
    # Python at trace time and avoids a ZeroDivisionError for a zero horizon.
    if epsilon_anneal_time > 0:
        slope = (epsilon_finish - epsilon_start) / epsilon_anneal_time
        intercept = epsilon_start
    else:
        slope, intercept = 0.0, epsilon_finish

    # Only anneal after learning starts (t > 0).
    eps = slope * jnp.maximum(0, t) + intercept

    # Clamp to both endpoints so increasing schedules (start < finish) saturate
    # at epsilon_finish as well, instead of growing without bound.
    lower = min(epsilon_start, epsilon_finish)
    upper = max(epsilon_start, epsilon_finish)
    return jnp.clip(eps, lower, upper)


class Transition(NamedTuple):
    """Per-step record stored in the trajectory buffer.

    ``make_buffer`` builds a dummy instance of this tuple to initialise the
    buffer, and ``get_value_target`` reads ``done``, ``reward`` and ``obs``
    from sampled batches. Field meanings below are inferred from that usage
    and from the field names -- confirm against the rollout code.
    """

    done: jnp.ndarray  # Boolean flag; used as (1 - done) to mask bootstrapping
    action: jnp.ndarray  # Action index (integer scalar in the dummy transition)
    reward: jnp.ndarray  # Scalar float reward
    obs: jnp.ndarray  # Environment observation (`state.observation`)
    # Additional info, e.g., episode stats. The dummy transition uses the keys
    # "episode_return", "episode_length" and "is_terminal_step".
    info: dict[str, Any] | None = None


@partial(jax.jit, static_argnames=("alpha_cvar",))
def distort_value(
    q_dist: jnp.ndarray, tau: jnp.ndarray, alpha_cvar: float
) -> jnp.ndarray:
    """Average the quantiles whose fraction lies at or below a fixed CVaR level.

    Assumes a fixed cvar level for now. Quantiles with ``tau > alpha_cvar`` are
    masked out and the remaining ones are averaged per action.

    If no entry of ``tau`` is at or below ``alpha_cvar`` (e.g. randomly sampled
    fractions that all land above a small level), the quantile(s) with the
    smallest ``tau`` are used instead, so the result is finite rather than the
    NaN produced by dividing by an empty count.

    Args:
        q_dist: Quantile values, shape (num_quantiles, num_actions).
        tau: Quantile fractions, shape (num_quantiles, ).
        alpha_cvar: CVaR level compared against ``tau``. Static under jit.

    Returns:
        Distorted value per action, shape (num_actions, ).
    """
    # Raise the cut-off to min(tau) when alpha_cvar would select nothing; when
    # at least one tau <= alpha_cvar this is exactly alpha_cvar. `initial`
    # keeps the reduction defined for an empty `tau`.
    threshold = jnp.maximum(alpha_cvar, jnp.min(tau, initial=jnp.inf))
    mask = tau <= threshold  # (num_quantiles, )

    # Mean over the selected quantiles: masked sum divided by selected count.
    tail_sum = jnp.sum(mask[:, None] * q_dist, axis=0)  # (num_actions, )
    tail_count = jnp.sum(mask, axis=0)  # scalar
    return tail_sum / tail_count  # (num_actions, )


def init_model_and_optim(env, model_init_fn, args):
    """Initialise model parameters, the optimizer, and its state.

    Args:
        env: Environment exposing ``init(key)``; the returned state's
            ``observation`` is used as the example input.
        model_init_fn: Object exposing ``init(key, batched_obs)`` that returns
            the model parameters.
        args: Configuration namespace. Reads ``seed``, ``lr``,
            ``lr_linear_decay``, ``max_grad_norm`` and ``optim_eps``; when
            ``lr_linear_decay`` is truthy also ``lr_anneal_iterations``,
            ``train_epochs_per_iter`` and ``min_lr``.

    Returns:
        Tuple ``(params, optimizer, opt_state)``.
    """
    root_key = jax.random.PRNGKey(args.seed)
    env_key, model_key = jax.random.split(root_key)

    # Use one observation, with a leading batch axis, as the example input.
    example_obs = env.init(env_key).observation
    params = model_init_fn.init(model_key, example_obs[jnp.newaxis, ...])

    # Either a linear schedule (callable) or a constant learning rate.
    if args.lr_linear_decay:
        learning_rate = partial(
            util.linear_schedule,
            num_updates=args.lr_anneal_iterations * args.train_epochs_per_iter,
            lr=args.lr,
            min_lr=args.min_lr,
        )
    else:
        learning_rate = args.lr

    # Gradient clipping is applied before the Adam update.
    clip_step = optax.clip_by_global_norm(args.max_grad_norm)
    adam_step = optax.adam(
        learning_rate=learning_rate, eps=args.optim_eps  # type: ignore
    )
    optimizer = optax.chain(clip_step, adam_step)

    return params, optimizer, optimizer.init(params)


def make_buffer(env, args):
    """Create a jitted flashbax trajectory buffer and its initial state.

    Args:
        env: Environment exposing ``init(key)``; only the observation of a
            freshly initialised state is used, as a template for the buffer.
        args: Configuration namespace. Reads ``buffer_size``, ``n_step``,
            ``buffer_batch_size`` and ``selfplay_batch_size``.

    Returns:
        Tuple ``(buffer_fn, buffer_state)``: the buffer with jitted
        ``init``/``add``/``sample``/``can_sample`` and the state produced by
        initialising it with a dummy transition.
    """
    # Sampled sequences hold n_step + 1 steps (the extra step is the one
    # get_value_target bootstraps from).
    sequence_length = args.n_step + 1
    trajectory_buffer = fbx.make_trajectory_buffer(
        max_size=args.buffer_size,
        min_length_time_axis=sequence_length,
        sample_batch_size=args.buffer_batch_size,
        sample_sequence_length=sequence_length,
        period=1,
        add_batch_size=args.selfplay_batch_size,
    )

    # Swap every entry point for its jitted version; `add` donates the buffer
    # state argument.
    jitted_fns = {
        "init": jax.jit(trajectory_buffer.init),
        "add": jax.jit(trajectory_buffer.add, donate_argnums=0),
        "sample": jax.jit(trajectory_buffer.sample),
        "can_sample": jax.jit(trajectory_buffer.can_sample),
    }
    trajectory_buffer = trajectory_buffer.replace(**jitted_fns)  # type: ignore

    # Build a single template transition to initialise the buffer with.
    init_keys = jax.random.split(jax.random.PRNGKey(0), 1)
    template_state = jax.vmap(env.init)(init_keys)
    template_info = {
        "episode_return": jnp.array(0.0),
        "episode_length": jnp.array(0),
        "is_terminal_step": jnp.array(False),
    }
    template_transition = Transition(
        done=jnp.array(False),
        action=jnp.array(0),
        reward=jnp.array(0.0),
        obs=template_state.observation[0],
        info=template_info,
    )

    return trajectory_buffer, trajectory_buffer.init(template_transition)


def get_value_target(learn_batch, _model, target_params, args):
    """Compute n-step bootstrapped quantile targets for a sampled batch.

    The target network is evaluated on the last observation of each sequence,
    the quantiles of its greedy action are selected, and the result is backed
    up through the preceding ``args.n_step`` steps with
    ``reward + (1 - done) * gamma * target``.

    Args:
        learn_batch: Pytree of arrays laid out as (b, t, *) with ``done``,
            ``reward`` and ``obs`` attributes. ``make_buffer`` samples
            sequences of length ``n_step + 1``.
        _model: Object whose ``apply(target_params, obs)`` returns a structure
            with ``q_values`` (argmax'd over the last axis) and ``q_dist``
            (indexed as (b, num_quantiles, num_actions)).
        target_params: Parameters passed to ``_model.apply``.
        args: Configuration namespace; reads ``n_step`` and ``gamma``.

    Returns:
        Target quantiles of shape (b, num_quantiles).
    """
    # (b, t, *) -> (t, b, *)
    batch = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), learn_batch)
    # Last step doesn't reset since we bootstrap from it (the preceding done is the reset)
    dones = batch.done.at[-1].set(False)

    # Bootstrap from the target network at the final observation.
    q_next_out = _model.apply(target_params, batch.obs[-1])
    greedy_actions = jnp.argmax(q_next_out.q_values, axis=-1)  # (b, )
    q_next_target = jnp.take_along_axis(
        q_next_out.q_dist,
        # (b, 1, 1); take_along_axis broadcasts it over the quantile axis.
        jnp.expand_dims(greedy_actions, axis=(-1, -2)),
        axis=-1,
    ).squeeze(-1)  # (b, num_quantiles)

    def body_fn(carry, ix):
        # One backup step; rewards and dones broadcast over the quantile axis.
        v = (
            batch.reward[ix, :, None] + (1 - dones[ix, :, None]) * args.gamma * carry
        )  # (b, num_quantiles)
        # Only the final carry is needed, so no per-step output is stacked.
        return v, None

    # reverse=True visits ix = n_step - 1, ..., 0 (backward pass).
    q_next_target, _ = jax.lax.scan(
        body_fn,
        q_next_target,
        jnp.arange(args.n_step),
        reverse=True,
    )
    return q_next_target  # (b, num_quantiles)
