import json
import os
import numpy as np
import copy
import datetime
import sys
import re
from tqdm import tqdm
import ollama
import random

from torch.utils.data import Dataset




class VQADataset(Dataset):
    """Map-style dataset over a VQA question file with optional annotations.

    Each item is a dict with ``image_path``, ``question`` and ``question_id``;
    when an annotation file was supplied it also carries ``answers`` (a list
    of answer strings).

    Args:
        image_dir_path: Directory that holds the images. For ``vqav2`` and
            ``ok_vqa`` its last path component must be ``train2014``,
            ``val2014`` or ``test2015``; that name is embedded in the COCO
            image file names.
        question_path: JSON file with a top-level ``"questions"`` list.
        annotations_path: JSON file with a top-level ``"annotations"`` list,
            or ``None`` when no ground truth is available.
        is_train: Stored on the instance; it does not influence how items
            are built.
        dataset_name: One of ``"vqav2"``, ``"ok_vqa"``, ``"vizwiz"`` or
            ``"textvqa"``; selects the image file naming scheme.
        max_samples: When given, only the first ``max_samples`` questions
            (and annotations) are kept.
    """

    def __init__(
        self, image_dir_path, question_path, annotations_path, is_train, dataset_name, max_samples=None
    ):
        # Context managers close the JSON files as soon as they are parsed.
        with open(question_path, "r", encoding="utf-8") as f:
            self.questions = json.load(f)["questions"]
        if annotations_path is not None:
            with open(annotations_path, "r", encoding="utf-8") as f:
                self.answers = json.load(f)["annotations"]
        else:
            self.answers = None

        if max_samples is not None:
            self.questions = self.questions[:max_samples]
            if self.answers is not None:
                self.answers = self.answers[:max_samples]

        # Built after truncation so every stored index is valid for
        # ``self.questions`` (and therefore for ``__getitem__``).
        self.qid_to_idx = {q["question_id"]: i for i, q in enumerate(self.questions)}

        self.image_dir_path = image_dir_path
        self.is_train = is_train
        self.dataset_name = dataset_name
        if self.dataset_name in {"vqav2", "ok_vqa"}:
            # The COCO split name is the last component of the image directory.
            self.img_coco_split = self.image_dir_path.strip("/").split("/")[-1]
            assert self.img_coco_split in {"train2014", "val2014", "test2015"}, (
                f"Unexpected COCO split directory: {self.img_coco_split!r}"
            )

    def __len__(self):
        """Return the number of questions kept in the dataset."""
        return len(self.questions)

    def get_img_path(self, question):
        """Return the image path for a question record.

        Args:
            question: A question dict containing ``image_id``.

        Raises:
            ValueError: If ``dataset_name`` is not a supported dataset.
        """
        if self.dataset_name in {"vqav2", "ok_vqa"}:
            # COCO file names zero-pad the numeric image id to 12 digits.
            return os.path.join(
                self.image_dir_path,
                f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg",
            )
        if self.dataset_name == "vizwiz":
            return os.path.join(self.image_dir_path, question["image_id"])
        if self.dataset_name == "textvqa":
            return os.path.join(self.image_dir_path, f"{question['image_id']}.jpg")
        raise ValueError(f"Unknown VQA dataset {self.dataset_name}")

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict (see the class docstring)."""
        question = self.questions[idx]
        results = {
            "image_path": self.get_img_path(question),
            "question": question["question"],
            "question_id": question["question_id"],
        }
        if self.answers is not None:
            # NOTE(review): annotations are matched to questions purely by
            # position; this assumes both files list entries in the same
            # order - confirm for the files in use.
            annotation = self.answers[idx]
            results["answers"] = [a["answer"] for a in annotation["answers"]]
        return results




def main():
    """Run zero-shot chain-of-thought VQA inference with LLaVA through Ollama.

    Builds a ``VQADataset`` over the first 10,000 VQAv2 val2014 questions,
    sends each question (plus its image, when the file exists on disk) to the
    ``llava:7b`` model and writes one record per question to a JSON file.

    Every record holds ``question_id``, ``question`` and ``output``. When the
    model call fails, ``output`` is ``"ERROR"`` and the record additionally
    carries ``"answer": "ERROR"``.
    """
    # Dataset paths.
    test_image_dir_path = "/home/test/yxl/MCoT/data/okvqa/val2014"
    test_questions_json_path = "/home/test/yxl/MCoT/data/vqa2/v2_OpenEnded_mscoco_val2014_questions.json"
    test_annotations_json_path = "/home/test/yxl/MCoT/data/vqa2/v2_mscoco_val2014_annotations.json"
    results_file = "/home/test/yxl/MCoT/vqa2/results/llava_test/CoT_1.json"

    # Create the output directory up front so a missing directory fails (or is
    # fixed) before the long inference loop instead of after it.
    os.makedirs(os.path.dirname(results_file), exist_ok=True)

    # NOTE: the training split is intentionally not loaded here. It was only
    # needed by few-shot prompt variants (random or question_id-selected
    # examples) that are disabled for this zero-shot run; use
    # VQADataset.qid_to_idx to select examples by question_id if they return.
    test_dataset = VQADataset(
        image_dir_path=test_image_dir_path,
        question_path=test_questions_json_path,
        annotations_path=test_annotations_json_path,
        is_train=False,  # evaluation split
        dataset_name="vqav2",
        max_samples=10000,
    )

    # Fixed seed so any client-side random sampling is reproducible.
    random.seed(3)

    predictions = []

    # Query the model once per sample.
    for item in tqdm(test_dataset, desc="Processing dataset"):
        image_path = item["image_path"]
        question = item["question"]
        question_id = item["question_id"]

        prompt = (
            "Let's think step by step\n\n"
            f"Question: {question}\n\n"
            "please answer the question in the form of' The answer is '\n\n"
        )

        message = {"role": "user", "content": prompt}

        # Attach the image when it is available; otherwise the model is asked
        # the question without visual input, which is worth surfacing.
        if os.path.exists(image_path):
            message["images"] = [image_path]
        else:
            print(f"Warning: image not found for question {question_id}: {image_path}")

        try:
            response = ollama.chat(
                model="llava:7b",
                stream=False,
                messages=[message],
                options={
                    "temperature": 0.5,
                    # Ollama's generation-length cap is ``num_predict``;
                    # ``max_tokens`` is not an Ollama option.
                    "num_predict": 512,
                    "top_p": 1.0,
                    "frequency_penalty": 0.0,
                    "presence_penalty": 0.0,
                    # NOTE(review): stopping at the first newline may cut off
                    # multi-line step-by-step reasoning - confirm intended.
                    "stop": ["\n"],
                },
            )
            output = response['message']['content']
        except Exception as e:
            # Best effort: record the failure and keep processing the rest.
            print(f"Error processing question {question_id}: {e}")
            predictions.append({
                "question_id": question_id,
                "question": question,
                # Same key as successful records so consumers can always
                # read "output"; the legacy "answer" marker is kept as well.
                "output": "ERROR",
                "answer": "ERROR",
            })
        else:
            predictions.append({
                "question_id": question_id,
                "question": question,
                "output": output,
            })

    # Persist all predictions.
    with open(results_file, "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=4)



# Run the inference pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
