# model.py
# -----------------------------------------------------------------------------
# Module: PrefillingModel
# Purpose: Unified interface for (i) prefilling for consensus scoring
#          and (ii) generating reasoning/answers across multiple VLM backends.
# -----------------------------------------------------------------------------

from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import torch
from torch import Tensor
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer
from qwen_vl_utils import process_vision_info

@dataclass
class Specials:
    """
    Per-backend chat-marker token ids and special ids consumed by label masking.
    """

    # Token ids of the tokenized user-turn header (an empty list skips that masking step).
    user_prefix: list[int]
    # Token ids of the tokenized assistant-turn header (an empty list skips that masking step).
    asst_prefix: list[int]
    # Padding token id excluded from the loss; None disables pad masking.
    pad_id: Optional[int] = None
    # Image placeholder / boundary token ids excluded from the loss.
    image_token_ids: tuple[int, ...] = ()

class PrefillingModel:
    """
    A light wrapper around a vision-language backend that exposes two capabilities:
      (1) prefill_nll: teacher-forced perplexity of a question+answer user turn
                       (for consensus scoring)
      (2) answer:      baseline generative answering (for draft reasoning)

    The backend is selected by case-insensitive substring matching on ``tag``:
    "qwen", "mimo", "intern", "glm" or "ovis". Each method checks these
    substrings in a fixed order, so a tag should contain only one of them.
    """

    # Padding id for OVIS; also the sentinel used to build its attention mask.
    _OVIS_PAD_ID = 151643
    # Pixel budget forwarded to OVIS `preprocess_inputs` (shared by scoring and generation).
    _OVIS_MIN_PIXELS = 200704
    _OVIS_MAX_PIXELS = 2408448
    # OVIS "thinking" mode is disabled for both preprocessing and generation.
    _OVIS_ENABLE_THINKING = False
    _OVIS_ENABLE_THINKING_BUDGET = False
    _OVIS_THINKING_BUDGET = 2048

    def __init__(self, model, processor, tokenizer, max_len: int, tag: str, device: str = "cuda"):
        """
        Parameters
        ----------
        model
            Backend model. Must provide ``generate``; when called with ``labels``
            it must return an object with a ``.loss`` attribute. OVIS models must
            additionally provide ``preprocess_inputs`` and ``dtype``.
        processor
            Backend processor (not used on the OVIS code paths).
        tokenizer
            Backend tokenizer; used to encode chat markers and, for OVIS, to decode.
        max_len : int
            Stored on the instance; not read by any method of this class.
        tag : str
            Backend identifier, lower-cased and matched by substring.
        device : str
            Device that encoded inputs are moved to.
        """
        self.model = model
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.tag = tag.lower()
        self.device = device

        self.gen_kwargs = self._default_gen_kwargs()
        self.specials = self._specials_for_tag()

    def _default_gen_kwargs(self) -> Dict:
        """
        Returns backend-specific generation defaults. Following each model's own evaluation framework.
        """
        if "qwen" in self.tag:
            return dict(max_new_tokens=2048, temperature=0.01, top_p=0.001, top_k=1, repetition_penalty=1.0)
        if "mimo" in self.tag:
            return dict(max_new_tokens=16384, temperature=0.0, top_p=1.0)
        if "intern" in self.tag:
            return dict(max_new_tokens=4096, do_sample=True, temperature=0.7, top_p=0.95)
        if "glm" in self.tag:
            return dict(max_new_tokens=8192, temperature=0.1, do_sample=True)
        if "ovis" in self.tag:
            return dict(max_new_tokens=3072, do_sample=False, top_p=None, top_k=None,
                        temperature=None, repetition_penalty=None)
        return dict(max_new_tokens=2048)

    def _resolve_gen_kwargs(self, max_tokens: Optional[int]) -> Dict:
        """
        Return a copy of the backend generation defaults with ``max_new_tokens``
        capped at ``max_tokens``.

        A falsy ``max_tokens`` (None / 0) keeps the default. The instance's
        ``gen_kwargs`` is never mutated.
        """
        gen_kwargs = dict(self.gen_kwargs)
        if max_tokens and "max_new_tokens" in gen_kwargs:
            gen_kwargs["max_new_tokens"] = min(max_tokens, gen_kwargs["max_new_tokens"])
        return gen_kwargs

    def _specials_for_tag(self) -> "Specials":
        """
        Defines chat markers and special token ids per backend.
        This concentrates all marker/id assumptions in one place.
        """
        if "qwen" in self.tag or "mimo" in self.tag or "intern" in self.tag:
            user = self.tokenizer.encode("<|im_start|>user\n", add_special_tokens=False)
            asst = self.tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
            img_ids = (151655, 151652, 151653)
            if "intern" in self.tag:
                img_ids += (151667, 151666, 151665)  # Additional image tokens for InternVL
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
            return Specials(user, asst, pad_id, img_ids)

        if "glm" in self.tag:
            user = self.tokenizer.encode("<|user|>\n", add_special_tokens=False)
            asst = self.tokenizer.encode("<|assistant|>\n", add_special_tokens=False)
            img_ids = (151343, 151339, 151340)
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
            return Specials(user, asst, pad_id, img_ids)

        if "ovis" in self.tag:
            user = self.tokenizer.encode("<|im_start|>user\n", add_special_tokens=False)
            asst = self.tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
            img_ids = (151655, 151652, 151653)
            return Specials(user, asst, self._OVIS_PAD_ID, img_ids)

        return Specials([], [], getattr(self.tokenizer, "pad_token_id", None), ())

    def _build_messages(self, image_path: str, text_prompt: str) -> List[Dict]:
        """
        Constructs a single-turn, image-grounded user message.
        InternVL expects a 'url' field; others accept 'image' path.
        """
        if "intern" in self.tag:
            return [{"role": "user",
                     "content": [{"type": "image", "url": image_path},
                                 {"type": "text", "text": text_prompt}]}]
        return [{"role": "user",
                 "content": [{"type": "image", "image": image_path},
                             {"type": "text", "text": text_prompt}]}]

    def _apply_template_and_encode(self, messages: List[Dict]) -> Dict[str, Tensor]:
        """
        Applies the backend's chat template (with the generation prompt appended)
        and packs multi-modal inputs.

        Returns a mapping of tensor inputs on ``self.device``: the processor's
        output object for GLM/Qwen/MiMo/InternVL, a plain dict for OVIS.
        """
        if "glm" in self.tag:
            return self.processor.apply_chat_template(
                messages, tokenize=True, add_generation_prompt=True,
                return_dict=True, return_tensors="pt",
            ).to(self.device)

        if "ovis" in self.tag:
            return self._handle_ovis_encoding(messages)

        # Standard processing for qwen, mimo, intern
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        return self.processor(
            text=[text], images=image_inputs, videos=video_inputs,
            padding=True, return_tensors="pt",
        ).to(self.device)

    @staticmethod
    def _load_rgb(image_path: str) -> "Image.Image":
        """Open ``image_path``, return an RGB copy and close the file handle."""
        with Image.open(image_path) as img:
            # convert() returns a new, fully loaded image, so it stays valid after close.
            return img.convert("RGB")

    def _ovis_preprocess(
        self, image_path: str, text: str
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """
        Run OVIS ``preprocess_inputs`` on one (image, text) user turn.

        Returns ``(input_ids, pixel_values, grid_thws)`` moved to ``self.device``;
        ``pixel_values`` is also cast to the model's dtype. The last two may be None.
        """
        ovis_msg = [{
            "role": "user",
            "content": [
                {"type": "image", "image": self._load_rgb(image_path)},
                {"type": "text", "text": text},
            ],
        }]
        input_ids, pixel_values, grid_thws = self.model.preprocess_inputs(
            messages=ovis_msg,
            add_generation_prompt=True,
            enable_thinking=self._OVIS_ENABLE_THINKING,
            min_pixels=self._OVIS_MIN_PIXELS,
            max_pixels=self._OVIS_MAX_PIXELS,
        )
        input_ids = input_ids.to(self.device)
        if pixel_values is not None:
            pixel_values = pixel_values.to(device=self.device, dtype=self.model.dtype)
        if grid_thws is not None:
            grid_thws = grid_thws.to(self.device)
        return input_ids, pixel_values, grid_thws

    def _handle_ovis_encoding(self, messages: List[Dict]) -> Dict[str, Tensor]:
        """
        Encode a message list built by ``_build_messages`` for OVIS.

        Raises
        ------
        ValueError
            If the first message has no image item or no text item.
        """
        image_path = None
        text = None
        for item in messages[0]["content"]:
            if item["type"] == "image" and image_path is None:
                image_path = item["image"]
            elif item["type"] == "text" and text is None:
                text = item["text"]

        if not image_path:
            raise ValueError("No image found in messages for OVIS")
        if text is None:
            raise ValueError("No text found in messages for OVIS")

        input_ids, pixel_values, grid_thws = self._ovis_preprocess(image_path, text)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": torch.ne(input_ids, self._OVIS_PAD_ID),
        }
        if pixel_values is not None:
            inputs["pixel_values"] = pixel_values
        if grid_thws is not None:
            inputs["grid_thws"] = grid_thws
        return inputs

    # Label Masking for Prefill NLL
    def _mask_labels_for_chat(
        self,
        input_ids: torch.Tensor,
        inputs: Optional[Dict] = None,
    ) -> torch.Tensor:
        """
        Build a `labels` tensor in which only the user-turn text (question + answer)
        contributes to the loss.

        Masking rules (masked positions are set to -100)
        -------------
        1. Everything up to and including the last <user> prefix.
        2. From the last <assistant> prefix (inclusive) to the end of sequence.
        3. PAD tokens and image-placeholder tokens.
        4. OVIS only, when `inputs` is given: negative ids (presumably OVIS
           image placeholders — TODO confirm).

        A prefix that is empty, absent, or longer than the sequence is skipped,
        leaving that side unmasked. Tokens the chat template emits between the end
        of the user text and the assistant prefix (e.g. an end-of-turn marker)
        are not masked and therefore contribute to the loss.

        Parameters
        ----------
        input_ids : Tensor
            Shape (1, L). Only row 0 is searched for the prefixes.
        inputs : Dict, optional
            Encoded inputs; only its presence is checked (rule 4).

        Returns
        -------
        Tensor
            A copy of `input_ids` where masked positions are set to -100.
        """
        labels = input_ids.clone()
        ids = input_ids[0]  # assumes batch size 1

        def mask_by_prefix(prefix_ids: List[int], mask_after: bool) -> None:
            """
            Locate the *last* occurrence of `prefix_ids` in `ids` and apply masking.

            mask_after = False  -> mask everything up to and including the prefix.
            mask_after = True   -> mask from the prefix start all the way to the end.
            """
            k = len(prefix_ids)
            # unfold() raises when the window is longer than the sequence.
            if k == 0 or k > ids.numel():
                return
            windows = ids.unfold(0, k, 1)  # (L - k + 1, k) sliding windows
            pattern = torch.tensor(prefix_ids, dtype=ids.dtype, device=ids.device)
            hits = (windows == pattern).all(dim=1).nonzero(as_tuple=False)
            if hits.numel() == 0:
                return
            pos = int(hits[-1, 0])
            if mask_after:
                labels[0, pos:] = -100
            else:
                labels[0, :pos + k] = -100

        # 1) mask system / user part
        mask_by_prefix(self.specials.user_prefix, mask_after=False)

        # 2) mask assistant prefix + further output
        mask_by_prefix(self.specials.asst_prefix, mask_after=True)

        # 3) mask pad and image tokens
        if self.specials.pad_id is not None:
            labels[labels == self.specials.pad_id] = -100
        for tid in self.specials.image_token_ids:
            labels[labels == tid] = -100

        # 4) OVIS: mask negative IDs
        if "ovis" in self.tag and inputs is not None:
            labels[input_ids < 0] = -100
            labels = labels.to(torch.long)

        return labels

    @torch.inference_mode()
    def prefill_nll(self, image_path: str, question: str, answer: str) -> float:
        """
        Score `answer` for (image, question) by teacher-forced perplexity.

        The user turn is "Question: <question>", a newline, then
        "Answer: <answer>". Labels come from `_mask_labels_for_chat`, so both the
        question and the answer text contribute to the loss; chat markers,
        padding and image placeholders do not.

        Despite the method name, the return value is the perplexity
        ``exp(loss)``. The model's ``loss`` is printed as "Avg NLL".

        Returns
        -------
        float
            exp(model loss); lower means the model finds the text more likely.
        """
        text_prompt = f"Question: {question}\nAnswer: {answer}"
        print(f"Prompt: {text_prompt}")

        messages = self._build_messages(image_path, text_prompt)
        inputs = self._apply_template_and_encode(messages)

        # Processor outputs and the OVIS plain dict both support item access.
        input_ids: Tensor = inputs["input_ids"].to(self.device)
        labels = self._mask_labels_for_chat(input_ids, inputs)

        loss = self.model(**inputs, labels=labels).loss
        ppl = torch.exp(loss).item()
        print(f"Avg NLL: {loss.item():.4f}")
        print(f"PPL: {ppl:.4f}")
        return ppl

    @torch.inference_mode()
    def answer(self, image_path: str, question: str, prompt_tpl: str,
               max_tokens: int, return_prompt: bool = False) -> str:
        """
        Generates a baseline answer under a given prompt template.

        Parameters
        ----------
        prompt_tpl : str
            Template formatted positionally with `question` (``prompt_tpl.format(question)``).
        max_tokens : int
            Cap on new tokens; the backend default is kept when it is smaller
            or when `max_tokens` is falsy. Applied to every backend, including OVIS.
        return_prompt : bool
            Currently unused; kept for interface compatibility.

        Returns
        -------
        str
            The decoded generation (prompt tokens excluded).
        """
        prompt = prompt_tpl.format(question)
        print("Answering the questions")

        if "ovis" in self.tag:
            return self._answer_ovis(image_path, prompt, max_tokens)

        messages = self._build_messages(image_path, prompt)
        inputs = self._apply_template_and_encode(messages)

        outputs = self.model.generate(**inputs, **self._resolve_gen_kwargs(max_tokens))

        if "glm" in self.tag:
            # Special tokens are kept for GLM output.
            prompt_len = inputs["input_ids"].shape[1]
            output_text = self.processor.decode(
                outputs[0][prompt_len:], skip_special_tokens=False
            )
        else:
            trimmed = [out[len(inp):] for inp, out in zip(inputs["input_ids"], outputs)]
            output_text = self.processor.batch_decode(
                trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]

        print(f'User: {prompt}\nAssistant: {output_text}')
        return output_text

    def _answer_ovis(self, image_path: str, prompt: str, max_tokens: Optional[int] = None) -> str:
        """
        OVIS-specific answer generation.

        ``max_tokens`` caps ``max_new_tokens`` the same way as in ``answer``; when
        omitted, the backend default from ``_default_gen_kwargs`` is used.
        """
        input_ids, pixel_values, grid_thws = self._ovis_preprocess(image_path, prompt)

        outputs = self.model.generate(
            inputs=input_ids,
            pixel_values=pixel_values,
            grid_thws=grid_thws,
            enable_thinking=self._OVIS_ENABLE_THINKING,
            enable_thinking_budget=self._OVIS_ENABLE_THINKING_BUDGET,
            thinking_budget=self._OVIS_THINKING_BUDGET,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=True,
            **self._resolve_gen_kwargs(max_tokens),
        )

        # NOTE(review): decoded without trimming the prompt, i.e. this assumes OVIS
        # `generate` returns only the new tokens — confirm for other OVIS versions.
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)