"""Equity Tokenizer: maps daily cross-section to K factor tokens."""

import torch
import torch.nn as nn

from equitiesjepa.modules import ISAB, PMA, SAB


class EquityTokenizer(nn.Module):
    """Turn a cross-section X_t ∈ ℝ^{N×F} into K factor tokens T_t ∈ ℝ^{K×d}.

    The pipeline applied by ``forward`` is, in order: an MLP ``phi`` acting on
    the last axis, a stack of ``ISAB`` blocks, one ``PMA`` block and a final
    ``SAB`` block.

    Args:
        dim_input: Input width of the first linear layer in ``phi`` (F).
        dim_hidden: Output width of ``phi`` (d); also handed to every
            ISAB/PMA/SAB block.
        num_tokens: Forwarded to ``PMA`` (presumably the number of output
            tokens K -- confirm against ``PMA``'s signature).
        num_inds: Forwarded to each ``ISAB`` (presumably the inducing-point
            count -- confirm against ``ISAB``'s signature).
        num_heads: Forwarded to the ISAB, PMA and SAB blocks.
        num_isab_layers: Number of ISAB blocks stacked after ``phi``.
        ln: Forwarded to the ISAB, PMA and SAB blocks as ``ln``.
        dropout: Forwarded to the ISAB, PMA and SAB blocks as ``dropout``.
    """

    def __init__(
        self,
        dim_input: int = 28,
        dim_hidden: int = 128,
        num_tokens: int = 24,
        num_inds: int = 32,
        num_heads: int = 4,
        num_isab_layers: int = 2,
        ln: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()

        # Plain (non-module) bookkeeping attributes.
        self.dim_input = dim_input
        self.dim_hidden = dim_hidden
        self.num_tokens = num_tokens

        # Options shared by every attention block built below.
        block_opts = {"ln": ln, "dropout": dropout}

        # Submodules are created and registered in a fixed order
        # (phi -> isab_layers -> pma -> sab) so that parameter ordering and
        # seeded initialisation stay stable.
        embed_layers = [
            nn.Linear(dim_input, dim_hidden),
            nn.GELU(),
            nn.Linear(dim_hidden, dim_hidden),
            nn.GELU(),
            nn.Linear(dim_hidden, dim_hidden),
        ]
        self.phi = nn.Sequential(*embed_layers)

        self.isab_layers = nn.ModuleList()
        for _ in range(num_isab_layers):
            self.isab_layers.append(
                ISAB(dim_hidden, dim_hidden, num_heads, num_inds, **block_opts)
            )

        self.pma = PMA(dim_hidden, num_heads, num_tokens, **block_opts)
        self.sab = SAB(dim_hidden, dim_hidden, num_heads, **block_opts)

    def forward(self, X, mask=None, return_pre_isab=False):
        """Tokenize ``X``.

        Args:
            X: Input tensor fed to ``phi``; its last axis must match
                ``dim_input``. (Assumed layout is ``(..., N, F)`` -- confirm
                with the ISAB/PMA implementations.)
            mask: Optional mask passed unchanged, as the ``mask`` keyword, to
                every ISAB block and to PMA. It is not passed to SAB.
            return_pre_isab: When true, also return the output of ``phi``
                (the embedding before any ISAB block).

        Returns:
            The SAB output, or the tuple ``(sab_output, phi_output)`` when
            ``return_pre_isab`` is true.
        """
        embedded = self.phi(X)

        hidden = embedded
        for layer in self.isab_layers:
            hidden = layer(hidden, mask=mask)

        pooled = self.pma(hidden, mask=mask)
        tokens = self.sab(pooled)

        return (tokens, embedded) if return_pre_isab else tokens

