# NSP-project/model_src/oracles.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, List, Tuple, Literal
import warnings
import numpy as np
import torch

from .generation import (
    LoadedLM,
    LmSampler,
    _last_step_probs,
    _truncate_top_p,
    _truncate_min_p,
    _sample_from_masked,
)
from .oracle_utils import NspOracle

TruncStrategy = Literal["top-p", "min-p"]


def _normalize_input_tokens(lm: LoadedLM, tokens: Iterable[str], expect_bos: bool = False) -> List[int]:
    """
    Convert arbitrary token list into a valid LM context for MQ:
      - If starts without [BOS], prepend [BOS]
      - Input should not contain [EOS]
      - If contains [EOS], truncate at first EOS (warn) because membership is defined on Σ* prefixes.
      - Ensure all remaining non-specials are in Σ.
    Returns ctx_ids (including [BOS]).
    """
    toks = list(tokens)
    if not toks:
        # Empty Σ-prefix → context is just [BOS]
        return [lm.bos_id]

    if toks[0] != "[BOS]":
        # warnings.warn("MembershipOracle: input did not start with [BOS]; prepending [BOS].")
        toks = ["[BOS]"] + toks

    ctx: List[int] = []
    saw_eos = False
    for i, t in enumerate(toks):
        if t not in lm.tok2id:
            raise ValueError(f"Unknown token '{t}'.")
        tid = lm.tok2id[t]
        if i == 0:
            if t != "[BOS]":
                raise AssertionError("First token must be [BOS] after normalization.")
            ctx.append(tid)
            continue
        if tid == lm.eos_id:
            saw_eos = True
            warnings.warn("MembershipOracle: input contained [EOS]; evaluating membership at prefix before EOS.")
            break
        if tid == lm.bos_id:
            warnings.warn("MembershipOracle: found [BOS] mid-sequence; ignoring it.")
            continue
        # regular Σ token
        ctx.append(tid)

    return ctx


@dataclass
class MembershipOracle:
    """
    Membership query MQ(y) for a prefix y under a truncation strategy.

    Accepts a list of Σ tokens (preferred) or a list whose first element is [BOS].
    MQ(y) == 1 iff every token of y lies in the truncated support at the step where it
    was taken AND EOS lies in the truncated support after all of y; otherwise 0.
    """
    lm: LoadedLM
    strategy: TruncStrategy
    param: float

    def _keep_mask(self, row: torch.Tensor) -> torch.Tensor:
        """Truncated-support mask of one next-token row under (strategy, param)."""
        if self.strategy == "top-p":
            return _truncate_top_p(row, float(self.param))
        if self.strategy == "min-p":
            return _truncate_min_p(row, float(self.param))
        raise ValueError(f"Unknown strategy: {self.strategy}")

    @torch.no_grad()
    def label(self, tokens: Iterable[str]) -> int:
        """
        Return the membership label (0 or 1) of `tokens`.

        `tokens` is a Σ-only sequence or one starting with [BOS]; a missing [BOS] is
        prepended by `_normalize_input_tokens`. A single teacher-forced forward pass
        yields the next-token distribution after every prefix; each distribution is
        restricted to Σ ∪ {EOS} (BOS zeroed, no renormalization) before truncation.
        """
        if self.strategy != "top-p" and self.strategy != "min-p":
            raise ValueError("MembershipOracle only supports 'top-p' or 'min-p' strategies.")

        ctx_ids = _normalize_input_tokens(self.lm, tokens)  # [BOS] + Σ ids
        last = len(ctx_ids) - 1  # number of Σ tokens == index of the terminal row

        batch = torch.tensor([ctx_ids], dtype=torch.long, device=self.lm.device)   # [1, T]
        dist = torch.softmax(self.lm.model(batch).logits, dim=-1).squeeze(0)      # [T, V]

        # BOS is never a valid next token; everything outside Σ ∪ {EOS} is masked too.
        dist[:, self.lm.bos_id] = 0.0
        allowed = torch.zeros((dist.size(-1),), dtype=dist.dtype, device=dist.device)
        allowed[self.lm.eos_id] = 1.0
        allowed[torch.tensor(self.lm.sigma_ids, dtype=torch.long, device=dist.device)] = 1.0
        dist = dist * allowed

        # Row `pos` is the distribution after the first `pos` Σ tokens; the token actually
        # taken there is ctx_ids[pos + 1]. `all` stops at the first inadmissible step.
        every_step_admissible = all(
            bool(self._keep_mask(dist[pos])[ctx_ids[pos + 1]].item())
            for pos in range(last)
        )
        if not every_step_admissible:
            return 0

        # Terminal row: y is accepted iff EOS survives truncation there.
        return int(bool(self._keep_mask(dist[last])[self.lm.eos_id].item()))

    @torch.no_grad()
    def label_batch(self, batch_tokens: List[Iterable[str]]) -> List[int]:
        """Label each token sequence in `batch_tokens` with `label`, preserving order."""
        return list(map(self.label, batch_tokens))


