#!/usr/bin/env python3
"""
Minimal Qwen2.5-Omni ASR script with beam search and per-beam probability output.
Uses Qwen2.5-Omni-3B for audio transcription.
"""

import torch
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
from datasets import load_dataset
import numpy as np
import librosa
from qwen_omni_utils import process_mm_info

def load_model_and_processor(model_name="Qwen/Qwen2.5-Omni-3B", *,
                             torch_dtype=torch.bfloat16, device_map="auto"):
    """Load a Qwen2.5-Omni model and its processor for text-only inference.

    Args:
        model_name: Hugging Face Hub id or local path of the checkpoint.
        torch_dtype: dtype the model weights are loaded in (default: bfloat16,
            the value this script has always used).
        device_map: device placement passed to ``from_pretrained``
            (default: ``"auto"``).

    Returns:
        A ``(model, processor)`` tuple. The model's talker (audio-output)
        module is disabled, so only text generation is available.
    """
    processor = Qwen2_5OmniProcessor.from_pretrained(model_name)
    model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
        device_map=device_map,
    )
    # Disable audio output since we only need transcription.
    model.disable_talker()
    return model, processor

def _build_conversation(audio_array, prompt_text=None):
    """Build the chat-format conversation (system + user turn) for one audio clip."""
    if prompt_text:
        system_text = (
            "You are an expert audio transcription system. Your task is to "
            "accurately transcribe the provided audio content into text."
        )
        user_text = (
            f"{prompt_text} Please transcribe this audio accurately into written form. "
            "add named entities <NOUN> tag just adjacent to the word"
        )
    else:
        # NOTE(review): without a prompt this asks for named entities rather
        # than a transcription, using the stock Qwen system prompt. Confirm
        # this is the intended task for the no-prompt path.
        system_text = (
            "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
            "capable of perceiving auditory and visual inputs, as well as generating "
            "text and speech."
        )
        user_text = "find all named entities in the speech"
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_text}],
        },
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": audio_array},
                {"type": "text", "text": user_text},
            ],
        },
    ]


def _move_inputs_to_model(inputs, model):
    """Move tensor inputs to the model's device and cast floating tensors to its dtype.

    Integer tensors (token ids, masks) keep their dtype, and non-tensor entries
    pass through unchanged.
    """
    moved = {}
    for key, value in inputs.items():
        if torch.is_tensor(value):
            if torch.is_floating_point(value):
                # Audio features come out of the feature extractor as float32.
                # The model weights are loaded in e.g. bfloat16, so cast to
                # avoid a dtype mismatch inside the audio encoder.
                value = value.to(device=model.device, dtype=model.dtype)
            else:
                value = value.to(model.device)
        moved[key] = value
    return moved


def _clean_transcription(transcription):
    """Strip chat-formatting artifacts and a leading audio description from one beam."""
    cleaned = transcription.strip()
    # Remove any leftover role marker such as "assistant:".
    if cleaned.startswith("assistant"):
        cleaned = cleaned[len("assistant"):].strip()
    if cleaned.startswith(":"):
        cleaned = cleaned[1:].strip()
    # The model sometimes opens with "The audio contains..."; keep the first
    # line that is not such a description. Lines are stripped before the
    # check so indented preamble lines are also skipped.
    if cleaned.lower().startswith("the audio"):
        for line in cleaned.split("\n"):
            stripped = line.strip()
            if stripped and not stripped.lower().startswith(
                ("the audio", "this audio", "i can hear")
            ):
                cleaned = stripped
                break
    return cleaned


