# test_inference.py
import torch
import soundfile as sf
import os
from loguru import logger
import numpy as np

# Ensure correct import per your project structure
from kimia_infer.api.kimia import KimiAudio

def _collect_audio_paths(audio_location):
    """Expand *audio_location* into a list of candidate audio file paths.

    A directory is expanded to the entries directly inside it whose names end
    with a common audio extension, sorted by name so runs are repeatable.
    Anything else (an existing file or a missing path) is returned unchanged
    as a single-element list so the caller can report a missing file itself.
    """
    if not os.path.isdir(audio_location):
        return [audio_location]

    # Extension filter chosen here to avoid feeding non-audio files
    # (READMEs, transcripts, ...) to the model.
    audio_extensions = (".wav", ".flac", ".mp3", ".ogg", ".m4a")
    return [
        os.path.join(audio_location, entry)
        for entry in sorted(os.listdir(audio_location))
        if entry.lower().endswith(audio_extensions)
    ]


def run_real_model_verification():
    """
    Run an end-to-end text-output inference smoke test through the KimiAudio API.

    Configuration is read from two environment variables:

    * ``KIMIA_VERIFY_MODEL_PATH`` - passed as ``model_path`` to ``KimiAudio``.
    * ``KIMIA_VERIFY_AUDIO_DIR`` - a directory of audio files, or a single
      audio file, to run the "describe the audio" prompt against.

    For every audio file the generated text is printed to stdout for human
    inspection. Failures on individual files are logged and do not abort the
    run; the final log line reports whether every file succeeded.

    Returns:
        None. The outcome is reported through log messages and stdout only.
    """
    logger.info("--- Starting Real Model Verification Suite ---")

    # --- 1. Configuration ---
    # Paths come from the environment so local paths never end up in the repo.
    settings = {
        "KIMIA_VERIFY_MODEL_PATH": os.environ.get("KIMIA_VERIFY_MODEL_PATH"),
        "KIMIA_VERIFY_AUDIO_DIR": os.environ.get("KIMIA_VERIFY_AUDIO_DIR"),
    }
    missing = [name for name, value in settings.items() if not value]
    if missing:
        logger.error(
            f"Missing required environment variable(s): {', '.join(missing)}"
        )
        return

    model_path = settings["KIMIA_VERIFY_MODEL_PATH"]
    audio_location = settings["KIMIA_VERIFY_AUDIO_DIR"]

    # Resolve the inputs before loading the model so an empty directory fails
    # fast instead of after an expensive model initialization.
    test_audio_paths = _collect_audio_paths(audio_location)
    if not test_audio_paths:
        logger.error(f"No audio files found in '{audio_location}'; nothing to verify.")
        return

    logger.info("\n--- STEP 2: Verifying End-to-End Inference with KimiAudio API ---")

    # Only the constructor is guarded here, so the "initialization" error
    # message cannot be triggered by a later, unrelated failure.
    logger.info(f"Initializing KimiAudio API from '{model_path}'...")
    try:
        # Only text output is requested below, so the detokenizer is skipped
        # (presumably it is needed only for audio output - TODO confirm).
        kimia_api = KimiAudio(model_path=model_path, load_detokenizer=False)
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception(f"STEP 2 failed during API initialization: {e}")
        return

    logger.success("KimiAudio API initialized successfully.")

    total = len(test_audio_paths)
    succeeded = 0
    failed = 0
    skipped = 0

    # Run inference for each audio file
    for idx, audio_path in enumerate(test_audio_paths, 1):
        logger.info(f"\n--- Processing Audio {idx}/{total}: {audio_path} ---")

        # Directories were expanded above, so anything that is not a regular
        # file at this point is missing or unusable.
        if not os.path.isfile(audio_path):
            logger.warning(f"Audio file not found: {audio_path}, skipping...")
            skipped += 1
            continue

        audio_name = os.path.basename(audio_path)
        chats = [
            {
                "role": "user",
                "message_type": "text",
                "content": (
                    "Describe the audio content in detail.\n"
                ),
            },
            {
                "role": "user",
                "message_type": "audio",
                "content": audio_path,
            },
        ]

        logger.info(f"Running inference for audio: {audio_name}")

        try:
            # text_temperature=0.0 is intended to make decoding reproducible.
            with torch.inference_mode():
                # The first element (generated waveform) is unused because
                # output_type="text" is requested.
                _, generated_text = kimia_api.generate(
                    chats=chats,
                    output_type="text",
                    text_temperature=0.0,
                    max_new_tokens=500,
                )
        except Exception as e:
            # Keep going: one bad file should not hide results for the others.
            logger.exception(f"Failed to process audio {idx} ({audio_path}): {e}")
            failed += 1
            continue

        succeeded += 1
        logger.success(f"Inference completed for audio {idx}")

        # Print results to stdout for human verification
        print("\n" + "=" * 70)
        print(f"INFERENCE RESULTS FOR AUDIO {idx}: {audio_name}")
        print("-" * 70)
        print(f"Input Audio: {audio_path}")
        print(f"Generated Text:\n---\n{generated_text}\n---")
        print("=" * 70)

    # Only claim success when every input actually produced a result.
    if succeeded == total:
        logger.info("\n--- Real Model Verification Suite Completed Successfully! ---")
    else:
        logger.error(
            "\n--- Real Model Verification Suite finished with problems: "
            f"{succeeded} succeeded, {failed} failed, {skipped} skipped ---"
        )

# Run the verification suite only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run_real_model_verification()
