# constrained_decoder.py
# Constrained decoding: maximize logP(y|c) - lambda * R(y;C)
# This implementation provides:
#  - a general ConstrainedDecoder wrapper using a base logprob function
#  - helpers to plug in a Hugging Face model to produce logprobs for the next token

from typing import List, Callable, Optional, Tuple, Iterable
import numpy as np
import math
import heapq
import logging
from templates import ConstraintSet

logger = logging.getLogger(__name__)


class ConstrainedDecoder:
    """
    Beam-search decoder that maximises ``log P(y | c) - lambda * R(y; C)``.

    Args:
        base_logprob_fn: callable ``prefix_tokens -> np.ndarray`` returning the
            next-token log-probabilities (shape ``(V,)``) for a token prefix.
        vocab: optional list mapping token ids to token strings; used to render
            hypotheses as text for the constraint checkers.
        constraint_set: object exposing ``constraint_penalty(text)`` that
            returns the numeric penalty ``R(text; C)``. When ``None`` a default
            ``ConstraintSet()`` is created.
    """

    # The ConstraintSet annotation is quoted so it is not evaluated when the
    # class is defined; only the default-construction path needs the name.
    def __init__(self, base_logprob_fn: Callable[[List[int]], np.ndarray], vocab: Optional[List[str]] = None, constraint_set: "Optional[ConstraintSet]" = None):
        self.base_logprob_fn = base_logprob_fn
        self.vocab = vocab or []
        # Compare against None explicitly so that a caller-supplied constraint
        # set that happens to be falsy (e.g. empty) is not silently replaced.
        self.constraint_set = constraint_set if constraint_set is not None else ConstraintSet()

    def _tokens_to_text(self, token_ids: List[int]) -> str:
        """
        Render token ids as a space-joined string.

        With a vocab, ids outside ``[0, len(vocab))`` become ``"<UNK>"``;
        without one, the ids themselves are stringified.
        """
        if self.vocab:
            size = len(self.vocab)
            return " ".join(self.vocab[i] if 0 <= i < size else "<UNK>" for i in token_ids)
        return " ".join(str(t) for t in token_ids)

    def _prefix_logprob(self, prefix: List[int]) -> float:
        """Best-effort ``sum_i log P(prefix[i] | prefix[:i])``; 0.0 if it cannot be computed."""
        total = 0.0
        try:
            for i, tok in enumerate(prefix):
                total += float(self.base_logprob_fn(prefix[:i])[tok])
        except Exception:
            # Deliberately best-effort: some scorers cannot handle an empty
            # context. Every hypothesis shares the prefix, so this term is a
            # constant offset and does not change the ranking.
            logger.debug("could not score prefix; using base prefix score 0.0", exc_info=True)
            return 0.0
        return total

    def _next_logprobs(self, tokens: List[int], vocab_size_hint: int) -> np.ndarray:
        """Next-token log-probs for ``tokens``; uniform over ``vocab_size_hint`` if the scorer fails."""
        try:
            return self.base_logprob_fn(tokens)  # numpy array (V,)
        except Exception as e:
            logger.warning("base_logprob_fn failed; using uniform fallback: %s", e)
            return np.log(np.ones(vocab_size_hint) / vocab_size_hint)

    def constrained_generate(
        self,
        prefix_tokens: List[int],
        max_len: int = 50,
        beam_size: int = 4,
        lambda_penalty: float = 10.0,
        eos_token_id: Optional[int] = None,
        vocab_size_hint: int = 50000,
    ) -> Tuple[List[int], float]:
        """
        Beam search with a constraint penalty.

        Score of a hypothesis = sum of log-probs - lambda_penalty * R(text; C),
        where ``text`` is the rendering of the *whole* hypothesis. The penalty
        is applied once per hypothesis, not accumulated over its prefixes.

        Args:
            prefix_tokens: token ids to condition on (not modified).
            max_len: maximum number of tokens to append.
            beam_size: number of hypotheses kept per step; also the number of
                top next-token candidates expanded per hypothesis.
            lambda_penalty: weight of the constraint penalty.
            eos_token_id: token id that terminates a hypothesis, if any.
            vocab_size_hint: vocabulary size used for the uniform fallback
                distribution when ``base_logprob_fn`` raises.

        Returns:
            ``(best_token_list, best_score)``. The best terminated hypothesis
            seen in the beam is preferred; if none terminated, the highest
            scoring live hypothesis is returned.

        Raises:
            ValueError: if ``beam_size`` is smaller than 1.

        Remarks:
          - base_logprob_fn is expected to return numpy array of log-probs shape (V,)
          - candidates are proposed greedily: each hypothesis is expanded with
            its top-``beam_size`` tokens by base log-prob, then re-ranked by
            the penalised score.
        """
        if beam_size < 1:
            raise ValueError("beam_size must be >= 1")

        # Copy so the caller's sequence is never aliased into the result, and
        # so any sequence type works with the list concatenation below.
        prefix = list(prefix_tokens)
        prefix_logp = self._prefix_logprob(prefix)
        init_pen = self.constraint_set.constraint_penalty(self._tokens_to_text(prefix))

        # Hypothesis layout: (score, logprob_sum, tokens, terminated).
        # logprob_sum is carried separately so the penalty for the current
        # text can be applied exactly once when the score is recomputed.
        beam = [(prefix_logp - lambda_penalty * init_pen, prefix_logp, prefix, False)]

        best_seq = None
        best_score = -math.inf

        for _ in range(max_len):
            candidates = []
            for hyp in beam:
                _, logp, tokens, terminated = hyp
                if terminated:
                    # Finished hypotheses keep competing for a beam slot.
                    candidates.append(hyp)
                    continue
                logprobs = self._next_logprobs(tokens, vocab_size_hint)
                for tok in np.argsort(-logprobs)[:beam_size]:
                    tok = int(tok)
                    new_tokens = tokens + [tok]
                    new_logp = logp + float(logprobs[tok])
                    penalty = self.constraint_set.constraint_penalty(self._tokens_to_text(new_tokens))
                    is_eos = eos_token_id is not None and tok == eos_token_id
                    candidates.append((new_logp - lambda_penalty * penalty, new_logp, new_tokens, is_eos))

            # Keep the beam_size highest-scoring candidates. Ranking uses the
            # score only, so ties never fall through to comparing token lists.
            beam = heapq.nlargest(beam_size, candidates, key=lambda h: h[0])

            # Remember the best finished hypothesis that made it into the beam.
            for score, _, tokens, terminated in beam:
                if terminated and score > best_score:
                    best_score = score
                    best_seq = tokens

            # Nothing left to expand.
            if all(h[3] for h in beam):
                break

        if best_seq is None:
            # No hypothesis terminated: fall back to the best live one.
            best_score, _, best_seq, _ = max(beam, key=lambda h: h[0])

        return best_seq, best_score