@dataclass
class VanillaEX:
    """
    Vanilla example oracle.

    Strings are drawn from the LM's NATURAL (untruncated) distribution, and each one is
    labelled by MembershipOracle under (strategy, param).
    Output: list of (sigma_tokens, membership_label) pairs, one per draw.
    """
    lm: LoadedLM
    strategy: TruncStrategy
    param: float

    def sample(self, n: int, max_steps: int) -> List[Tuple[List[str], int]]:
        """
        Draw `n` labelled examples, passing `max_steps` to the sampler for each draw.

        A trailing "[EOS]" is stripped so the labelled string is a Σ* prefix; a draw
        that does not end in "[EOS]" is labelled as-is.
        """
        natural = LmSampler(self.lm, strategy="natural", param=0.0)
        membership = MembershipOracle(self.lm, self.strategy, self.param)

        def draw() -> Tuple[List[str], int]:
            generated, _ = natural.generate_one(max_steps)  # BOS excluded; may end in EOS
            ends_in_eos = bool(generated) and generated[-1] == "[EOS]"
            sigma = generated[:-1] if ends_in_eos else generated
            return sigma, membership.label(sigma)

        return [draw() for _ in range(n)]


@dataclass
class NspEX:
    """
    NSP example oracle producing (sigma_tokens, NSP_labels) pairs.

    Two sampling modes; in both, labels come from NspOracle(strategy, param):
      1) sample_strat:   draw strings under the same truncation rule (strategy, param).
      2) sample_natural: draw strings from the LM's natural (untruncated) distribution.

    Both return ``(pairs, meta)`` where ``pairs[i] = (sigma_tokens, labels)`` and
    ``meta[i] = (prefixes, columns, probs)`` as returned by ``NspOracle.labels_for``.
    """
    lm: LoadedLM
    strategy: TruncStrategy
    param: float

    def _nsp_oracle(self) -> NspOracle:
        """Build an NSP oracle for this (strategy, param) with caching enabled."""
        return NspOracle(self.lm, self.strategy, self.param, use_cache=True)

    def sample_one(self, sampler, max_len: int, max_steps: int, max_tries: int = 40) -> List[str]:
        """
        Draw one Σ-only string that ended with [EOS] and is strictly shorter than `max_len`.

        Up to `max_tries` draws are made with ``sampler.generate_one(max_steps)``; a draw
        is rejected if it is empty, does not end in "[EOS]", or is too long. If none
        qualifies, a warning is printed and one further draw is returned as-is (with a
        trailing "[EOS]" stripped if present).

        Progress messages start from the 5th attempt onward.
        """
        verbose = False
        for k in range(max_tries):
            if (k + 1) % 5 == 0:
                verbose = True
                print(f"NspEX.sample_one: attempt {k+1}/{max_tries}")
            toks, _ = sampler.generate_one(max_steps)  # excludes BOS, may include EOS
            # Reject draws that never emitted EOS; `not toks` guards against an empty draw.
            if not toks or toks[-1] != "[EOS]":
                continue
            sigma_toks = toks[:-1]
            if len(sigma_toks) < max_len:
                if verbose:
                    print(f"NspEX.sample_one: found valid sample of length {len(sigma_toks)}")
                return sigma_toks

        print(f"Warning: sample_one hit max_tries={max_tries} without a short enough sample.")
        toks, _ = sampler.generate_one(max_steps)
        return toks[:-1] if (toks and toks[-1] == "[EOS]") else toks

    def _collect(self, sampler, n: int, max_len: int, max_steps: int, tag: str) -> Tuple[
            List[Tuple[List[str], np.ndarray]],
            List[Tuple[List[str], List[str], np.ndarray]]
        ]:
        """Draw `n` strings with `sampler` and label each with a fresh NSP oracle."""
        oracle = self._nsp_oracle()
        pairs: List[Tuple[List[str], np.ndarray]] = []
        meta = []
        for i in range(n):
            sigma_toks = self.sample_one(sampler, max_len, max_steps)
            labels, prefixes, columns, probs = oracle.labels_for(sigma_toks)
            pairs.append((sigma_toks, labels))
            meta.append((prefixes, columns, probs))
            if i % 500 == 0:
                print(f"NspEX.{tag}: sampled {i+1}/{n} examples")
        return pairs, meta

    def sample_strat(self, n: int, max_len: int, max_steps: int) -> Tuple[
            List[Tuple[List[str], np.ndarray]],
            List[Tuple[List[str], List[str], np.ndarray]]
        ]:
        """Sample `n` strings under (strategy, param) and label them with NspOracle."""
        sampler = LmSampler(self.lm, strategy=self.strategy, param=self.param)
        return self._collect(sampler, n, max_len, max_steps, "sample_strat")

    def sample_natural(self, n: int, max_len: int, max_steps: int) -> Tuple[
            List[Tuple[List[str], np.ndarray]],
            List[Tuple[List[str], List[str], np.ndarray]]
        ]:
        """Sample `n` strings from the natural distribution and label them with NspOracle."""
        sampler = LmSampler(self.lm, strategy="natural", param=0.0)
        return self._collect(sampler, n, max_len, max_steps, "sample_natural")




