from typing import Final, NamedTuple

from flax import nnx
from jax import Array, lax, numpy as jnp
from jax.nn import relu, softplus
from jax.scipy.linalg import block_diag
from optax import squared_error

from offline.lbp.tc.modules.vector_quantizer_ema import VectorQuantizerEMA
from offline.lbp.tc.modules.ssm_init import make_DPLR_HiPPO
from offline.lbp.tc.modules.ssm import BatchS5SSM
from offline.modules.actor.base import EPS
from offline.modules.actor.utils import gaussian_log_likelihood
from offline.modules.mlp import MLP


class TCAEResults(NamedTuple):
    """Scalar losses and codebook statistics returned by ``TCAutoEncoder.__call__``."""

    # Mean negative Gaussian log-likelihood of the actions under the actor decoder.
    loss_action: Array
    # Loss reported by the vector quantizer (``vq_results.loss``).
    loss_latent: Array
    # Mean masked reward-decoder loss; zero when no reward decoder is configured.
    loss_reward: Array
    # Mean masked transition-decoder loss; zero when no transition decoder is
    # configured.
    loss_transition: Array
    # Perplexity reported by the vector quantizer (``vq_results.perplexity``).
    perplexity: Array


class Encoder(nnx.Module):
    """Encode a trajectory into per-step latent vectors.

    Observations (and optionally rewards) are embedded with linear layers,
    passed through a ReLU and a linear projection, processed by a
    ``BatchS5SSM`` layer initialised from a block-diagonal HiPPO
    approximation, and finally projected to ``out_features``.
    """

    def __init__(
        self,
        clip_eigenvalues: bool,
        hidden_features: int,
        max_timescale: float,
        min_timescale: float,
        num_blocks: int,
        observation_dim: int,
        observation_embedding_dim: int,
        out_features: int,
        reward_embedding_dim: int,
        rngs: nnx.Rngs,
        ssm_base_size: int,
    ):
        """Build the embedding layers and the S5 state-space layer.

        Args:
            clip_eigenvalues: Forwarded to ``BatchS5SSM`` as ``clip_eigs``.
            hidden_features: Feature width fed to and produced by the SSM
                (``H``).
            max_timescale: Forwarded to ``BatchS5SSM`` as ``dt_max``.
            min_timescale: Forwarded to ``BatchS5SSM`` as ``dt_min``.
            num_blocks: Number of diagonal blocks the HiPPO initialisation is
                replicated over.
            observation_dim: Size of the last axis of the observations.
            observation_embedding_dim: Width of the observation embedding.
            out_features: Width of the produced latents.
            reward_embedding_dim: Width of the reward embedding; when it is
                not positive, rewards are not embedded at all.
            rngs: Random number streams used by every sub-module.
            ssm_base_size: Base state size; it is split into ``num_blocks``
                HiPPO blocks and half of it is passed to the SSM as ``P``.

        NOTE(review): the eigenvalue count ``num_blocks * (block // 2)``
        matches ``P = ssm_base_size // 2`` when ``ssm_base_size`` is a
        multiple of ``2 * num_blocks``; other values may leave the two
        inconsistent -- TODO confirm what ``BatchS5SSM`` expects.
        """
        full_block = int(ssm_base_size / num_blocks)
        eigenvalues, _, _, eigenvectors, _ = make_DPLR_HiPPO(full_block)

        # Only the first half of each block's spectrum is kept (presumably
        # because the SSM stores one eigenvalue per conjugate pair -- TODO
        # confirm), so both the block and the total state size are halved.
        half_block = full_block // 2
        state_size = ssm_base_size // 2
        eigenvalues = eigenvalues[:half_block]
        eigenvectors = eigenvectors[:, :half_block]
        eigenvectors_inv = eigenvectors.conj().T

        # The state matrix is block-diagonal: repeat the HiPPO approximation
        # once per block.
        eigenvalues = (
            eigenvalues * jnp.ones((num_blocks, half_block))
        ).ravel()
        eigenvectors = block_diag(*([eigenvectors] * num_blocks))
        eigenvectors_inv = block_diag(*([eigenvectors_inv] * num_blocks))

        # Sub-modules are created in a fixed order because each constructor
        # consumes keys from the shared ``rngs``.
        self.observation_encoder = nnx.Linear(
            observation_dim, observation_embedding_dim, rngs=rngs
        )
        self.reward_encoder: nnx.Linear | None = (
            nnx.Linear(1, reward_embedding_dim, rngs=rngs)
            if reward_embedding_dim > 0
            else None
        )
        embedding_dim = observation_embedding_dim + reward_embedding_dim
        self.linear = nnx.Linear(embedding_dim, hidden_features, rngs=rngs)
        self.s5 = BatchS5SSM(
            C_init="lecun_normal",
            clip_eigs=clip_eigenvalues,
            discretization="zoh",
            dt_max=max_timescale,
            dt_min=min_timescale,
            H=hidden_features,
            Lambda_im_init=eigenvalues.imag,
            Lambda_re_init=eigenvalues.real,
            P=state_size,
            rngs=rngs,
            V=eigenvectors,
            Vinv=eigenvectors_inv,
        )
        self.final = nnx.Linear(hidden_features, out_features, rngs=rngs)

    def __call__(self, observations: Array, rewards: Array) -> Array:
        """Map a trajectory to latents.

        Args:
            observations: ``[batch_size, max_episode_length, observation_dim]``.
            rewards: Rewards with a trailing axis of size 1 (they feed a
                ``Linear(1, ...)`` layer); ignored when no reward encoder
                was created.

        Returns:
            ``[batch_size, max_episode_length, out_features]`` latents.
        """
        # [batch_size, max_episode_length, obs_dim (+ rew_dim)]
        features = self.observation_encoder(observations)  # type: ignore
        if self.reward_encoder is not None:
            reward_features = self.reward_encoder(rewards)  # type: ignore
            features = jnp.concat((features, reward_features), axis=-1)
        # [batch_size, max_episode_length, hidden_features]
        hidden = self.linear(relu(features))
        # [batch_size, max_episode_length, hidden_features]
        hidden = self.s5(hidden)
        # [batch_size, max_episode_length, out_features]
        return self.final(hidden)


