# chronoscore/encoder.py
"""
ChronosCore encoder:
 - quantize slack state into integer tokens
 - per-task token embedding
 - Transformer encoder that processes sequence-of-tasks and returns per-task features

Interface:
   encoder = ChronosEncoder(cfg, n_tasks)
   features = encoder(state_tuple)  # returns tensor [n_tasks, latent_dim]
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from configs import Config
from typing import Tuple

class ChronosEncoder(nn.Module):
    """Embed a tuple of quantized slack tokens and contextualize them per task.

    Each of the ``n_tasks`` integer tokens is looked up in a learned embedding
    table, a learned per-position bias is added, and the resulting sequence is
    passed through a Transformer encoder as a single batch element.

    Attributes:
        n_tasks: Number of tasks, i.e. the required length of ``state``.
        latent_dim: Width of the embeddings and of the returned features.
        vocab_size: Number of valid token ids (``cfg.n_quanta + 5``).
    """

    def __init__(self, cfg: "Config", n_tasks: int):
        """Build the embedding table, Transformer stack and positional biases.

        Args:
            cfg: Configuration object; this constructor reads ``latent_dim``,
                ``n_quanta``, ``transformer_heads``, ``transformer_ffn_dim``
                and ``transformer_layers`` from it.
            n_tasks: Number of tasks (sequence length) the encoder handles.
        """
        super().__init__()
        self.n_tasks = n_tasks
        self.latent_dim = cfg.latent_dim
        # Five ids beyond n_quanta are kept as a slack reserve.
        self.vocab_size = cfg.n_quanta + 5
        self.token_emb = nn.Embedding(self.vocab_size, self.latent_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=cfg.transformer_heads,
            dim_feedforward=cfg.transformer_ffn_dim,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=cfg.transformer_layers
        )
        # Learned per-task positional bias, initialised close to zero.
        self.pos_emb = nn.Parameter(torch.randn(n_tasks, self.latent_dim) * 0.01)

    def forward(self, state: Tuple[int, ...]) -> torch.Tensor:
        """Encode one quantized slack state.

        Args:
            state: Quantized slack ints, one per task (length ``n_tasks``),
                each in ``[0, vocab_size)``.

        Returns:
            Tensor of shape ``[n_tasks, latent_dim]`` with per-task features.

        Raises:
            ValueError: If ``state`` does not contain exactly ``n_tasks``
                entries.
            IndexError: If any token id lies outside ``[0, vocab_size)``.
        """
        values = list(state)
        # Without this check a length-1 state would silently broadcast against
        # pos_emb and yield a well-shaped but meaningless result.
        if len(values) != self.n_tasks:
            raise ValueError(
                f"state has {len(values)} entries, expected n_tasks={self.n_tasks}"
            )

        # Validate on the host first: an out-of-range index on a CUDA device
        # surfaces as a device-side assert rather than a Python exception.
        tokens = torch.tensor(values, dtype=torch.long)  # [n_tasks]
        if tokens.numel() > 0:
            lowest = int(tokens.min())
            highest = int(tokens.max())
            if lowest < 0 or highest >= self.vocab_size:
                raise IndexError(
                    f"state tokens must lie in [0, {self.vocab_size}); "
                    f"got min={lowest}, max={highest}"
                )
        tokens = tokens.to(self.pos_emb.device)

        x = self.token_emb(tokens) + self.pos_emb  # [n_tasks, latent_dim]
        # The encoder layers use the default sequence-first layout
        # [seq_len, batch, dim]; run the tasks as one sequence with batch=1.
        out = self.transformer(x.unsqueeze(1))  # [n_tasks, 1, latent_dim]
        return out.squeeze(1)  # [n_tasks, latent_dim]
