from __future__ import annotations

"""
MiniCPM-o-2_6 model runner.
"""

from typing import Any, List, Tuple

import torch

from .base import ModelRunner, register_model


@register_model("MiniCPM-o-2_6")
class MiniCPMRunner(ModelRunner):
    """Runner for the MiniCPM-o-2_6 multimodal (image / audio / text) chat model.

    Weights, tokenizer and processor are loaded by ``load_model()``;
    ``run_model()`` turns a conversation (as produced by
    ``build_conversation()``) into a single generated reply string.
    """

    def __init__(
        self,
        model_path: str = "/path/to/MiniCPM-o-2_6",
        system_prompt: str | None = None,
        init_vision: bool = True,
        init_audio: bool = True,
        init_tts: bool = False,
        attn_implementation: str = "sdpa",
        max_new_tokens: int = 1024,
    ) -> None:
        """Store configuration; no model weights are loaded here.

        Args:
            model_path: Local path or hub id passed to ``from_pretrained``.
            system_prompt: System prompt override; the built-in
                multiple-choice prompt is used when falsy.
            init_vision / init_audio / init_tts: Forwarded to the remote
                model's ``from_pretrained`` to enable those sub-modules.
            attn_implementation: Attention backend passed to ``from_pretrained``.
            max_new_tokens: Generation length limit used by ``run_model()``.
        """
        self.model_path = model_path
        self.model = None
        self.tokenizer = None
        # Filled in by load_model(); initialised here so run_model() can
        # report an unloaded runner instead of hitting an AttributeError.
        self.processor = None
        self.init_vision = init_vision
        self.init_audio = init_audio
        self.init_tts = init_tts
        self.attn_impl = attn_implementation
        self.max_new_tokens = max_new_tokens
        # Each sentence ends with a space so the implicitly concatenated
        # literals do not run together (e.g. "diagrams.Use").
        self.system_prompt = system_prompt or (
            "You are an assistant tasked with solving multiple-choice questions that require logical"
            " reasoning over the supplied knowledge diagrams. "
            "Use only the information explicitly given—do not rely on outside or commonsense knowledge. "
            "Read the question and given information, think step-by-step, and answer the question. "
            "At the end of your answer, answer precisely in the format 'Answer: X' where X is the chosen letter A / B / C / D."
        )

    def load_model(self) -> None:
        """Load the model, tokenizer and processor from ``self.model_path``.

        The model is put in eval mode and moved to CUDA when available.
        """
        from transformers import AutoModel, AutoProcessor, AutoTokenizer

        model = AutoModel.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            # Honour the constructor argument (was hard-coded to "sdpa").
            attn_implementation=self.attn_impl,
            torch_dtype=torch.float32,
            init_vision=self.init_vision,
            init_audio=self.init_audio,
            init_tts=self.init_tts,
        )
        model = model.eval()
        if torch.cuda.is_available():
            model = model.to("cuda")
        self.model = model

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path, trust_remote_code=True
        )

        # The remote model class may expose a ``processor`` attribute that is
        # still None (NOTE(review): presumably filled lazily by its own chat()),
        # so fall back to loading one whenever it is missing *or* unset.
        processor = getattr(self.model, "processor", None)
        if processor is None:
            processor = AutoProcessor.from_pretrained(
                self.model_path, trust_remote_code=True
            )
        self.processor = processor

    def build_conversation(
        self, user_content: list[dict], system_prompt: str | None = None
    ) -> list[dict]:
        """Wrap ``user_content`` in a [system, user] message list.

        Keeps a consistent interface with other runners. ``system_prompt``
        defaults to the runner's configured prompt.
        """
        if system_prompt is None:
            system_prompt = self.system_prompt
        return [
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
            {"role": "user", "content": user_content},
        ]

    def build_minicpm_msgs(
        self, conversation: list[dict]
    ) -> Tuple[list[dict], list[Any], list[Any], list[int], str]:
        """Convert our conversation into MiniCPM chat inputs.

        Each message's content is flattened into one newline-joined string in
        which images and audio clips are replaced by MiniCPM placeholders; the
        media themselves are loaded and collected in order of appearance.

        A message's ``content`` may be a list of items, a single item dict, a
        plain string, or None. Items may be
        ``{"type": "image", "image"|"path": ...}``,
        ``{"type": "audio", "audio"|"path": ...}``,
        ``{"type": "text", "text": ...}`` or bare strings (treated as text).
        Other items are ignored.

        Returns:
            ``(copy_msgs, images, audios, audio_parts, prompt_text)`` where
            ``copy_msgs`` are new message dicts (the input is not modified),
            ``images`` are RGB PIL images, ``audios`` are waveforms loaded by
            librosa at 16 kHz mono, ``audio_parts`` holds the message index of
            each audio clip, and ``prompt_text`` is the chat-templated prompt.
        """
        from PIL import Image
        import librosa

        copy_msgs: list[dict] = []
        images: List[Any] = []
        audios: List[Any] = []
        audio_parts: List[int] = []

        for i, msg in enumerate(conversation):
            content = msg.get("content") or []
            # A lone dict or string is one item; iterating a string directly
            # would otherwise yield (and drop) single characters.
            if isinstance(content, (dict, str)):
                content = [content]

            pieces: list[str] = []
            for c in content:
                if isinstance(c, str):
                    if c:
                        pieces.append(c)
                    continue
                if not isinstance(c, dict):
                    continue

                typ = c.get("type")
                if typ == "image":
                    src = c.get("image") or c.get("path")
                    if src:
                        if isinstance(src, Image.Image):
                            images.append(src.convert("RGB"))
                        else:
                            # convert() returns an independent image, so the
                            # file handle can be closed right away.
                            with Image.open(src) as im:
                                images.append(im.convert("RGB"))
                        pieces.append("(<image>./</image>)")
                elif typ == "audio":
                    path = c.get("audio") or c.get("path")
                    if path:
                        wav, _ = librosa.load(path, sr=16000, mono=True)
                        audios.append(wav)
                        audio_parts.append(i)
                        pieces.append("(<audio>./</audio>)")
                elif typ == "text":
                    txt = c.get("text")
                    if txt:
                        pieces.append(txt)

            # Join with newlines (not omni concatenation) to match non-omni text
            copy_msgs.append(
                {"role": msg.get("role", "user"), "content": "\n".join(pieces)}
            )

        # Apply chat template to get prompt string
        prompt_text = self.processor.tokenizer.apply_chat_template(
            copy_msgs, tokenize=False, add_generation_prompt=True
        )
        return copy_msgs, images, audios, audio_parts, prompt_text

    def run_model(self, conversation: Any) -> str:
        """Generate a reply for ``conversation`` and return it as stripped text.

        Raises:
            RuntimeError: if ``load_model()`` has not been called.
        """
        if self.model is None or self.tokenizer is None or self.processor is None:
            raise RuntimeError("MiniCPMRunner is not loaded. Call load_model() first.")

        _msgs, images, audios, audio_parts, prompt_text = self.build_minicpm_msgs(
            conversation
        )

        inputs = self.processor(
            text=[prompt_text],
            images=[images],
            audios=[audios],
            audio_parts=[audio_parts],
            return_tensors="pt",
            max_slice_nums=None,
            use_image_id=None,
            chunk_input=True,
        ).to(self.model.device)

        gen_inputs = dict(inputs)
        # NOTE(review): presumably generate() does not accept this key — confirm.
        gen_inputs.pop("image_sizes", None)

        with torch.no_grad():
            _texts, outputs = self.model.generate(
                **gen_inputs,
                tokenizer=self.tokenizer,
                max_new_tokens=self.max_new_tokens,
                decode_text=True,
            )

        reply = self.processor.decode(outputs.sequences[0])
        return str(reply).strip()

