import json
import os
import numpy as np
import copy
import datetime
import sys
import re
from tqdm import tqdm
import ollama

from torch.utils.data import Dataset

from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

# Pin this process to physical GPU 1; inside the process it is then addressed
# as "cuda" / cuda:0. CUDA reads this variable when it is first initialized,
# so it must be set before any CUDA work (importing torch alone does not
# initialize CUDA, which is why setting it after the imports above still works).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


class VQADataset(Dataset):
    """Map-style dataset over a VQA-style JSON question file.

    The file at ``data_path`` must be a JSON object whose ``"data"`` key holds
    a list of question records. Each record needs ``"image_id"``,
    ``"question"`` and ``"question_id"``; ``"answers"`` is passed through when
    present. Image paths are derived from ``image_id`` with a per-dataset
    naming scheme (see ``get_img_path``).

    Args:
        image_dir_path: Directory containing the images. For ``vqav2`` and
            ``ok_vqa`` its last path component must be the COCO split name
            (``train2014``, ``val2014`` or ``test2015``).
        data_path: Path to the JSON question file.
        is_train: Stored as ``self.is_train``; it does not affect how image
            paths are built in this class.
        dataset_name: One of ``vqav2``, ``ok_vqa``, ``vizwiz``, ``textvqa``.
        max_samples: If not None, keep only the first ``max_samples`` records.

    Raises:
        ValueError: If a COCO-based dataset's image directory is not named
            after a known COCO split.
    """

    # Datasets whose images follow the COCO file-naming scheme.
    _COCO_DATASETS = frozenset({"vqav2", "ok_vqa"})
    # COCO split names accepted as the last component of image_dir_path.
    _COCO_SPLITS = frozenset({"train2014", "val2014", "test2015"})

    def __init__(
            self, image_dir_path, data_path, is_train, dataset_name, max_samples=None
    ):
        # Explicit UTF-8 so non-ASCII questions decode identically regardless
        # of the system locale.
        with open(data_path, "r", encoding="utf-8") as f:
            data_dict = json.load(f)

        self.data = data_dict["data"]
        if max_samples is not None:
            self.data = self.data[:max_samples]

        self.image_dir_path = image_dir_path
        self.is_train = is_train
        self.dataset_name = dataset_name
        if self.dataset_name in self._COCO_DATASETS:
            # COCO file names embed the split; take it from the directory name.
            self.img_coco_split = self.image_dir_path.strip("/").split("/")[-1]
            if self.img_coco_split not in self._COCO_SPLITS:
                raise ValueError(
                    f"Expected the last component of image_dir_path to be one of "
                    f"{sorted(self._COCO_SPLITS)}, got {self.img_coco_split!r}"
                )

    def __len__(self):
        """Return the number of (possibly truncated) question records."""
        return len(self.data)

    def get_img_path(self, item):
        """Return the image file path for one question record.

        Raises:
            ValueError: If ``self.dataset_name`` is not a supported dataset.
        """
        if self.dataset_name in self._COCO_DATASETS:
            # e.g. COCO_val2014_000000000042.jpg (image_id zero-padded to 12 digits).
            # The name is the same for train and non-train splits.
            file_name = f"COCO_{self.img_coco_split}_{item['image_id']:012d}.jpg"
        elif self.dataset_name == "vizwiz":
            # image_id is used as-is as the file name.
            file_name = item["image_id"]
        elif self.dataset_name == "textvqa":
            file_name = f"{item['image_id']}.jpg"
        else:
            raise ValueError(f"Unknown VQA dataset {self.dataset_name}")
        return os.path.join(self.image_dir_path, file_name)

    def __getitem__(self, idx):
        """Return a dict with image_path, question, question_id (and answers if present)."""
        item = self.data[idx]
        results = {
            "image_path": self.get_img_path(item),
            "question": item["question"],
            "question_id": item["question_id"],
        }
        # Answers are optional (e.g. absent for test splits).
        if "answers" in item:
            results["answers"] = item["answers"]
        return results


def _load_existing_predictions(results_file):
    """Load predictions saved by a previous run so processing can resume.

    If ``results_file`` does not exist, its parent directory is created (when
    it has one) and empty results are returned. If the file exists but cannot
    be read or parsed, a message is printed and the run starts from scratch.

    Args:
        results_file: Path of the JSON results file written by this script.

    Returns:
        ``(predictions, completed_question_ids)``: the saved list of
        prediction dicts and the set of their ``"question_id"`` values.
    """
    if not os.path.exists(results_file):
        directory = os.path.dirname(results_file)
        # os.makedirs("") raises, so only create a directory if there is one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        return [], set()

    try:
        with open(results_file, "r", encoding="utf-8") as f:
            existing_predictions = json.load(f)
        completed_question_ids = {item["question_id"] for item in existing_predictions}
    except (OSError, ValueError, KeyError, TypeError) as e:
        # ValueError covers json.JSONDecodeError / UnicodeDecodeError;
        # KeyError / TypeError cover JSON that is not a list of prediction dicts.
        print("结果文件存在但无法读取，将从头开始。")
        print(f"  ({type(e).__name__}: {e})")
        return [], set()

    print(f"已检测到现有结果文件，将从上次进度继续。已完成 {len(completed_question_ids)} 个问题。")
    return existing_predictions, completed_question_ids


