# NSP-project/model_src/generation.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Iterable, Literal
import json
import math
import numpy as np
import torch
from transformers import GPT2LMHeadModel
import pdb

# Decoding strategies understood by LmSampler: "natural" keeps the whole
# next-token distribution, while "top-p" and "min-p" truncate it before sampling.
Strategy = Literal["natural", "top-p", "min-p"]


@dataclass
class LoadedLM:
    """Bundle of a loaded GPT-2 LM head model, its vocabulary maps and Σ.

    Instances are assembled by ``load_lm``.
    """
    model: GPT2LMHeadModel           # the language model (load_lm sets eval mode)
    tok2id: Dict[str, int]           # token string -> vocabulary id
    id2tok: Dict[int, str]           # vocabulary id -> token string
    bos_id: int                      # id of the "[BOS]" token
    eos_id: int                      # id of the "[EOS]" token
    sigma_tokens: List[str]          # Σ order (excludes [BOS], [EOS])
    sigma_ids: List[int]             # ids aligned with sigma_tokens
    device: torch.device             # device the model was moved to


def _load_id2tok(model_dir: str) -> Dict[int, str]:
    """Read ``id2tok.json`` from *model_dir* and return it keyed by integer id.

    JSON object keys are always strings, so every key is converted with
    ``int`` before the mapping is returned.
    """
    vocab_path = f"{model_dir}/id2tok.json"
    with open(vocab_path, "r", encoding="utf-8") as handle:
        raw_mapping = json.load(handle)

    by_id: Dict[int, str] = {}
    for raw_key, token in raw_mapping.items():
        by_id[int(raw_key)] = token
    return by_id


