import torch
import torch.nn.functional as F
from typing import Tuple, Optional


class SelfConfidenceRewardModel:
    """
    Reward model that uses the model's own logits to compute a
    self-confidence score for the generated answer, avoiding extra
    verifier generations.

    By default, the reward is (mean_top1_prob - 1.0), which lies in [-1, 0].
    Higher is better (closer to 0 indicates higher confidence), matching the
    existing Latent-Seek code's expectation that rewards are <= 0 and we stop
    early when reward > threshold (e.g., -0.2).

    Per-position confidence measures, computed from the softmax of the logits
    at each scored answer position:
        "top1": probability of the most likely token.
        "gap":  top-1 probability minus top-2 probability.
    Both are the model's own certainty at that position, not the probability
    assigned to the provided answer token.
    Aggregators over positions: "mean" or "min".
    """

    _MEASURES = ("top1", "gap")
    _AGGREGATORS = ("mean", "min")

    def __init__(
        self,
        *,
        dllm_type: str,
        model,
        tokenizer,
        device: str = "cuda",
        # Confidence config
        measure: str = "top1",  # "top1" or "gap"
        aggregator: str = "mean",  # "mean" or "min"
        # LLaDA specific
        llada_mask_id: int = 126336,
    ) -> None:
        """
        Args:
            dllm_type: backbone family, "llada" or "dream" (case-insensitive).
                Unsupported values are rejected when get_reward is called.
            model: object whose ``forward(input_ids)`` result exposes a
                ``.logits`` tensor of shape (1, L, V).
            tokenizer: tokenizer supporting ``tokenizer([text],
                return_tensors="pt").input_ids`` and
                ``tokenizer.encode(text, add_special_tokens=False)``.
            device: device the input ids are placed on.
            measure: "top1" or "gap" (case-insensitive).
            aggregator: "mean" or "min" (case-insensitive).
            llada_mask_id: token id used to mask the answer span for LLaDA.

        Raises:
            ValueError: if ``measure`` or ``aggregator`` is not supported.
        """
        self.dllm_type = dllm_type.lower()
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.measure = measure.lower()
        self.aggregator = aggregator.lower()
        self.llada_mask_id = llada_mask_id

        # Fail fast on typos instead of silently falling back to top1/mean.
        if self.measure not in self._MEASURES:
            raise ValueError(
                f"Unsupported measure {measure!r}; expected one of {self._MEASURES}"
            )
        if self.aggregator not in self._AGGREGATORS:
            raise ValueError(
                f"Unsupported aggregator {aggregator!r}; expected one of {self._AGGREGATORS}"
            )

    def _encode_prompt(self, prompt_text: str) -> torch.Tensor:
        """Tokenize ``prompt_text`` into a (1, P) long tensor on ``self.device``."""
        enc = self.tokenizer([prompt_text], return_tensors="pt")
        # Force long: an empty prompt can come back as an empty tensor built
        # from ``[[]]``, which torch types as float32 and which would otherwise
        # promote the concatenated input ids to float.
        return enc.input_ids.to(device=self.device, dtype=torch.long)

    def _empty_confidence(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return empty (top1, gap) tensors for answers with no scorable token."""
        empty = torch.empty(0, dtype=torch.float32, device=self.device)
        return empty, empty.clone()

    @staticmethod
    def _top1_and_gap(logits_ans: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute per-position top-1 probability and top1-top2 gap.

        Args:
            logits_ans: (1, T, V) logits for the scored answer positions.

        Returns:
            top1: (T,) float32 top-1 probabilities.
            gap:  (T,) float32 top-1 minus top-2 probabilities.
        """
        # Softmax in float32: bf16/fp16 probabilities keep only a few
        # significant digits, which would quantize the reward.
        probs = torch.softmax(logits_ans.float(), dim=-1)  # (1, T, V)
        top2 = torch.topk(probs, k=2, dim=-1).values  # (1, T, 2)
        top1 = top2[..., 0].squeeze(0).contiguous()  # (T,)
        gap = (top2[..., 0] - top2[..., 1]).squeeze(0).contiguous()  # (T,)
        return top1, gap

    @torch.no_grad()
    def _confidence_llada(
        self, prompt_text: str, answer_text: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute token-level confidence for LLaDA by masking the answer span
        and reading the logits at masked positions.

        Only the answer's token count is used; its token ids are never fed
        to the model.

        Returns:
            top1: (T,) top1 probabilities for each answer token position
            gap:  (T,) top1 - top2 gaps
            Both are empty (and no forward pass is run) for an empty answer.
        """
        prompt_ids = self._encode_prompt(prompt_text)  # (1, P)
        num_ans = len(self.tokenizer.encode(answer_text, add_special_tokens=False))
        if num_ans == 0:
            return self._empty_confidence()

        num_prompt = prompt_ids.shape[1]
        # Sequence layout: prompt tokens followed by num_ans mask tokens.
        x = torch.full(
            (1, num_prompt + num_ans),
            self.llada_mask_id,
            dtype=torch.long,
            device=self.device,
        )
        x[:, :num_prompt] = prompt_ids

        logits = self.model.forward(x).logits  # (1, P+T, V)
        return self._top1_and_gap(logits[:, num_prompt : num_prompt + num_ans, :])

    @torch.no_grad()
    def _confidence_dream(
        self, prompt_text: str, answer_text: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute token-level confidence for DREAM-style DLLMs by teacher-forcing
        over the prompt+answer and using next-token distributions.

        Logits at position j are treated as the distribution for the token at
        position j+1. With a non-empty prompt, all T answer tokens are scored.
        With an empty prompt, the first answer token has no predicting
        position, so only the remaining T-1 tokens are scored.

        NOTE(review): the answer tokens are fed unmasked. If the backbone
        attends bidirectionally, each scoring position can see the token it
        scores, which would inflate confidence. Confirm against the model's
        attention mask.

        Returns:
            top1, gap: (N,) tensors, N = T (or T-1 for an empty prompt).
            Both are empty (and no forward pass is run) when N <= 0.
        """
        prompt_ids = self._encode_prompt(prompt_text)  # (1, P)
        ans_ids = self.tokenizer.encode(answer_text, add_special_tokens=False)
        num_prompt = prompt_ids.shape[1]
        num_ans = len(ans_ids)

        num_scored = num_ans if num_prompt > 0 else num_ans - 1
        if num_scored <= 0:
            return self._empty_confidence()

        ans_ids_t = torch.tensor(ans_ids, dtype=torch.long, device=self.device).unsqueeze(0)
        x = torch.cat([prompt_ids, ans_ids_t], dim=-1)  # (1, P+T)
        logits = self.model.forward(x).logits  # (1, P+T, V)

        # The first scored answer token sits at P (or at 1 when P == 0), so
        # its distribution is read one position earlier.
        start_pos = num_prompt - 1 if num_prompt > 0 else 0
        return self._top1_and_gap(logits[:, start_pos : start_pos + num_scored, :])

    def _aggregate(self, values: torch.Tensor) -> float:
        """
        Reduce per-position confidences to a scalar.

        Returns 0.0 for an empty tensor (i.e. the lowest confidence, which
        maps to reward -1.0). Otherwise returns the min or the mean,
        according to ``self.aggregator``.
        """
        if values.numel() == 0:
            return 0.0
        if self.aggregator == "min":
            return float(values.min().item())
        return float(values.mean().item())

    def get_reward(
        self,
        question: str,
        solution: str,
        input_text: Optional[str] = None,
    ) -> float:
        """
        Compute reward from self-confidence on the provided solution.

        Args:
            question: unused; kept for API compatibility
            solution: candidate answer string
            input_text: full formatted prompt (chat template). If None, an
                        empty prompt is used, so the answer is scored without
                        any conditioning context.
        Returns:
            reward in [-1, 0], higher is better (closer to 0). An answer
            with no scorable tokens yields -1.0.
        Raises:
            ValueError: if ``dllm_type`` is neither "llada" nor "dream".
        """
        prompt = input_text if input_text is not None else ""

        if self.dllm_type == "llada":
            top1, gap = self._confidence_llada(prompt, solution)
        elif self.dllm_type == "dream":
            top1, gap = self._confidence_dream(prompt, solution)
        else:
            raise ValueError(f"Unsupported dllm_type for self-confidence: {self.dllm_type}")

        conf = self._aggregate(gap if self.measure == "gap" else top1)

        # Map confidence in [0,1] to reward in [-1, 0]
        return float(conf - 1.0)

