from __future__ import annotations

"""Qwen3-VL-8B-Thinking wrapper for the submission package.

Design goals:
- Single, readable file.
- Supports only Qwen3-VL-8B-Thinking.
- Exposes just what inference needs:
  build inputs -> generate -> find <think> span -> forward with attentions.
"""

import math
import os
from dataclasses import dataclass
from typing import Any, Optional, Sequence

import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

from qwen_vl_utils import process_vision_info


# Model identifier handed to both `from_pretrained` calls in `Qwen8B.__init__`;
# this is the only checkpoint the wrapper loads.
MODEL_ID = "Qwen/Qwen3-VL-8B-Thinking"


@dataclass(frozen=True)
class ModelInputs:
    """Immutable bundle of the tensors handed to the model for one request.

    The two vision fields default to ``None`` so that the same container
    can describe an input without image data.
    """

    # Token ids; forwarded to the model as ``input_ids``.
    input_ids: torch.Tensor
    # Mask forwarded to the model as ``attention_mask``.
    attention_mask: torch.Tensor
    # Image tensor produced by the processor, if any.
    pixel_values: Optional[torch.Tensor] = None
    # Per-image grid description produced by the processor, if any; read
    # elsewhere in this file as rows of three integers.
    image_grid_thw: Optional[torch.Tensor] = None


@dataclass(frozen=True)
class GenerationSpans:
    """Positions of the thinking markers inside a token sequence.

    Each field is ``None`` when the corresponding marker was not located.
    """

    # Index of the opening marker, or None.
    think_start: Optional[int]
    # Index of the closing marker, or None.
    think_end: Optional[int]

    def thought_span(self, *, seq_len: int) -> Optional[tuple[int, int]]:
        """Return ``(start, end)`` bounds of the thought that follows the opener.

        ``start`` is ``think_start + 1``. ``end`` is ``think_end`` when it is
        set, otherwise ``seq_len``.

        Returns:
            ``None`` when ``think_start`` is ``None``; the pair otherwise.

        Raises:
            ValueError: if the resulting ``end`` is smaller than ``start``.
        """
        opener = self.think_start
        if opener is None:
            return None

        closer = self.think_end
        first = int(opener) + 1
        # Without a closing marker the span runs to the end of the sequence.
        last = int(closer) if closer is not None else int(seq_len)

        if first > last:
            raise ValueError("Invalid think span")
        return (first, last)


def _find_first(seq: Sequence[int], value: int, start: int = 0) -> Optional[int]:
    """Return the first index, scanning from ``start``, where ``seq`` holds ``value``.

    Both the element and ``value`` are passed through ``int()`` before being
    compared. ``None`` is returned when the scan reaches the end of ``seq``
    without a match.
    """
    matches = (
        pos
        for pos in range(int(start), len(seq))
        if int(seq[pos]) == int(value)
    )
    hit = next(matches, None)
    return None if hit is None else int(hit)


