import torch
import torch.nn.functional as f
from torch import Tensor, nn

from clcp.models.base import Model
from ml_utils import log


def freeze_mdl(mdl) -> None:
    """Freeze ``mdl`` in place.

    Turns off gradient tracking for every tensor yielded by
    ``mdl.parameters()`` and then switches ``mdl`` to evaluation mode.
    """
    for tensor in mdl.parameters():
        tensor.requires_grad_(False)
    mdl.eval()


class CrossEncoderNLI(Model, arch="cross-encoder-nli"):
    """Cross-encoder that emits one "entailment" logit per input sequence.

    The encoder obtained from ``get_model`` produces per-token hidden states;
    the hidden state at position 0 is fed through a small MLP head that
    outputs a single value per example.
    """

    requires_paired_inp = True

    def __init__(self, backbone: str) -> None:
        """Build the encoder and the single-logit classification head.

        Args:
            backbone: Name forwarded to ``self.get_model`` to obtain the encoder.
        """
        super().__init__()
        self.encoder = self.get_model(name=backbone)

        cfg = self.encoder.config
        self.emb_dim = cfg.hidden_size
        # Record the single-label mapping on the encoder's config.
        cfg.label2id = {"entailment": 0}
        cfg.id2label = {0: "entailment"}

        width = self.emb_dim
        head_layers = [
            nn.Dropout(p=0.1),
            nn.Linear(width, width),
            nn.GELU(),
            nn.LayerNorm(width),
            nn.Linear(width, 1),
        ]
        self.clf_head = nn.Sequential(*head_layers)

        self.setup_device()

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Return one logit per example, shape (B,).

        Args:
            input_ids: Token ids passed straight to the encoder.
            attention_mask: Attention mask passed straight to the encoder.
        """
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden = encoded["last_hidden_state"]  # (B,S,E)
        first_tok = hidden[:, 0, :]  # (B,E) -- position-0 token as sequence summary
        score = self.clf_head(first_tok)  # (B,1)
        return score.squeeze(-1)  # (B,)


class CrossEncoderNLITriplet(Model, arch="cross-encoder-nli-triplet"):
    """Cross-encoder with a 3-way entailment / neutral / contradiction head.

    In training mode ``forward`` returns the raw 3-class logits. In eval mode
    it returns a single entailment score per example and caches the 3-class
    logits on the instance, so that ``loss_fn`` can still compute a 3-way
    cross-entropy for the most recent eval batch.
    """

    requires_paired_inp = True

    def __init__(self, backbone: str) -> None:
        """Build the encoder and the 3-class classification head.

        Args:
            backbone: Name forwarded to ``self.get_model`` to obtain the encoder.
        """
        super().__init__()
        self.encoder = self.get_model(name=backbone)
        self.emb_dim = self.encoder.config.hidden_size
        self.encoder.config.label2id = {"entailment": 0, "neutral": 1, "contradiction": 2}
        self.encoder.config.id2label = {0: "entailment", 1: "neutral", 2: "contradiction"}
        self.clf_head = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(self.emb_dim, self.emb_dim),
            nn.GELU(),
            nn.LayerNorm(self.emb_dim),
            nn.Linear(self.emb_dim, 3),
        )
        # Full (B, 3) logits of the most recent eval-mode forward pass.
        # Declared up front so loss_fn never hits a missing attribute.
        self._last_full_logits = None
        self.setup_device()

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Run the encoder and head.

        Args:
            input_ids: Token ids passed straight to the encoder.
            attention_mask: Attention mask passed straight to the encoder.

        Returns:
            Training mode: the 3-class logits, shape (B, 3).
            Eval mode: an entailment score of shape (B,), computed as
            ``logit[entailment] - logsumexp(logit[neutral], logit[contradiction])``,
            i.e. the log-odds of entailment under a softmax over the 3 logits.
        """
        embs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"]  # (B,S,E)
        cls = embs[:, 0, :]  # (B,E)
        logits = self.clf_head(cls)  # (B, 3)

        if self.training:
            return logits

        # Keep the full logits around: the eval output below collapses them
        # to one score, but loss_fn still needs all three classes.
        self._last_full_logits = logits
        return logits[:, 0] - torch.logsumexp(logits[:, 1:], dim=-1)  # (B, ) entailment logit

    def loss_fn(self, yhs: Tensor, ys: Tensor) -> Tensor:
        """Return the 3-way cross-entropy loss.

        Args:
            yhs: Model output. In training mode these are the (B, 3) logits
                and are used directly. In eval mode they are replaced by the
                logits cached by the latest eval-mode ``forward`` call.
            ys: Class indices; cast to ``long`` before computing the loss.

        If no eval-mode forward pass has populated the cache yet, ``yhs`` is
        used as given instead of failing on a missing attribute.
        """
        if not self.training:
            # getattr guards instances that were restored without __init__.
            cached = getattr(self, "_last_full_logits", None)
            if cached is not None:
                yhs = cached
        return f.cross_entropy(yhs, ys.long())
