from __future__ import annotations

"""
Baichuan-Omni-1d5 model runner.
"""

from typing import Any, List, Tuple
import json

import torch

from .base import ModelRunner, register_model

def _build_mm_content_str(processor, user_content: list[dict]) -> str:
    parts: List[str] = []
    img_start = getattr(processor, "image_start_tag", None)
    img_end = getattr(processor, "image_end_tag", None)
    aud_start = getattr(processor, "audio_start_tag", None)
    aud_end = getattr(processor, "audio_end_tag", None)

    for item in user_content:
        typ = item.get("type")
        if typ == "text":
            t = item.get("text", "")
            if t:
                parts.append(t)
                if not t.endswith("\n"):
                    parts.append("\n")
        elif typ == "image" and img_start and img_end:
            path = item.get("image") or item.get("path")
            if path:
                payload = json.dumps({"path": path}, ensure_ascii=False)
                parts.append(f"{img_start}{payload}{img_end}")
        elif typ == "audio" and aud_start and aud_end:
            path = item.get("audio") or item.get("path")
            if path:
                payload = json.dumps({"path": path}, ensure_ascii=False)
                parts.append(f"{aud_start}{payload}{aud_end}")

    return "".join(parts).strip()


@register_model("Baichuan-Omni")
class BaichuanOmniRunner(ModelRunner):
    """Run Baichuan-Omni-1d5 on a single-turn multimodal conversation.

    Call :meth:`load_model` once before :meth:`run_model`. Conversations are
    lists of ``{"role": ..., "content": ...}`` dicts, as produced by
    :meth:`build_conversation`.
    """

    # Attributes of the processor output that are forwarded to
    # ``model.generate`` when present (after moving them to the model device).
    _MM_FEATURE_KEYS: Tuple[str, ...] = (
        "audios",
        "encoder_length",
        "bridge_length",
        "images",
        "patch_nums",
        "images_grid",
        "videos",
        "videos_patch_nums",
        "videos_grid",
    )

    def __init__(
        self,
        model_path: str = "/path/to/Baichuan-Omni-1d5",
        system_prompt: str | None = None,
    ) -> None:
        """Store configuration; no weights are loaded until :meth:`load_model`.

        Args:
            model_path: Checkpoint location. ``load_model`` passes
                ``local_files_only=True``, so it must be resolvable without
                network access.
            system_prompt: Default system prompt for :meth:`build_conversation`.
                Any falsy value (``None`` or ``""``) selects the built-in
                multiple-choice prompt below.
        """
        self.model_path = model_path
        self.model = None
        self.tokenizer = None
        self.mm_processor = None
        # Exception raised by ``model.bind_processor`` during loading, kept so a
        # later "no processor" error in run_model can report the root cause.
        self._processor_error: Exception | None = None
        # Each sentence after the first starts with a space so the adjacent
        # literals concatenate into properly separated sentences.
        self.system_prompt = system_prompt or (
            "You are an assistant tasked with solving multiple-choice questions that require logical"
            " reasoning over the supplied knowledge diagrams."
            " Use only the information explicitly given—do not rely on outside or commonsense knowledge."
            " Read the question and given information, think step-by-step and answer the question."
            " At the end of your answer, answer precisely in the format 'Answer: X' where X is the chosen letter A / B / C / D."
        )

    def load_model(self) -> None:
        """Load the fp16 model, its tokenizer and the multimodal processor.

        The model is loaded with ``device_map="auto"``, eager attention and
        ``local_files_only=True``. If the model's ``generation_config`` lacks
        ``user_token_id`` / ``assistant_token_id``, they are filled from the
        tokenizer's ``<C_Q>`` / ``<C_A>`` tokens when those tokens exist;
        :meth:`run_model` uses them to frame the prompt.
        """
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            attn_implementation="eager",
            local_files_only=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path, trust_remote_code=True, use_fast=True
        )

        self._processor_error = None
        try:
            # ``bind_processor`` is presumably supplied by the checkpoint's
            # remote code; fall back to a pre-attached ``processor`` attribute.
            self.mm_processor = self.model.bind_processor(self.tokenizer, training=False)
        except Exception as exc:  # deliberate best effort: any failure -> fallback
            self._processor_error = exc
            self.mm_processor = getattr(self.model, "processor", None)

        gen_cfg = getattr(self.model, "generation_config", None)
        if gen_cfg is not None:
            self._ensure_role_id(gen_cfg, "user_token_id", ["<C_Q>"])
            self._ensure_role_id(gen_cfg, "assistant_token_id", ["<C_A>"])

    def _ensure_role_id(self, gen_cfg: Any, attr: str, token_strs: List[str]) -> Any:
        """Make sure ``gen_cfg.<attr>`` holds a token id, resolving it if needed.

        An existing non-``None`` value is left untouched and returned. Otherwise
        the first entry of *token_strs* that the tokenizer maps to a real id is
        stored and returned. Ids that are not ints, are negative, or equal the
        tokenizer's unknown-token id are rejected, so a missing special token is
        never silently replaced by ``<unk>``. Returns ``None`` if nothing matched.
        """
        current = getattr(gen_cfg, attr, None)
        if current is not None:
            return current
        unk_id = getattr(self.tokenizer, "unk_token_id", None)
        for token in token_strs:
            tid = self.tokenizer.convert_tokens_to_ids(token)
            if isinstance(tid, int) and tid >= 0 and tid != unk_id:
                setattr(gen_cfg, attr, tid)
                return tid
        return None

    def build_conversation(
        self, user_content: list[dict], system_prompt: str | None = None
    ) -> list[dict]:
        """Wrap *user_content* in a ``[system, user]`` conversation.

        Args:
            user_content: Content items for the user turn (see
                ``_build_mm_content_str`` for the accepted item shapes).
            system_prompt: System text; ``None`` uses ``self.system_prompt``.
        """
        if system_prompt is None:
            system_prompt = self.system_prompt
        return [
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
            {"role": "user", "content": user_content},
        ]

    @staticmethod
    def _as_content_list(content: Any) -> list:
        """Normalize a message's content to a list of content-item dicts.

        ``None`` becomes ``[]``, a plain string becomes one text item, and a
        single dict is wrapped in a list; anything else is converted with
        ``list()``.
        """
        if content is None:
            return []
        if isinstance(content, str):
            return [{"type": "text", "text": content}]
        if isinstance(content, dict):
            return [content]
        return list(content)

    @classmethod
    def _message_text(cls, content: Any) -> str:
        """Join the non-empty ``text`` fields of a message's text items with newlines.

        Items without a ``type`` key are treated as text; non-dict items are
        ignored. Returns ``""`` when there is no text.
        """
        texts = [
            item.get("text") or ""
            for item in cls._as_content_list(content)
            if isinstance(item, dict) and item.get("type", "text") == "text"
        ]
        return "\n".join(t for t in texts if t)

    @staticmethod
    def _wrap_role_tokens(
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None,
        user_id: int | None,
        assistant_id: int | None,
    ) -> Tuple[torch.Tensor, torch.Tensor | None]:
        """Prepend *user_id* and append *assistant_id* (each optional) to every row.

        The attention mask, when given, is extended with ones at the same
        positions so it stays aligned with *input_ids*.
        """

        def column(ref: torch.Tensor, value: int) -> torch.Tensor:
            # One (batch, 1) column matching ref's dtype and device.
            return torch.full((ref.size(0), 1), value, dtype=ref.dtype, device=ref.device)

        if user_id is not None:
            input_ids = torch.cat([column(input_ids, user_id), input_ids], dim=1)
            if attention_mask is not None:
                attention_mask = torch.cat([column(attention_mask, 1), attention_mask], dim=1)
        if assistant_id is not None:
            input_ids = torch.cat([input_ids, column(input_ids, assistant_id)], dim=1)
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask, column(attention_mask, 1)], dim=1)
        return input_ids, attention_mask

    @classmethod
    def _move_to_device(cls, obj: Any, device: Any) -> Any:
        """Recursively move tensors inside lists/tuples/dicts to *device*.

        Non-tensor leaves are returned unchanged; container types are preserved.
        """
        if isinstance(obj, torch.Tensor):
            return obj.to(device)
        if isinstance(obj, (list, tuple)):
            return type(obj)(cls._move_to_device(x, device) for x in obj)
        if isinstance(obj, dict):
            return {k: cls._move_to_device(v, device) for k, v in obj.items()}
        return obj

    def run_model(self, conversation: Any) -> str:
        """Generate a reply for *conversation* and return it as stripped text.

        The system text (if any) and the rendered user content are joined into
        one prompt string. It is run through the bound multimodal processor and
        framed with the user/assistant role token ids from
        ``generation_config``. Only the newly generated tokens are decoded.

        Raises:
            RuntimeError: If :meth:`load_model` has not been called, if no
                multimodal processor is available, or if the processor
                produced no ``input_ids``.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("BaichuanOmniRunner is not loaded. Call load_model() first.")
        if self.mm_processor is None:
            raise RuntimeError(
                "BaichuanOmniRunner has no multimodal processor: model.bind_processor() "
                "failed and the model exposes no `processor` attribute."
            ) from getattr(self, "_processor_error", None)

        user_msg = next((m for m in conversation if m.get("role") == "user"), None)
        system_msg = next((m for m in conversation if m.get("role") == "system"), None)

        user_content = self._as_content_list(user_msg.get("content") if user_msg else None)
        system_str = self._message_text(system_msg.get("content") if system_msg else None)

        # Build the content string with inline multimodal tags from the processor.
        # Only prepend the system text when there actually is some.
        content_str = _build_mm_content_str(self.mm_processor, user_content)
        if system_str:
            content_str = f"{system_str}\n\n{content_str}"

        # Process via the model-bound processor to obtain tensors and MM features.
        proc_out = self.mm_processor([content_str])
        input_ids = getattr(proc_out, "input_ids", None)
        if input_ids is None:
            raise RuntimeError("Baichuan Omni processor failed to build inputs for the given conversation.")
        device = self.model.device
        input_ids = input_ids.to(device)
        attention_mask = getattr(proc_out, "attention_mask", None)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        # Chat framing: <user> prompt <assistant>, using whichever ids are configured.
        gen_cfg = getattr(self.model, "generation_config", None)
        input_ids, attention_mask = self._wrap_role_tokens(
            input_ids,
            attention_mask,
            getattr(gen_cfg, "user_token_id", None),
            getattr(gen_cfg, "assistant_token_id", None),
        )

        gen_kwargs: dict[str, Any] = {"inputs": input_ids, "attention_mask": attention_mask}
        for key in self._MM_FEATURE_KEYS:
            value = getattr(proc_out, key, None)
            if value is not None:
                gen_kwargs[key] = self._move_to_device(value, device)

        with torch.no_grad():
            output_ids = self.model.generate(**gen_kwargs)

        # Drop the echoed prompt; decode only the newly generated tokens.
        reply_ids = output_ids[:, input_ids.shape[1]:]
        reply = self.tokenizer.batch_decode(
            reply_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return reply.strip()