class DeterminsiticDecoder(nnx.Module):
    """MLP decoder that returns a point estimate and no standard deviation."""

    def __init__(
        self,
        in_features0: int,
        in_features1: int,
        latent_features: int,
        out_features: int,
        rngs: nnx.Rngs,
        **kwargs,
    ):
        """Create the underlying MLP.

        Args:
            in_features0: Width of the first input.
            in_features1: Width of the second input; when it is not positive
                the second input is dropped in ``__call__``.
            latent_features: Width of the latent input.
            out_features: Width of the prediction.
            rngs: Random number streams for the MLP.
            **kwargs: Extra keyword arguments forwarded to ``MLP``.
        """
        self.second_input = in_features1 > 0
        mlp_in_features = in_features0 + in_features1 + latent_features
        self.model = MLP(
            in_features=mlp_in_features,
            out_features=out_features,
            rngs=rngs,
            **kwargs,
        )

    def __call__(self, inputs0, inputs1, latents: Array) -> tuple[Array, None]:
        """Decode from the concatenated inputs.

        ``inputs1`` is only used when the decoder was built with a positive
        ``in_features1``. The second element of the result is always ``None``
        so the return shape matches ``GaussianDecoder``.
        """
        if self.second_input:
            parts = (inputs0, inputs1, latents)
        else:
            parts = (inputs0, latents)
        prediction = self.model(jnp.concatenate(parts, axis=-1))
        return prediction, None


class GaussianDecoder(nnx.Module):
    """MLP decoder that returns the mean and standard deviation of a Gaussian."""

    def __init__(
        self,
        in_features0: int,
        in_features1: int,
        latent_features: int,
        out_features: int,
        rngs: nnx.Rngs,
        eps: float = EPS,
        **kwargs,
    ):
        """Create the underlying MLP.

        Args:
            in_features0: Width of the first input.
            in_features1: Width of the second input; when it is not positive
                the second input is dropped in ``__call__``.
            latent_features: Width of the latent input.
            out_features: Width of the mean (and of the standard deviation);
                the MLP itself outputs twice as many features.
            rngs: Random number streams for the MLP.
            eps: Constant added to the standard deviation after softplus.
            **kwargs: Extra keyword arguments forwarded to ``MLP``.
        """
        self.second_input = in_features1 > 0
        self.eps = eps
        mlp_in_features = in_features0 + in_features1 + latent_features
        self.model = MLP(
            in_features=mlp_in_features,
            # One half of the output is the mean, the other the raw std.
            out_features=out_features * 2,
            rngs=rngs,
            **kwargs,
        )

    def __call__(self, inputs0, inputs1, latents: Array) -> tuple[Array, Array]:
        """Decode ``(means, stds)`` from the concatenated inputs.

        ``inputs1`` is only used when the decoder was built with a positive
        ``in_features1``.
        """
        if self.second_input:
            parts = (inputs0, inputs1, latents)
        else:
            parts = (inputs0, latents)
        raw = self.model(jnp.concatenate(parts, axis=-1))
        means, raw_stds = jnp.split(raw, 2, axis=-1)
        # softplus keeps the deviation positive; eps bounds it away from zero.
        return means, softplus(raw_stds) + self.eps