# ---------------------------
# Example helper: Hugging Face integration (optional)
# ---------------------------
def hf_logprob_fn_from_model(tokenizer, model):
    """
    Return a base_logprob_fn(prefix_tokens) -> logprobs (numpy array).

    The returned closure runs ``model`` on the prefix and converts the logits
    at the last position into next-token log-probabilities.

    Args:
        tokenizer: accepted for API symmetry; not used by the closure.
        model: callable taking an ``input_ids`` tensor of shape ``(1, L)`` and
            returning an object with a ``.logits`` attribute, and exposing
            ``.parameters()`` (used to locate the device).
            Assumes ``.logits`` has shape ``(1, L, V)`` -- TODO confirm for
            the model class in use.

    Returns:
        ``logprob_fn(prefix_tokens)`` returning a float32 numpy array of
        log-probabilities over the vocabulary. The closure raises
        ``ValueError`` when ``prefix_tokens`` is empty.

    Note: for production, do batching and caching.
    """
    import torch

    def logprob_fn(prefix_tokens: List[int]) -> np.ndarray:
        # An empty context would build a (1, 0) input and there would be no
        # last position to read; fail with a clear message instead.
        if len(prefix_tokens) == 0:
            raise ValueError("prefix_tokens must contain at least one token")

        device = next(model.parameters()).device
        input_ids = torch.tensor(prefix_tokens, dtype=torch.long).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(input_ids).logits
            # Cast to float32 first: half-precision logits lose accuracy in
            # the softmax, and bfloat16 tensors cannot be converted to numpy.
            next_logits = logits[0, -1, :].float()
            # log-softmax gives exact, normalised log-probs; there is no
            # epsilon floor that would clip the tail of the distribution.
            logprobs = torch.log_softmax(next_logits, dim=-1)
        return logprobs.cpu().numpy()

    return logprob_fn