def _save_predictions(predictions, results_file):
    """Write ``predictions`` to ``results_file`` as indented JSON, atomically.

    The data goes to a temporary sibling file that is then moved into place
    with ``os.replace``, so an interruption mid-write cannot leave a truncated
    results file (which the resume logic would be unable to parse).
    """
    tmp_path = results_file + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=4)
    os.replace(tmp_path, results_file)


def _build_messages(image_path, question):
    """Build a single-turn Qwen2.5-VL chat conversation for one question.

    Qwen2.5-VL expects ``content`` to be a list of typed parts; images placed
    anywhere else (e.g. an ollama-style ``"images"`` key) are ignored by both
    ``apply_chat_template`` and ``process_vision_info``.
    """
    prompt = 'Please answer the question in the form of\' The answer is \'\n\n'
    prompt += f"Question: {question}\n\n"

    content = []
    if os.path.exists(image_path):
        content.append({"type": "image", "image": image_path})
    else:
        # Keep the original best-effort behavior (ask without the image),
        # but make it visible instead of silent.
        print(f"Warning: image not found, asking without it: {image_path}")
    content.append({"type": "text", "text": prompt})
    return [{"role": "user", "content": content}]


def _generate_answer(model, processor, messages, max_new_tokens=128):
    """Run one generation for ``messages`` and return the decoded answer text."""
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + completion; keep only the newly generated tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]


def _run_inference(dataset, indices, predictions, results_file, batch_size,
                   model_name="Qwen/Qwen2.5-VL-7B-Instruct"):
    """Answer ``dataset[i]`` for every ``i`` in ``indices`` with the VLM.

    Appends one ``{"question_id", "question", "output"}`` dict per question
    to ``predictions`` in place (so the caller can still save partial progress
    if this is interrupted). ``output`` is ``"ERROR"`` when inference raised.
    The full list is saved to ``results_file`` every ``batch_size`` questions.
    """
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_name, torch_dtype="auto", device_map="auto"
    )
    processor = AutoProcessor.from_pretrained(model_name)

    for n_done, i in enumerate(tqdm(indices, desc="Processing dataset"), start=1):
        item = dataset[i]
        question = item["question"]
        question_id = item["question_id"]
        messages = _build_messages(item["image_path"], question)

        try:
            output = _generate_answer(model, processor, messages)
        except Exception as e:
            # Best-effort per question: record the failure and continue.
            # NOTE: "ERROR" entries count as completed on resume and are not retried.
            print(f"Error processing question {question_id}: {e}")
            output = "ERROR"

        predictions.append({
            "question_id": question_id,
            "question": question,
            "output": output,
        })

        if n_done % batch_size == 0:
            _save_predictions(predictions, results_file)
            print(f"已保存 {len(predictions)} 个问题的结果到 {results_file}")


def main():
    """Answer TextVQA training questions with Qwen2.5-VL and save them as JSON.

    Resumes from the results file when it already holds predictions: questions
    whose ``question_id`` is already present are skipped. Results are saved
    every ``batch_size`` questions and once more at the end — also when the
    run fails or is interrupted — each time atomically.
    """
    # Dataset paths
    train_image_dir_path = "/home/test/yxl/MCoT/data/textvqa/images"
    train_data_json_path = "/home/test/yxl/MCoT/data/textvqa/TextVQA_0.5.1_train.json"

    # Where predictions are written (and resumed from)
    results_file = "/home/test/yxl/MCoT/textvqa/results/qwen2.5vl/AP_5.json"
    batch_size = 100  # save after every 100 processed questions

    predictions, completed_question_ids = _load_existing_predictions(results_file)

    train_dataset = VQADataset(
        image_dir_path=train_image_dir_path,
        data_path=train_data_json_path,
        is_train=True,
        dataset_name="textvqa",
        max_samples=1000,
    )

    # Only questions without a saved prediction still need to be run.
    indices_to_process = [
        i for i in range(len(train_dataset))
        if train_dataset[i]["question_id"] not in completed_question_ids
    ]
    print(f"总共需要处理 {len(indices_to_process)} 个问题。")

    try:
        # Skip loading the 7B model entirely when everything is already done.
        if indices_to_process:
            _run_inference(
                train_dataset, indices_to_process, predictions, results_file, batch_size
            )
    finally:
        # Persist everything produced so far, including on Ctrl-C or a crash.
        _save_predictions(predictions, results_file)
    print(f"处理完成！共保存 {len(predictions)} 个问题的结果到 {results_file}")


if __name__ == "__main__":
    # Run the inference job only when executed as a script, not on import.
    main()
