import os
import json
import argparse
import time
import torch
import multiprocessing as mp
import traceback
import math
import re  
import sys
from tqdm import tqdm
from transformers import Qwen2_5OmniProcessor
from qwen_omni_utils import process_mm_info
from evaluation.utils import parse_single_choice_response
    

def chat_and_measure(input_modality, file_path, prompt, sys_prompt, model, processor, model_path):
    """Run one greedy multimodal generation and return the cleaned answer text.

    Builds a two-message conversation (system prompt, then a user turn holding a
    single media item followed by a text prompt), preprocesses it with
    ``processor`` and ``process_mm_info``, generates text only
    (``return_audio=False``, greedy, up to 2048 new tokens) and trims some
    conversational boilerplate from the decoded answer.

    Args:
        input_modality: content type of the media item; used both as the
            ``"type"`` value and as the key holding ``file_path`` (e.g. ``"video"``).
        file_path: path to the media file.
        prompt: user text prompt.
        sys_prompt: system prompt text.
        model: model exposing ``generate``, ``device`` and ``dtype``.
        processor: processor exposing ``apply_chat_template``, ``__call__`` and
            ``batch_decode``.
        model_path: unused; kept for call-site compatibility.

    Returns:
        The decoded newly generated text (prompt tokens excluded) with the
        boilerplate removed and surrounding whitespace stripped.

    Raises:
        Any exception raised while decoding is re-raised unchanged after the raw
        generation output is printed.
    """
    conversation = [
        {
            "role": "system",
            "content": [{"type": "text", "text": sys_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": input_modality, input_modality: file_path},
                {"type": "text", "text": prompt},
            ],
        },
    ]

    # Also feed the audio track of video inputs to the model.
    USE_AUDIO_IN_VIDEO = True

    text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    audios, images, videos = process_mm_info(conversation, use_audio_in_video=USE_AUDIO_IN_VIDEO)

    inputs = processor(
        text=text,
        audio=audios,
        images=images,
        videos=videos,
        return_tensors="pt",
        padding=True,
        use_audio_in_video=USE_AUDIO_IN_VIDEO,
    )
    inputs = inputs.to(model.device).to(model.dtype)

    with torch.no_grad():
        text_ids = model.generate(
            **inputs,
            use_audio_in_video=USE_AUDIO_IN_VIDEO,
            do_sample=False,
            return_audio=False,
            max_new_tokens=2048,
            use_cache=True,
        )

    torch.cuda.synchronize()

    try:
        # Drop the prompt tokens so only the newly generated tokens are decoded.
        generated_text = processor.batch_decode(
            text_ids[:, inputs['input_ids'].shape[1]:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
    except Exception:
        print(f"[Decode Error] Type: {type(text_ids)}, Data: {text_ids}")
        raise

    # Keep only the text before a trailing follow-up question, then remove
    # leading filler phrases.
    model_generation = (
        generated_text.split("What do you think")[0]
        .split("So, what")[0]
        .replace("Well, from what I can hear,", "")
        .replace("Well, from what you've said,", "")
        .strip()
    )

    return model_generation

def worker_proc(rank, gpu_id, args, task_chunk, out_path):
    """Evaluate one chunk of tasks on a single GPU and write results as JSON lines.

    Started as a child process by ``run_multi_gpu``. Loads the
    compression-enabled Qwen2.5-Omni model onto ``cuda:{gpu_id}`` in bfloat16
    with flash-attention 2, disables the talker, and sets the thinker's
    ``compression_config`` from ``args.rho_audio`` / ``args.rho_video``.
    For every task it writes the original QA record, extended with
    ``"output"`` (the model answer) and ``"video_path"``, as one JSON line to
    ``out_path``. Items whose generation fails are logged and skipped.

    Args:
        rank: worker index, used in logs.
        gpu_id: CUDA device index to run on.
        args: parsed CLI namespace; reads ``model_path``, ``rho_audio`` and ``rho_video``.
        task_chunk: list of task dicts with keys ``video_id``, ``video_path`` and
            ``raw_data`` (the original QA item).
        out_path: JSONL file that is (over)written with this worker's results.
    """
    device_map = {"": f"cuda:{gpu_id}"}
    torch.cuda.set_device(gpu_id)
    sys_prompt = "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."

    # Imported inside the worker process, so only the spawned children load the
    # project's modified model code.
    from avcompression_transformer.modeling_qwen2_5_omni import Qwen2_5OmniForConditionalGeneration

    compression_config = {"rho_audio": args.rho_audio, "rho_video": args.rho_video}

    model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        attn_implementation="flash_attention_2",
    )
    # Text-only evaluation: the speech generator is not needed.
    model.disable_talker()
    model.thinker.compression_config = compression_config

    processor = Qwen2_5OmniProcessor.from_pretrained(args.model_path)

    with open(out_path, "w", encoding="utf-8") as fout:
        for item in tqdm(task_chunk, desc=f"Worker-{rank}[GPU-{gpu_id}]"):
            video_path = item['video_path']
            video_id = item['video_id']
            prompt = item['raw_data'].get('prompt', '')

            try:
                # chat_and_measure returns a single string (the cleaned answer).
                response = chat_and_measure(
                    "video", video_path, prompt, sys_prompt, model, processor, args.model_path
                )
            except Exception as e:
                # Best-effort evaluation: log the failure and continue with the next item.
                print(f"[Worker-{rank}] Error on {video_id}: {e}")
                traceback.print_exc()
                continue

            out_data = dict(item["raw_data"])
            out_data["output"] = response
            out_data["video_path"] = video_path

            fout.write(json.dumps(out_data, ensure_ascii=False) + "\n")
            # Flush per item so partial results survive a crash of this worker.
            fout.flush()

    print(f"[Worker-{rank}] Done. Results -> {out_path}")

def run_multi_gpu(args):
    """Shard QA items across GPUs, run one worker process per shard and merge the results.

    Each QA item's video is expected at
    ``{video_base_dir}/{video_id}/{video_id}_video.mp4``. Items whose video is
    missing are skipped and counted. The remaining tasks are split into at most
    ``num_gpus`` contiguous chunks. Chunk ``i`` runs in a separate process on
    GPU ``i`` and writes ``<fout_path stem>.part{i}<ext>``. Once all workers
    have exited, the part files are concatenated in rank order into
    ``args.fout_path`` and deleted.

    Args:
        args: namespace with ``qa_path``, ``video_base_dir``, ``fout_path`` and
            ``num_gpus`` (plus whatever ``worker_proc`` reads).

    Raises:
        ValueError: if ``args.num_gpus`` is less than 1.
    """
    num_gpus = args.num_gpus
    if num_gpus < 1:
        raise ValueError(f"num_gpus must be >= 1, got {num_gpus}")

    print(f"Loading data from {args.qa_path}...")
    try:
        with open(args.qa_path, 'r', encoding='utf-8') as f:
            qa_data = json.load(f)
    except Exception as e:
        print(f"Error loading JSON: {e}")
        return

    task_list = []
    missing_count = 0

    for item in qa_data:
        vid = item['video_id']
        video_path = os.path.join(args.video_base_dir, vid, f"{vid}_video.mp4")

        if os.path.exists(video_path):
            task_list.append({
                "video_id": vid,
                "video_path": video_path,
                "raw_data": item,
            })
        else:
            missing_count += 1

    print(f"Total tasks loaded: {len(task_list)}. Missing videos: {missing_count}")

    # Ceil-division keeps the number of chunks <= num_gpus; floor of 1 keeps
    # range() valid when there are no tasks.
    chunk_size = max(1, math.ceil(len(task_list) / num_gpus))
    chunks = [task_list[i:i + chunk_size] for i in range(0, len(task_list), chunk_size)]

    out_dir = os.path.dirname(args.fout_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Derive per-rank part files from the final path's stem so that a part file
    # can never coincide with fout_path, even when it does not end in ".jsonl".
    root, ext = os.path.splitext(args.fout_path)
    part_ext = ext or ".jsonl"

    processes = []
    tmp_files = []

    for rank, chunk in enumerate(chunks):
        if not chunk:
            continue
        if rank >= num_gpus:  # Defensive: never start more workers than GPUs.
            break

        gpu_id = rank % num_gpus
        tmp_out = f"{root}.part{rank}{part_ext}"
        tmp_files.append(tmp_out)

        p = mp.Process(
            target=worker_proc,
            args=(rank, gpu_id, args, chunk, tmp_out),
        )
        p.start()
        processes.append((rank, p))

    for rank, p in processes:
        p.join()
        if p.exitcode != 0:
            print(f"[Warning] Worker-{rank} exited with code {p.exitcode}; its results may be incomplete.")

    print("Merging results...")
    with open(args.fout_path, "w", encoding="utf-8") as fout:
        for tmp in tmp_files:
            if os.path.exists(tmp):
                with open(tmp, "r", encoding="utf-8") as fin:
                    fout.writelines(fin)
                os.remove(tmp)

    print(f"All done. Saved to {args.fout_path}")

if __name__ == "__main__":
    # Workers use CUDA, which PyTorch does not support in forked subprocesses,
    # so child processes are started with "spawn".
    mp.set_start_method("spawn", force=True)
    parser = argparse.ArgumentParser(description="DailyOmni Evaluation")

    # Basic path arguments (all are needed by the pipeline, so they are required).
    parser.add_argument("--model_path", type=str, required=True,
                        help="Model checkpoint path passed to from_pretrained")
    parser.add_argument("--video_base_dir", type=str, required=True,
                        help="Directory containing <video_id>/<video_id>_video.mp4")
    parser.add_argument("--qa_path", type=str, required=True,
                        help="QA JSON file: a list of items with a 'video_id' key")
    parser.add_argument("--fout_path", type=str, required=True, help="Path to save output jsonl")
    parser.add_argument("--num_gpus", type=int, default=8)

    # Compression ratios stored in model.thinker.compression_config by each worker.
    parser.add_argument('--rho_audio', type=float)
    parser.add_argument('--rho_video', type=float)

    args = parser.parse_args()

    run_multi_gpu(args)