# NSP-project/model_src/oracle_utils.py
from __future__ import annotations
from typing import Dict, List, Iterable, Optional, Tuple, Literal
import numpy as np
import torch

from .generation import LoadedLM, _last_step_probs, _truncate_top_p, _truncate_min_p  # reuse

TruncStrategy = Literal["top-p", "min-p"]


class PrefixCache:
    """Memo table keyed by a context's token ids.

    Keys are tuples of token ids. Values are whatever tensor the caller stores
    (NspOracle stores CPU next-step probability tensors here). Entries are never
    evicted, so the table grows with the number of distinct keys stored.
    """

    def __init__(self):
        self._cache: Dict[Tuple[int, ...], torch.Tensor] = dict()

    def get(self, key: Tuple[int, ...]) -> Optional[torch.Tensor]:
        """Return the value stored under ``key``, or ``None`` when nothing is stored."""
        try:
            return self._cache[key]
        except KeyError:
            return None

    def set(self, key: Tuple[int, ...], value: torch.Tensor) -> None:
        """Store ``value`` under ``key``, overwriting any existing entry."""
        self._cache.update({key: value})






class NspOracle:
    """
    NSP labels under a truncation strategy.

    Returns labels as a 2D numpy array (N+1, |Σ|+1) where
      columns 0..|Σ|-1 are continuation bits (Σ in lm.sigma_tokens order),
      column  |Σ|     is the membership bit ("mem" = EOS admissible).

    Gating semantics (global feasibility):
      Given a concrete string x over Σ, if at any prefix i the realized next token x[i]
      is not admissible under the truncation at that prefix, then for the prefix that
      includes that inadmissible token (row i+1) and all subsequent prefixes, all NSP
      labels are zeros and do not depend on the LM anymore.

    Constructor args:
      lm:        loaded LM wrapper; this class reads its ``model``, ``device``, ``bos_id``,
                 ``eos_id``, ``sigma_ids``, ``sigma_tokens``, ``tok2id`` and ``id2tok``.
      strategy:  "top-p" or "min-p". Any other value raises ValueError the first time a
                 truncation mask is computed (not at construction).
      param:     truncation parameter passed to the strategy helper (coerced to float).
      use_cache: if True, ``_probs_for_context`` memoizes results in a PrefixCache.
    """
    def __init__(self, lm: "LoadedLM", strategy: "TruncStrategy", param: float, use_cache: bool = True):
        self.lm = lm
        self.strategy = strategy
        self.param = float(param)
        self.cache = PrefixCache() if use_cache else None

        # Precompute keep-vocab masks (Σ ∪ {EOS}) on device and CPU: 1.0 at those ids, 0.0 elsewhere.
        # NOTE(review): V is taken from id2tok when present; this assumes it equals the model's
        # logits width, otherwise the elementwise multiplies below fail to broadcast — confirm.
        V = len(self.lm.id2tok) if hasattr(self.lm, "id2tok") else self.lm.model.config.vocab_size
        dev = self.lm.device
        keep_dev = torch.zeros((V,), dtype=torch.float32, device=dev)
        keep_dev[self.lm.eos_id] = 1.0
        keep_dev[torch.tensor(self.lm.sigma_ids, dtype=torch.long, device=dev)] = 1.0
        self._keep_vocab_dev = keep_dev
        self._keep_vocab_cpu = keep_dev.detach().cpu()

    @torch.no_grad()
    def _probs_for_context(self, ctx_ids: List[int]) -> torch.Tensor:
        """
        (Kept for compatibility; not used in the main vectorized path.)
        ctx_ids includes [BOS] and any Σ tokens already consumed.
        Returns probs (CPU tensor) at next position, with BOS zeroed and restricted to Σ∪{EOS}.

        The masked probabilities are NOT renormalized, so they sum to the LM's mass on
        Σ∪{EOS}. When caching is enabled, the same tensor object is stored and returned
        on later hits, so callers should not modify it in place.
        """
        key = tuple(ctx_ids)
        if self.cache is not None:
            cached = self.cache.get(key)
            if cached is not None:
                return cached

        input_ids = torch.tensor([ctx_ids], dtype=torch.long, device=self.lm.device)
        logits = self.lm.model(input_ids).logits                    # [1, T, V]
        probs_next = torch.softmax(logits, dim=-1).squeeze(0)[-1]   # [V] at last step
        probs_next[self.lm.bos_id] = 0.0
        probs_next = probs_next * self._keep_vocab_dev              # restrict to Σ∪{EOS}
        probs_cpu = probs_next.detach().cpu()
        if self.cache is not None:
            self.cache.set(key, probs_cpu)
        return probs_cpu

    def _allowed_mask(self, probs: torch.Tensor) -> torch.Tensor:
        """
        Compute the allowed-set mask ([V]) for the current strategy.

        BOS is not allowed mid-sequence; truncation is computed over Σ∪{EOS}. The input
        tensor is not modified. Works for CPU tensors and for tensors on any accelerator.

        Raises:
          ValueError: if ``self.strategy`` is neither "top-p" nor "min-p".
        """
        if probs.device.type == "cpu":
            keep_vocab = self._keep_vocab_cpu
        else:
            # No-op when probs already lives on lm.device; copies only for a different device,
            # which previously caused a device-mismatch error.
            keep_vocab = self._keep_vocab_dev.to(probs.device)

        pr = probs.clone()
        pr[self.lm.bos_id] = 0.0
        pr = pr * keep_vocab

        if self.strategy == "top-p":
            keep = _truncate_top_p(pr, self.param)
        elif self.strategy == "min-p":
            keep = _truncate_min_p(pr, self.param)
        else:
            raise ValueError(f"Unknown strategy: {self.strategy}")
        return keep

    def _encode_sigma_tokens(self, x_tokens: Iterable[str]) -> List[int]:
        """Map Σ tokens to vocab ids, rejecting unknown tokens and the [BOS]/[EOS] specials."""
        ids: List[int] = []
        for t in x_tokens:
            if t not in self.lm.tok2id:
                raise ValueError(f"Unknown token '{t}'.")
            tid = self.lm.tok2id[t]
            if tid in (self.lm.bos_id, self.lm.eos_id):
                raise ValueError("x_tokens must not contain specials ([BOS],[EOS]).")
            ids.append(tid)
        return ids

    @torch.no_grad()
    def labels_for(self, x_tokens: Iterable[str]) -> tuple[np.ndarray, list[str], list[str], np.ndarray]:
        """
        x_tokens: Σ-only sequence (no specials), e.g., ["(", ")", ")"].

        Returns:
          (labels, prefixes, columns, probs)
            - labels: np.uint8 (N+1, |Σ|+1)
                      columns 0..|Σ|-1: continuation bits in lm.sigma_tokens order
                      column  |Σ|     : membership bit ("mem") = EOS admissible
            - prefixes: list[str] of length N+1, e.g., ["[BOS]", "[BOS] (", "[BOS] ( )", ...]
            - columns:  list[str] = lm.sigma_tokens + ["mem"]
            - probs:    np.float32 (N+1, |Σ|+1), softmax probabilities for Σ then EOS
                        after removing BOS and non-Σ tokens (not renormalized);
                        rows after a violation are zeros.

        Raises:
          ValueError: if a token is unknown or is the [BOS]/[EOS] special, or if the
                      strategy is unknown.
        """
        ids = self._encode_sigma_tokens(x_tokens)
        N = len(ids)

        # One forward for all prefixes: [BOS] + ids  → logits for rows 0..N
        dev = self.lm.device
        input_ids = torch.tensor([[self.lm.bos_id] + ids], dtype=torch.long, device=dev)  # [1, N+1]
        logits = self.lm.model(input_ids).logits                                          # [1, N+1, V]
        probs_full = torch.softmax(logits, dim=-1).squeeze(0)                             # [N+1, V]
        probs_full[:, self.lm.bos_id] = 0.0
        probs_full = probs_full * self._keep_vocab_dev                                    # [N+1, V]

        # Row n is the distribution after a prefix of length n. Compute masks row by row
        # and stop at the first row whose realized next token ids[n] is inadmissible:
        # every later row is gated to zero and needs no mask.
        keep_rows: List[torch.Tensor] = []
        for n in range(N + 1):
            keep = self._allowed_mask(probs_full[n])   # [V]
            keep_rows.append(keep)
            if n < N and not bool(keep[ids[n]].item()):
                break
        n_live = len(keep_rows)  # rows 0..n_live-1 carry labels; rows n_live..N stay zero

        S = len(self.lm.sigma_ids)
        labels = np.zeros((N + 1, S + 1), dtype=np.uint8)
        probs = np.zeros((N + 1, S + 1), dtype=np.float32)
        columns = list(self.lm.sigma_tokens) + ["mem"]

        # Gather the Σ columns followed by EOS for all live rows in one op and copy to the
        # host once, instead of one .item() device sync per (row, column) cell.
        col_ids = list(self.lm.sigma_ids) + [self.lm.eos_id]
        keep_mat = torch.stack(keep_rows)                                          # [n_live, V]
        keep_cols = torch.tensor(col_ids, dtype=torch.long, device=keep_mat.device)
        labels[:n_live] = keep_mat.index_select(1, keep_cols).bool().cpu().numpy()
        prob_cols = keep_cols.to(probs_full.device)
        probs[:n_live] = probs_full[:n_live].index_select(1, prob_cols).float().cpu().numpy()

        # Human-readable prefixes "[BOS] t1 t2 ..." for every row (gated rows included).
        prefix_tokens = ["[BOS]"] + [self.lm.id2tok[i] for i in ids]
        prefixes: list[str] = [" ".join(prefix_tokens[: n + 1]) for n in range(N + 1)]

        return labels, prefixes, columns, probs

    def labels_for_batch(self, batch_tokens: List[Iterable[str]]) -> list[tuple[np.ndarray, list[str], list[str], np.ndarray]]:
        """Apply ``labels_for`` to each sequence independently; returns results in input order."""
        return [self.labels_for(x) for x in batch_tokens]