import math
from typing import Optional, Tuple, Dict, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import gc
from torch import Tensor
from torch.distributions import Categorical
from torchrl.modules import MaskedCategorical

from model.module import (
    MyTransformerEncoderLayer,
    make_transformer_block_mask,
    aggregate_over_valid_dims,
    build_mlp,
    TransformerBlock,
)

# Constants
# Default lower bound on predicted standard deviations (default of `std_min`).
DEFAULT_STD_MIN = 1e-4
# Default number of mixture components of the GMM prediction head.
DEFAULT_NUM_COMPONENTS = 20
# Default dropout rate used by TAMO.
DEFAULT_DROPOUT = 0.0
# Added to the softmax probabilities in Decoder._sample before the Categorical is built.
EXPLORATION_EPSILON = 1e-8
# Transformer layers built in this file use dim_feedforward = FEEDFORWARD_MULTIPLIER * dim_attn.
FEEDFORWARD_MULTIPLIER = 4


def create_projection(dim_in: int, dim_out: int) -> nn.Module:
    """Return a layer mapping ``dim_in`` features to ``dim_out`` features.

    When both sizes are equal no parameters are needed and an ``nn.Identity``
    is returned; otherwise a fresh ``nn.Linear(dim_in, dim_out)`` is created.
    """
    if dim_in == dim_out:
        return nn.Identity()
    return nn.Linear(dim_in, dim_out)


def make_encoder_pad_mask(
    x_mask: Tensor,
    y_mask: Tensor,
    N: int,
    q_mask: Optional[Tensor] = None,  # dummy for API consistency
) -> Tensor:
    """Build the encoder padding mask, repeated for each of the ``N`` points.

    Args:
        x_mask: Valid input dimensions [B, dx_max]
        y_mask: Valid output dimensions [B, dy_max]
        N: Number of points the mask is repeated for
        q_mask: Ignored; accepted only so the signature matches other mask builders

    Returns:
        Mask [B, N, dx_max + dy_max]; an entry is True where the dimension is
        NOT valid (i.e. should be ignored), x dimensions first, then y.
    """
    batch, dx_max = x_mask.shape
    _, dy_max = y_mask.shape

    # Broadcast each per-batch validity mask over the N points.
    per_point = [
        m.unsqueeze(1).expand(batch, N, width)
        for m, width in ((x_mask, dx_max), (y_mask, dy_max))
    ]

    valid = torch.cat(per_point, dim=-1)
    # Flip "valid" into "pad".
    return ~valid