def _tok_maps_from_id2tok(id2tok: Dict[int, str]) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Invert *id2tok* and return the pair ``(tok2id, id2tok)``.

    The second element is the same dict object that was passed in. When two
    ids map to the same token string, the id visited later wins in ``tok2id``.
    """
    tok2id: Dict[str, int] = {}
    for token_id, token in id2tok.items():
        tok2id[token] = token_id
    return tok2id, id2tok


def load_lm(model_dir: str, device: Optional[torch.device | str] = None) -> LoadedLM:
    """
    Load a model saved by train_lm.py together with its vocabulary and Σ.

    model_dir: directory holding the pretrained weights and ``id2tok.json``.
    device: target device; a string is converted to ``torch.device`` and
        ``None`` selects CUDA when available, otherwise CPU.

    Σ consists of every vocabulary token other than [BOS] and [EOS], sorted
    by (length, string), i.e. length-lexicographic order.

    Raises RuntimeError when the vocabulary lacks [BOS] or [EOS].
    """
    if isinstance(device, str):
        device = torch.device(device)
    elif device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The model is loaded first; the vocabulary file is read afterwards.
    model = GPT2LMHeadModel.from_pretrained(model_dir)
    model.to(device)
    model.eval()

    tok2id, id2tok = _tok_maps_from_id2tok(_load_id2tok(model_dir))

    specials = ("[BOS]", "[EOS]")
    if not all(special in tok2id for special in specials):
        raise RuntimeError("Vocab must contain [BOS] and [EOS].")

    # Σ = non-special tokens, shorter strings first, ties broken lexicographically.
    sigma_tokens = [tok for tok in id2tok.values() if tok not in specials]
    sigma_tokens.sort(key=lambda tok: (len(tok), tok))
    sigma_ids = [tok2id[tok] for tok in sigma_tokens]

    return LoadedLM(
        model=model,
        tok2id=tok2id,
        id2tok=id2tok,
        bos_id=tok2id["[BOS]"],
        eos_id=tok2id["[EOS]"],
        sigma_tokens=sigma_tokens,
        sigma_ids=sigma_ids,
        device=device,
    )


# ---------- Sampling helpers ----------

@torch.no_grad()
def _last_step_probs(model: GPT2LMHeadModel, input_ids: torch.Tensor) -> torch.Tensor:
    """
    Run one forward pass and return the next-token distribution.

    input_ids: [B, T]
    returns: [B, V] softmax over the logits at the final time step.
    Gradients are disabled for the call.
    """
    final_logits = model(input_ids=input_ids).logits[:, -1, :]  # [B, V]
    return final_logits.softmax(dim=-1)


def _mask_bos(probs: torch.Tensor, bos_id: int) -> None:
    """Zero the [BOS] entry along the last axis of *probs*, in place.

    This prevents [BOS] from being drawn mid-sequence. The remaining
    entries are left untouched (no renormalization happens here).
    """
    probs.select(-1, bos_id).zero_()


def _truncate_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """
    Nucleus (top-p) mask over the last axis of *probs*.

    probs: [V] or [..., V] probabilities. They are used as given; no
        renormalization is applied, so *p* is compared against raw mass.
    p: cumulative-mass threshold. A token is kept when the mass of all
        strictly higher-ranked tokens is below *p*.
    returns: boolean mask with the same shape as *probs*.

    The top-ranked token of every row is always kept, so the mask is never
    empty even for ``p <= 0``.
    """
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    # Mass strictly before each rank (exclusive cumulative sum).
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    keep_sorted = mass_before < p
    # always keep the first token
    keep_sorted[..., 0] = True
    keep_mask = torch.zeros_like(probs, dtype=torch.bool)
    # Scatter along the same (last) axis that was sorted, so batched input
    # maps each row's ranks back to that row's vocabulary positions.
    keep_mask.scatter_(-1, sorted_idx, keep_sorted)
    return keep_mask


def _truncate_min_p(probs: torch.Tensor, tau: float) -> torch.Tensor:
    """
    Min-p mask: mark tokens whose probability is at least *tau*.

    *tau* is used as an absolute threshold on the given probabilities.
    When no token reaches it, only the argmax token is kept.
    """
    keep_mask = probs.ge(tau)
    if keep_mask.any():
        return keep_mask
    # Nothing cleared the threshold: fall back to the single most likely token.
    keep_mask[probs.argmax()] = True
    return keep_mask


def _sample_from_masked(probs: torch.Tensor, keep_mask: torch.Tensor) -> int:
    """Draw one token id from *probs* restricted to *keep_mask*.

    The kept probabilities are rescaled to sum to one before sampling. If
    the kept mass is not positive, the argmax of the unmasked *probs* is
    returned instead of sampling.
    """
    masked = probs * keep_mask.float()
    total = masked.sum()
    if total.item() <= 0:
        # Safety fallback: nothing left to sample from.
        return int(probs.argmax().item())
    drawn = (masked / total).multinomial(num_samples=1)
    return int(drawn.item())


class LmSampler:
    """
    Thin sampling wrapper around a LoadedLM.

    Every sequence starts from [BOS]. At each step [BOS] is masked out, the
    distribution is optionally truncated according to ``strategy`` ("natural"
    keeps everything, "top-p" / "min-p" use ``param`` as their threshold),
    and one token is drawn. Generation ends at the first [EOS] or after
    ``max_steps`` tokens.
    """

    def __init__(self, lm: LoadedLM, strategy: Strategy, param: float = 0.0):
        self.lm = lm
        self.strategy = strategy
        self.param = float(param)

    def _keep_mask(self, probs: torch.Tensor) -> torch.Tensor:
        """Return the boolean support mask for the configured strategy."""
        if self.strategy == "natural":
            return torch.ones_like(probs, dtype=torch.bool)
        if self.strategy == "top-p":
            return _truncate_top_p(probs, self.param)
        if self.strategy == "min-p":
            return _truncate_min_p(probs, self.param)
        raise ValueError(f"Unknown strategy: {self.strategy}")

    @torch.no_grad()
    def generate_one(self, max_steps: int) -> Tuple[List[str], bool]:
        """
        Sample a single sequence of at most *max_steps* tokens.

        Returns (tokens, truncated): the tokens exclude [BOS] and include
        [EOS] when it was drawn; ``truncated`` is True when the step budget
        ran out without [EOS] being produced.
        """
        lm = self.lm
        device = lm.device
        bos_id, eos_id = lm.bos_id, lm.eos_id

        context = torch.tensor([[bos_id]], dtype=torch.long, device=device)
        generated: List[int] = []
        reached_eos = False

        for _ in range(max_steps):
            # The whole prefix is re-fed each step; only the last position is used.
            probs = _last_step_probs(self.lm.model, context).squeeze(0)  # [V]
            _mask_bos(probs, bos_id)

            next_id = _sample_from_masked(probs, self._keep_mask(probs))
            generated.append(next_id)

            step = torch.tensor([[next_id]], device=device)
            context = torch.cat([context, step], dim=1)

            if next_id == eos_id:
                reached_eos = True
                break

        tokens = [self.lm.id2tok[token_id] for token_id in generated]
        return tokens, not reached_eos

    def generate_n(self, n: int, max_steps: int) -> List[Tuple[List[str], bool]]:
        """Call ``generate_one(max_steps)`` *n* times and collect the results."""
        samples: List[Tuple[List[str], bool]] = []
        for _ in range(n):
            samples.append(self.generate_one(max_steps))
        return samples



# ---------- Generation accuracy (membership in target language) ----------

def _strip_final_eos(tokens: List[str]) -> Tuple[List[str], bool]:
    """
    Split a sampled token list into its body and an "ended with EOS" flag.

    tokens: sampler output (no BOS; EOS present only if it was generated).
    returns: (tokens without a trailing "[EOS]", whether one was removed).
    When there is no trailing "[EOS]" the input list itself is returned.
    """
    ends_with_eos = bool(tokens) and tokens[-1] == "[EOS]"
    if not ends_with_eos:
        return tokens, False
    return tokens[:-1], True


def generation_accuracy(
        lm: LoadedLM,
        sampler: LmSampler,
        language,                  # instance of datagen.languages.Language
        n_samples: int,
        max_steps: int,
    ) -> Tuple[float, int, int]:
    """
    Estimate how often sampled strings are complete members of the language.

    Draws n_samples sequences from *sampler* (each limited to max_steps
    tokens). A sample is counted as correct only when it
      (i) finished with EOS and was not truncated, and
      (ii) is accepted by ``language.is_positive`` on its Σ-only tokens.
    Any exception raised by ``is_positive`` counts the sample as incorrect.
    *lm* is accepted but not used by this function.

    Returns (acc, correct, total) where total is n_samples and acc is
    correct / max(n_samples, 1).
    """
    n_correct = 0
    for _ in range(n_samples):
        tokens, was_truncated = sampler.generate_one(max_steps)
        sigma_only, ended_with_eos = _strip_final_eos(tokens)
        if ended_with_eos and not was_truncated:
            # language.is_positive expects only Σ tokens
            try:
                accepted = bool(language.is_positive(sigma_only))
            except Exception:
                # If language.validate_tokens rejects something, treat as incorrect.
                accepted = False
            if accepted:
                n_correct += 1
        # Incomplete or truncated samples are not valid complete strings.
    denominator = max(n_samples, 1)
    return (n_correct / denominator), n_correct, n_samples





