"""Embedders used to build ACEv2 model."""

from typing import Any, Callable, List, Optional, Tuple

import torch

from src.models.utils import (
    build_mlp_with_linear_skipcon,
    positional_encoding_init,
)
from src.utils import DataAttr


class Embedder(torch.nn.Module):
    """
    Embeds context pairs (xc, yc), buffer pairs (xb, yb), or targets xt into a
    shared D-dimensional space, with an additional learnable marker indicating
    context, target, or buffer.

    The embedder uses separate MLPs for x and y values (the same two MLPs are
    shared by all modes), and adds a learned marker embedding to distinguish
    between different data modes (context, buffer, target).

    Note that no ``forward`` method is defined; callers use ``embed_context``,
    ``embed_buffer`` and ``embed_target`` directly.
    """

    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        hidden_dim: int,
        out_dim: int,
        depth: int,
        mlp_builder: Callable[..., torch.nn.Module] = build_mlp_with_linear_skipcon,
        pos_emb_init: bool = False,
    ):
        """
        Initialize the Embedder module.

        Args:
            dim_x: Input dimension for x values
            dim_y: Input dimension for y values
            hidden_dim: Hidden dimension for the MLPs
            out_dim: Output embedding dimension
            depth: Number of layers in the MLPs
            mlp_builder: Function to build MLP networks; called as
                ``mlp_builder(in_dim, hidden_dim, out_dim, depth)``
            pos_emb_init: Whether to initialize marker embeddings with
                positional encoding (via ``positional_encoding_init``)
        """
        super().__init__()
        # Maps a data mode name to its row in the marker embedding table.
        self.marker_lookup = {"target": 0, "context": 1, "buffer": 2}
        num_markers = len(self.marker_lookup)

        # Module creation order is kept (x, y, marker) so that random
        # initialization consumes the RNG in the same sequence as before.
        self.x_embed = mlp_builder(dim_x, hidden_dim, out_dim, depth)
        self.y_embed = mlp_builder(dim_y, hidden_dim, out_dim, depth)
        self.marker_embed = torch.nn.Embedding(num_markers, out_dim)

        if pos_emb_init:
            # The third argument is forwarded unchanged; its meaning is
            # defined by positional_encoding_init, which is not visible here.
            init_weight = positional_encoding_init(num_markers, out_dim, 2)
            # nn.Module refuses to assign a plain Tensor to a registered
            # parameter slot (TypeError), so wrap it unless the helper
            # already returned a Parameter (in which case it is used as-is,
            # including its requires_grad setting).
            if not isinstance(init_weight, torch.nn.Parameter):
                init_weight = torch.nn.Parameter(init_weight)
            self.marker_embed.weight = init_weight

    def _get_marker_embedding(
        self,
        batch_size: int,
        marker_type: str,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Get marker embedding for the specified type.

        Args:
            batch_size: Number of rows in the returned index tensor
            marker_type: One of the keys of ``self.marker_lookup``
            device: Device on which the index tensor is created

        Returns:
            The marker embedding looked up for a ``(batch_size, 1)`` index
            tensor, i.e. one copy of the marker vector per batch row.

        Raises:
            KeyError: If ``marker_type`` is not a key of ``self.marker_lookup``.
        """
        marker = self.marker_lookup[marker_type]
        marker_idx = torch.full(
            (batch_size, 1), marker, dtype=torch.long, device=device
        )
        return self.marker_embed(marker_idx)

    def _embed_pair(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        marker_type: str,
    ) -> torch.Tensor:
        """Embed an (x, y) pair and add the marker for ``marker_type``."""
        x_emb = self.x_embed(x)
        y_emb = self.y_embed(y)
        # The marker has a singleton second dimension and relies on
        # broadcasting when added below.
        # NOTE(review): this presumes x_emb is (batch, num_points, out_dim);
        # confirm against callers.
        marker_emb = self._get_marker_embedding(
            x_emb.size(0), marker_type, x_emb.device
        )
        return x_emb + y_emb + marker_emb

    def embed_context(self, batch: DataAttr) -> torch.Tensor:
        """Embed context pairs (xc, yc) with context marker."""
        return self._embed_pair(batch.xc, batch.yc, "context")

    def embed_buffer(self, batch: DataAttr) -> torch.Tensor:
        """Embed buffer pairs (xb, yb) with buffer marker."""
        return self._embed_pair(batch.xb, batch.yb, "buffer")

    def embed_target(self, batch: DataAttr) -> torch.Tensor:
        """Embed target inputs (xt) with target marker (no y values are used)."""
        x_emb = self.x_embed(batch.xt)
        marker_emb = self._get_marker_embedding(
            x_emb.size(0), "target", x_emb.device
        )
        return x_emb + marker_emb

