# fluency_ppl.py
"""
Perplexity-based fluency scoring with batch support.

Backends:
  - HF (Transformers): original implementation using AutoModelForCausalLM
  - vLLM: uses prompt_logprobs to compute teacher-forcing next-token cross entropy

Public API:
  - calc_ppl(text, device=None, backend=None)
  - calc_ppl_batch(texts, batch_size=..., device=None, backend=None)

Notes for vLLM:
  - We call llm.generate() with TokensPrompt dicts: {"prompt_token_ids": [...]}
    (Do NOT pass prompt_token_ids=... keyword; some vLLM versions reject it.)
  - We request prompt_logprobs and compute CE over prompt tokens.
  - Prefer choosing GPU via CUDA_VISIBLE_DEVICES at launch time.
"""

from __future__ import annotations

from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple

import os
import math
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer  # type: ignore

from config import (
    PPL_BACKEND,
    PPL_MODEL_NAME,
    PPL_MAX_LENGTH,
    PPL_BATCH_SIZE,
    PPL_VLLM_MODEL_NAME,
    PPL_VLLM_DTYPE,
    PPL_VLLM_TENSOR_PARALLEL_SIZE,
    PPL_VLLM_GPU_MEMORY_UTILIZATION,
    PPL_VLLM_TRUST_REMOTE_CODE,
    PPL_VLLM_ENFORCE_EAGER,
    PPL_VLLM_ENABLE_PREFIX_CACHING,
    PPL_VLLM_MAX_MODEL_LEN,
    PPL_VLLM_EVAL_MAX_NEW_TOKENS,
    PPL_VLLM_PROMPT_LOGPROBS_K,
)


# =========================================================
# HF backend (original)
# =========================================================

@lru_cache()
def _get_ppl_model_hf(device: str = "auto") -> Tuple[AutoTokenizer, AutoModelForCausalLM, str]:
    """
    Load the HF tokenizer and causal LM named by PPL_MODEL_NAME onto `device`.

    Results are memoized by `lru_cache`, keyed on the `device` string exactly as
    passed (so "cuda:0" and "cuda:1" each get their own cached model instance).

    Parameters:
      device: a torch device string, or "auto" to pick "cuda" when
              torch.cuda.is_available() and "cpu" otherwise.

    Returns:
      (tokenizer, model, resolved_device_string)
    """
    tok = AutoTokenizer.from_pretrained(PPL_MODEL_NAME)
    if tok.pad_token is None:
        # Batched scoring needs a pad token; fall back to EOS when none is defined.
        tok.pad_token = tok.eos_token

    lm = AutoModelForCausalLM.from_pretrained(PPL_MODEL_NAME)
    lm.eval()

    requested = device
    if requested == "auto":
        requested = "cuda" if torch.cuda.is_available() else "cpu"
    target = torch.device(requested)

    # When an explicit GPU index was given, also make it the current CUDA device.
    has_explicit_gpu_index = target.type == "cuda" and target.index is not None
    if has_explicit_gpu_index:
        torch.cuda.set_device(target.index)

    lm.to(target)
    return tok, lm, str(target)


def _ppl_from_logits(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    """
    Compute per-example perplexity from logits with masking.

    logits: [B, L, V]
    input_ids: [B, L]
    attention_mask: [B, L] with 1 for valid tokens, 0 for padding
    """
    shift_logits = logits[:, :-1, :].contiguous()      # [B, L-1, V]
    shift_labels = input_ids[:, 1:].contiguous()       # [B, L-1]
    shift_mask = attention_mask[:, 1:].contiguous()    # [B, L-1]

    B, Lm1 = shift_labels.shape

    loss_flat = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        reduction="none",
    ).view(B, Lm1)

    loss_flat = loss_flat * shift_mask

    denom = shift_mask.sum(dim=1).clamp(min=1)
    loss_per_ex = loss_flat.sum(dim=1) / denom

    ppl = torch.exp(loss_per_ex)
    return ppl


