from __future__ import annotations

"""
Phi-4 Multimodal model runner.
"""

from typing import Any, Optional, Tuple, List

import torch
from PIL import Image
import soundfile as sf

from .base import ModelRunner, register_model


def _prepare_prompt_and_media(
    conversation: list[dict], system_prompt: str
) -> Tuple[str, List[Any], List[Any]]:
    pieces: list[str] = []
    images: List[Any] = []
    audios: List[Any] = []

    img_idx = 0
    aud_idx = 0

    pieces.append("<|user|>")
    if isinstance(system_prompt, str) and system_prompt.strip():
        pieces.append(system_prompt.strip() + "\n\n")

    for msg in conversation:
        role = msg.get("role")
        if role == "system":
            continue
        content = msg.get("content", [])
        if isinstance(content, dict):
            content = [content]
        for part in content:
            typ = part.get("type")
            if typ == "text":
                t = part.get("text")
                if isinstance(t, str) and t:
                    pieces.append(t)
                    if not t.endswith("\n"):
                        pieces.append("\n")
            elif typ == "image":
                path = part.get("image") or part.get("path")
                if path:
                    try:
                        img = Image.open(path).convert("RGB")
                        images.append(img)
                        img_idx += 1
                        pieces.append(f"<|image_{img_idx}|>")
                    except Exception:
                        pass
            elif typ == "audio":
                path = part.get("audio") or part.get("path")
                if path:
                    try:
                        wav, sr = sf.read(path)
                        audios.append((wav, sr))
                        aud_idx += 1
                        pieces.append(f"<|audio_{aud_idx}|>")
                    except Exception:
                        pass

    pieces.append("<|end|><|assistant|>")
    prompt = "".join(pieces)
    return prompt, images, audios


@register_model("Phi-4-Multimodal")
class Phi4OmniRunner(ModelRunner):
    """Runner that answers multimodal conversations with a Phi-4 checkpoint.

    The model and processor are loaded lazily by :meth:`load_model`; calling
    :meth:`run_model` before that raises ``RuntimeError``.
    """

    def __init__(
        self,
        model_path: str = "/path/to/Phi-4-multimodal-instruct",
        system_prompt: str | None = None,
        max_new_tokens: int = 1024,
    ) -> None:
        """Store configuration; no model weights are loaded here.

        Args:
            model_path: Location passed to the ``from_pretrained`` loaders.
            system_prompt: Instruction text prepended to every prompt. When
                ``None`` or empty, a default multiple-choice reasoning
                instruction is used.
            max_new_tokens: Upper bound on generated tokens per reply.
        """
        self.model_path = model_path
        self.model = None
        self.processor = None
        self.generation_config = None
        self.max_new_tokens = max_new_tokens
        # Adjacent literals are concatenated as-is, so every fragment after
        # the first must begin with a space to keep the sentences separated.
        self.system_prompt = system_prompt or (
            "You are an assistant tasked with solving multiple-choice questions that require logical"
            " reasoning over the supplied knowledge diagrams."
            " Use only the information explicitly given—do not rely on outside or commonsense knowledge."
            " Read the question and given information, think step-by-step, and answer the question."
            " At the end of your answer, answer precisely in the format 'Answer: X' where X is the chosen letter A / B / C / D."
        )

    def load_model(self) -> None:
        """Load the processor, the model, and (if available) its generation config."""
        # Imported here so that importing this module does not require
        # transformers until a model is actually loaded.
        from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
        self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            _attn_implementation="sdpa",
            device_map="auto",
        )

        # The generation config is optional: fall back to the model's own
        # defaults when the file is missing or cannot be parsed.
        try:
            self.generation_config = GenerationConfig.from_pretrained(
                self.model_path, "generation_config.json"
            )
        except Exception:
            self.generation_config = None

    def build_conversation(self, user_content: list[dict]) -> list[dict]:
        """Wrap user content parts in a two-message system + user conversation."""
        return [
            {"role": "system", "content": [{"type": "text", "text": self.system_prompt}]},
            {"role": "user", "content": user_content},
        ]

    def _build_phi4_prompt(self, conversation: list[dict]) -> tuple[str, list[Any], list[Any]]:
        """Return ``(prompt, images, audios)`` for ``conversation`` using this runner's system prompt."""
        return _prepare_prompt_and_media(conversation, self.system_prompt)

    def run_model(self, conversation: Any) -> str:
        """Generate a reply for ``conversation`` and return it as stripped text.

        Args:
            conversation: Message list as produced by :meth:`build_conversation`.

        Returns:
            The decoded completion with the prompt tokens removed, known
            end-of-turn markers deleted, and surrounding whitespace stripped.

        Raises:
            RuntimeError: If :meth:`load_model` has not been called.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Phi4OmniRunner is not loaded. Call load_model() first.")

        prompt, images, audios = self._build_phi4_prompt(conversation)

        # Empty media lists are passed as None rather than [].
        inputs = self.processor(
            text=prompt,
            images=images if images else None,
            audios=audios if audios else None,
            return_tensors="pt",
        ).to(self.model.device)

        # Match the model's dtype (float16 as loaded above).
        # NOTE(review): presumably only floating-point tensors are cast and
        # integer ids are left untouched — confirm for the processor's output type.
        inputs = inputs.to(self.model.dtype)

        gen_kwargs = {}
        if self.generation_config is not None:
            gen_kwargs["generation_config"] = self.generation_config

        with torch.no_grad():
            generate_ids = self.model.generate(
                **inputs,
                **gen_kwargs,
                max_new_tokens=self.max_new_tokens,
            )

        # generate() output starts with the prompt tokens; keep only the new ones.
        if "input_ids" in inputs:
            prompt_len = inputs["input_ids"].shape[1]
            reply_ids = generate_ids[:, prompt_len:]
        else:
            reply_ids = generate_ids

        reply = self.processor.batch_decode(
            reply_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        # Remove any end-of-turn markers that survived decoding as plain text.
        for marker in ("<|im_end|>", "<|turn_end|>", "</s>", "<eos>"):
            reply = reply.replace(marker, "")
        return reply.strip()
