import json
import pathlib
from typing import Literal

import torch
import torch.nn.functional as f
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError
from torch import Tensor

from clcp.models.base import Model
from ml_utils import log


class CrossEncoder(Model, arch="cross-encoder"):
    """Scores a tokenized text pair with a single joint encoder pass.

    Two scoring paths exist, selected by the model name:

    * Names containing ``"Qwen3-Reranker"`` use the logits of the final
      sequence position and return a probability (see
      ``_qwen3_reranker_forward``).
    * Every other model returns the raw logit of the ``"entailment"`` label
      (column 0 when the config's ``label2id`` has no such label).
    """

    requires_paired_inp = True

    def __init__(self, name: str) -> None:
        """Load the encoder and work out which scoring path applies.

        Args:
            name: Model identifier passed to ``get_model``; also inspected for
                the ``"Qwen3-Reranker"`` substring.
        """
        super().__init__()
        self.encoder = self.get_model(name)

        label_to_id = self.encoder.config.label2id
        self.entailment_idx = label_to_id.get("entailment", 0)
        log.debug(f"{name=}; {self.entailment_idx=}")
        self.setup_device()

        self.is_qwen_reranker = "Qwen3-Reranker" in name
        if not self.is_qwen_reranker:
            return

        # NOTE(review): hard-coded vocabulary ids, presumably the tokenizer's
        # "yes" / "no" tokens -- confirm against the tokenizer in use.
        self.token_true_id = 9693
        self.token_false_id = 2152
        log.debug(f"{self.is_qwen_reranker=}")

    def _qwen3_reranker_forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Return p("yes") per row from the final-position logits.

        Only the two logits at ``token_false_id`` and ``token_true_id`` are
        kept; a softmax over that pair yields the returned probability.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        final_step = outputs.logits[:, -1, :]  # (B, V)
        # Gather both columns at once, ordered [false, true].
        false_true = final_step[:, [self.token_false_id, self.token_true_id]]  # (B, 2)
        probs = f.softmax(false_true, dim=1)  # (B, 2)
        return probs[:, 1]  # (B, ) p("yes")

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Return one score per input row.

        Qwen3 rerankers yield a probability; all other encoders yield the
        unnormalised logit of the entailment class.
        """
        if not self.is_qwen_reranker:
            outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
            return outputs.logits[:, self.entailment_idx]
        return self._qwen3_reranker_forward(input_ids=input_ids, attention_mask=attention_mask)


class DualEncoder(Model, arch="dual-encoder"):
    """Scores text pairs by the dot product of their L2-normalised embeddings.

    Each text is embedded independently; the pooling strategy (``"cls"``,
    ``"mean"`` or ``"last"``) is resolved once from the model name via
    ``get_pooling_mode``.
    """

    requires_paired_inp = False

    def __init__(self, name: str) -> None:
        """Load the encoder and resolve its pooling mode.

        Args:
            name: Model identifier passed to ``get_model`` and
                ``get_pooling_mode``.
        """
        super().__init__()
        self.encoder = self.get_model(name)
        self.pool_mode = get_pooling_mode(name=name)
        log.debug(f"{name=}; {self.pool_mode=}")

        self.setup_device()

    def _token_embeddings(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Run the encoder and return its ``last_hidden_state``, (B*2, S, E)."""
        return self.encoder(input_ids, attention_mask)["last_hidden_state"]

    def cls_embedding(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Return the unit-norm embedding of the first token of every row."""
        embs = self._token_embeddings(input_ids, attention_mask)  # (B*2, S, E)
        cls = embs[:, 0, :]  # (B*2, E)
        return f.normalize(cls, p=2, dim=1)  # ||E||_2 = 1

    def mean_embedding(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Return the unit-norm mean of the attended token embeddings per row.

        A row whose attention mask is entirely zero yields a zero vector
        instead of NaN.
        """
        embs = self._token_embeddings(input_ids, attention_mask)  # (B*2, S, E)
        mask = attention_mask.unsqueeze(-1)  # (B*2, S) -> (B*2, S, 1)
        num = (embs * mask).sum(1)  # (B*2, E)
        # Clamp so a fully masked row divides by 1 rather than 0.
        den = mask.sum(1).clamp(min=1)  # (B*2, 1)
        return f.normalize(num / den, p=2, dim=1)  # (B*2, E) -> ||E||_2 = 1

    def last_embedding(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Return the unit-norm embedding of the last attended token per row.

        With left padding the last attended token sits at position ``-1`` for
        every row. With right padding position ``-1`` may be a padding token,
        so the index is derived from the attention mask instead.
        """
        embs = self._token_embeddings(input_ids, attention_mask)  # (B*2, S, E)
        # Every row attends to its final position <=> batch is left-padded
        # (or not padded at all).
        ends_on_token = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if ends_on_token:
            last_emb = embs[:, -1, :]  # (B*2, E)
        else:
            last_idx = (attention_mask.sum(dim=1) - 1).long().to(embs.device)  # (B*2,)
            rows = torch.arange(embs.shape[0], device=embs.device)
            last_emb = embs[rows, last_idx]  # (B*2, E)
        return f.normalize(last_emb, p=2, dim=1)  # ||E||_2 = 1

    def get_embedding(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Embed the batch with the pooling strategy stored in ``pool_mode``.

        Raises:
            ValueError: If ``pool_mode`` is not ``"cls"``, ``"mean"`` or ``"last"``.
        """
        if self.pool_mode == "cls":
            return self.cls_embedding(input_ids, attention_mask)
        if self.pool_mode == "mean":
            return self.mean_embedding(input_ids, attention_mask)
        if self.pool_mode == "last":
            return self.last_embedding(input_ids, attention_mask)
        msg = f"Unknown pooling mode: {self.pool_mode}"
        raise ValueError(msg)

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Return the paired dot product of the two halves of the batch.

        The incoming batch is a concat of tokenized (docs_a, docs_b) pairs, so
        row ``i`` of the first half is scored against row ``i`` of the second.

        Raises:
            ValueError: If the batch has an odd number of rows and therefore
                cannot be split into two equal halves.
        """
        n_rows = input_ids.shape[0]
        if n_rows % 2 != 0:
            # chunk(2) would otherwise give unequal halves that either
            # broadcast silently or fail with an opaque shape error.
            msg = f"Expected an even number of rows (docs_a + docs_b), got {n_rows}"
            raise ValueError(msg)
        cls = self.get_embedding(input_ids=input_ids, attention_mask=attention_mask)  # (B*2, E)
        cls_a, cls_b = cls.chunk(2)  # (B, E), (B, E)
        return (cls_a * cls_b).sum(dim=1)  # paired dot product (B,)


def _pooling_mode_from_config(pool_cfg: dict, name: str) -> Literal["mean", "cls", "last"]:
    """Translate a ``1_Pooling/config.json`` mapping into a supported pooling mode.

    The first truthy ``pooling_mode_*`` flag (in file order) decides the mode.

    Raises:
        ValueError: If no flag is enabled or the enabled flag names a pooling
            strategy other than cls / mean / last-token.
    """
    # Explicit map: splitting the key on "_" mis-parses e.g.
    # "pooling_mode_lasttoken" and "pooling_mode_mean_sqrt_len_tokens".
    supported: dict[str, Literal["mean", "cls", "last"]] = {
        "pooling_mode_cls_token": "cls",
        "pooling_mode_mean_tokens": "mean",
        "pooling_mode_lasttoken": "last",
    }
    enabled = next((k for k, v in pool_cfg.items() if k.startswith("pooling_mode_") and v), None)
    if enabled is None:
        msg = f"No pooling mode enabled in 1_Pooling/config.json of {name!r}"
        raise ValueError(msg)
    if enabled not in supported:
        msg = f"Unsupported pooling mode {enabled!r} in 1_Pooling/config.json of {name!r}"
        raise ValueError(msg)
    return supported[enabled]


def get_pooling_mode(name: str) -> Literal["mean", "cls", "last"]:
    """Decide how token embeddings of model ``name`` are pooled.

    Known model families are matched by substring first. Otherwise the
    model's ``1_Pooling/config.json`` is fetched from the Hugging Face Hub and
    its enabled ``pooling_mode_*`` flag is used. When the repository has no
    such file the mode defaults to ``"cls"``.

    Args:
        name: Model identifier (also used as the Hub repo id for the lookup).

    Returns:
        One of ``"mean"``, ``"cls"`` or ``"last"``.

    Raises:
        ValueError: If the pooling config enables no mode, or one that is not
            cls / mean / last-token.
    """
    # Order matters: "e5-mistral-7b-instruct" must be caught before the
    # generic "e5-" rule, and "Qwen3-Embedding" before the generic "Qwen3".
    if "Qwen3-Embedding" in name or "e5-mistral-7b-instruct" in name:
        return "last"
    if "e5-" in name:
        return "mean"
    if any(prefix in name for prefix in ("gte-", "bge-", "Qwen3")):
        return "cls"
    try:
        pool_cfg_path = hf_hub_download(name, "1_Pooling/config.json")
    except EntryNotFoundError:
        # Repository exists but ships no pooling config.
        return "cls"
    pool_cfg = json.loads(pathlib.Path(pool_cfg_path).read_text(encoding="utf-8"))
    return _pooling_mode_from_config(pool_cfg, name)