def process_audio_with_beam_search(model, processor, audio_array, sample_rate,
                                   prompt_text=None, num_beams=5, max_new_tokens=256):
    """Run Qwen2.5-Omni on one audio clip with beam search and score the beams.

    Args:
        model: Qwen2.5-Omni model (talker disabled) from ``load_model_and_processor``.
        processor: The matching ``Qwen2_5OmniProcessor``.
        audio_array: Audio waveform (presumably a 1-D mono numpy array; confirm
            for other datasets). Resampled to 16 kHz if needed.
        sample_rate: Sampling rate of ``audio_array`` in Hz.
        prompt_text: Optional context prepended to the transcription request.
            When falsy, a named-entity prompt is used instead.
        num_beams: Beam width. Every beam is returned.
        max_new_tokens: Generation length cap passed to ``generate``.
            NOTE(review): some transformers versions route Qwen2.5-Omni's cap
            through ``thinker_max_new_tokens``; confirm this value is honored.

    Returns:
        ``(transcriptions, probabilities)``: cleaned decoded text for each
        returned sequence, and one float per sequence. The floats are a
        softmax over the beams' length-normalized log-scores, so they are
        relative to each other and sum to 1. They are not absolute sequence
        probabilities.
    """
    # Resample audio if needed (Qwen2.5-Omni typically uses 16kHz).
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        print(f"Resampling audio from {sample_rate} Hz to {target_sample_rate} Hz")
        audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=target_sample_rate)
        sample_rate = target_sample_rate

    conversation = _build_conversation(audio_array, prompt_text)

    # Apply the chat template and extract multimodal inputs with qwen_omni_utils.
    text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)

    inputs = processor(
        text=text,
        audio=audios,
        images=images,
        videos=videos,
        return_tensors="pt",
        padding=True,
        use_audio_in_video=False,
    )
    inputs = _move_inputs_to_model(inputs, model)

    generation_kwargs = {
        "num_beams": num_beams,
        "num_return_sequences": num_beams,
        "return_dict_in_generate": True,
        "output_scores": True,
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "temperature": 1.0,
        "pad_token_id": processor.tokenizer.pad_token_id,
        "eos_token_id": processor.tokenizer.eos_token_id,
        "return_audio": False,  # Only text output for transcription.
        "use_audio_in_video": False,
    }

    with torch.no_grad():
        try:
            generated_outputs = model.generate(**inputs, **generation_kwargs)
        except Exception as e:
            print(f"Generation failed with full config, trying simplified version: {e}")
            # Retry without the token-id / sampling overrides. If this also
            # fails, the exception propagates with the first error chained.
            fallback_kwargs = {
                key: generation_kwargs[key]
                for key in (
                    "num_beams",
                    "num_return_sequences",
                    "return_dict_in_generate",
                    "output_scores",
                    "max_new_tokens",
                    "return_audio",
                    "use_audio_in_video",
                )
            }
            generated_outputs = model.generate(**inputs, **fallback_kwargs)

    sequences = generated_outputs.sequences
    # Only beam-search outputs carry ``sequences_scores``. Greedy decoding
    # (num_beams=1) produces a single sequence without it.
    scores = getattr(generated_outputs, "sequences_scores", None)
    if scores is None:
        # A softmax over a single sequence would be 1.0 anyway.
        probabilities = torch.ones(sequences.shape[0])
    else:
        # Normalize across beams in float32 for numerical stability. These
        # are relative beam weights, not absolute probabilities.
        probabilities = torch.softmax(scores.float(), dim=0)

    # Generated sequences include the prompt; keep only the new tokens.
    input_length = inputs["input_ids"].shape[1]
    generated_sequences = sequences[:, input_length:]

    transcriptions = processor.batch_decode(generated_sequences, skip_special_tokens=True)
    cleaned_transcriptions = [_clean_transcription(t) for t in transcriptions]

    return cleaned_transcriptions, probabilities.cpu().tolist()

def main():
    """Demo: beam-search Qwen2.5-Omni over the first sample of the dataset.

    Decodes the clip twice, first with no prompt and then with a short
    domain hint, and prints every beam with its relative probability. It
    then prints the sample's ``text`` field as ground truth when that key
    exists.
    """
    print("Loading Qwen2.5-Omni model...")
    model, processor = load_model_and_processor()

    print("Loading audio dataset...")
    dataset = load_dataset("prdeepakbabu/AgenticASR-MultiVoice-Enhanced", "default", split="train")

    first_row = dataset[0]
    waveform = first_row["audio"]["array"]
    rate = first_row["audio"]["sampling_rate"]

    print(f"Processing audio sample (length: {len(waveform)} samples, rate: {rate} Hz)")

    # Each run is (section header, optional prompt text).
    runs = (
        ("\n=== Qwen2.5-Omni ASR Beam Search Results (No Prompt) ===", None),
        ("\n=== Qwen2.5-Omni ASR Beam Search Results (With Prompt) ===",
         "This is a financial analysis audio."),
    )
    for header, prompt in runs:
        print(header)
        beams, beam_probs = process_audio_with_beam_search(
            model, processor, waveform, rate, prompt_text=prompt
        )
        for rank, (beam_text, beam_prob) in enumerate(zip(beams, beam_probs), start=1):
            print(f"Beam {rank} (prob: {beam_prob:.4f}): {beam_text}")

    # Show ground truth for comparison, if the dataset provides it.
    if "text" not in first_row:
        print("\nGround truth: Not available in dataset")
        return
    print(f"\nGround truth: {first_row['text']}")

if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not on import.
    main()
