import torch
import os
import json
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
import numpy as np
from tqdm import tqdm
from datasets import load_dataset, concatenate_datasets

# from transformers import AutoModelForCausalLM 
from modeling_phi3_v import Phi3VForCausalLM, Phi3Attention
from transformers import AutoProcessor 

from utils.data_utils import load_yaml, construct_prompt, save_json, process_single_sample, CAT_SHORT2LONG
from utils.model_utils import phi3_image_processor, call_phi3_engine_df
from utils.eval_utils import parse_multi_choice_response, parse_open_response
from argparse import ArgumentParser

# Maps a model-variant key to the attention module class whose cached state
# must be reset after each generate() call (None = nothing to reset).
TAGET_MODULE = dict(
    phi3=None,
    phi3_h2o=Phi3Attention,
)

# Load the vision-language model and its processor once, at import time;
# the functions below use these module-level globals.
model_id = "microsoft/Phi-3.5-vision-instruct"
model = Phi3VForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="cuda",
    trust_remote_code=True,
    # Eager attention; presumably required so the custom attention code in
    # modeling_phi3_v is exercised — TODO confirm.
    _attn_implementation='eager',
)
# num_crops is forwarded to the remote-code image processor.
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
    num_crops=4,
)

# Generate a caption for a single image (one image per model call).
def generate_caption_for_image(image):
    """
    Generate a Rabbids-style story description for one image.

    Args:
        image: A PIL image. The prompt refers to it through the
            ``<|image_1|>`` placeholder.

    Returns:
        The list produced by ``processor.batch_decode``: one decoded string
        per batch row, i.e. a single-element list here.
    """
    # Image placeholder tag for the first (and only) image in the prompt.
    placeholder = "<|image_1|>\n"

    # Instruction text appended after the image placeholder.
    query = (
        "Please generate a story description for the uploaded image. The requirements are:\n"
        "- The description for the image should be consistent with the style of the 'Rabbids' cartoon, ensuring that the text and visuals are aligned in terms of style.\n"
        "- The description should be engaging and entertaining, with elements that captivate the audience and maintain their interest in the story.\n"
        "- The description should be closely related to the content of the image, ensuring that the text is coherent with the visuals and maintains logical consistency in the narrative.\n"
    )

    messages = [{"role": "user", "content": placeholder + query}]

    # Render the chat template as text; the processor tokenizes it together
    # with the image.
    prompt = processor.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Put the inputs on the same device the model was loaded onto
    # (device_map="cuda"), instead of hard-coding a device index.
    inputs = processor(prompt, [image], return_tensors="pt").to(model.device)

    generation_args = {
        "max_new_tokens": 196,  # upper bound on the generated caption length
        "temperature": 0.7,
        "do_sample": True,
    }

    cache_cls = TAGET_MODULE["phi3_h2o"]
    try:
        generate_ids = model.generate(
            **inputs,
            eos_token_id=processor.tokenizer.eos_token_id,
            **generation_args,
        )
    finally:
        # Reset the attention modules' cached state even if generation failed,
        # so nothing carries over into the next image (presumably the H2O
        # KV-cache — TODO confirm). isinstance(x, None) would raise, so skip
        # the reset when no cache class is configured.
        if cache_cls is not None:
            for module in model.modules():
                if isinstance(module, cache_cls):
                    module._clean_cache()

    # Drop the prompt tokens so only newly generated tokens are decoded.
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]

    response = processor.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    return response

# Read the dataset.
def load_data(file_path):
    """
    Load and parse a JSONL dataset (one JSON value per line).

    Blank or whitespace-only lines, e.g. a stray empty line at the end of the
    file, are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        file_path: Path to the ``.jsonl`` file, read as UTF-8.

    Returns:
        A list with the parsed value of every non-blank line, in file order.
    """
    # Explicit UTF-8 so non-ASCII paths/text decode the same on every platform.
    with open(file_path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]

# Caption every image referenced by the dataset, one image per model call.
def process_dataset(data, limit=3, image_root='/hy-tmp/Rabbids/rabbids', image_size=(256, 256)):
    """
    Generate a caption for each image in the dataset and report latency stats.

    Args:
        data: List of dataset records. Each record must have an ``'images'``
            key holding image paths relative to ``image_root``.
        limit: Only the first ``limit`` records are processed. The default of
            3 keeps the original debugging cap; pass ``None`` to process all.
        image_root: Directory the relative image paths are joined onto.
        image_size: ``(width, height)`` each image is resized to before
            captioning (aspect ratio is not preserved).

    Returns:
        A list with one ``{"generated_caption": ...}`` dict per image, in
        dataset order.
    """
    all_results = []
    timings = []  # per-image latency in milliseconds

    for item in tqdm(data[:limit], desc="Processing items", unit="item"):  # tqdm progress bar
        for rel_path in item['images']:
            image_path = os.path.join(image_root, rel_path)
            # Close the file handle immediately; resize() returns a new,
            # fully loaded image that does not depend on the open file.
            with Image.open(image_path) as img:
                image = img.resize(image_size)

            # Time the call with CUDA events; synchronize before reading them.
            starter = torch.cuda.Event(enable_timing=True)
            ender = torch.cuda.Event(enable_timing=True)
            starter.record()

            generated_caption = generate_caption_for_image(image)

            ender.record()
            torch.cuda.synchronize()
            timings.append(starter.elapsed_time(ender))  # milliseconds

            all_results.append({"generated_caption": generated_caption})

    # Mean and standard deviation of the per-image latency.
    mean_syn = np.mean(timings) if timings else 0
    std_syn = np.std(timings) if timings else 0
    mean_fps = 1000. / mean_syn if mean_syn > 0 else 0  # 1000 ms / mean latency (ms)

    print(f"Mean response time: {mean_syn:.3f} ms")
    print(f"Std response time: {std_syn:.3f} ms")
    print(f"Mean FPS: {mean_fps:.2f} FPS")

    return all_results

# Script entry point.
def main():
    """Caption the dataset and write one JSON object per line beside it."""
    # Dataset location -- replace with your actual path.
    dataset_path = '/hy-tmp/Rabbids/val.jsonl'

    results = process_dataset(load_data(dataset_path))

    # Output goes into the same directory as the input dataset.
    output_dir = os.path.dirname(dataset_path)
    output_path = os.path.join(output_dir, "generated_captions_EBM.jsonl")

    with open(output_path, 'w') as out:
        out.writelines(json.dumps(record) + "\n" for record in results)

    print(f"Results saved to: {output_path}")

# Run main() only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
