from __future__ import annotations
from typing import List
import torch
import torch.nn as nn

class ControlUnitLearner(nn.Module):
    """Encode a batch of strings into ``out_dim``-dimensional vectors.

    The constructor tries to load a Hugging Face encoder (and its tokenizer)
    named ``plm_name``.  If anything in that step fails -- the ``transformers``
    import or any of the ``from_pretrained`` calls -- a small character-level
    CNN is used instead.  In both cases a linear
    projection ``proj`` maps the encoder output to ``out_dim``.

    ``self._tok`` is the tokenizer when the HF encoder is active and ``None``
    when the character-CNN fallback is active; ``forward`` branches on it.
    """

    # Attribute paths (relative to ``self.encoder``) probed, in order, to find
    # the stack of encoder blocks used for partial fine-tuning.
    _LAYER_PATHS = ("encoder.layer", "encoder.layers", "layer")

    def __init__(self, plm_name: str, out_dim: int):
        """Build the encoder and the output projection.

        Args:
            plm_name: Model identifier passed to the ``from_pretrained`` calls.
            out_dim: Size of the projected output vectors.
        """
        super().__init__()
        self.plm_name = plm_name
        # Deliberately broad: any failure while loading the pretrained model
        # switches to the self-contained fallback encoder instead of crashing.
        try:
            from transformers import AutoModel, AutoTokenizer, AutoConfig
            self._tok = AutoTokenizer.from_pretrained(plm_name)
            self._cfg = AutoConfig.from_pretrained(plm_name)
            self.encoder = AutoModel.from_pretrained(plm_name)
            hidden = self._cfg.hidden_size
            print(f"[CUL] Loaded HF encoder '{plm_name}' hidden_size={hidden}")
        except Exception as e:
            print(f"[CUL][WARN] Failed to load HF model '{plm_name}': {e}. Falling back to a small character-CNN.")
            self._tok = None
            hidden = 256
            self.encoder = nn.Sequential(
                nn.Conv1d(1, 64, 7, padding=3), nn.ReLU(),
                nn.Conv1d(64, 128, 5, padding=2), nn.ReLU(),
                nn.AdaptiveAvgPool1d(1), nn.Flatten(), nn.Linear(128, hidden),
            )
        self.proj = nn.Linear(hidden, out_dim)

    def forward(self, texts: List[str], device: torch.device) -> torch.Tensor:
        """Return one projected vector per input string.

        Args:
            texts: Batch of strings to encode.
            device: Device on which the input tensors are created / moved to.
                The module's own parameters are not moved here.

        Returns:
            Output of ``self.proj`` applied to the pooled encoder features,
            one row per entry of ``texts``.
        """
        if getattr(self, "_tok", None) is not None:
            return self._forward_hf(texts, device)
        return self._forward_char_cnn(texts, device)

    def _forward_hf(self, texts: List[str], device: torch.device) -> torch.Tensor:
        """Tokenize, encode and mean-pool over non-padding tokens."""
        batch = self._tok(texts, padding=True, truncation=True, return_tensors="pt").to(device)
        token_embeddings = self.encoder(**batch).last_hidden_state
        # Broadcast the attention mask over the feature axis so that padded
        # positions contribute nothing to the sum.
        mask = batch["attention_mask"].unsqueeze(-1).type_as(token_embeddings)
        summed = (token_embeddings * mask).sum(dim=1)
        # Clamp guards against division by zero for an all-padding row.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return self.proj(summed / counts)

    def _forward_char_cnn(self, texts: List[str], device: torch.device) -> torch.Tensor:
        """Encode strings with the character-CNN fallback.

        Each character becomes ``(ord(c) % 128) / 127.0``; shorter strings are
        right-padded with zeros to the longest string in the batch.
        """
        # Width is at least 1 so that a batch of only empty strings still
        # yields a non-empty sequence axis for the conv/pool stack.
        width = max((len(t) for t in texts), default=0)
        width = max(width, 1)
        codes = [[ord(c) % 128 for c in t] + [0] * (width - len(t)) for t in texts]
        # reshape (rather than unsqueeze) keeps the (N, 1, width) layout even
        # when ``texts`` is empty and ``codes`` is a flat empty list.
        pad = torch.tensor(codes, dtype=torch.float32, device=device)
        pad = pad.reshape(len(texts), 1, width) / 127.0
        return self.proj(self.encoder(pad))

    @staticmethod
    def _locate_layers(encoder: nn.Module):
        """Return the first block container found under ``_LAYER_PATHS``, else ``None``."""
        for path in ControlUnitLearner._LAYER_PATHS:
            node = encoder
            for attr in path.split("."):
                node = getattr(node, attr, None)
                if node is None:
                    break
            # Only accept something we can measure and iterate block by block.
            if node is not None and hasattr(node, "__len__") and hasattr(node, "__iter__"):
                return node
        return None

    @staticmethod
    def _select_layer_indices(last_slice: str, n_layers: int) -> set:
        """Translate ``last_slice`` into a set of non-negative block indices.

        Accepts a Python-style slice (``"-2:"``, ``"1:3"``, ``"::2"``) or a
        single, possibly negative, index (``"-1"``).  Anything unparsable or
        out of range selects only the last block.
        """
        positions = range(n_layers)
        try:
            if ":" in last_slice:
                parts = [int(p) if p.strip() else None for p in last_slice.split(":")]
                if len(parts) > 3:
                    raise ValueError(f"too many slice fields: {last_slice!r}")
                return set(positions[slice(*parts)])
            # Indexing the range normalizes negative indices and rejects
            # out-of-range ones with IndexError.
            return {positions[int(last_slice)]}
        except (ValueError, IndexError, TypeError):
            return {n_layers - 1}

    def set_finetune(self, mode: str, last_slice: str = "-2:"):
        """Choose which parameters receive gradients.

        ``proj`` is always trainable.  The encoder is handled by ``mode``:

        * ``"none"`` -- encoder fully frozen.
        * ``"all"``  -- encoder fully trainable.
        * anything else -- encoder frozen except the blocks selected by
          ``last_slice``; if no block container can be located, a warning is
          printed and only ``proj`` is trained.

        Args:
            mode: Fine-tuning mode, see above.
            last_slice: Slice or index string selecting the trainable blocks
                (only used for partial fine-tuning).
        """
        train_encoder = mode == "all"
        for p in self.encoder.parameters():
            p.requires_grad = train_encoder
        for p in self.proj.parameters():
            p.requires_grad = True
        if mode in ("none", "all"):
            return

        layers = self._locate_layers(self.encoder)
        if layers is None:
            print("[CUL][WARN] Cannot locate encoder layers; falling back to proj-only tuning.")
            return

        selected = self._select_layer_indices(last_slice, len(layers))
        for i, block in enumerate(layers):
            trainable = i in selected
            for p in block.parameters():
                p.requires_grad = trainable