@dataclass
class LMPrefixEQ:
    """
    Generative query simulator for L*-NSP (subcase B2):
      - Guardrail: verifies the given Σ-only prefix never took an inadmissible step under
        the truncation rule (top-p/min-p). If it did, raises RuntimeError.
      - Sampling: starting from the prefix, repeatedly SAMPLE a next Σ token from the
        truncated distribution (no greedy) until EOS becomes admissible or `max_steps`
        new tokens have been appended. If EOS never becomes admissible, retry up to
        `retries` times, then raise RuntimeError.

    Every next-token row is restricted to Σ ∪ {EOS} (BOS and any other vocabulary entry
    zeroed, no renormalization) before truncation. This is the same restriction that
    MembershipOracle.label applies, so admissibility decisions here agree with MQ, and
    only Σ tokens can ever be sampled.

    Returns a Σ-only list (no [BOS], no [EOS]) that is verified by MembershipOracle (==1).
    """
    lm: LoadedLM
    strategy: TruncStrategy
    param: float
    retries: int = 20  # maximum attempts if fail to reach EOS-admissible within max_steps

    def _apply_trunc(self, probs_1d: torch.Tensor) -> torch.Tensor:
        """Apply top-p or min-p to a single probability row (1D). Returns boolean keep mask."""
        if self.strategy == "top-p":
            return _truncate_top_p(probs_1d, float(self.param))
        elif self.strategy == "min-p":
            return _truncate_min_p(probs_1d, float(self.param))
        else:
            raise ValueError(f"Unknown strategy: {self.strategy}")

    def _restrict_row(self, probs: torch.Tensor) -> torch.Tensor:
        """
        Zero every column of `probs` (shape [..., V]) outside Σ ∪ {EOS}, and always zero BOS.

        Mirrors the masking in MembershipOracle.label (no renormalization). Returns a new
        tensor; `probs` itself is not modified.
        """
        keep_vocab = torch.zeros((probs.size(-1),), dtype=probs.dtype, device=probs.device)
        keep_vocab[self.lm.eos_id] = 1.0
        keep_vocab[torch.tensor(self.lm.sigma_ids, dtype=torch.long, device=probs.device)] = 1.0
        restricted = probs * keep_vocab
        restricted[..., self.lm.bos_id] = 0.0
        return restricted

    @torch.no_grad()
    def _next_row(self, ctx_ids: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (restricted next-token probs, truncation keep mask) after `ctx_ids`."""
        input_ids = torch.tensor([ctx_ids], dtype=torch.long, device=self.lm.device)
        row = self._restrict_row(_last_step_probs(self.lm.model, input_ids).squeeze(0))
        return row, self._apply_trunc(row)

    @torch.no_grad()
    def _assert_prefix_feasible(self, sigma_prefix: List[str]) -> None:
        """
        Guardrail: ensure each provided Σ token is admissible at the step it was taken.

        Raises:
          ValueError: on an unknown token or a [BOS]/[EOS] inside the prefix.
          RuntimeError: on the first inadmissible step.
        """
        # Map Σ tokens to ids; forbid specials
        ids = []
        for t in sigma_prefix:
            if t not in self.lm.tok2id:
                raise ValueError(f"Unknown token '{t}'.")
            tid = self.lm.tok2id[t]
            if tid in (self.lm.bos_id, self.lm.eos_id):
                raise ValueError("Prefix must be Σ-only: do not include [BOS] or [EOS].")
            ids.append(tid)

        if not ids:
            return  # empty prefix: no steps were taken, nothing to check

        ctx_ids = [self.lm.bos_id] + ids  # [BOS] + Σ-prefix

        # Teacher-forced probs for all rows, restricted exactly as MQ restricts them.
        input_ids = torch.tensor([ctx_ids], dtype=torch.long, device=self.lm.device)  # [1, T]
        logits = self.lm.model(input_ids).logits.squeeze(0)                           # [T, V]
        probs_rows = self._restrict_row(torch.softmax(logits, dim=-1))                # [T, V]

        # Row i is the distribution after i Σ tokens; the token taken there is ids[i].
        for i, tid_next in enumerate(ids):
            keep = self._apply_trunc(probs_rows[i])
            if not bool(keep[tid_next].item()):
                raise RuntimeError(
                    f"LMPrefixEQ: inadmissible step in given prefix at position {i}: "
                    f"token='{self.lm.id2tok[tid_next]}' not in truncated support."
                )

    def _rollout(
        self,
        base_ctx_ids: List[int],
        sigma_prefix: List[str],
        row: torch.Tensor,
        keep: torch.Tensor,
        max_steps: int,
        mq: MembershipOracle,
    ) -> List[str] | None:
        """
        Run one sampling attempt.

        Starting from the distribution (`row`, `keep`) at `base_ctx_ids`, append up to
        `max_steps` sampled Σ tokens. EOS admissibility is checked after every append,
        including the last one. Returns the MQ-verified Σ sequence, or None on a dead end
        (no admissible Σ token) or when `max_steps` is exhausted.
        """
        # Never sample outside Σ; BOS/EOS are forbidden explicitly as well.
        forbidden = torch.ones(keep.size(-1), dtype=torch.bool, device=keep.device)
        forbidden[torch.tensor(self.lm.sigma_ids, dtype=torch.long, device=keep.device)] = False
        forbidden[self.lm.eos_id] = True
        forbidden[self.lm.bos_id] = True

        ctx_ids = list(base_ctx_ids)
        sigma_out = list(sigma_prefix)
        for _ in range(max_steps):
            keep_sample = keep.clone()
            keep_sample[forbidden] = False
            if not bool(keep_sample.any().item()):
                return None  # dead end: no admissible Σ token at this prefix

            # Sample one token (renormalized over kept support)
            next_id = _sample_from_masked(row, keep_sample)
            if next_id in (self.lm.bos_id, self.lm.eos_id):
                return None  # excluded by the mask; defensive guard only

            ctx_ids.append(next_id)
            sigma_out.append(self.lm.id2tok[next_id])

            row, keep = self._next_row(ctx_ids)
            if bool(keep[self.lm.eos_id].item()):
                if mq.label(sigma_out) != 1:
                    raise RuntimeError("LMPrefixEQ: MQ verification failed at termination.")
                return sigma_out
        return None

    @torch.no_grad()
    def sample(self, prefix_tokens: Iterable[str], max_len: int = 100, max_steps: int = 256, verbose: bool = True) -> List[str]:
        """
        Extend a Σ-only prefix by sampling until EOS becomes admissible.

        Args:
          prefix_tokens: iterable of Σ tokens ONLY (no "[BOS]" or "[EOS]").
          max_len: preferred maximum length of the returned sequence. An accepting
            sequence longer than this is kept as a backup. It is returned, with a warning,
            only if no attempt yields one of length <= max_len. This limit is not applied
            when EOS is already admissible at the given prefix.
          max_steps: maximum number of *new* tokens to append per attempt.
          verbose: print progress messages.

        Returns:
          Σ-only list of tokens (no [BOS]/[EOS]) that is accepted under the truncation rule.

        Raises:
          TypeError if a string is provided instead of a list of tokens.
          ValueError if tokens contain specials or unknown tokens.
          RuntimeError if the given prefix is infeasible or an accepting suffix
                       cannot be found within the retry budget.
        """
        # Enforce: list/iterable of Σ tokens, not a raw string
        if isinstance(prefix_tokens, str):
            raise TypeError(
                "LMPrefixEQ.sample expects an iterable/list of Σ tokens, not a single string."
            )

        sigma_prefix: List[str] = list(prefix_tokens)

        # Guardrail: ensure the provided prefix never took an inadmissible step
        self._assert_prefix_feasible(sigma_prefix)

        base_ctx_ids = [self.lm.bos_id] + [self.lm.tok2id[t] for t in sigma_prefix]
        mq = MembershipOracle(self.lm, self.strategy, self.param)

        # The distribution at the starting prefix is deterministic: compute it once and
        # reuse it for every attempt.
        start_row, start_keep = self._next_row(base_ctx_ids)
        if bool(start_keep[self.lm.eos_id].item()):
            if mq.label(sigma_prefix) != 1:
                raise RuntimeError("LMPrefixEQ: MQ verification failed at starting prefix.")
            return sigma_prefix

        backup_seq = None
        for attempt in range(1, self.retries + 1):
            if verbose and attempt % 5 == 0:
                print(f"LMPrefixEQ: attempt {attempt}/{self.retries}")

            candidate = self._rollout(base_ctx_ids, sigma_prefix, start_row, start_keep, max_steps, mq)
            if candidate is None:
                continue  # dead end or max_steps exhausted: retry
            if len(candidate) <= max_len:
                if verbose:
                    print('PrefixEQ: Found valid suffix')
                return candidate
            # Accepting but too long: remember it and keep looking for a shorter one.
            if verbose:
                print('PrefixEQ: Found backup')
            backup_seq = candidate

        if backup_seq is not None:
            warnings.warn(
                f"LMPrefixEQ: accepting suffix found but exceeded max_len={max_len}; returning anyway."
            )
            return backup_seq
        # If all attempts failed
        raise RuntimeError(
            f"LMPrefixEQ: accepting suffix could not be found after {self.retries} attempts "
            f"(max_steps={max_steps}) for prefix={sigma_prefix}"
        )