class DimensionWiseEmbedder(nn.Module):
    """Embeds x and y values dimension-wise with learnable missing value markers.

    Every scalar coordinate is lifted to a ``dim_mlp`` vector by a linear layer
    (one shared layer for all x dimensions, one for all y dimensions). Where a
    mask flags a dimension as missing, its embedding is replaced by a learned
    marker vector instead.
    """

    def __init__(self, dim_mlp: int, max_x_dim: int, max_y_dim: int):
        """
        Args:
            dim_mlp: Width of every per-dimension embedding
            max_x_dim: Not used by this class (kept in the constructor signature)
            max_y_dim: Not used by this class (kept in the constructor signature)
        """
        super().__init__()
        self.dim_mlp = dim_mlp
        self.mlp_x = nn.Linear(1, dim_mlp)
        self.mlp_y = nn.Linear(1, dim_mlp)
        self.missing_x_marker = nn.Parameter(torch.randn(1, dim_mlp))
        self.missing_y_marker = nn.Parameter(torch.randn(1, dim_mlp))

    @staticmethod
    def _fill_missing_with_marker(
        emb: Tensor, mask: Tensor, missing_marker: Tensor
    ) -> Tensor:
        """Replace the embeddings of missing dims with a learnable marker.

        Args:
            emb: Embeddings [B, N, d, H]
            mask: Per-dimension flags [B, d]; truthy keeps the embedding,
                falsy substitutes the marker
            missing_marker: Marker vector [1, H]

        Returns:
            Tensor [B, N, d, H]
        """
        B, N, d, H = emb.shape
        # torch.where needs a boolean condition; accept 0/1 masks as well.
        keep = mask.bool()[:, None, :, None].expand(B, N, d, H)
        marker = missing_marker.expand(B, N, d, H)
        return torch.where(keep, emb, marker)

    def _embed_x(
        self, x: Tensor, query_mask: Optional[Tensor], B: int, N: int, dx_max: int
    ) -> Tensor:
        """Embed x values, replacing missing dims with learnable markers."""
        # Each scalar goes through the same Linear(1, dim_mlp).
        x = x.reshape(B * N * dx_max, 1)
        x_emb = self.mlp_x(x).view(B, N, dx_max, self.dim_mlp)
        if query_mask is not None:
            x_emb = self._fill_missing_with_marker(
                x_emb, query_mask, self.missing_x_marker
            )
        return x_emb

    def _embed_y(
        self,
        y: Optional[Tensor],
        observed_mask: Optional[Tensor],
        B: int,
        N: int,
        dy_max: int,
    ) -> Tensor:
        """Embed y values, replacing missing dims with learnable markers.

        When ``y`` is None every y slot is filled with the missing-y marker.
        """
        if y is None:
            return self.missing_y_marker.expand(B, N, dy_max, self.dim_mlp)

        y = y.reshape(B * N * dy_max, 1)
        y_emb = self.mlp_y(y).view(B, N, dy_max, self.dim_mlp)
        if observed_mask is not None:
            y_emb = self._fill_missing_with_marker(
                y_emb, observed_mask, self.missing_y_marker
            )
        return y_emb

    def forward(
        self,
        x: Tensor,
        x_mask: Tensor,
        y_mask: Tensor,
        y: Optional[Tensor] = None,
        y_obs_mask: Optional[Tensor] = None,
        q_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Embed input and output values with missing value handling.

        Args:
            x: Input values [B, N, dx_max]
            x_mask: Valid x dimensions [B, dx_max] (not read by this method)
            y_mask: Valid y dimensions [B, dy_max]; only its width is used
            y: Output values [B, N, dy_max] (None for targets)
            y_obs_mask: Observed y dimensions [B, dy_max]; unobserved dims get
                the missing-y marker. None keeps every y embedding.
            q_mask: Query mask for missing x handling [B, dx_max]; dims that
                are not set get the missing-x marker. None keeps every x embedding.

        Returns:
            Embedded values [B, N, dx_max + dy_max, dim_mlp]
        """
        B, N, dx_max = x.shape
        _, dy_max = y_mask.shape

        x_emb = self._embed_x(x, q_mask, B, N, dx_max)
        y_emb = self._embed_y(y, y_obs_mask, B, N, dy_max)

        return torch.cat([x_emb, y_emb], dim=2)


class FourierFeatures(nn.Module):
    """Fourier feature embeddings for positional encoding.

    A position ``p`` is mapped to ``sin(2*pi*p*f)`` / ``cos(2*pi*p*f)`` for a
    fixed set of frequencies ``f`` and then projected to ``dim_mlp`` by a
    linear layer.
    """

    def __init__(
        self, dim_mlp: int, num_frequencies: int = 10, freq_scale: str = "log"
    ):
        """
        Args:
            dim_mlp: Output width of the feature projection
            num_frequencies: Number of frequencies
            freq_scale: "log" for 1, 2, 4, ... or "linear" for 1, 2, 3, ...

        Raises:
            ValueError: If ``freq_scale`` is neither "log" nor "linear".
        """
        super().__init__()
        self.num_frequencies = num_frequencies

        # Initialize frequency scales
        if freq_scale == "log":
            freqs = 2 ** torch.arange(num_frequencies).float()
        elif freq_scale == "linear":
            freqs = torch.arange(1, num_frequencies + 1).float()
        else:
            raise ValueError(f"Unknown freq_scale: {freq_scale}")

        # A buffer follows module.to(device); persistent=False keeps it out of
        # the state_dict so checkpoints saved before this change still load.
        self.register_buffer("freqs", freqs, persistent=False)

        self.mlp = nn.Linear(2 * num_frequencies, dim_mlp)

    def forward(self, positions: Tensor, sin_first: bool = True) -> Tensor:
        """
        Compute Fourier features for given positions.

        Args:
            positions: Position indices [..., ]
            sin_first: Whether to put sin before cos in features

        Returns:
            Fourier features [..., dim_mlp]
        """
        positions = positions.float()
        # Normally a no-op; covers positions living on another device than the module.
        freqs = self.freqs.to(device=positions.device, dtype=positions.dtype)

        # Compute sin/cos arguments: [..., num_frequencies]
        args = 2 * math.pi * positions.unsqueeze(-1) * freqs

        # Concatenate sin and cos features
        sin_feats, cos_feats = torch.sin(args), torch.cos(args)
        halves = [sin_feats, cos_feats] if sin_first else [cos_feats, sin_feats]

        return self.mlp(torch.cat(halves, dim=-1))


class DimensionAgnosticEncoder(nn.Module):
    """Encoder that handles variable-dimension inputs with learnable or Fourier IDs.

    Each point is treated as a sequence of ``D = dx_max + dy_max`` per-dimension
    tokens that is run through a transformer encoder. The output is then
    combined with a per-dimension identifier embedding, which is either a
    learned table or Fourier features of the dimension index.
    """

    def __init__(
        self,
        dim_mlp: int,
        dim_attn: int,
        nhead: int,
        dropout: float,
        num_layers: int,
        max_x_dim: int,
        max_y_dim: int,
        use_learnable_ids: bool = True,
        id_value_aggregator: str = "hadamard",
        use_target_y_id: bool = True,
        dim_hidden: int = 128,
        **kwargs,
    ):
        """
        Args:
            dim_mlp: Token width outside the transformer
            dim_attn: Token width inside the transformer
            nhead: Number of attention heads
            dropout: Dropout rate of the transformer layers
            num_layers: Number of transformer layers
            max_x_dim: Largest number of x dimensions (size of the x ID table)
            max_y_dim: Largest number of y dimensions (size of the y ID table)
            use_learnable_ids: Learned ID table if True, Fourier features otherwise
            id_value_aggregator: "hadamard" or "mlp"; how values and IDs are combined
            use_target_y_id: Stored on the instance; not read inside this class
            dim_hidden: Hidden width of the aggregation MLP ("mlp" aggregator only)
            **kwargs: Accepted and ignored
        """
        super().__init__()
        self.dim_mlp = dim_mlp
        self.dim_attn = dim_attn
        self.use_target_y_id = use_target_y_id
        self.id_value_aggregator = id_value_aggregator
        self.use_learnable_ids = use_learnable_ids

        # Dimension identifiers: learned table, or Fourier features of the index.
        if not use_learnable_ids:
            self.ffs = FourierFeatures(dim_mlp, num_frequencies=10, freq_scale="log")
            self.register_buffer("pos_x", torch.arange(max_x_dim))
            self.register_buffer("pos_y", torch.arange(max_y_dim))
        else:
            self.id_x = nn.Parameter(torch.randn(max_x_dim, dim_mlp))
            self.id_y = nn.Parameter(torch.randn(max_y_dim, dim_mlp))

        # Width adapters around the transformer.
        self.in_proj = create_projection(dim_mlp, dim_attn)
        self.out_proj = create_projection(dim_attn, dim_mlp)

        # Transformer over the per-dimension tokens of one point.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=dim_attn,
                nhead=nhead,
                dim_feedforward=FEEDFORWARD_MULTIPLIER * dim_attn,
                dropout=dropout,
                batch_first=True,
                activation="relu",
            ),
            num_layers=num_layers,
        )

        # Only the "mlp" aggregator needs extra parameters.
        if id_value_aggregator == "mlp":
            self.mlp = build_mlp(2 * dim_mlp, dim_hidden, dim_mlp, 3)

    def forward(
        self,
        tokens: Tensor,
        x_mask: Tensor,
        y_mask: Tensor,
        pad_mask: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Run the per-dimension tokens through the transformer and attach IDs.

        Args:
            tokens: Input tokens [B, N, D, H]
            x_mask: Valid x dimensions [B, dx_max]; only its width is used here
            y_mask: Valid y dimensions [B, dy_max]; only its width is used here
            pad_mask: Passed to the transformer as ``src_key_padding_mask``
            mask: Passed to the transformer as attention ``mask``

        Returns:
            Processed tokens [B, N, D, H], x_ids [dx_max, H], y_ids [dy_max, H]
        """
        batch, n_pts, n_dims, width = tokens.shape
        dx_max, dy_max = x_mask.shape[-1], y_mask.shape[-1]

        assert (
            n_dims == dx_max + dy_max
        ), f"Dimension mismatch: {n_dims} != {dx_max + dy_max}"

        # Every point becomes one sequence of D tokens.
        hidden = self.in_proj(tokens).view(batch * n_pts, n_dims, -1)
        hidden = self.transformer(hidden, mask=mask, src_key_padding_mask=pad_mask)
        hidden = self.out_proj(hidden)

        x_id, y_id = self._get_position_embeddings(dx_max, dy_max)
        fused = self._aggregate_embeddings(
            hidden, x_id, y_id, batch, n_pts, dx_max, dy_max
        )

        return fused.view(batch, n_pts, n_dims, width), x_id, y_id

    def _get_position_embeddings(
        self, dx_max: int, dy_max: int
    ) -> Tuple[Tensor, Tensor]:
        """Return the ID embeddings of the first ``dx_max`` x and ``dy_max`` y dims."""
        if not self.use_learnable_ids:
            # x and y share one Fourier module; they differ in the sin/cos order.
            x_id = self.ffs(self.pos_x[:dx_max], sin_first=True)
            y_id = self.ffs(self.pos_y[:dy_max], sin_first=False)
            return x_id, y_id
        return self.id_x[:dx_max], self.id_y[:dy_max]

    def _aggregate_embeddings(
        self,
        seq_out: Tensor,
        x_id: Tensor,
        y_id: Tensor,
        B: int,
        N: int,
        dx_max: int,
        dy_max: int,
    ) -> Tensor:
        """Combine transformer outputs [B * N, D, H] with the dimension IDs."""
        n_seq = B * N
        # Same IDs for every sequence: [B * N, dx_max + dy_max, H]
        id_embeddings = torch.cat(
            [
                x_id.unsqueeze(0).expand(n_seq, dx_max, self.dim_mlp),
                y_id.unsqueeze(0).expand(n_seq, dy_max, self.dim_mlp),
            ],
            dim=1,
        )

        aggregator = self.id_value_aggregator
        if aggregator == "hadamard":
            return seq_out * id_embeddings
        if aggregator == "mlp":
            pairs = torch.cat([id_embeddings, seq_out], dim=-1)
            return self.mlp(pairs.flatten(0, 1)).view(n_seq, -1, self.dim_mlp)
        raise NotImplementedError(f"Unknown aggregator: {self.id_value_aggregator}")


class Decoder(nn.Module):
    """Unified decoder for prediction and optimization tasks."""

    def __init__(
        self,
        dim_mlp: int,
        dim_attn: int,
        nhead: int,
        dropout: float,
        num_layers: int,
        dim_hidden: int,
        depth: int,  # dummy for compatibility
        max_x_dim: int,
        max_y_dim: int,
        num_components: int,
        std_min: float = DEFAULT_STD_MIN,
        use_ar: bool = True,
        perm_ar_order: bool = True,
        my_encoder_layer: bool = True,
        tokenization_mode: str = "v1",
        transformer_based: bool = True,
        ar_cond_option: str = "input",
        learnable_token_sel_bias: bool = True,
    ):
        """Build the decoder's projections, special tokens, transformer and heads.

        Args:
            dim_mlp: Token width outside the transformer
            dim_attn: Token width inside the transformer
            nhead: Number of attention heads
            dropout: Dropout rate of the transformer layers
            num_layers: Number of transformer layers
            dim_hidden: Hidden width of the prediction and policy heads
            depth: Not used by this constructor
            max_x_dim: Not used by this constructor
            max_y_dim: Not used by this constructor
            num_components: Number of mixture components of the prediction head
            std_min: Lower bound on predicted standard deviations
            use_ar: Use the autoregressive policy for optimization
            perm_ar_order: Randomly permute the chunk order in the autoregressive policy
            my_encoder_layer: With tokenization "v1", use MyTransformerEncoderLayer
            tokenization_mode: One of "v1", "v2", "v3", "v4"
            transformer_based: Whether a transformer is created at all
            ar_cond_option: "input" or "output"
            learnable_token_sel_bias: Make the selected-token vector trainable
        """
        super().__init__()
        self.dim_mlp, self.dim_attn = dim_mlp, dim_attn
        self.use_ar = use_ar
        self.perm_ar_order = perm_ar_order
        self.transformer_based = transformer_based
        self.ar_cond_option = ar_cond_option
        self.tokenization_mode = tokenization_mode

        # Reject unknown option strings early.
        assert ar_cond_option in ("input", "output")
        assert tokenization_mode in ("v1", "v2", "v3", "v4")

        # Width adapters around the transformer.
        self.in_proj = create_projection(dim_mlp, dim_attn)
        self.out_proj = create_projection(dim_attn, dim_mlp)

        # Remaining-budget embedding and the two task tokens.
        self.budget_proj = nn.Linear(1, dim_mlp)
        self.id_task = nn.Parameter(torch.randn(2, dim_mlp))

        # Token standing in for "the element selected so far": trainable or fixed.
        selected_init = torch.randn(1, dim_mlp)
        if learnable_token_sel_bias:
            self.token_selected = nn.Parameter(selected_init)
        else:
            self.register_buffer("token_selected", selected_init)

        # No transformer attribute exists when transformer_based is False.
        if transformer_based:
            self.transformer = self._create_transformer(
                dim_attn, nhead, dropout, num_layers, my_encoder_layer
            )

        # Output heads: GMM parameters for prediction, one logit per candidate for the policy.
        self.prediction_head = GMMPredictionHead(
            dim_mlp, dim_hidden, 3, num_components, std_min
        )
        self.policy_head = build_mlp(dim_mlp, dim_hidden, 1, 3)

    def _create_transformer(
        self,
        dim_attn: int,
        nhead: int,
        dropout: float,
        num_layers: int,
        my_encoder_layer: bool,
    ) -> nn.Module:
        """Create transformer encoder."""
        if self.tokenization_mode == "v1" and my_encoder_layer:
            encoder_layer = MyTransformerEncoderLayer(
                d_model=dim_attn,
                nhead=nhead,
                dim_feedforward=FEEDFORWARD_MULTIPLIER * dim_attn,
                dropout=dropout,
                batch_first=True,
                activation="relu",
            )
        else:
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=dim_attn,
                nhead=nhead,
                dim_feedforward=FEEDFORWARD_MULTIPLIER * dim_attn,
                dropout=dropout,
                batch_first=True,
                activation="relu",
            )
        return nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(
        self,
        task_type: str,
        token_val: Tensor,
        token_id: Tensor,
        x_dim_mask: Optional[Tensor] = None,
        y_dim_mask: Optional[Tensor] = None,
        q_dim_mask: Optional[Tensor] = None,
        q_chunk: Optional[Tensor] = None,
        use_budget: bool = True,
        t: Optional[int] = None,
        T: Optional[int] = None,
        epsilon: float = 1.0,
        evaluate: bool = False,
        logit_mask: Optional[Tensor] = None,  # [B, d]
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        """
        Forward pass for prediction or optimization.

        Args:
            task_type: "prediction" or "optimization"
            token_val: Value tokens [B, n, d, H]
            token_id: ID tokens [B, n, H]
            Other args specific to task type

        Returns:
            For prediction: GMM parameters [B, n, dy_max, num_components, 3]
            For optimization: (next_x, indices, logp, entropy, logits)
        """
        assert task_type in {"prediction", "optimization"}

        if task_type == "prediction":
            return self._forward_prediction(token_val, token_id, x_dim_mask, y_dim_mask)
        else:
            return self._forward_optimization(
                token_val,
                token_id,
                q_dim_mask,
                q_chunk,
                use_budget,
                t,
                T,
                epsilon,
                evaluate,
                logit_mask,
            )

    def _forward_prediction(
        self,
        token_val: Tensor,
        token_id: Tensor,
        x_dim_mask: Tensor,
        y_dim_mask: Tensor,
    ) -> Tensor:
        """Run the prediction task and return the prediction head's output.

        The value tokens [B, n, d, H] are prefixed with the task and ID tokens,
        passed through the transformer as ``B * n`` independent sequences, and
        the candidate positions are handed to the prediction head.
        """
        batch, n_rows, _, width = token_val.shape

        # No selected token and no budget token for prediction.
        token_val, token_id, token_task, _, _ = self._make_token(
            token_val, token_id, "prediction", False, False
        )
        sequence, _, _, _, _, n_cand, block_size = self._make_sequence_in(
            token_task, token_id, token_val
        )
        attn_mask, _ = self._make_decoder_mask(
            1, 1, 0, 0, n_cand, block_size, sequence
        )

        # One transformer sequence per (batch, row) pair.
        flat_sequence = sequence.view(batch * n_rows, -1, width)
        hidden = self.transformer(self.in_proj(flat_sequence), mask=attn_mask)
        hidden = self.out_proj(hidden)

        # Keep only the candidate tokens and restore the [B, n, d, H] layout.
        candidates = self._slice_candidate(n_cand, hidden, block_size)
        candidates = candidates.view(batch, n_rows, n_cand, -1)

        return self.prediction_head(candidates, x_dim_mask, y_dim_mask)

    def _forward_optimization(
        self,
        token_val: Tensor,
        token_id: Tensor,
        q_dim_mask: Tensor,
        q_chunk: Tensor,
        use_budget: bool,
        t: Optional[int],
        T: Optional[int],
        epsilon: float,
        evaluate: bool,
        logit_mask: Optional[Tensor] = None,  # [B, d]
    ) -> Tuple[Tensor, ...]:
        """Forward pass for optimization task."""
        B, n, d, H = token_val.shape
        use_ar = self.use_ar

        # Prepare tokens
        token_val, token_id, token_task, token_global, token_selected = (
            self._make_token(
                token_val, token_id, "optimization", use_ar, use_budget, t, T
            )
        )

        # Optionally permute for AR
        if use_ar and self.perm_ar_order:
            token_val, perm_idx = self._permute(token_val, n, d)

        # Create input sequence
        token_in, _, _, _, _, d, block_size = self._make_sequence_in(
            token_task, token_id, token_val, token_global, token_selected
        )

        # Create mask
        mask, _ = self._make_decoder_mask(
            1, 1, 0, 1 if use_ar else 0, d, block_size, token_in
        )

        if use_ar:
            results = self._forward_autoregressive(
                B, n, d, H, token_in, mask, block_size, epsilon, evaluate
            )
        else:
            results = self._forward_factorized(
                B, n, d, H, token_in, mask, block_size, epsilon, evaluate, logit_mask
            )

        # Gather selected x and handle permutation
        selected_x = self._gather_selected_x(q_chunk, results[0], q_dim_mask, B, n, d)

        if use_ar and self.perm_ar_order:
            results = self._unpermute_results(results, perm_idx)

        return (selected_x,) + results

    def _forward_autoregressive(
        self,
        B: int,
        n: int,
        d: int,
        H: int,
        token_in: Tensor,
        mask: Tensor,
        block_size: int,
        epsilon: float,
        evaluate: bool,
    ) -> Tuple[Tensor, ...]:
        """Autoregressive policy forward pass."""
        seq_in = self.in_proj(token_in).view(B, n, -1, self.dim_attn)
        seq_len = mask.shape[0]
        start_selected = seq_len - block_size * d - 1

        chunk_indices = []
        chunk_logp = []
        chunk_entropy = []
        chunk_logits = [] if evaluate else None

        chunk_seq_in = seq_in[:, 0].clone()

        for i in range(n):
            # Process current chunk
            seq_out = self.transformer(chunk_seq_in, mask=mask)
            seq_out_candidate = self._slice_candidate(d, seq_out, block_size)
            out_candidate = self.out_proj(seq_out_candidate)
            logits = self.policy_head(out_candidate).squeeze(-1)

            # Sample
            indices, logp, entropy = self._sample(logits, epsilon)

            # Update next chunk if needed
            if i + 1 < n and self.ar_cond_option == "output":
                seq_selected = self._gather_seq_ele(
                    seq_out_candidate, indices.unsqueeze(-1)
                ).detach()
                chunk_seq_in = seq_in[:, i + 1].clone()
                chunk_seq_in[:, start_selected : start_selected + 1] = seq_selected

            # Track results
            chunk_indices.append(indices)
            chunk_logp.append(logp)
            chunk_entropy.append(entropy)
            if evaluate:
                chunk_logits.append(logits)

        # Stack results
        selected_indices = torch.stack(chunk_indices, dim=-1)
        selected_logp = torch.stack(chunk_logp, dim=-1)
        selected_entropy = torch.stack(chunk_entropy, dim=-1)
        chunk_logits = torch.stack(chunk_logits, dim=1) if evaluate else None

        return selected_indices, selected_logp, selected_entropy, chunk_logits

    def _forward_factorized(
        self,
        B: int,
        n: int,
        d: int,
        H: int,
        token_in: Tensor,
        mask: Tensor,
        block_size: int,
        epsilon: float,
        evaluate: bool,
        logit_mask: Optional[Tensor] = None,  # [B, n, d]
    ) -> Tuple[Tensor, ...]:
        """Fully factorized policy forward pass."""
        # Reshape for batch processing
        token_in = token_in.view(B * n, -1, H)
        if logit_mask is not None:
            logit_mask = logit_mask.view(B * n, -1)

        # Apply transformer
        seq_in = self.in_proj(token_in)
        seq_out = self.transformer(seq_in, mask=mask)
        seq_out_proj = self.out_proj(seq_out)

        # Extract candidates and compute logits
        out_candidate = self._slice_candidate(d, seq_out_proj, block_size)
        logits = self.policy_head(out_candidate).squeeze(-1)

        # Sample
        indices, logp, entropy = self._sample(logits, epsilon, logit_mask)

        # Reshape results
        selected_indices = indices.view(B, n)
        selected_logp = logp.view(B, n)
        selected_entropy = entropy.view(B, n)
        chunk_logits = logits.view(B, n, d) if evaluate else None

        return selected_indices, selected_logp, selected_entropy, chunk_logits

    def _gather_selected_x(
        self,
        q_chunk: Tensor,
        indices: Tensor,
        q_dim_mask: Tensor,
        B: int,
        n: int,
        d: int,
    ) -> Tensor:
        """Gather selected x values and apply mask."""
        # Gather selected x
        selected_x = self._gather_seq_ele(
            q_chunk.view(B * n, d, -1), indices.unsqueeze(-1).view(B * n, 1)
        )
        selected_x = selected_x.squeeze(1).view(B, n, -1)

        # Apply mask and sum
        selected_x *= q_dim_mask.float()
        return torch.sum(selected_x, dim=1, keepdim=True)

    def _unpermute_results(
        self, results: Tuple[Tensor, ...], perm_idx: Tensor
    ) -> Tuple[Tensor, ...]:
        """Undo permutation of results."""
        inverse_perm = perm_idx.argsort(dim=1)
        unpermuted = []

        for i, result in enumerate(results):
            if i == 0:  # indices
                unpermuted.append(torch.gather(result, 1, inverse_perm))
            elif i < 3:  # logp, entropy
                unpermuted.append(torch.gather(result, 1, inverse_perm))
            elif result is not None:  # chunk_logits
                expanded_perm = inverse_perm.unsqueeze(-1).expand_as(result)
                unpermuted.append(torch.gather(result, 1, expanded_perm))
            else:
                unpermuted.append(None)

        return tuple(unpermuted)

    def _permute(self, x: Tensor, n: int, d: int) -> Tuple[Tensor, Tensor]:
        """Randomly permute chunks for autoregressive training."""
        b, n, d, p = x.shape

        # Generate random permutation
        perm_idx = torch.argsort(torch.rand(b, n, device=x.device), dim=1)
        perm_idx_expanded = perm_idx.view(b, n, 1, 1).expand(b, n, d, p)

        # Apply permutation
        output = torch.gather(x, 1, perm_idx_expanded)
        return output, perm_idx

    def _sample(
        self,
        logits: Tensor,
        epsilon: float = 1.0,
        mask: Optional[Tensor] = None,  # [B, d]
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Sample from categorical distribution with epsilon-greedy exploration."""
        assert torch.all(torch.isfinite(logits)), "Logits contain NaN or Inf"

        B, d = logits.shape

        # Create distribution
        if mask is None:
            probs = F.softmax(logits, dim=-1) + EXPLORATION_EPSILON
            dist = Categorical(probs=probs)
        else:
            dist = MaskedCategorical(logits=logits, mask=mask)

        # Sample or take argmax
        if np.random.rand() < epsilon:
            samples = dist.sample().clamp(min=0, max=d - 1)
        else:
            if mask is None:
                samples = logits.argmax(dim=-1)
            else:
                masked_logits = logits.masked_fill(~mask.bool(), float("-inf"))
                samples = masked_logits.argmax(dim=-1)

        logp = dist.log_prob(samples)
        entropy = dist.entropy()

        return samples, logp, entropy

    def _gather_seq_ele(self, data: Tensor, indices: Tensor) -> Tensor:
        """Gather elements along sequence dimension."""
        B, N, H = data.shape
        indices_expanded = indices.unsqueeze(-1).expand(B, -1, H)
        return torch.gather(data, dim=1, index=indices_expanded)

    def _make_sequence_in(
        self,
        token_task: Tensor,
        token_id: Tensor,
        token_val: Tensor,
        token_global: Optional[Tensor] = None,
        token_selected: Optional[Tensor] = None,
        tokenization_mode: Optional[str] = None,
    ) -> Tuple[Tensor, int, int, int, int, int, int]:
        """Create input sequence based on tokenization mode."""
        if tokenization_mode is None:
            tokenization_mode = self.tokenization_mode

        B, n, n_cand, H = token_val.shape

        # Default empty tokens
        if token_global is None:
            token_global = torch.empty((B, n, 0, H), device=token_val.device)
        if token_selected is None:
            token_selected = torch.empty((B, n, 0, H), device=token_val.device)

        # Get token counts
        n_task = token_task.shape[-2]
        n_id = token_id.shape[-2]
        n_global = token_global.shape[-2]
        n_selected = token_selected.shape[-2]

        # Build sequence based on mode
        if tokenization_mode == "v1":
            block_size = 1
            token_in = torch.cat(
                [token_task, token_id, token_global, token_selected, token_val], dim=-2
            )
        else:
            # For v2, v3, v4 modes
            token_id_expanded = token_id.expand(B, n, n_cand, H)

            if tokenization_mode == "v2":
                token_task_expanded = token_task.expand(B, n, n_cand, H)
                token_cand = torch.stack(
                    [token_val, token_id_expanded, token_task_expanded], dim=-2
                )
            else:
                token_cand = torch.stack([token_val, token_id_expanded], dim=-2)

            block_size = token_cand.shape[-2]
            token_cand = token_cand.view(B, n, -1, H)

            if tokenization_mode == "v4":
                token_in = torch.cat(
                    [token_task, token_global, token_selected, token_cand], dim=-2
                )
            else:
                token_in = torch.cat([token_global, token_selected, token_cand], dim=-2)

        return token_in, n_task, n_id, n_global, n_selected, n_cand, block_size

    def _make_decoder_mask(
        self,
        n_task: int,
        n_id: int,
        n_global: int,
        n_selected: int,
        n_cand: int,
        block_size_cand: int,
        token: Tensor,
    ) -> Tuple[Tensor, None]:
        """Create attention mask for decoder."""
        if not self.transformer_based:
            return None, None

        B, n, seq_len, H = token.shape
        mask = torch.zeros(seq_len, seq_len, dtype=torch.bool, device=token.device)

        # Build block diagonal mask for candidates
        I = torch.ones(
            (block_size_cand, block_size_cand), device=token.device, dtype=torch.bool
        )
        mask_ = torch.block_diag(*[I for _ in range(n_cand)])
        start_tok_cand = seq_len - block_size_cand * n_cand
        mask[start_tok_cand:, start_tok_cand:] = mask_

        # Non-candidate tokens can be attended by all
        mask[:, :start_tok_cand] = True

        return ~mask, None

    def _slice_candidate(
        self, n_cand: int, sequence: Tensor, block_size: int
    ) -> Tensor:
        """Extract candidate tokens from sequence."""
        B, seq_len, H = sequence.shape
        start_tok_cand = seq_len - block_size * n_cand
        token_cand = sequence[:, start_tok_cand::block_size]
        assert token_cand.shape == (B, n_cand, H)
        return token_cand

    def _make_token(
        self,
        token_val: Tensor,
        token_id: Tensor,
        task_type: str,
        use_ar: bool,
        use_budget: bool,
        t: Optional[int] = None,
        T: Optional[int] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        """Prepare tokens for decoder input."""
        B, n, _, H = token_val.shape
        do_optimization = task_type == "optimization"

        # Task token
        token_task = self.id_task[int(do_optimization)]
        token_task = token_task[None, None, None, :].expand(B, n, 1, H)

        # ID token
        token_id = token_id.unsqueeze(-2).expand(B, n, -1, H)

        # Optional tokens
        token_global = None
        token_selected = None

        if do_optimization:
            if use_ar:
                token_selected = self.token_selected.to(token_val)
                token_selected = (
                    token_selected.unsqueeze(0).unsqueeze(0).expand(B, n, 1, H)
                )

            if use_budget:
                assert t is not None and T is not None
                assert 0 <= t <= T

                budget_ratio = (T - t) / T
                state_in = budget_ratio * torch.ones((1, 1), device=token_val.device)
                token_global = self.budget_proj(state_in)

            if token_global is not None:
                token_global = token_global[None, None, :, :].expand(B, n, -1, H)

        return token_val, token_id, token_task, token_global, token_selected


class GMMPredictionHead(nn.Module):
    """Prediction head that outputs Gaussian Mixture Model parameters.

    One small MLP per mixture component maps a token to three raw numbers
    (mean, standard deviation, weight) for that component.
    """

    def __init__(
        self,
        dim_mlp: int,
        dim_hidden: int,
        depth: int,
        num_components: int = DEFAULT_NUM_COMPONENTS,
        std_min: float = DEFAULT_STD_MIN,
        **kwargs,
    ):
        """
        Args:
            dim_mlp: Width of the incoming tokens
            dim_hidden: Hidden width passed to ``build_mlp``
            depth: Depth passed to ``build_mlp``
            num_components: Number of mixture components (one MLP each)
            std_min: Lower bound on the predicted standard deviations
            **kwargs: Accepted and ignored
        """
        super().__init__()
        self.std_min = std_min
        self.depth = depth
        self.num_components = num_components

        # One MLP per component, each emitting (mu, std, weight).
        component_heads = []
        for _ in range(num_components):
            component_heads.append(build_mlp(dim_mlp, dim_hidden, 3, depth))
        self.heads = nn.ModuleList(component_heads)

    def forward(
        self,
        input: Tensor,  # [B, DY, N, H]
        x_mask: Tensor,  # [B, DX]
        y_mask: Tensor,  # [B, DY]
    ) -> Tensor:
        """
        Predict GMM parameters for each output dimension.

        Output dimensions that are not set in ``y_mask`` are filled with NaN.
        ``x_mask`` is accepted but not used.

        Returns:
            GMM parameters [B, N, DY, num_components, 3] where 3 = (mean, std, weight)
        """
        batch, n_out, n_pts, width = input.shape

        # Run every component head on the flattened tokens: K x [B*DY*N, 3]
        flat_tokens = input.reshape(batch * n_out * n_pts, width)
        per_component = [head(flat_tokens) for head in self.heads]

        # [B*DY*N, 3, K] -> [B*DY*N, 3K]: all means, then all stds, then all weights
        params = torch.stack(per_component, dim=-1).flatten(-2, -1)
        params = params.view(batch, n_out, n_pts, -1)

        # Blank out the output dimensions that do not exist for this batch element.
        valid = y_mask[..., None, None].expand_as(params)
        params = torch.where(valid, params, torch.nan)

        raw_means, raw_stds, raw_weights = params.chunk(3, dim=-1)

        # Standard deviations stay above std_min; weights sum to one over components.
        stds = self.std_min + (1 - self.std_min) * F.softplus(raw_stds)
        weights = F.softmax(raw_weights, dim=-1)

        gmm = torch.stack([raw_means, stds, weights], dim=-1)
        return gmm.transpose(1, 2)  # [B, N, DY, num_components, 3]


class TAMO(nn.Module):
    def __init__(
        self,
        max_x_dim: int,
        max_y_dim: int,
        dim_mlp: int,
        dim_attn: int,
        nhead: int = 4,
        dropout: float = DEFAULT_DROPOUT,
        num_layers: int = 6,
        dim_hidden: int = 128,
        num_components: int = DEFAULT_NUM_COMPONENTS,
        std_min: float = DEFAULT_STD_MIN,
        id_value_aggregator: str = "hadamard",
        use_learnable_ids: bool = True,
        encoder_kwargs: Optional[Dict] = None,
        decoder_kwargs: Optional[Dict] = None,
        share_decoder: bool = True,
        fill_missing_x: bool = False,
        **kwargs,
    ):
        """Assemble the embedder, encoder, transformer block and decoder(s).

        Args:
            max_x_dim: Maximum number of input dimensions; forwarded to the
                embedder, the encoder and the decoder(s).
            max_y_dim: Maximum number of output dimensions; forwarded to the
                same submodules.
            dim_mlp: Forwarded to every submodule.
            dim_attn: Forwarded to the encoder, transformer block and decoder(s).
            nhead: Head count of the transformer block only (the encoder reads
                its own value from ``encoder_kwargs``).
            dropout: Forwarded to the encoder, transformer block and decoder(s).
            num_layers: Layer count of the transformer block only.
            dim_hidden: Forwarded to the decoder(s).
            num_components: Forwarded to the decoder(s).
            std_min: Forwarded to the decoder(s).
            id_value_aggregator: Forwarded to the encoder.
            use_learnable_ids: Forwarded to the encoder.
            encoder_kwargs: Must provide ``"nhead"`` and ``"num_layers"`` for
                the encoder; defaults to 4 heads / 4 layers when None.
            decoder_kwargs: Extra keyword arguments for ``Decoder``; a built-in
                default configuration is used when None.
            share_decoder: When True a single ``Decoder`` is created; otherwise
                a ``ModuleDict`` with ``"prediction"`` and ``"policy"`` decoders.
            fill_missing_x: Stored on the instance; read by ``get_seq_in`` when
                a query mask is supplied.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.max_y_dim = max_y_dim
        self.max_x_dim = max_x_dim
        self.share_decoder = share_decoder
        self.fill_missing_x = fill_missing_x

        # Fall back to the built-in configurations when none were supplied.
        enc_cfg = (
            {"nhead": 4, "num_layers": 4} if encoder_kwargs is None else encoder_kwargs
        )
        if decoder_kwargs is None:
            dec_cfg = {
                "nhead": 4,
                "num_layers": 4,
                "use_ar": True,
                "perm_ar_order": True,
                "my_encoder_layer": True,
                "tokenization_mode": "v1",
                "transformer_based": True,
                "ar_cond_option": "input",
                "learnable_token_sel_bias": True,
            }
        else:
            dec_cfg = decoder_kwargs

        # Submodules are created in a fixed order: embedder, encoder,
        # transformer block, decoder(s).
        self.embedder = DimensionWiseEmbedder(dim_mlp, max_x_dim, max_y_dim)

        self.encoder = DimensionAgnosticEncoder(
            dim_mlp=dim_mlp,
            dim_attn=dim_attn,
            nhead=enc_cfg["nhead"],
            num_layers=enc_cfg["num_layers"],
            dropout=dropout,
            max_x_dim=max_x_dim,
            max_y_dim=max_y_dim,
            id_value_aggregator=id_value_aggregator,
            use_learnable_ids=use_learnable_ids,
        )

        self.transformer_block = TransformerBlock(
            dim_mlp=dim_mlp,
            dim_attn=dim_attn,
            nhead=nhead,
            dropout=dropout,
            num_layers=num_layers,
        )

        fixed_decoder_args = {
            "dim_mlp": dim_mlp,
            "dim_attn": dim_attn,
            "dropout": dropout,
            "dim_hidden": dim_hidden,
            "depth": 3,
            "max_x_dim": max_x_dim,
            "max_y_dim": max_y_dim,
            "num_components": num_components,
            "std_min": std_min,
        }
        # dict(**a, **b) raises TypeError when the user configuration repeats
        # one of the fixed keys instead of silently overriding it.
        decoder_params = dict(**fixed_decoder_args, **dec_cfg)

        if share_decoder:
            self.decoder = Decoder(**decoder_params)
        else:
            prediction_decoder = Decoder(**decoder_params)
            policy_decoder = Decoder(**decoder_params)
            self.decoder = nn.ModuleDict(
                {"prediction": prediction_decoder, "policy": policy_decoder}
            )

    def _make_tokens(
        self,
        x: Tensor,
        x_dim_mask: Tensor,
        y_dim_mask: Tensor,
        y: Optional[Tensor] = None,
        q_dim_mask: Optional[Tensor] = None,
        y_obs_mask: Optional[Tensor] = None,
    ):
        """Embed, encode and aggregate a set of points into one token per point.

        Args:
            x: Inputs; the second axis is the number of points.
            x_dim_mask: Valid x dimensions [B, dx_max].
            y_dim_mask: Valid y dimensions [B, dy_max].
            y: Optional outputs, forwarded to the embedder.
            q_dim_mask: Optional query mask, forwarded to the embedder.
            y_obs_mask: Optional observed-y mask, forwarded to the embedder.

        Returns:
            ``(seq_in, x_ids, y_ids)``: the sum of the aggregated x and y
            tokens, plus the two id tensors returned by the encoder.
        """
        batch, dx_max = x_dim_mask.shape
        _, dy_max = y_dim_mask.shape
        num_points = x.shape[1]

        # Per-dimension embeddings of the raw values.
        embedded = self.embedder(
            x=x,
            y=y,
            x_mask=x_dim_mask,
            y_mask=y_dim_mask,
            q_mask=q_dim_mask,
            y_obs_mask=y_obs_mask,
        )

        # The pad mask is flattened to one row per (batch element, point).
        pad_mask = make_encoder_pad_mask(x_dim_mask, y_dim_mask, num_points).view(
            batch * num_points, -1
        )

        encoded, x_ids, y_ids = self.encoder(
            embedded, pad_mask=pad_mask, x_mask=x_dim_mask, y_mask=y_dim_mask
        )

        # Along the third axis the first dx_max entries belong to x, the rest to y.
        encoded_x = encoded[:, :, :dx_max]
        encoded_y = encoded[:, :, dx_max:]

        def broadcast(mask: Tensor, num_dims: int, like: Tensor) -> Tensor:
            # [B, dim] -> [B, n, dim, H] so the mask lines up with the encodings.
            return mask.unsqueeze(1).unsqueeze(-1).expand(
                batch, num_points, num_dims, like.shape[-1]
            )

        token_x = aggregate_over_valid_dims(
            encoded_x, broadcast(x_dim_mask, dx_max, encoded_x)
        )
        token_y = aggregate_over_valid_dims(
            encoded_y, broadcast(y_dim_mask, dy_max, encoded_y)
        )

        return token_x + token_y, x_ids, y_ids

    def get_seq_in(
        self,
        x_ctx: Tensor,
        y_ctx: Tensor,
        x_tar: Tensor,
        x_dim_mask: Tensor,
        y_dim_mask: Tensor,
        q_dim_mask: Optional[Tensor] = None,
        y_dim_mask_tar: Optional[Tensor] = None,
        read_cache: bool = False,
        write_cache: bool = False,
    ):
        """Build the transformer input sequence ``[context tokens, target tokens]``.

        Args:
            x_ctx: Context inputs.
            y_ctx: Context outputs.
            x_tar: Target inputs.
            x_dim_mask: Valid x dimensions [B, dx_max].
            y_dim_mask: Valid y dimensions [B, dy_max]; also passed as the
                observed-y mask when tokenizing the context.
            q_dim_mask: Optional query mask. Unless ``fill_missing_x`` is set,
                it is AND-ed into the target x mask and then dropped.
            y_dim_mask_tar: Optional y mask used for tokenization; defaults to
                ``y_dim_mask``.
            read_cache: Use the cached context tokens when a cache exists.
            write_cache: Only honoured together with ``read_cache``: appends
                the freshly computed context tokens to an existing cache, or
                creates the cache from them. Without ``read_cache`` nothing is
                stored.

        Returns:
            ``(seq_in, x_ids, y_ids, nc, N)`` where ``nc`` is the number of
            context tokens actually used, ``N`` is ``nc`` plus the number of
            target tokens, and the ids come from tokenizing the context.
        """
        # Resolve the masks used for the target points.
        x_dim_mask_tar = x_dim_mask
        if q_dim_mask is not None and not self.fill_missing_x:
            # Without missing-x filling, only target x dimensions that are
            # also part of the query get embedded.
            x_dim_mask_tar = x_dim_mask & q_dim_mask
            q_dim_mask = None

        if y_dim_mask_tar is None:
            y_dim_mask_tar = y_dim_mask

        # NOTE: the context receives the observed y mask so that partially
        # observed outputs can be handled.
        fresh_ctx, x_ids, y_ids = self._make_tokens(
            x=x_ctx,
            x_dim_mask=x_dim_mask,
            y_dim_mask=y_dim_mask_tar,
            y_obs_mask=y_dim_mask,
            y=y_ctx,
            q_dim_mask=None,
        )

        tar_tokens, _, _ = self._make_tokens(
            x=x_tar,
            x_dim_mask=x_dim_mask_tar,
            y_dim_mask=y_dim_mask_tar,
            y=None,
            q_dim_mask=q_dim_mask,
        )

        # Pick the context tokens: the fresh ones by default, the cache when
        # reading is requested and a cache exists (or is created right now).
        ctx_tokens = fresh_ctx
        if read_cache:
            if hasattr(self, "_seq_in_ctx_cache"):
                if write_cache:
                    self._seq_in_ctx_cache = torch.cat(
                        [self._seq_in_ctx_cache, fresh_ctx], dim=1
                    )
                ctx_tokens = self._seq_in_ctx_cache
            elif write_cache:
                self._seq_in_ctx_cache = fresh_ctx
                ctx_tokens = self._seq_in_ctx_cache
            # Otherwise there is no cache to read and none to write, so the
            # fresh context tokens are used.

        nc = ctx_tokens.shape[1]
        N = nc + tar_tokens.shape[1]
        seq_in = torch.cat([ctx_tokens, tar_tokens], dim=1)
        return seq_in, x_ids, y_ids, nc, N

    def _clear_cache(self):
        """Drop the cached context tokens, if any, and release unused memory."""
        try:
            del self._seq_in_ctx_cache
        except AttributeError:
            # Nothing has been cached.
            pass

        # Run the garbage collector and ask the CUDA allocator to free its
        # cached blocks.
        gc.collect()
        torch.cuda.empty_cache()

    def _is_cache_empty(self):
        """Return True when no context-token cache is stored on this module."""
        return not hasattr(self, "_seq_in_ctx_cache")

    def _slice_context(self, x_ctx, y_ctx, read_cache, num_new: int = 1):
        """Keep only the newest context points when a cache will be read.

        Args:
            x_ctx: Context inputs; points are indexed along the second axis.
            y_ctx: Context outputs; points are indexed along the second axis.
            read_cache: Whether the caller intends to read the token cache.
            num_new: Number of trailing points to keep.

        Returns:
            ``(x_ctx, y_ctx)`` unchanged when ``read_cache`` is false or no
            cache exists; otherwise the last ``num_new`` points of each.
        """
        if not read_cache or self._is_cache_empty():
            return x_ctx, y_ctx

        # Older points are already represented by the cached tokens.
        newest = slice(-num_new, None)
        return x_ctx[:, newest, :], y_ctx[:, newest, :]

    def forward(
        self,
        x_ctx: Tensor,
        y_ctx: Tensor,
        x_tar: Tensor,
        x_dim_mask: Tensor,
        y_dim_mask: Tensor,
        y_dim_mask_tar: Optional[Tensor] = None,
        q_dim_mask: Optional[Tensor] = None,
        read_cache: bool = False,
        write_cache: bool = False,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Tokenize context and targets, run the transformer block, keep the targets.

        Args:
            x_ctx: Context inputs [B, nc, dx_max]
            y_ctx: Context outputs [B, nc, dy_max]
            x_tar: Target inputs [B, nt, dx_max]
            x_dim_mask: Valid x dimensions [B, dx_max]
            y_dim_mask: Valid y dimensions [B, dy_max]
            y_dim_mask_tar: Optional y mask forwarded to ``get_seq_in``
            q_dim_mask: Query mask for optimization [B, dx_max]
            read_cache: Forwarded to ``get_seq_in``
            write_cache: Forwarded to ``get_seq_in``

        Returns:
            Target representations [B, nt, H], x_ids [dx_max, H], y_ids [dy_max, H]
        """
        tokens, x_ids, y_ids, num_ctx, num_total = self.get_seq_in(
            x_ctx=x_ctx,
            y_ctx=y_ctx,
            x_tar=x_tar,
            x_dim_mask=x_dim_mask,
            y_dim_mask=y_dim_mask,
            q_dim_mask=q_dim_mask,
            y_dim_mask_tar=y_dim_mask_tar,
            read_cache=read_cache,
            write_cache=write_cache,
        )

        attn_mask = make_transformer_block_mask(
            x_dim_mask, y_dim_mask, num_total, num_ctx
        )
        encoded = self.transformer_block(tokens, mask=attn_mask)

        # The first num_ctx positions are context tokens; only the target
        # positions are returned.
        target_repr = encoded[:, num_ctx:]
        return target_repr, x_ids, y_ids

    def _create_dim_mask(
        self, x_mask: Tensor, x_mask_tar: Tensor, nc: int, nt: int, dx_max: int, H: int
    ) -> Tensor:
        """Expand per-batch dimension masks to shape [B, nc + nt, dx_max, H].

        When ``x_mask_tar`` is the very same object as ``x_mask`` a single
        expanded view is returned. Otherwise the context part (``nc`` points,
        from ``x_mask``) and the target part (``nt`` points, from
        ``x_mask_tar``) are expanded separately and concatenated along the
        point axis.
        """
        batch = x_mask.shape[0]
        ctx_broadcast = x_mask[:, None, :, None]

        if x_mask_tar is x_mask:
            return ctx_broadcast.expand(batch, nc + nt, dx_max, H)

        ctx_part = ctx_broadcast.expand(batch, nc, dx_max, H)
        tar_part = x_mask_tar[:, None, :, None].expand(batch, nt, dx_max, H)
        return torch.cat([ctx_part, tar_part], dim=1)

    def predict(
        self,
        x_ctx: Tensor,
        y_ctx: Tensor,
        x_tar: Tensor,
        x_dim_mask: Tensor,
        y_dim_mask: Tensor,
        y_dim_mask_tar: Optional[Tensor] = None,
        read_cache: bool = False,
        write_cache: bool = False,
        **kwargs,
    ) -> Tensor:
        """
        Run the prediction task for the given target inputs.

        Args:
            x_ctx: Context inputs.
            y_ctx: Context outputs.
            x_tar: Target inputs.
            x_dim_mask: Valid x dimensions [B, dx_max].
            y_dim_mask: Valid y dimensions [B, dy_max].
            y_dim_mask_tar: Optional y mask for the targets; defaults to
                ``y_dim_mask`` and is the y mask handed to the decoder.
            read_cache: Read the context-token cache; must not be combined
                with ``write_cache`` (asserted).
            write_cache: Forwarded to ``forward``.
            **kwargs: Accepted and ignored.

        Returns:
            GMM parameters [B, nt, dy_max, num_components, 3]
        """
        assert not (read_cache and write_cache), "Avoid messing up optimization caches."
        x_ctx, y_ctx = self._slice_context(x_ctx, y_ctx, read_cache)

        B, dy_max = y_dim_mask.shape
        if y_dim_mask_tar is None:
            y_dim_mask_tar = y_dim_mask

        # Encode; the x ids are not needed for prediction.
        target_repr, _, y_ids = self.forward(
            x_ctx=x_ctx,
            y_ctx=y_ctx,
            x_tar=x_tar,
            x_dim_mask=x_dim_mask,
            y_dim_mask=y_dim_mask,
            y_dim_mask_tar=y_dim_mask_tar,
            read_cache=read_cache,
            write_cache=write_cache,
        )

        # Replicate the target representation once per output dimension and
        # the y ids once per batch element.
        token_val = target_repr[:, None].expand(-1, dy_max, -1, -1)
        token_id = y_ids[None].expand(B, -1, -1)

        if self.share_decoder:
            decoder = self.decoder
        else:
            decoder = self.decoder["prediction"]

        return decoder(
            task_type="prediction",
            token_val=token_val,
            token_id=token_id,
            x_dim_mask=x_dim_mask,
            y_dim_mask=y_dim_mask_tar,
        )

    def action(
        self,
        x_ctx: Tensor,
        y_ctx: Tensor,
        x_dim_mask: Tensor,
        y_dim_mask: Tensor,
        q_chunk: Tensor,
        q_dim_mask: Tensor,
        t: int,
        T: int,
        y_dim_mask_tar: Optional[Tensor] = None,
        use_budget: bool = True,
        epsilon: float = 1.0,
        evaluate: bool = False,
        read_cache: bool = False,
        write_cache: bool = False,
        logit_mask: Optional[Tensor] = None,
        auto_clear_cache: bool = True,
    ) -> Tuple[Tensor, ...]:
        """
        Select next action for optimization.

        Args:
            x_ctx: Context inputs [B, nc, dx_max]
            y_ctx: Context outputs [B, nc, dy_max]
            x_dim_mask: Valid x dimensions [B, dx_max]
            y_dim_mask: Valid y dimensions [B, dy_max]
            q_chunk: Query candidates [B, n, d, dx_max]
            q_dim_mask: Valid query dimensions [B, n, dx_max]
            t: Current time step
            T: Total time steps
            y_dim_mask_tar: Optional target y mask [B, dy_max]; replicated for
                every chunk before being forwarded to ``forward``
            use_budget: Whether to use time budget
            epsilon: Exploration rate
            evaluate: Whether to return logits
            read_cache: Read the context-token cache; requires
                ``write_cache=True`` (asserted)
            write_cache: Forwarded to ``forward``
            logit_mask: Optional mask forwarded unchanged to the decoder
            auto_clear_cache: Clear the context-token cache once ``t >= T``

        Returns:
            (next_x, indices, log_probs, entropies, logits). ``next_x`` and
            ``indices`` are detached; the other three are returned exactly as
            produced by the decoder.
        """
        if read_cache:
            assert write_cache, "Set write_cache=True when read_cache=True."
        x_ctx, y_ctx = self._slice_context(x_ctx, y_ctx, read_cache)

        B, n, d, dx_max = q_chunk.shape
        dy_max = y_dim_mask.shape[-1]
        nc = x_ctx.shape[1]  # measured after slicing, so it may be the new points only

        # Replicate context and masks for every chunk and fold the chunk axis
        # into the batch axis: [B, ...] -> [B * n, ...].
        x_ctx_exp = (
            x_ctx.unsqueeze(1).expand(B, n, nc, dx_max).reshape(B * n, nc, dx_max)
        )
        y_ctx_exp = (
            y_ctx.unsqueeze(1).expand(B, n, nc, dy_max).reshape(B * n, nc, dy_max)
        )
        x_mask_exp = x_dim_mask.unsqueeze(1).expand(B, n, dx_max).reshape(B * n, dx_max)
        y_mask_exp = y_dim_mask.unsqueeze(1).expand(B, n, dy_max).reshape(B * n, dy_max)
        if y_dim_mask_tar is not None:
            y_dim_mask_tar = (
                y_dim_mask_tar.unsqueeze(1).expand(B, n, dy_max).reshape(B * n, dy_max)
            )

        # Encode. reshape (not view) is used on the caller-provided tensors:
        # view raises for non-contiguous inputs such as a mask built with
        # expand(), whereas reshape copies only when it has to.
        token_val, token_id_x, _ = self.forward(
            x_ctx=x_ctx_exp,
            y_ctx=y_ctx_exp,
            x_tar=q_chunk.reshape(B * n, d, dx_max),
            x_dim_mask=x_mask_exp,
            y_dim_mask=y_mask_exp,
            y_dim_mask_tar=y_dim_mask_tar,
            q_dim_mask=q_dim_mask.reshape(B * n, dx_max),
            read_cache=read_cache,
            write_cache=write_cache,
        )

        # Restore the chunk axis and aggregate the x ids over the dimensions
        # that are valid for each chunk.
        token_val = token_val.reshape(B, n, d, -1)
        token_id_x_exp = token_id_x.unsqueeze(0).unsqueeze(0).expand(B, n, -1, -1)
        q_mask_exp = q_dim_mask.unsqueeze(-1).expand_as(token_id_x_exp)
        token_id_aggregated = aggregate_over_valid_dims(token_id_x_exp, q_mask_exp)

        # Decode
        decoder = self.decoder if self.share_decoder else self.decoder["policy"]
        results = decoder(
            task_type="optimization",
            token_val=token_val,
            token_id=token_id_aggregated,
            q_chunk=q_chunk,
            q_dim_mask=q_dim_mask,
            use_budget=use_budget,
            t=t,
            T=T,
            epsilon=epsilon,
            evaluate=evaluate,
            logit_mask=logit_mask,
        )

        if auto_clear_cache and t >= T:
            # Clear caches at the end of optimization
            self._clear_cache()

        # Only next_x and indices are detached; log_probs, entropies and
        # logits are passed through untouched.
        return (
            results[0].detach(),  # next_x
            results[1].detach(),  # indices
            results[2],  # log_probs (keep gradients)
            results[3],  # entropies
            results[4],  # logits (optional)
        )
