"""Contains a DeepPerceiverDecoder implementation."""

from typing import Any

import einops
import torch
from torch import nn

from .perceiver_block import PerceiverBlock
from .continuous_sincos_embed import ContinuousSincosEmbed
from .linear_projection import LinearProjection


class DeepPerceiverDecoder(nn.Module):
    """A deep perceiver decoder model.

    Query tokens are built from the output positions (``ContinuousSincosEmbed`` followed by a
    two-layer GELU MLP). A stack of ``depth`` blocks then updates the queries using the projected
    latent tokens as keys/values. A final LayerNorm + linear head produces the output values.

    Args:
        dim: Hidden dimension of the decoder (queries and projected latent tokens).
        num_heads: Number of heads in each decoder block.
        ndim: Number of dimensions for the position input.
        input_dim: Dimension of the incoming latent tokens; they are projected to ``dim``.
        output_dim: Dimension of each predicted output value.
        depth: Number of stacked decoder blocks.
        block_ctor: Block constructor. It is called with the keyword arguments ``dim``,
            ``num_heads``, ``init_weights``, ``eps`` and ``mlp_hidden_dim``.
        init_weights: Initialization method for the weights (forwarded to projections and blocks).
        eps: Epsilon forwarded to every block and used by the final LayerNorm.
        mlp_expansion_factor: Ratio of the hidden dimension of the block MLPs to ``dim``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ndim: int,
        input_dim: int,
        output_dim: int,
        depth: int = 1,
        block_ctor: type = PerceiverBlock,
        init_weights: str = "truncnormal002",
        eps: float = 1e-6,
        mlp_expansion_factor: int = 4,
    ):
        super().__init__()
        # query construction: position embedding -> MLP, both in `dim`
        self.pos_embed = ContinuousSincosEmbed(dim=dim, ndim=ndim)
        self.query = nn.Sequential(
            LinearProjection(dim, dim, init_weights=init_weights),
            nn.GELU(),
            LinearProjection(dim, dim, init_weights=init_weights),
        )
        # latent projection input_dim -> dim
        # NOTE(review): `optional=True` presumably makes this an identity when input_dim == dim — confirm
        self.proj = LinearProjection(input_dim, dim, init_weights=init_weights, optional=True)
        self.blocks = nn.ModuleList(
            [
                block_ctor(
                    dim=dim,
                    num_heads=num_heads,
                    init_weights=init_weights,
                    eps=eps,
                    mlp_hidden_dim=dim * mlp_expansion_factor,
                )
                for _ in range(depth)
            ],
        )
        # output head
        self.norm = nn.LayerNorm(dim, eps=eps)
        self.pred = LinearProjection(dim, output_dim, init_weights=init_weights)

    def forward(
        self,
        x: torch.Tensor,
        pos: torch.Tensor,
        block_kwargs: dict[str, Any] | None = None,
        unbatch_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass of the model.

        Args:
            x: Latent tokens as dense tensor (batch_size, num_latent_tokens, input_dim).
            pos: Query positions (batch_size, num_output_pos, ndim).
            block_kwargs: Additional keyword arguments forwarded to every block.
            unbatch_mask: Optional index/mask applied to the flattened
                (batch_size * num_output_pos) predictions, e.g. to drop padded positions.
                It is ignored when batch_size == 1.

        Returns:
            The predictions as sparse tensor (batch_size * num_output_pos, output_dim). If
            ``unbatch_mask`` is given and batch_size != 1, only the rows it selects are returned.
        """
        assert x.ndim == 3, f"expected x with 3 dims (batch_size, num_tokens, input_dim), got shape {tuple(x.shape)}"
        assert pos.ndim == 3, f"expected pos with 3 dims (batch_size, num_pos, ndim), got shape {tuple(pos.shape)}"

        # resolve the optional kwargs once instead of on every block iteration
        kwargs = block_kwargs or {}

        # create query from the output positions
        query = self.query(self.pos_embed(pos))

        # project latent tokens to the perceiver dim
        x = self.proj(x)

        # perceiver: queries are updated, latent tokens serve as keys/values
        for block in self.blocks:
            query = block(q=query, kv=x, **kwargs)

        # predict value
        query = self.pred(self.norm(query))

        # dense tensor (batch_size, max_num_points, dim) -> sparse tensor (batch_size * num_points, dim)
        query = einops.rearrange(query, "batch_size max_num_points dim -> (batch_size max_num_points) dim")

        # batch_size == 1 -> no padding is needed, so the mask is deliberately skipped
        if unbatch_mask is not None and len(pos) != 1:
            query = query[unbatch_mask]
        return query
