"""
Qwen2.5-Omni model runner.
"""

from __future__ import annotations

from typing import Any

import torch

from .base import ModelRunner, register_model


@register_model("Qwen2.5-Omni")
class QwenOmniRunner(ModelRunner):
    """Runner that answers multiple-choice prompts with Qwen2.5-Omni (text output only).

    Typical flow: construct the runner, call ``load_model()`` once, build a chat with
    ``build_conversation()``, then pass it to ``run_model()`` to get the decoded reply.
    """

    def __init__(
        self,
        model_path: str = "/path/to/Qwen2.5-Omni-7B",
        system_prompt: str | None = None,
        *,
        torch_dtype: torch.dtype | str = torch.float16,
        attn_implementation: str = "sdpa",
    ) -> None:
        """Store configuration; weights are only loaded by ``load_model()``.

        Args:
            model_path: Path or hub id handed to ``from_pretrained``.
            system_prompt: Default system prompt for ``build_conversation()``. Any
                falsy value (``None`` or ``""``) selects the built-in prompt below.
            torch_dtype: Dtype the model weights are loaded in (``from_pretrained``).
            attn_implementation: Attention backend passed to ``from_pretrained``.
        """
        self.model_path = model_path
        self.torch_dtype = torch_dtype
        self.attn_implementation = attn_implementation
        self.model = None
        self.processor = None
        # Every fragment after the first starts with a space. Otherwise implicit
        # string concatenation glues sentences together ("diagrams.Use ...").
        self.system_prompt = system_prompt or (
            "You are an assistant tasked with solving multiple-choice questions that require logical"
            " reasoning over the supplied knowledge diagrams."
            " Use only the information explicitly given—do not rely on outside or commonsense knowledge."
            " Read the question and given information, think step-by-step, and answer the question."
            " At the end of your answer, answer precisely in the format 'Answer: X' where X is the chosen letter A / B / C / D."
        )

    def load_model(self) -> None:
        """Load the Qwen2.5-Omni model and its processor from ``self.model_path``.

        ``transformers`` is imported here, so it is only needed once a model is
        actually loaded.
        """
        from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor

        self.model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
            self.model_path,
            torch_dtype=self.torch_dtype,
            device_map="auto",
            attn_implementation=self.attn_implementation,
        )
        self.processor = Qwen2_5OmniProcessor.from_pretrained(self.model_path)

    def build_conversation(
        self, user_content: list[dict], system_prompt: str | None = None
    ) -> list[dict]:
        """Wrap ``user_content`` in a two-message (system, user) chat.

        Args:
            user_content: Content items for the user turn. These are presumably
                Qwen-style dicts such as ``{"type": "text", "text": ...}``; they are
                passed through unchanged.
            system_prompt: Per-call override. When ``None``, ``self.system_prompt``
                is used.

        Returns:
            ``[system_message, user_message]``. The system text is wrapped as a
            single ``{"type": "text", ...}`` item.
        """
        prompt = self.system_prompt if system_prompt is None else system_prompt
        return [
            {"role": "system", "content": [{"type": "text", "text": prompt}]},
            {"role": "user", "content": user_content},
        ]

    def run_model(self, conversation: Any) -> str:
        """Generate a text reply for ``conversation`` and return it decoded.

        Args:
            conversation: Chat messages, e.g. as produced by ``build_conversation()``.

        Returns:
            Only the newly generated text (prompt removed), with surrounding
            whitespace stripped.

        Raises:
            RuntimeError: If ``load_model()`` has not been called yet.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("QwenOmniRunner is not loaded. Call load_model() first.")

        from qwen_omni_utils import process_mm_info

        prompt_text = self.processor.apply_chat_template(
            conversation, add_generation_prompt=True, tokenize=False
        )
        # Collect every media modality referenced by the conversation. Videos must
        # reach the processor too; otherwise their template placeholders have no data.
        # use_audio_in_video=False is passed identically here, to the processor
        # and to generate() so all three stay consistent.
        audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)

        inputs = self.processor(
            text=prompt_text,
            images=images,
            audio=audios,
            videos=videos,
            return_tensors="pt",
            padding=True,
            use_audio_in_video=False,
        )
        # NOTE(review): assumes BatchFeature.to(dtype) casts only floating-point
        # tensors (current transformers behavior), so input_ids stay integer.
        inputs = inputs.to(self.model.device).to(self.model.dtype)

        with torch.no_grad():
            # return_audio=False: request text token ids only, no speech output.
            text_ids = self.model.generate(
                **inputs, use_audio_in_video=False, return_audio=False
            )
        # The returned ids start with the prompt; keep only the continuation.
        prompt_len = inputs["input_ids"].shape[1]
        reply_ids = text_ids[:, prompt_len:]

        return self.processor.batch_decode(
            reply_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0].strip()
