import torch
import torch.nn as nn
import torch.nn.functional as F

class RouterST(nn.Module):
    """
    Router on top of precomputed sentence embeddings:

      - prompt_emb: [B, C] from sentence-transformers
      - model_embeddings: [M, C] trainable, init from SBERT on model cards
      - domain_head: prompt_emb -> logits over domains

    Routing:
      score(x, m) = dot(h_q(x), e_m)
      h_q(x) is the prompt embedding itself (no learned projection) and
      e_m is row m of ``model_embeddings``. At inference, candidates are
      first restricted to models in the top-k predicted domains.

    Training:
      - Domain loss: CE over domains
      - Model loss: CE over *all* models (no domain mask)
    """

    def __init__(self, model_embs_init: torch.Tensor,
                 model_domains: torch.Tensor,
                 num_domains: int):
        """
        Args:
            model_embs_init: [M, C] initial model embeddings; copied into a
                trainable parameter.
            model_domains: domain id of each model, M values (copied and
                flattened to a [M] int64 tensor).
            num_domains: number of domain classes D predicted by
                ``domain_head``.

        Raises:
            ValueError: if ``model_domains`` does not hold exactly one
                domain id per model.
        """
        super().__init__()
        # [M, C]
        self.model_embeddings = nn.Parameter(model_embs_init.detach().clone())
        self.emb_dim = model_embs_init.size(1)
        self.num_models = model_embs_init.size(0)
        self.num_domains = num_domains

        domains = model_domains.detach().clone().long().reshape(-1)
        if domains.numel() != self.num_models:
            raise ValueError(
                f"model_domains must have one entry per model: expected "
                f"{self.num_models}, got {domains.numel()}"
            )
        # [M] domain id per model, used for gating in `route` (and kept for
        # analysis). Registered as a buffer rather than a plain attribute so
        # that `.to(device)` / `.cuda()` move it together with the
        # parameters; persistent=False keeps state_dict keys identical to
        # checkpoints saved before it was a buffer.
        self.register_buffer("model_domains", domains, persistent=False)

        # Domain classifier
        self.domain_head = nn.Linear(self.emb_dim, num_domains)

    @torch.no_grad()
    def route(self, prompt_emb: torch.Tensor, top_k_domains: int = 1) -> int:
        """
        Select the best model for one prompt.

        The domain head's ``top_k_domains`` most probable domains define the
        candidate set: all models whose domain is among them, or every model
        if none match. The candidate with the highest dot-product score
        against the prompt embedding is returned.

        Args:
            prompt_emb: [C] or [1, C] prompt embedding. If more rows are
                given, only the first one is routed.
            top_k_domains: number of domains to keep; capped at D.

        Returns:
            Global index in [0, M) of the selected model.
        """
        # Remember the caller's mode so routing mid-training does not leave
        # the module stuck in eval mode.
        was_training = self.training
        self.eval()
        try:
            device = self.model_embeddings.device
            dtype = self.model_embeddings.dtype

            if prompt_emb.dim() == 1:
                prompt_emb = prompt_emb.unsqueeze(0)  # [1, C]
            # Cast to the parameters' dtype: embeddings arriving as float64
            # (e.g. via numpy) would otherwise fail inside the Linear layer.
            prompt_emb = prompt_emb.to(device=device, dtype=dtype)

            # Domain gating
            logits_dom = self.domain_head(prompt_emb)       # [1, D]
            probs_dom = F.softmax(logits_dom, dim=-1)[0]    # [D]
            top_k = min(top_k_domains, self.num_domains)
            top_domains = torch.topk(probs_dom, k=top_k).indices  # [k]

            # Candidate models in the top domains: [M, 1] vs [1, k] -> [M]
            domains = self.model_domains.to(device)
            cand_mask = (domains.unsqueeze(1) == top_domains.unsqueeze(0)).any(dim=1)
            if not cand_mask.any():
                # No model belongs to the predicted domains: fall back to all.
                cand_mask = torch.ones_like(cand_mask)

            cand_indices = cand_mask.nonzero(as_tuple=True)[0]  # [M']
            cand_embs = self.model_embeddings[cand_indices]     # [M', C]

            scores = (prompt_emb @ cand_embs.T)[0]              # [M']
            best_local = int(scores.argmax().item())
            return int(cand_indices[best_local].item())
        finally:
            self.train(was_training)

    def training_step(
        self,
        batch_prompt_embs: torch.Tensor,  # [B, C]
        batch_model_idx: torch.Tensor,    # [B]
        batch_domains: torch.Tensor,      # [B]
        optimizer: torch.optim.Optimizer,
        domain_loss_weight: float = 0.1,
    ) -> float:
        """
        Run one optimisation step on a batch and return the total loss.

        loss = CE(h_q @ model_embeddings.T, model labels)
               + domain_loss_weight * CE(domain_head(h_q), domain labels)

        The model loss is computed over ALL models (no domain mask).

        Args:
            batch_prompt_embs: [B, C] prompt embeddings.
            batch_model_idx: [B] index of the target model per prompt.
            batch_domains: [B] domain label per prompt.
            optimizer: optimizer over this module's parameters.
            domain_loss_weight: weight of the domain CE term.

        Returns:
            The scalar total loss as a Python float.
        """
        self.train()
        device = self.model_embeddings.device
        dtype = self.model_embeddings.dtype

        # Match the parameters' dtype (float64 inputs would break the Linear
        # layer) and use int64 class targets as F.cross_entropy requires.
        h_q = batch_prompt_embs.to(device=device, dtype=dtype)  # [B, C]
        m_idx = batch_model_idx.to(device).long()              # [B]
        d_idx = batch_domains.to(device).long()                # [B]

        # Domain loss
        logits_dom = self.domain_head(h_q)   # [B, D]
        domain_loss = F.cross_entropy(logits_dom, d_idx)

        # Model loss over all M
        scores_all = h_q @ self.model_embeddings.T   # [B, M]
        ce_loss = F.cross_entropy(scores_all, m_idx)

        loss = ce_loss + domain_loss_weight * domain_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return float(loss.item())
