#%%
import json
import os
import re
from collections import defaultdict

# Load the per-object counting summary plus the VQAv2 train/val question sets,
# then group every question under the image id it refers to.
with open("vqav2_data/counting_summary_combined_2.json", "r") as summary_fh:
    object_counts = json.load(summary_fh)

with open("vqav2_data/v2_OpenEnded_mscoco_train2014_questions.json", "r") as train_fh:
    train_qs = json.load(train_fh)["questions"]

with open("vqav2_data/v2_OpenEnded_mscoco_val2014_questions.json", "r") as val_fh:
    val_qs = json.load(val_fh)["questions"]

# image_id -> list of question dicts; train questions are appended before val.
vqa_question_map = defaultdict(list)
for split_questions in (train_qs, val_qs):
    for question_entry in split_questions:
        vqa_question_map[question_entry["image_id"]].append(question_entry)

# def extract_object(question):
#     question = question.lower()
#     if "how many" in question:
#         after = question.split("how many", 1)[1]
#         words = after.strip().split()
#         if words:
#             obj = words[0]
#             return obj
#     return None
# Whole-word "is"/"are" or the phrase "be seen". A plain substring test would
# also fire on "this", "his", "visible" (for "is") or "square", "care" (for
# "are"), letting through e.g. "How many people can this bus hold?".
_COUNT_VERB_RE = re.compile(r"\b(?:is|are)\b|\bbe seen\b")


def extract_object(question):
    """Return the word right after "how many" if the question looks like a
    plain counting question, otherwise None.

    The candidate word is accepted when either
      * it ends the question, e.g. "How many dogs?", or
      * the rest of the question contains the whole word "is"/"are" or the
        phrase "be seen", e.g. "How many dogs are there?",
        "How many cars can be seen?".

    Args:
        question: Raw question text; matching is case-insensitive.

    Returns:
        The lower-cased candidate word (with any "?" removed), or None when
        the question is not recognised as a counting question.
    """
    question = question.lower()
    if "how many" not in question:
        return None

    words = question.split("how many", 1)[1].split()
    if not words:
        return None

    obj = words[0]
    if "?" in obj:
        # "How many dogs?" -> "dogs"; a lone "?" leaves nothing usable.
        return obj.replace("?", "") or None

    remaining = " ".join(words[1:])
    if _COUNT_VERB_RE.search(remaining):
        return obj

    return None


# Pair every (object, count, image) entry from the counting summary with the
# VQA questions on that image whose extracted "how many <word>" object equals
# the summary's object name (case-insensitive), then save the pairs as JSON.
matched_records = []

# question_id -> extract_object(question). An image usually appears under
# several objects/counts, so memoising avoids re-parsing its questions.
_extracted_by_qid = {}

for obj_name, count_dict in object_counts.items():
    target = obj_name.lower()  # invariant for all images of this object
    for count_str, info in count_dict.items():
        for img_id_str in info["image_ids"]:
            try:
                img_id = int(img_id_str)
            except (TypeError, ValueError):
                # Skip ids that cannot be read as an integer image id; the
                # question map is keyed by int image_id.
                continue

            for q in vqa_question_map.get(img_id, []):
                qid = q["question_id"]
                if qid not in _extracted_by_qid:
                    _extracted_by_qid[qid] = extract_object(q["question"])
                extracted = _extracted_by_qid[qid]
                if extracted == target:
                    matched_records.append({
                        "object": obj_name,
                        "count": int(count_str),
                        "image_id": img_id,
                        "question_id": qid,
                        "question": q["question"],
                        "matched_with": extracted,
                    })

with open("matched_from_object_imageids_2.json", "w") as f:
    json.dump(matched_records, f, indent=2)




# # %%

# #%%
# import json

# # Load both JSON files
# with open("matched_from_object_imageids.json", "r") as f1, open("matched_from_object_imageids_2.json", "r") as f2:
#     first_data = json.load(f1)
#     second_data = json.load(f2)

# first_questions = set(item["question"] for item in first_data)
# second_questions = set(item["question"] for item in second_data)

# unique_questions = first_questions - second_questions

# unique_entries = [item for item in first_data if item["question"] in unique_questions]

# with open("unique_questions.json", "w") as f_out:
#     json.dump(unique_entries, f_out, indent=2)


#%%
import openai
import json
import os
import time

# Label each candidate counting question with GPT: the prompt asks for "Yes"
# when the question counts entities visible in the image and "No" for
# capacity / possibility / non-visual questions. The raw reply is stored under
# item["gpt"]; a failed request is recorded as "error" so one bad call does
# not abort the whole run.

# Never hard-code the key in source. Read it from the environment and fail
# fast if it is missing. Otherwise every request fails with an auth error and
# the loop silently writes "error" for every question.
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    raise RuntimeError("Set the OPENAI_API_KEY environment variable before running this cell.")
client = openai.OpenAI(api_key=api_key)

input_file = "unique_questions.json"
output_file = "counting_with_gpt_label.json"

with open(input_file, "r") as f:
    data = json.load(f)

for i, item in enumerate(data):
    question = item["question"]

    prompt = (
        "Does the following question ask to count visible entities in an image?\n"
        "Answer 'Yes' if the question is a real counting question (e.g., 'How many people are in the photo?').\n"
        "Answer 'No' if it's about capacity, possibility, or non-visual information.\n\n"
        f"Question: \"{question}\"\n\n"
        "Answer with only one word: Yes or No."
    )

    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}],
            temperature=0,  # deterministic-as-possible labels
            max_tokens=10,  # a one-word answer is expected
        )
        gpt_answer = response.choices[0].message.content.strip()
        item["gpt"] = gpt_answer
        print(f"[{i+1}/{len(data)}] {question} → {gpt_answer}")
    except Exception as e:
        # Per-item best effort: record the failure and keep labeling the rest.
        item["gpt"] = "error"
        print(f"[{i+1}/{len(data)}] Error on question: {question}\n{e}")

with open(output_file, "w") as f_out:
    json.dump(data, f_out, indent=2)


# %%
