# models/transformer_agent.py
"""
Transformer-based Q-network for ChronosCore.
This module provides a reusable Agent class composed of:
  - per-task token embedding (slack tokens)
  - Transformer encoder (stack of nn.TransformerEncoderLayer)
  - per-task Q-head that maps each task representation to a scalar Q-value
  - optional global idle-head producing the Q for the idle action

Interface:
    agent = TransformerAgent(cfg, n_tasks)
    q_values = agent(state_tuple)   # returns torch.Tensor shape [n_tasks + 1]
Notes:
  - state_tuple is a tuple of quantized slack ints (length = n_tasks)
  - keeps device-awareness via cfg.device
  - easy to extend with relative positional biases in positional_utils.py
"""
from dataclasses import dataclass
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from configs import Config

class PerTaskMLP(nn.Module):
    """Two-layer perceptron that scores each task representation with one scalar Q."""

    def __init__(self, in_dim: int, hidden: int = None):
        """
        in_dim: size of each task's feature vector
        hidden: width of the hidden layer; a falsy value (None or 0) selects
            max(32, in_dim // 2)
        """
        super().__init__()
        if not hidden:
            hidden = max(32, in_dim // 2)
        layers = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # x: [n_tasks, dim]; the linear stack yields [n_tasks, 1] and the
        # trailing singleton axis is dropped -> [n_tasks]
        scores = self.net(x)
        return scores.squeeze(-1)

class GlobalIdleHead(nn.Module):
    """Head that scores the idle action from all per-task features laid end to end."""

    def __init__(self, n_tasks: int, feat_dim: int):
        """
        n_tasks: number of task rows expected in the input
        feat_dim: feature size per task; also used as the hidden width
        """
        super().__init__()
        flat_dim = n_tasks * feat_dim
        self.net = nn.Sequential(
            nn.Linear(flat_dim, feat_dim),
            nn.ReLU(),
            nn.Linear(feat_dim, 1),
        )

    def forward(self, task_feats):
        # task_feats: [n_tasks, feat_dim] -> one vector of n_tasks * feat_dim
        joined = task_feats.reshape(-1)
        idle_q = self.net(joined)
        # drop the size-1 output axis so the result is a 0-dim tensor
        return idle_q.squeeze(-1)

class TransformerAgent(nn.Module):
    def __init__(self, cfg: Config, n_tasks: int, vocab_size: int = None):
        """
        cfg: Config object
        n_tasks: number of tasks (sequence length)
        vocab_size: number of quantized slack tokens (optional)
        """
        super().__init__()
        self.cfg = cfg
        self.device = cfg.device
        self.n_tasks = n_tasks
        self.latent_dim = cfg.latent_dim
        self.vocab_size = vocab_size or (cfg.n_quanta + 10)

        # token embedding and optional small MLP for initial projection
        self.token_emb = nn.Embedding(self.vocab_size, self.latent_dim)
        # learned per-position biases (keeps order info but small)
        self.pos_bias = nn.Parameter(torch.randn(n_tasks, self.latent_dim) * 0.01)

        # transformer encoder
        enc_layer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=max(1, cfg.transformer_heads),
            dim_feedforward=cfg.transformer_ffn_dim,
            dropout=0.1,
            activation="relu"
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=max(1, cfg.transformer_layers))

        # output heads
        self.per_task_head = PerTaskMLP(self.latent_dim)
        self.idle_head = GlobalIdleHead(n_tasks, self.latent_dim)

        # small initializer
        self._init_weights()

        self.to(self.device)

    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, state: Tuple[int, ...]) -> torch.Tensor:
        """
        state: tuple of length n_tasks of int tokens
        returns: q_values tensor of length n_tasks + 1 (last is idle)
        """
        device = self.device
        tokens = torch.tensor(list(state), dtype=torch.long, device=device)  # [n_tasks]
        x = self.token_emb(tokens)  # [n_tasks, latent_dim]
        x = x + self.pos_bias       # add learned positional bias

        # transformer expects [seq_len, batch, dim]; set batch=1
        x_in = x.unsqueeze(1)       # [n_tasks, 1, dim]
        out = self.encoder(x_in)    # [n_tasks, 1, dim]
        out = out.squeeze(1)        # [n_tasks, dim]

        per_task_q = self.per_task_head(out)   # [n_tasks]
        idle_q = self.idle_head(out)          # scalar
        q_all = torch.cat([per_task_q, idle_q.unsqueeze(0)], dim=0)  # [n_tasks+1]
        return q_all
