import json
import os
import numpy as np
import copy
import datetime
import sys
import re
from tqdm import tqdm
import ollama
import random

from torch.utils.data import Dataset

from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

# Expose only the first GPU to this process. CUDA reads this variable when it
# is first initialized, so it has to be set before any CUDA work happens.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})


class VQADataset(Dataset):
    """Map-style dataset over a VQA-format question file.

    Each item is a dict with ``image_path``, ``question`` and ``question_id``,
    plus ``answers`` when the source record contains them.

    Args:
        image_dir_path: Directory containing the images. For ``vqav2`` /
            ``ok_vqa`` its last path component must be the COCO split name
            (``train2014``, ``val2014`` or ``test2015``); it is used to build
            the COCO image file names.
        data_path: Path to a JSON file whose top-level ``"data"`` key holds
            the list of question records.
        is_train: Whether this is a training split. Stored on the instance;
            image paths do not depend on it.
        dataset_name: One of ``"vqav2"``, ``"ok_vqa"``, ``"vizwiz"`` or
            ``"textvqa"``.
        max_samples: If given, keep only the first ``max_samples`` records.

    Raises:
        ValueError: For ``vqav2`` / ``ok_vqa`` when the COCO split cannot be
            inferred from ``image_dir_path``.
    """

    def __init__(
        self, image_dir_path, data_path, is_train, dataset_name, max_samples=None
    ):
        # Load the annotation file.
        with open(data_path, "r", encoding="utf-8") as f:
            data_dict = json.load(f)

        # The question records live under the "data" key.
        self.data = data_dict["data"]
        if max_samples is not None:
            self.data = self.data[:max_samples]
        # Build the index only after truncation so every value is a valid
        # position in self.data and dropped questions are simply absent.
        self.qid_to_idx = {q["question_id"]: i for i, q in enumerate(self.data)}

        self.image_dir_path = image_dir_path
        self.is_train = is_train
        self.dataset_name = dataset_name
        if self.dataset_name in {"vqav2", "ok_vqa"}:
            self.img_coco_split = self.image_dir_path.strip("/").split("/")[-1]
            valid_splits = {"train2014", "val2014", "test2015"}
            # Explicit check instead of `assert`, which is stripped under -O.
            if self.img_coco_split not in valid_splits:
                raise ValueError(
                    f"Cannot infer COCO split from image_dir_path "
                    f"{image_dir_path!r}: last component must be one of "
                    f"{sorted(valid_splits)}"
                )

    def __len__(self):
        return len(self.data)

    def get_img_path(self, item):
        """Return the image file path for a question record.

        Raises:
            ValueError: If ``dataset_name`` is not a supported dataset.
        """
        if self.dataset_name in {"vqav2", "ok_vqa"}:
            # Train and eval splits share the COCO naming scheme; the split
            # name comes from the image directory (validated in __init__).
            return os.path.join(
                self.image_dir_path,
                f"COCO_{self.img_coco_split}_{item['image_id']:012d}.jpg",
            )
        elif self.dataset_name == "vizwiz":
            return os.path.join(self.image_dir_path, item["image_id"])
        elif self.dataset_name == "textvqa":
            return os.path.join(self.image_dir_path, f"{item['image_id']}.jpg")
        else:
            raise ValueError(f"Unknown VQA dataset {self.dataset_name}")

    def __getitem__(self, idx):
        item = self.data[idx]
        results = {
            "image_path": self.get_img_path(item),
            "question": item["question"],
            "question_id": item["question_id"],
        }
        # Only copy answers when the split provides them (test splits may not).
        if "answers" in item:
            results["answers"] = item["answers"]
        return results


def _build_messages(image_path, prompt):
    """Build a single-turn Qwen2.5-VL chat conversation for one question.

    Qwen2.5-VL's chat template and ``qwen_vl_utils.process_vision_info`` only
    pick up images given as ``{"type": "image", ...}`` entries inside a
    list-valued ``content``; a top-level ``"images"`` key (Ollama's message
    format) is ignored, so the image must be placed in ``content``.

    Args:
        image_path: Path to the image file, or ``None`` for a text-only query.
        prompt: The text prompt.

    Returns:
        A one-element list of chat messages.
    """
    content = []
    if image_path is not None:
        content.append({"type": "image", "image": image_path})
    content.append({"type": "text", "text": prompt})
    return [{"role": "user", "content": content}]


def _generate_answer(model, processor, messages, max_new_tokens=128):
    """Generate the model's reply to one chat conversation.

    Args:
        model: The Qwen2.5-VL model loaded in ``main``.
        processor: The matching ``AutoProcessor``.
        messages: Chat messages as built by ``_build_messages``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded reply text, without the prompt tokens.
    """
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation; keep only the new tokens.
    trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]


def main():
    """Run zero-shot chain-of-thought Qwen2.5-VL-7B-Instruct on TextVQA val.

    Each question is sent together with its image. The raw generations are
    written to ``results_file`` as a JSON list of records with keys
    ``question_id``, ``question`` and ``output``. Items that fail keep the
    same keys with ``"output": "ERROR"`` plus an ``"error"`` message, so
    every record has the same schema.
    """
    image_dir_path = "/home/test/yxl/MCoT/data/textvqa/images"
    test_data_json_path = "/home/test/yxl/MCoT/data/textvqa/TextVQA_0.5.1_val.json"
    results_file = "/home/test/yxl/MCoT/textvqa/results/qwen2.5vl_test/CoT.json"
    model_name = "Qwen/Qwen2.5-VL-7B-Instruct"

    test_dataset = VQADataset(
        image_dir_path=image_dir_path,
        data_path=test_data_json_path,
        is_train=False,
        dataset_name="textvqa",
        max_samples=None,
    )

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_name, torch_dtype="auto", device_map="auto"
    )
    processor = AutoProcessor.from_pretrained(model_name)

    predictions = []
    for item in tqdm(test_dataset, desc="Processing dataset"):
        image_path = item["image_path"]
        question = item["question"]
        question_id = item["question_id"]

        prompt = f"please answer the question in the form of\' The answer is \'\n\nQuestion: {question}\n\nLet\'s think step by step\n\n"

        if not os.path.exists(image_path):
            # Fall back to a text-only query, but make the gap visible.
            print(f"Warning: image not found for question {question_id}: {image_path}")
            image_path = None
        messages = _build_messages(image_path, prompt)

        try:
            output = _generate_answer(model, processor, messages)
        except Exception as e:  # one bad item should not abort the whole run
            print(f"Error processing question {question_id}: {e}")
            predictions.append({
                "question_id": question_id,
                "question": question,
                "output": "ERROR",
                "error": str(e),
            })
            continue

        predictions.append({
            "question_id": question_id,
            "question": question,
            "output": output,
        })

    # Create the results directory if needed so the final write cannot fail
    # after the whole dataset has been processed.
    os.makedirs(os.path.dirname(results_file), exist_ok=True)
    with open(results_file, "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=4)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

