import torch
from abc import ABC
import uuid
from typing import List, Dict, Any, Union
from transformers import Qwen3OmniMoeForConditionalGeneration, AutoProcessor
from qwen_omni_utils import process_mm_info
from services.base_model import Truncate

class Qwen3OmniProcessor(Truncate):
    """Text-generation wrapper around a Qwen3-Omni MoE checkpoint.

    The model and its processor are loaded lazily by ``initialize_model``;
    ``generate`` / ``batch_generate`` then turn chat-style ``messages`` into
    decoded text.
    """

    def __init__(self, model_path: str, **kwargs):
        """Store configuration; no weights are loaded here.

        Args:
            model_path: Forwarded to the base class and later used as the
                ``from_pretrained`` location.
            **kwargs: Forwarded to the base class. ``use_audio_in_video``
                (default ``True``) is also read here.
        """
        super().__init__(model_path, **kwargs)
        self.model = None
        self.processor = None
        self.use_audio_in_video = kwargs.get("use_audio_in_video", True)
        # Recomputed by every build_input() call; defined here so the
        # attribute exists even before the first request.
        self.contain_audio_tracks = True

    def initialize_model(self) -> None:
        """Load the processor and the model from ``self.model_path``.

        Raises:
            RuntimeError: If loading the processor or the model fails; the
                original exception is attached as ``__cause__``.
        """
        try:
            print(f"Loading model from {self.model_path}")
            processor = AutoProcessor.from_pretrained(self.model_path)

            model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
                self.model_path,
                torch_dtype="auto",
                device_map="auto",
                attn_implementation="sdpa",
            )
            model.eval()
        except Exception as e:
            raise RuntimeError(f"[ERROR] Failed to initialize model: {e}") from e

        # Publish both objects only after everything succeeded, so a failed
        # load never leaves a processor without a model.
        self.processor = processor
        self.model = model

    def _require_initialized(self) -> None:
        """Raise a clear error when the model/processor are not loaded yet."""
        if self.model is None or self.processor is None:
            raise RuntimeError(
                "[ERROR] Model is not initialized; call initialize_model() first."
            )

    def _audio_in_video_enabled(self):
        """Return the effective ``use_audio_in_video`` flag for the current request."""
        return self.use_audio_in_video and self.contain_audio_tracks

    def build_input(self, messages: List[Dict[str, Any]], video_info=None, **kwargs) -> Any:
        """Convert chat messages into model-ready inputs.

        Args:
            messages: Chat messages passed to the processor's chat template
                and to ``process_mm_info``.
            video_info: Optional collection of mappings; audio-in-video is
                only kept enabled when every entry has a truthy ``"audio"``
                value. When empty or ``None`` audio tracks are assumed present.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The processor output moved to the model's device and dtype
            (presumably a transformers ``BatchFeature`` - confirm against the
            installed transformers version).

        Raises:
            RuntimeError: If ``initialize_model`` has not been called.
        """
        self._require_initialized()

        # Side effect kept on purpose: generate() reads this attribute after
        # calling build_input().
        self.contain_audio_tracks = (
            all(vinfo.get("audio") for vinfo in video_info) if video_info else True
        )
        use_audio = self._audio_in_video_enabled()

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        audios, images, videos = process_mm_info(
            messages, use_audio_in_video=use_audio
        )

        inputs = self.processor(
            text=[text],
            audio=audios,
            images=images,
            videos=videos,
            return_tensors="pt",
            padding=True,
            use_audio_in_video=use_audio,
        ).to(device=self.model.device, dtype=self.model.dtype)

        return inputs

    def generate(self, messages: List[Dict[str, Any]], video_info=None, **kwargs):
        """Generate a text reply for a single conversation.

        Args:
            messages: Chat messages for one request.
            video_info: Optional per-video metadata, see ``build_input``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The decoded text of the first (only) sequence, with the prompt
            tokens and special tokens removed.

        Raises:
            RuntimeError: If ``initialize_model`` has not been called.
        """
        self._require_initialized()

        try:
            inputs = self.build_input(messages=messages, video_info=video_info)

            with torch.no_grad():
                # The second return value (audio) is not used by this wrapper.
                text_ids, _audio = self.model.generate(
                    **inputs,
                    speaker="Ethan",
                    # Needed so the first return value exposes `.sequences`.
                    thinker_return_dict_in_generate=True,
                    use_audio_in_video=self._audio_in_video_enabled(),
                )

            # Drop the prompt tokens so only newly generated ids are decoded.
            input_len = inputs["input_ids"].shape[1]
            generated_ids_trimmed = text_ids.sequences[:, input_len:]
            generated_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]
            return generated_text
        finally:
            # Run the cleanup hook even when generation/decoding fails, so a
            # failed request does not skip it.
            self.lightweight_gpu_reset()

    def batch_generate(self, request_list: List[Dict[str, Any]], **kwargs) -> List[str]:
        """Run ``generate`` sequentially for each request.

        Args:
            request_list: Each item is unpacked as keyword arguments to
                ``generate`` (e.g. ``messages``, ``video_info``).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Generated texts in the same order as ``request_list``.
        """
        return [self.generate(**req) for req in request_list]