class Qwen8B:
    """Thin inference wrapper around the ``MODEL_ID`` checkpoint.

    The methods cover the steps named in the module header: build model
    inputs from an image and a question, generate a continuation, locate the
    ``<think>`` span and the image placeholder tokens, and run forward passes
    (with attentions, or from pre-computed input embeddings).
    """

    def __init__(
        self,
        *,
        weights_dir: Optional[str],
        device: Optional[str] = None,
        allow_download: bool = False,
        attn_implementation: str = "eager",
        dtype: torch.dtype = torch.bfloat16,
    ):
        """Load model, processor and tokenizer, and cache special-token ids.

        Args:
            weights_dir: Passed as ``cache_dir`` to both ``from_pretrained``
                calls when not ``None``.
            device: Target device string. Defaults to ``"cuda"`` when
                ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
            allow_download: When ``False`` (default) loading is restricted to
                local files (``local_files_only=True``).
            attn_implementation: Forwarded to ``from_pretrained``.
            dtype: Forwarded to ``from_pretrained`` for the model weights.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.allow_download = bool(allow_download)
        self.local_files_only = not self.allow_download

        if self.local_files_only:
            # setdefault: never override a value the caller exported already.
            # NOTE(review): these are set after `transformers` was imported;
            # confirm they still take effect. `local_files_only` below is the
            # flag that is passed explicitly.
            os.environ.setdefault("HF_HUB_OFFLINE", "1")
            os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")

        kwargs: dict[str, Any] = {
            "trust_remote_code": True,
            # Place the whole model on a single device.
            "device_map": {"": self.device},
            "dtype": dtype,
            "attn_implementation": str(attn_implementation),
            "local_files_only": self.local_files_only,
        }
        if weights_dir is not None:
            kwargs["cache_dir"] = str(weights_dir)

        self.model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **kwargs)
        self.processor = AutoProcessor.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            local_files_only=self.local_files_only,
            cache_dir=str(weights_dir) if weights_dir is not None else None,
        )
        self.tokenizer = self.processor.tokenizer

        # Ids of the marker tokens searched for by the span/vision helpers.
        self._think_id = self.tokenizer.convert_tokens_to_ids("<think>")
        self._end_think_id = self.tokenizer.convert_tokens_to_ids("</think>")
        self._image_pad_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>")

    def build_inputs_and_image(
        self,
        *,
        image: Image.Image,
        question: str,
        system_prompt: Optional[str] = None,
        force_think: bool = True,
    ) -> tuple[ModelInputs, Image.Image]:
        """Turn one image plus one question into model-ready tensors.

        Args:
            image: Image placed in the user turn.
            question: Text placed after the image in the user turn.
            system_prompt: Optional system turn; skipped when falsy.
            force_think: When true, make sure the prompt ends with
                ``<think>`` followed by a newline.

        Returns:
            ``(inputs, prepared_image)`` where ``inputs`` holds tensors moved
            to ``self.device`` and ``prepared_image`` is the first image
            returned by ``process_vision_info`` (or ``image`` itself when
            that helper returns none).
        """
        messages = []
        if system_prompt:
            messages.append(
                {"role": "system", "content": [{"type": "text", "text": system_prompt}]}
            )

        messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": question},
                ],
            }
        )

        prompt_text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        if force_think:
            tail = prompt_text.rstrip()
            if tail.endswith("<think>"):
                # Template already opened the thought; only ensure the newline.
                if not prompt_text.endswith("\n"):
                    prompt_text = prompt_text + "\n"
            else:
                prompt_text = prompt_text + "<think>\n"

        image_inputs, video_inputs = process_vision_info(messages)
        prepared_image = image_inputs[0] if image_inputs else image

        inputs = self.processor(
            text=[prompt_text],
            images=image_inputs,
            videos=video_inputs,
            padding=False,
            return_tensors="pt",
        )
        inputs = inputs.to(self.device)
        return (
            ModelInputs(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                pixel_values=inputs.get("pixel_values"),
                image_grid_thw=inputs.get("image_grid_thw"),
            ),
            prepared_image,
        )

    def generate(
        self,
        inputs: ModelInputs,
        *,
        max_new_tokens: int,
        stop_at_end_think: bool = True,
    ) -> tuple[list[int], int]:
        """Greedy-decode a continuation of ``inputs``.

        Args:
            inputs: Prompt tensors as produced by ``build_inputs_and_image``.
            max_new_tokens: Upper bound on generated tokens.
            stop_at_end_think: When true, cut the returned ids right after
                the first ``</think>`` found at or after the prompt length.

        Returns:
            ``(ids, prompt_len)``: the first row of the generated sequence as
            a list of ints (prompt included) and the prompt length.
        """
        self.model.eval()
        prompt_len = int(inputs.input_ids.shape[1])
        with torch.no_grad():
            full_ids = self.model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                pixel_values=inputs.pixel_values,
                image_grid_thw=inputs.image_grid_thw,
                max_new_tokens=int(max_new_tokens),
                do_sample=False,
            )
        ids = full_ids[0].tolist()
        if stop_at_end_think:
            # Truncation happens after generation; it does not stop decoding early.
            end_pos = _find_first(ids, self._end_think_id, start=prompt_len)
            if end_pos is not None:
                ids = ids[: end_pos + 1]
        return ids, prompt_len

    def find_spans(
        self, full_input_ids: Sequence[int], *, prompt_len: int
    ) -> GenerationSpans:
        """Locate the ``<think>`` / ``</think>`` markers in a full sequence.

        The opening marker is the first ``<think>`` in the generated part
        (index ``>= prompt_len``); if there is none, the last ``<think>``
        inside the prompt is used. The closing marker is searched only after
        both the opening marker and the prompt.

        Args:
            full_input_ids: Prompt plus generated token ids.
            prompt_len: Number of leading ids that belong to the prompt.

        Returns:
            A ``GenerationSpans``; both fields are ``None`` when no opening
            marker exists, and ``think_end`` is ``None`` when no closing
            marker follows it.
        """
        # Normalise once so comparisons behave the same for lists of ints and
        # for sequences of integer-like scalars (matches `_find_first`).
        ids = [int(tok) for tok in full_input_ids]
        prompt_len = int(prompt_len)

        gen_thinks = [
            i for i in range(prompt_len, len(ids)) if ids[i] == self._think_id
        ]
        if gen_thinks:
            think_start = int(gen_thinks[0])
        else:
            prompt_thinks = [
                i for i, tok in enumerate(ids[:prompt_len]) if tok == self._think_id
            ]
            think_start = int(prompt_thinks[-1]) if prompt_thinks else None

        if think_start is None:
            return GenerationSpans(think_start=None, think_end=None)

        # Critical detail: only search for </think> after prompt_len to avoid instruction-text matches.
        search_start = max(int(think_start) + 1, int(prompt_len))
        ends = [
            i for i in range(search_start, len(ids)) if ids[i] == self._end_think_id
        ]
        # NOTE(review): the LAST closing marker is used here, whereas
        # `generate` truncates at the FIRST one; the two agree whenever the
        # sequence came from `generate(stop_at_end_think=True)`.
        think_end = int(ends[-1]) if ends else None
        return GenerationSpans(think_start=int(think_start), think_end=think_end)

    def decode(self, ids: Sequence[int]) -> str:
        """Decode ``ids`` to text, keeping special tokens in the output."""
        return self.tokenizer.decode(list(ids), skip_special_tokens=False)

    def find_vision_tokens(self, full_input_ids: Sequence[int]) -> list[int]:
        """Return every index whose token equals the ``<|image_pad|>`` id."""
        return [
            i
            for i, tok in enumerate(full_input_ids)
            if int(tok) == int(self._image_pad_id)
        ]

    def build_vision_token_map(
        self, full_input_ids: Sequence[int], *, inputs: ModelInputs
    ) -> tuple[list[int], int, int]:
        """Map image placeholder tokens onto a 2-D grid.

        Args:
            full_input_ids: Token ids to scan for image placeholders.
            inputs: Source of ``image_grid_thw``; when that is ``None`` the
                grid is inferred as a square from the placeholder count.

        Returns:
            ``(token_positions, grid_h, grid_w)`` with
            ``grid_h * grid_w == len(token_positions)``.

        Raises:
            ValueError: if the grid cannot be inferred, if more than one
                image is described, or if the placeholder count matches
                neither ``h * w`` nor an exact 2x2 merge of it.
        """
        token_positions = self.find_vision_tokens(full_input_ids)
        num_tokens = len(token_positions)
        info = inputs.image_grid_thw
        if info is None:
            # Exact integer square root; avoids float rounding in `** 0.5`.
            side = math.isqrt(num_tokens)
            if side * side != num_tokens:
                raise ValueError(
                    "Cannot infer vision grid (non-square and image_grid_thw missing)"
                )
            return token_positions, side, side
        if info.ndim != 2 or int(info.shape[0]) != 1:
            raise ValueError("Only single-image inputs are supported")

        _, h, w = [int(x) for x in info[0].tolist()]
        if num_tokens == h * w:
            return token_positions, h, w
        # One placeholder per 2x2 patch group. Both sides must be even,
        # otherwise h // 2 * w // 2 would not equal the placeholder count.
        if num_tokens * 4 == h * w and h % 2 == 0 and w % 2 == 0:
            return token_positions, h // 2, w // 2
        raise ValueError(
            f"Unexpected vision token count: {num_tokens} for image_grid_thw (h={h}, w={w})"
        )

    def compute_position_ids(
        self,
        *,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        inputs: ModelInputs,
    ) -> torch.Tensor:
        """Compute position ids for ``input_ids`` via the inner model.

        Only ``inputs.image_grid_thw`` is read from ``inputs``; no video grid
        is supplied. The second value returned by ``get_rope_index`` is
        discarded.
        """
        # Qwen3-VL uses M-RoPE position IDs (3, batch, seq_len). The implementation is in inner model.
        inner = self.model.model
        pos, _ = inner.get_rope_index(
            input_ids=input_ids,
            image_grid_thw=inputs.image_grid_thw,
            video_grid_thw=None,
            attention_mask=attention_mask,
        )
        return pos

    def forward_with_attentions(
        self,
        *,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        inputs: ModelInputs,
    ) -> Any:
        """Run one forward pass requesting attention maps.

        Vision tensors are taken from ``inputs``. The KV cache is disabled.
        The call is not wrapped in ``torch.no_grad()``; callers decide
        whether gradients are tracked.

        Returns:
            Whatever the underlying model returns for this call.
        """
        self.model.eval()
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            pixel_values=inputs.pixel_values,
            image_grid_thw=inputs.image_grid_thw,
            output_attentions=True,
            use_cache=False,
        )

    def forward_ablation(
        self,
        *,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
    ) -> Any:
        """Run one forward pass from pre-computed embeddings.

        ``input_ids`` is accepted for signature parity with
        ``forward_with_attentions`` but is not forwarded: the model receives
        ``input_ids=None`` together with ``inputs_embeds``. No vision tensors
        are passed on this path.

        Returns:
            Whatever the underlying model returns for this call.
        """
        self.model.eval()
        return self.model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=False,
        )

    def prepare_inputs_embeds(self, inputs: ModelInputs) -> torch.Tensor:
        """Build input embeddings with image features written in place.

        Token embeddings are looked up for ``inputs.input_ids``; when
        ``inputs.pixel_values`` is present, the rows at image-token positions
        are overwritten (on a copy) with the vision tower's output.

        Returns:
            The embedding tensor to pass as ``inputs_embeds``.
        """
        inner = self.model.model
        # Use the public accessor rather than a hard-coded attribute path, so
        # the lookup works wherever the embedding table is nested.
        inputs_embeds = self.model.get_input_embeddings()(inputs.input_ids)
        if inputs.pixel_values is None:
            return inputs_embeds

        pixel_values = inputs.pixel_values.type(inner.visual.dtype)
        image_embeds = inner.visual(pixel_values, grid_thw=inputs.image_grid_thw)
        # NOTE(review): the vision tower may return extra outputs next to the
        # patch embeddings (a tuple whose first item is the embeddings); only
        # that first item is used here. Confirm against the installed
        # transformers version whether the extras matter for your use.
        if isinstance(image_embeds, (tuple, list)):
            image_embeds = image_embeds[0]
        image_embeds = image_embeds.to(
            device=inputs_embeds.device, dtype=inputs_embeds.dtype
        )

        image_mask = inputs.input_ids == inner.config.image_token_id
        # Clone so the embedding layer's output is not modified in place.
        inputs_embeds = inputs_embeds.clone()
        inputs_embeds[image_mask] = image_embeds
        return inputs_embeds
