from typing import NamedTuple

import haiku as hk
import jax
import jax.numpy as jnp

import src.lib.util as util
from src.baselines.qrdqn.util import distort_value


class QRNetworkOutputs(NamedTuple):
    """Outputs of one forward pass of a quantile-regression Q-network."""

    # Per-action scalar values reduced from ``q_dist`` (the producing networks
    # wrap them in ``jax.lax.stop_gradient``).
    q_values: jnp.ndarray
    # Quantile estimates laid out as (batch, num_quantiles, num_actions).
    q_dist: jnp.ndarray


class SharedBackbone(hk.Module):
    """Feature extractor shared by the QR and TQL networks.

    Image-like observations (``is_state_vector=False``) pass through a small
    convolutional trunk with three residual conv blocks; vector observations
    skip it. In both cases the result is flattened per example and projected
    by a two-layer MLP (512 -> 256) whose final layer has no activation.
    """

    def __init__(
        self,
        is_state_vector: bool,
    ) -> None:
        """
        Args:
            is_state_vector: True when observations are flat vectors; False
                routes them through the convolutional trunk first.
        """
        super().__init__()
        self.is_state_vector = is_state_vector

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Embed a batch of observations.

        Args:
            x: Observations with the batch on axis 0. For images this is
                presumably (batch, height, width, channels), e.g. a MinAtar
                10x10 grid; TODO confirm channel layout against the env.

        Returns:
            float32 features of shape (batch, 256), without a final activation.
        """
        h = x.astype(jnp.float32)

        if not self.is_state_vector:
            # Stem conv: 16 channels, 3x3 kernel, default SAME padding, so the
            # spatial size is preserved.
            h = jax.nn.relu(hk.Conv2D(16, kernel_shape=3)(h))
            # Three residual blocks of the form h <- h + relu(conv(h)).
            for _block in range(3):
                h = h + jax.nn.relu(
                    hk.Conv2D(16, kernel_shape=3, padding="SAME")(h)
                )

        flat = jnp.reshape(h, (h.shape[0], -1))  # one row per example
        hidden = jax.nn.relu(hk.Linear(512)(flat))
        return hk.Linear(256)(hidden)


class QRNetwork(hk.Module):
    """Feed-forward quantile-regression Q-network over single observations.

    Observations are embedded by ``SharedBackbone`` and a linear head predicts
    ``num_quantiles`` values for each of ``num_actions`` actions. Scalar action
    values are then derived with ``distort_value`` and carry no gradient.
    """

    def __init__(
        self,
        num_actions: int,
        num_quantiles: int,
        is_state_vector: bool,
        alpha_cvar: float,
    ):
        """
        Args:
            num_actions: Size of the discrete action space.
            num_quantiles: Number of quantiles predicted per action.
            is_state_vector: Forwarded to ``SharedBackbone``; False selects its
                convolutional trunk.
            alpha_cvar: Forwarded unchanged to ``distort_value``.
        """
        super().__init__()
        self.num_actions = num_actions
        self.num_quantiles = num_quantiles
        self.is_state_vector = is_state_vector
        self.alpha_cvar = alpha_cvar

        # Per-quantile tau values (from util.make_tau_hats) passed to
        # distort_value when reducing q_dist to q_values.
        self.tau_hats = util.make_tau_hats(num_quantiles)

        self.feature_extractor = SharedBackbone(
            is_state_vector=is_state_vector,
        )

    def __call__(self, x: jnp.ndarray) -> QRNetworkOutputs:
        """Compute the quantile distribution and action values for ``x``.

        Args:
            x: Batch of observations, (batch, *obs_shape).

        Returns:
            ``QRNetworkOutputs`` whose ``q_dist`` has shape
            (batch, num_quantiles, num_actions) and whose ``q_values`` come from
            ``distort_value`` applied per example, with gradients stopped.
        """
        x = x.astype(jnp.float32)

        # The backbone ends in a bare Linear layer, so apply the activation here.
        features = jax.nn.relu(self.feature_extractor(x))  # (b, hidden_size)
        flat_dist = hk.Linear(self.num_quantiles * self.num_actions)(features)
        q_dist = jnp.reshape(
            flat_dist, (-1, self.num_quantiles, self.num_actions)
        )  # (b, num_quantiles, num_actions)

        # Map distort_value over the batch axis of q_dist; tau_hats and
        # alpha_cvar are shared across the batch.
        q_values = jax.vmap(distort_value, in_axes=(0, None, None))(
            q_dist, self.tau_hats, self.alpha_cvar
        )
        # q_values carries no gradient; training signal must flow via q_dist.
        q_values = jax.lax.stop_gradient(q_values)  # (b, num_actions)
        return QRNetworkOutputs(q_dist=q_dist, q_values=q_values)


class TQLNetwork(hk.Module):
    """Recurrent quantile-regression Q-network over an observation history.

    Each frame of the history is embedded independently by ``SharedBackbone``.
    The embeddings are unrolled through a GRU, and the GRU output at the last
    time step (the current state) feeds a residual layer and a linear quantile
    head. Outputs use the same ``QRNetworkOutputs`` layout as ``QRNetwork``.
    """

    def __init__(
        self,
        num_actions: int,
        num_quantiles: int,
        is_state_vector: bool,
        alpha_cvar: float,
    ):
        """
        Args:
            num_actions: Size of the discrete action space.
            num_quantiles: Number of quantiles predicted per action.
            is_state_vector: Forwarded to ``SharedBackbone``; False selects its
                convolutional trunk.
            alpha_cvar: Forwarded unchanged to ``distort_value``.
        """
        super().__init__()
        self.num_actions = num_actions
        self.num_quantiles = num_quantiles
        # Stored for parity with QRNetwork; not otherwise read in this class.
        self.is_state_vector = is_state_vector
        self.alpha_cvar = alpha_cvar

        # Per-quantile tau values (from util.make_tau_hats) passed to
        # distort_value when reducing q_dist to q_values.
        self.tau_hats = util.make_tau_hats(num_quantiles)

        self.feature_extractor = SharedBackbone(is_state_vector=is_state_vector)

    def __call__(
        self,
        x: jnp.ndarray,
    ) -> QRNetworkOutputs:
        """Compute the quantile distribution for the most recent step of ``x``.

        Args:
            x: Observation histories, (batch, seq_len, *obs_shape).

        Returns:
            ``QRNetworkOutputs`` whose ``q_dist`` has shape
            (batch, num_quantiles, num_actions) and whose ``q_values`` come from
            ``distort_value`` applied per example, with gradients stopped.
        """
        batch, seq_len = x.shape[0], x.shape[1]
        # GRU width. The residual Linear below must match it for the add to work.
        hidden_size = 256

        # Fold time into batch so the backbone embeds every frame independently.
        frames = x.reshape((-1, *x.shape[2:]))  # (batch * seq_len, *obs_shape)
        embeddings = jax.nn.relu(self.feature_extractor(frames))
        # Restore the time axis: (batch, seq_len, feature_size).
        embeddings = embeddings.reshape((batch, seq_len, -1))

        core = hk.GRU(hidden_size)
        # Batch-major unroll from a fresh initial state. The final carry is unused.
        rnn_outs, _ = hk.dynamic_unroll(
            core,
            embeddings,
            core.initial_state(batch),
            time_major=False,
        )
        # The last time step is the current state: (batch, hidden_size).
        last_out = rnn_outs[:, -1, :]  # type: ignore

        # Residual block: last_out + relu(Linear(last_out)).
        trunk = last_out + jax.nn.relu(hk.Linear(hidden_size)(last_out))
        flat_dist = hk.Linear(self.num_quantiles * self.num_actions)(trunk)
        q_dist = jnp.reshape(
            flat_dist, (-1, self.num_quantiles, self.num_actions)
        )  # (b, num_quantiles, num_actions)

        # Map distort_value over the batch axis of q_dist; tau_hats and
        # alpha_cvar are shared across the batch.
        q_values = jax.vmap(distort_value, in_axes=(0, None, None))(
            q_dist, self.tau_hats, self.alpha_cvar
        )
        # q_values carries no gradient; training signal must flow via q_dist.
        q_values = jax.lax.stop_gradient(q_values)  # (b, num_actions)

        return QRNetworkOutputs(q_dist=q_dist, q_values=q_values)


def create_qr_networks(
    num_actions: int,
    num_quantiles: int,
    is_state_vector: bool,
    alpha_cvar: float,
):
    """Build Haiku-transformed QR, TQL and joint-initialisation models.

    Args:
        num_actions: Size of the discrete action space.
        num_quantiles: Number of quantiles predicted per action.
        is_state_vector: False routes observations through the CNN trunk.
        alpha_cvar: Forwarded to both networks (and on to ``distort_value``).

    Returns:
        ``(qr_model, tql_model, init_model)``. Each is wrapped with
        ``hk.without_apply_rng``, so its ``apply`` takes no RNG key.

        * ``qr_model``: ``(b, *obs_shape)`` -> ``QRNetworkOutputs``.
        * ``tql_model``: ``(b, history_length, *obs_shape)`` ->
          ``QRNetworkOutputs``.
        * ``init_model``: ``(b, history_length, *obs_shape)`` ->
          ``(tql_out, qr_out)``. Its ``init`` creates the parameters of both
          networks in one call; presumably the caller then feeds them to the
          other two models' ``apply`` (TODO confirm at the call site).
    """
    # Identical constructor arguments for every network built below.
    network_kwargs = dict(
        num_actions=num_actions,
        num_quantiles=num_quantiles,
        is_state_vector=is_state_vector,
        alpha_cvar=alpha_cvar,
    )

    def qr_model_fn(x: jnp.ndarray) -> QRNetworkOutputs:
        # x: (b, *obs_shape)
        return QRNetwork(**network_kwargs)(x)

    def tql_model_fn(x: jnp.ndarray) -> QRNetworkOutputs:
        # x: (b, history_length, *obs_shape)
        return TQLNetwork(**network_kwargs)(x)

    def init_model_fn(x: jnp.ndarray) -> tuple[QRNetworkOutputs, QRNetworkOutputs]:
        # x: (b, history_length, *obs_shape)
        # Instantiate both networks so a single ``init`` creates all parameters.
        qr = QRNetwork(**network_kwargs)
        tql = TQLNetwork(**network_kwargs)

        tql_out = tql(x)
        # QRNetwork consumes single observations, so give it the first frame of
        # each history: (b, *obs_shape).
        qr_out = qr(x[:, 0])

        return tql_out, qr_out

    qr_model = hk.without_apply_rng(hk.transform(qr_model_fn))
    tql_model = hk.without_apply_rng(hk.transform(tql_model_fn))
    init_model = hk.without_apply_rng(hk.transform(init_model_fn))
    return qr_model, tql_model, init_model