def _calc_ppl_batch_hf(
    texts: List[str],
    batch_size: int,
    device: Optional[str],
) -> List[float]:
    if not texts:
        return []

    tokenizer, model, resolved_device = _get_ppl_model_hf(device or "auto")

    ppls: List[float] = []
    for i in range(0, len(texts), batch_size):
        chunk = texts[i : i + batch_size]

        enc = tokenizer(
            chunk,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=PPL_MAX_LENGTH,
        )

        input_ids = enc["input_ids"].to(resolved_device)
        attention_mask = enc["attention_mask"].to(resolved_device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

        ppl_tensor = _ppl_from_logits(logits, input_ids, attention_mask)
        ppls.extend([float(x) for x in ppl_tensor.detach().cpu().tolist()])

    return ppls


# =========================================================
# vLLM backend
# =========================================================

# Module-level cache of constructed vLLM engines. Keys are the configuration
# strings built in _init_vllm_llm_cached; each value is a dict of the form
# {"llm": <engine>, "tokenizer": <tokenizer>}.
_VLLM_CACHE: Dict[str, Any] = {}


def _maybe_set_cuda_visible_devices_from_device_arg(device: Optional[str]) -> None:
    """
    Best-effort GPU selection for vLLM.

    vLLM generally expects GPU selection via CUDA_VISIBLE_DEVICES.
    If the user passes device like "cuda:1" and CUDA_VISIBLE_DEVICES is unset,
    we set it to that index before constructing vLLM.

    If CUDA_VISIBLE_DEVICES is already set, we do not override it.
    """
    if not device:
        return
    if os.environ.get("CUDA_VISIBLE_DEVICES"):
        return
    if not device.startswith("cuda:"):
        return
    try:
        idx = int(device.split("cuda:")[1])
    except Exception:
        return
    os.environ["CUDA_VISIBLE_DEVICES"] = str(idx)


def _init_vllm_llm_cached(device: Optional[str]) -> Tuple[Any, Any]:
    """
    Build (or fetch from _VLLM_CACHE) the vLLM engine and its tokenizer.

    The cache key combines the vLLM-related config values with the current
    CUDA_VISIBLE_DEVICES, so a different configuration or GPU visibility
    produces a separate engine.

    NOTE:
      - Prefer launching with CUDA_VISIBLE_DEVICES for correct GPU binding;
        `device` is only applied on a best-effort basis.
      - VLLM_ENABLE_V1_MULTIPROCESSING defaults to "0" (unless already set in
        the environment) to keep everything in-process.

    Returns:
      (llm, tokenizer)

    Raises:
      RuntimeError: if vllm cannot be imported.
    """
    # Must run before the key is built: the key embeds CUDA_VISIBLE_DEVICES.
    _maybe_set_cuda_visible_devices_from_device_arg(device)

    key_parts = [
        f"{PPL_VLLM_MODEL_NAME}",
        f"{PPL_VLLM_DTYPE}",
        f"tp{PPL_VLLM_TENSOR_PARALLEL_SIZE}",
        f"mml{PPL_VLLM_MAX_MODEL_LEN}",
        f"gmu{PPL_VLLM_GPU_MEMORY_UTILIZATION}",
        f"trc{int(PPL_VLLM_TRUST_REMOTE_CODE)}",
        f"eager{int(PPL_VLLM_ENFORCE_EAGER)}",
        f"pc{int(PPL_VLLM_ENABLE_PREFIX_CACHING)}",
        f"cvis{os.environ.get('CUDA_VISIBLE_DEVICES','')}",
    ]
    cache_key = "|".join(key_parts)

    entry = _VLLM_CACHE.get(cache_key)
    if entry is not None:
        return entry["llm"], entry["tokenizer"]

    os.environ.setdefault("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    # Imported lazily so the HF backend works without vllm installed.
    try:
        from vllm import LLM  # type: ignore
    except Exception as e:
        raise RuntimeError("vLLM backend requested but vllm is not available. Please install vllm.") from e

    engine_kwargs = {
        "model": PPL_VLLM_MODEL_NAME,
        "tensor_parallel_size": int(PPL_VLLM_TENSOR_PARALLEL_SIZE),
        "gpu_memory_utilization": float(PPL_VLLM_GPU_MEMORY_UTILIZATION),
        "dtype": str(PPL_VLLM_DTYPE),
        "max_model_len": int(PPL_VLLM_MAX_MODEL_LEN),
        "trust_remote_code": bool(PPL_VLLM_TRUST_REMOTE_CODE),
        "enforce_eager": bool(PPL_VLLM_ENFORCE_EAGER),
        "disable_log_stats": True,
        "enable_prefix_caching": bool(PPL_VLLM_ENABLE_PREFIX_CACHING),
    }
    engine = LLM(**engine_kwargs)
    engine_tokenizer = engine.get_tokenizer()

    _VLLM_CACHE[cache_key] = {"llm": engine, "tokenizer": engine_tokenizer}
    return engine, engine_tokenizer


def _extract_logprob_value(v: Any) -> Optional[float]:
    """
    vLLM may store logprob values as floats or objects with .logprob.
    """
    if v is None:
        return None
    if isinstance(v, (float, int)):
        return float(v)
    if hasattr(v, "logprob"):
        try:
            return float(v.logprob)
        except Exception:
            return None
    return None


def _prompt_next_token_ce_from_vllm(output: Any, prompt_token_ids: List[int]) -> Tuple[float, int, int]:
    """
    Compute teacher-forcing next-token CE from vLLM output.prompt_logprobs.

    We ignore position i=0 and sum -log p(token_i) for i>=1.

    Returns:
      (nll_sum, n_tokens_counted, n_tokens_missing)
    """
    plp = getattr(output, "prompt_logprobs", None)
    if plp is None:
        return 0.0, 0, 0

    nll_sum = 0.0
    n = 0
    missing = 0

    T = min(len(plp), len(prompt_token_ids))
    for i in range(1, T):
        item = plp[i]
        if item is None:
            missing += 1
            continue

        tok = int(prompt_token_ids[i])
        lp = None

        # Common: dict[token_id] -> Logprob
        if isinstance(item, dict):
            if tok in item:
                lp = _extract_logprob_value(item[tok])
            else:
                # Some builds might not include the chosen token in top-k dict.
                missing += 1
                continue
        else:
            # Some builds may store the chosen token logprob directly.
            lp = _extract_logprob_value(item)

        if lp is None:
            missing += 1
            continue

        nll_sum += -lp
        n += 1

    return nll_sum, n, missing


def _tokenize_texts_for_vllm(tokenizer: Any, texts: List[str], max_prompt_len: int) -> List[List[int]]:
    """
    Tokenize and truncate each text to `max_prompt_len` tokens (including special tokens).
    """
    token_ids_list: List[List[int]] = []
    for s in texts:
        enc = tokenizer(
            s,
            add_special_tokens=True,
            truncation=True,
            max_length=int(max_prompt_len),
        )
        ids = enc["input_ids"] if isinstance(enc, dict) else enc.input_ids
        token_ids_list.append(list(ids))
    return token_ids_list


def _calc_ppl_batch_vllm(
    texts: List[str],
    batch_size: int,
    device: Optional[str],
) -> List[float]:
    if not texts:
        return []

    llm, tokenizer = _init_vllm_llm_cached(device)

    # vLLM constraint: prompt_len + max_tokens <= max_model_len
    effective_max_prompt_len = min(int(PPL_MAX_LENGTH), int(PPL_VLLM_MAX_MODEL_LEN) - int(PPL_VLLM_EVAL_MAX_NEW_TOKENS))
    if effective_max_prompt_len < 2:
        raise RuntimeError(
            f"Invalid effective_max_prompt_len={effective_max_prompt_len}. "
            f"Need at least 2 tokens for next-token CE. "
            f"PPL_MAX_LENGTH={PPL_MAX_LENGTH}, PPL_VLLM_MAX_MODEL_LEN={PPL_VLLM_MAX_MODEL_LEN}, "
            f"PPL_VLLM_EVAL_MAX_NEW_TOKENS={PPL_VLLM_EVAL_MAX_NEW_TOKENS}"
        )

    try:
        from vllm import SamplingParams  # type: ignore
    except Exception as e:
        raise RuntimeError("vllm is not available (SamplingParams import failed).") from e

    sp = SamplingParams(
        temperature=0.0,
        top_p=1.0,
        max_tokens=int(PPL_VLLM_EVAL_MAX_NEW_TOKENS),
        prompt_logprobs=int(PPL_VLLM_PROMPT_LOGPROBS_K),
        detokenize=False,
        skip_special_tokens=False,
    )

    ppls: List[float] = []
    for i in range(0, len(texts), batch_size):
        chunk = texts[i : i + batch_size]
        prompt_token_ids_list = _tokenize_texts_for_vllm(tokenizer, chunk, max_prompt_len=effective_max_prompt_len)
        tokens_prompts = [{"prompt_token_ids": ids} for ids in prompt_token_ids_list]

        outs = llm.generate(tokens_prompts, sampling_params=sp, use_tqdm=False)

        for out, ids in zip(outs, prompt_token_ids_list):
            nll_sum, n_tok, missing = _prompt_next_token_ce_from_vllm(out, ids)

            if n_tok <= 0:
                ppls.append(float("nan"))
                continue

            # If too many tokens are missing, return NaN to avoid silently biased PPL.
            # Increase PPL_VLLM_PROMPT_LOGPROBS_K in config.py if this happens often.
            if missing > 0 and (missing / max(1, (n_tok + missing))) > 0.05:
                ppls.append(float("nan"))
                continue

            ppl = math.exp(nll_sum / n_tok)
            ppls.append(float(ppl))

    return ppls


# =========================================================
# Public API
# =========================================================

def calc_ppl_batch(
    texts: List[str],
    batch_size: int = PPL_BATCH_SIZE,
    device: Optional[str] = None,
    backend: Optional[str] = None,
) -> List[float]:
    """
    Compute perplexity for a list of texts using batching.

    Parameters
    ----------
    texts:
        List of texts to score.
    batch_size:
        Mini-batch size for scoring.
    device:
        - HF: explicit device (e.g. "cuda:1", "cuda:0", "cpu")
        - vLLM: best-effort; prefer CUDA_VISIBLE_DEVICES at launch
    backend:
        "hf" or "vllm" (case-insensitive, surrounding whitespace ignored).
        If None or empty, uses config.PPL_BACKEND, and "hf" if that is empty too.

    Returns
    -------
    List[float]
        Perplexity values, one per input text.

    Raises
    ------
    ValueError
        If the resolved backend is neither "hf" nor "vllm".
    """
    requested = backend or PPL_BACKEND or "hf"
    be = requested.strip().lower()
    if be == "hf":
        return _calc_ppl_batch_hf(texts=texts, batch_size=batch_size, device=device)
    if be == "vllm":
        return _calc_ppl_batch_vllm(texts=texts, batch_size=batch_size, device=device)
    # Report the value actually resolved: `backend` is None when the bad name
    # came from config.PPL_BACKEND, which made the old message misleading.
    raise ValueError(f"Unknown PPL backend: {requested!r} (expected 'hf' or 'vllm').")


def calc_ppl(text: str, device: Optional[str] = None, backend: Optional[str] = None) -> float:
    """
    Score one text: delegates to calc_ppl_batch with a one-element batch and
    returns its only result.
    """
    scores = calc_ppl_batch([text], batch_size=1, device=device, backend=backend)
    return scores[0]