class TCAutoEncoder(nnx.Module):
    def __init__(
        self,
        action_dim: int,
        clip_eigenvalues: bool,
        codebook_size: int,
        decay: float,
        decode_reward: bool,
        decode_transition: bool,
        deterministic_reward: bool,
        deterministic_transition: bool,
        hidden_features: int,
        latent_dim: int,
        max_timescale: float,
        min_timescale: float,
        num_blocks: int,
        observation_dim: int,
        observation_embedding_dim: int,
        reward_embedding_dim: int,
        rngs: nnx.Rngs,
        ssm_base_size: int,
        use_next_observation: bool,
        **kwargs,
    ):
        """Build the encoder, the vector quantizer and the decoders.

        Args:
            action_dim: Width of the actions predicted by the actor decoder.
            clip_eigenvalues: Forwarded to ``Encoder``.
            codebook_size: Number of embeddings of the vector quantizer.
            decay: Decay forwarded to ``VectorQuantizerEMA``.
            decode_reward: Whether a reward decoder is created.
            decode_transition: Whether a transition decoder is created.
            deterministic_reward: Use ``DeterminsiticDecoder`` rather than
                ``GaussianDecoder`` for rewards.
            deterministic_transition: Same choice for the transition decoder.
            hidden_features: Forwarded to ``Encoder`` and to every decoder.
            latent_dim: Width of the latents / codebook embeddings.
            max_timescale: Forwarded to ``Encoder``.
            min_timescale: Forwarded to ``Encoder``.
            num_blocks: Forwarded to ``Encoder``.
            observation_dim: Width of the observations.
            observation_embedding_dim: Forwarded to ``Encoder``.
            reward_embedding_dim: Forwarded to ``Encoder``.
            rngs: Random number streams shared by all sub-modules.
            ssm_base_size: Forwarded to ``Encoder``.
            use_next_observation: Whether the reward decoder also receives the
                next observation.
            **kwargs: Extra keyword arguments forwarded to every decoder.
        """

        def build_decoder(
            deterministic: bool, in_features1: int, out_features: int
        ) -> DeterminsiticDecoder | GaussianDecoder:
            # Every optional decoder shares the same inputs; only the head
            # type, the second-input width and the output width differ.
            decoder_cls = (
                DeterminsiticDecoder if deterministic else GaussianDecoder
            )
            return decoder_cls(
                hidden_features=hidden_features,
                in_features0=observation_dim,
                in_features1=in_features1,
                latent_features=latent_dim,
                out_features=out_features,
                rngs=rngs,
                **kwargs,
            )

        # Sub-modules are created in a fixed order because each constructor
        # consumes keys from the shared ``rngs``.
        self.encoder = Encoder(
            clip_eigenvalues=clip_eigenvalues,
            hidden_features=hidden_features,
            max_timescale=max_timescale,
            min_timescale=min_timescale,
            num_blocks=num_blocks,
            observation_dim=observation_dim,
            observation_embedding_dim=observation_embedding_dim,
            out_features=latent_dim,
            reward_embedding_dim=reward_embedding_dim,
            rngs=rngs,
            ssm_base_size=ssm_base_size,
        )
        self.vector_quantizer = VectorQuantizerEMA(
            decay=decay,
            embedding_dim=latent_dim,
            num_embeddings=codebook_size,
            rngs=rngs,
        )
        # The actor is always Gaussian and never takes a second input.
        self.decoder_actor = GaussianDecoder(
            hidden_features=hidden_features,
            in_features0=observation_dim,
            in_features1=0,
            latent_features=latent_dim,
            out_features=action_dim,
            rngs=rngs,
            **kwargs,
        )

        self.decoder_reward: DeterminsiticDecoder | GaussianDecoder | None
        if not decode_reward:
            self.decoder_reward = None
        else:
            self.decoder_reward = build_decoder(
                deterministic_reward,
                observation_dim if use_next_observation else 0,
                1,
            )

        self.decoder_transition: DeterminsiticDecoder | GaussianDecoder | None
        if not decode_transition:
            self.decoder_transition = None
        else:
            self.decoder_transition = build_decoder(
                deterministic_transition, 0, observation_dim
            )

        # Next observations are needed as the transition target and/or as
        # the reward decoder's second input.
        self.feed_next_observations: Final[bool] = decode_transition or (
            decode_reward and use_next_observation
        )

    def __call__(
        self,
        actions: Array,
        decode_indices: Array,
        decode_mask: Array,
        latent_indices: Array,
        observations: Array,
        rewards: Array,
    ):
        """Compute all auto-encoder losses for a batch of trajectories.

        Every quantized latent (one per entry of ``latent_indices``) is paired
        with every decoded time step (one per entry of ``decode_indices``).

        Args:
            actions: ``[..., max_episode_length, action_dim]``.
            decode_indices: ``[..., subsample_decode]`` time steps to decode.
            decode_mask: Same shape as ``decode_indices``. Where set, the next
                observation is read at ``index + 1`` and the reward and
                transition losses are kept. It is the predicate of
                ``lax.select``, so it is presumably boolean -- TODO confirm.
            latent_indices: ``[..., subsample_latent]`` time steps whose
                encoder outputs are quantized.
            observations: ``[..., max_episode_length, observation_dim]``.
            rewards: ``[..., max_episode_length, 1]``.

        Returns:
            A ``TCAEResults`` tuple.
        """
        vq_results = self.compute_embeddings(
            latent_indices=latent_indices,
            observations=observations,
            rewards=rewards,
        )
        # [..., subsample_latent, latent_dim]
        codes = vq_results.quantized
        # [..., subsample_decode, subsample_latent, latent_dim]
        codes = jnp.broadcast_to(
            jnp.expand_dims(codes, -3),
            decode_indices.shape + codes.shape[-2:],
        )
        # [..., subsample_decode, subsample_latent]
        pair_shape = codes.shape[:-1]
        # [..., subsample_decode, 1]
        index_column = jnp.expand_dims(decode_indices, -1)
        # [..., subsample_decode, 1, 1]
        pair_mask = jnp.expand_dims(decode_mask, (-1, -2))

        def gather(values: Array, indices: Array) -> Array:
            # [..., subsample_decode, subsample_latent, feature_dim]
            return self._subsample(
                values, batch_shape=pair_shape, indices=indices
            )

        step_actions = gather(actions, index_column)
        step_observations = gather(observations, index_column)
        step_rewards = gather(rewards, index_column)

        loss_action = self._compute_action_loss(
            actions=step_actions,
            latents=codes,
            observations=step_observations,
        )

        if self.feed_next_observations:
            # Advance the index only where the mask is set; elsewhere the
            # current step is re-read and the loss is masked out.
            next_index_column = lax.select(
                pair_mask[..., 0], index_column + 1, index_column
            )
            next_observations = gather(observations, next_index_column)
            # NOTE(review): this branch is also taken when only the
            # transition decoder needs next observations, so the reward loss
            # is masked even if the reward decoder ignores them -- confirm
            # this is intended.
            reward_mask = pair_mask
        else:
            # No decoder consumes next observations here: the reward decoder
            # (if any) has no second input and there is no transition
            # decoder, so the un-subsampled array is only a placeholder.
            next_observations = observations
            reward_mask = jnp.asarray(1, dtype=decode_mask.dtype)

        loss_reward = self._compute_reward_loss(
            latents=codes,
            mask=reward_mask,
            next_observations=next_observations,
            observations=step_observations,
            rewards=step_rewards,
        )
        loss_transition = self._compute_transition_loss(
            latents=codes,
            mask=pair_mask,
            next_observations=next_observations,
            observations=step_observations,
        )
        return TCAEResults(
            loss_action=loss_action,
            loss_latent=vq_results.loss,
            loss_reward=loss_reward,
            loss_transition=loss_transition,
            perplexity=vq_results.perplexity,
        )

    def _compute_action_loss(
        self, actions: Array, latents: Array, observations: Array
    ):
        """Return the mean negative Gaussian log-likelihood of ``actions``."""
        # The actor decoder is built without a second input, so the second
        # positional argument is only a placeholder.
        means, stds = self.decoder_actor(observations, observations, latents)
        log_likelihood = gaussian_log_likelihood(
            means=means, samples=actions, stds=stds
        )
        return -jnp.mean(log_likelihood)

    def _compute_reward_loss(
        self,
        latents: Array,
        mask: Array,
        next_observations: Array,
        observations: Array,
        rewards: Array,
    ):
        """Return the mean masked reward-prediction loss.

        The loss is the squared error for a deterministic decoder and the
        negative Gaussian log-likelihood otherwise. Without a reward decoder
        the result is a zero scalar of the latents' dtype. The mean runs over
        all elements, masked ones included.
        """
        decoder = self.decoder_reward
        if decoder is None:
            return jnp.array(0, dtype=latents.dtype)
        # [..., subsample_decode, subsample_latent, 1]
        means, stds = decoder(observations, next_observations, latents)
        if stds is not None:
            # [..., subsample_decode, subsample_latent, 1]
            per_element = -gaussian_log_likelihood(
                means=means, reduce=False, samples=rewards, stds=stds
            )
        else:
            # [..., subsample_decode, subsample_latent, 1]
            per_element = squared_error(means, rewards)
        return jnp.mean(per_element * mask)

    def _compute_transition_loss(
        self,
        latents: Array,
        mask: Array,
        next_observations: Array,
        observations: Array,
    ):
        """Return the mean masked next-observation prediction loss.

        The loss is the squared error for a deterministic decoder and the
        negative Gaussian log-likelihood otherwise. Without a transition
        decoder the result is a zero scalar of the latents' dtype. The mean
        runs over all elements, masked ones included.
        """
        decoder = self.decoder_transition
        if decoder is None:
            return jnp.array(0, dtype=latents.dtype)
        # The transition decoder is built without a second input, so the
        # second positional argument is only a placeholder.
        # [..., subsample_decode, subsample_latent, observation_dim]
        means, stds = decoder(observations, observations, latents)
        if stds is not None:
            # [..., subsample_decode, subsample_latent, observation_dim]
            per_element = -gaussian_log_likelihood(
                means=means, samples=next_observations, reduce=False, stds=stds
            )
        else:
            # [..., subsample_decode, subsample_latent, observation_dim]
            per_element = squared_error(means, next_observations)
        return jnp.mean(per_element * mask)

    def _subsample(
        self, inputs: Array, batch_shape: tuple[int, ...], indices: Array
    ):
        """Pick time steps from ``inputs`` and repeat them per latent.

        Args:
            inputs: ``[..., max_episode_length, features]``.
            batch_shape: ``[..., subsample_decode, subsample_latent]``.
            indices: ``[..., subsample_decode, 1]`` time-step indices.

        Returns:
            ``[..., subsample_decode, subsample_latent, features]``.
        """
        # [..., subsample_decode, features]
        picked = jnp.take_along_axis(inputs, indices, axis=-2)
        target_shape = batch_shape + (picked.shape[-1],)
        # Insert a singleton latent axis, then broadcast over it.
        return jnp.broadcast_to(jnp.expand_dims(picked, -2), target_shape)

    def compute_embeddings(
        self, latent_indices: Array, observations: Array, rewards: Array
    ):
        """Encode the trajectory and quantize the latents at ``latent_indices``.

        Args:
            latent_indices: ``[..., subsample_latent]`` time-step indices.
            observations: Trajectory observations passed to the encoder.
            rewards: Trajectory rewards passed to the encoder.

        Returns:
            Whatever the vector quantizer returns for the selected latents.
        """
        # [..., max_episode_length, latent_dim]
        per_step_latents = self.encoder(observations, rewards)
        # [..., subsample_latent, latent_dim]
        selected = jnp.take_along_axis(
            per_step_latents,
            # [..., subsample_latent, 1]
            jnp.expand_dims(latent_indices, -1),
            axis=-2,
        )
        return self.vector_quantizer(selected)

    @property
    def codebook_size(self):
        """Size of axis 1 of the vector quantizer's ``embeddings``.

        Presumably the embeddings are stored as
        ``[embedding_dim, num_embeddings]`` -- TODO confirm against
        ``VectorQuantizerEMA``.
        """
        embeddings = self.vector_quantizer.embeddings
        return embeddings.shape[1]
