

#%%
import json
import os
import re
from collections import defaultdict, OrderedDict

# Directory that holds the VQAv2 annotation/question JSON files and receives
# the outputs written by this cell.
root_dir = "vqav2_data"
# Split names that are interpolated into the annotation/question file names.
splits = ["train2014", "val2014"]

# def extract_object(question):
#     question = question.lower()
#     if "how many" in question:
#         after = question.split("how many", 1)[1]
#         words = after.strip().split()
#         if words:
#             obj = words[0]
#             return obj
#     return None

def extract_object(question):
    """Extract the counted word from a "how many ..." question.

    The word that immediately follows the first "how many" is the candidate.
    It is accepted when it carries the question mark (i.e. it ends the
    question) or when the rest of the question contains "is", "are" or
    " be seen".

    Args:
        question: Raw question text; matching is case-insensitive.

    Returns:
        The lower-cased candidate word with any "?" removed, or None when the
        question has no "how many", nothing follows it, the candidate consists
        only of question marks, or the remainder does not match.
    """
    lowered = question.lower()
    if "how many" not in lowered:
        return None

    # str.split() without arguments already discards surrounding whitespace.
    words = lowered.split("how many", 1)[1].split()
    if not words:
        return None

    obj = words[0]

    if "?" in obj:
        # "How many?" leaves only punctuation; an empty string must not be
        # returned because callers use the result as a dictionary key.
        return obj.replace("?", "") or None

    remaining = " ".join(words[1:])
    # NOTE(review): these are substring checks, so "is" also matches inside
    # words such as "this". Kept unchanged so the extracted set stays the same.
    if "is" in remaining or "are" in remaining or " be seen" in remaining:
        return obj

    return None



# Nested tally: object word -> answer (as a string) -> number of matching
# annotations together with the image id of each one.
result = OrderedDict()

for split in splits:
    ann_path = os.path.join(root_dir, f"v2_mscoco_{split}_annotations.json")
    ques_path = os.path.join(root_dir, f"v2_OpenEnded_mscoco_{split}_questions.json")

    with open(ann_path, "r") as ann_file:
        annotations = json.load(ann_file)["annotations"]

    with open(ques_path, "r") as ques_file:
        questions = json.load(ques_file)["questions"]

    # Question text lives in a separate file; join it to annotations by id.
    qid2question = {entry["question_id"]: entry["question"] for entry in questions}

    for ann in annotations:
        if "how many" not in ann.get("question_type", "").lower():
            continue

        obj = extract_object(qid2question.get(ann["question_id"], ""))
        if obj is None:
            continue

        # Only answers that parse as integers are tallied.
        try:
            count_key = str(int(ann["multiple_choice_answer"]))
        except ValueError:
            continue

        img_id = str(ann["image_id"])
        per_count = result.setdefault(obj, OrderedDict())
        bucket = per_count.setdefault(count_key, {"count": 0, "image_ids": []})
        bucket["count"] += 1
        bucket["image_ids"].append(img_id)

# Full tally as JSON.
output_json_path = os.path.join(root_dir, "counting_summary_combined_2.json")
with open(output_json_path, "w") as json_out:
    json.dump(result, json_out, indent=2)

# One extracted object word per line, in first-seen order.
output_txt_path = os.path.join(root_dir, "object_titles_combined_2.txt")
with open(output_txt_path, "w") as txt_out:
    txt_out.writelines(name + "\n" for name in result)

print(f"Saved to:\n{output_json_path}\n{output_txt_path}")



#%%
import json

def extract_questions(json_path):
    """Load (question_id, question) pairs from a questions JSON file.

    Args:
        json_path: Path to a JSON file whose top-level "questions" entry is a
            list of objects carrying "question_id" and "question".

    Returns:
        A list of (question_id, question) tuples in file order.
    """
    # Read as UTF-8 explicitly instead of relying on the platform default
    # encoding; the CSV export of these questions is written as UTF-8 too.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return [(q['question_id'], q['question']) for q in data['questions']]
# Gather the question lists of both splits, validation first.
val_questions = extract_questions('vqav2_data/v2_OpenEnded_mscoco_val2014_questions.json')
train_questions = extract_questions('vqav2_data/v2_OpenEnded_mscoco_train2014_questions.json')
all_questions = [*val_questions, *train_questions]

# Print a short preview of the combined list.
for preview_id, preview_text in all_questions[:10]:
    print(f"ID: {preview_id}, Question: {preview_text}")

import csv

# Export every (question_id, question) pair, header row first.
with open('all_questions.csv', 'w', newline='', encoding='utf-8') as csvfile:
    csv_writer = csv.writer(csvfile)
    csv_writer.writerow(['question_id', 'question'])
    csv_writer.writerows(all_questions)
#%%
import openai
import time
import json

# NOTE: the API key below is a placeholder string.
client = openai.OpenAI(api_key="your_api_key_here")

object_list_file = "vqav2_data/object_titles_combined_2.txt"

# One candidate word per non-blank line of the list file.
with open(object_list_file, "r") as words_file:
    candidate_words = [stripped for stripped in map(str.strip, words_file) if stripped]

# Instruction text shared by every prompt; only the leading question varies.
prompt_instructions = (
    "Ignore spelling or grammar errors (like plural issues), but only say 'Yes' if the word refers to a tangible, countable **thing** "
    "(like apple, sheep, chair, car). Say 'No' if it refers to a number (e.g., '1's'), an abstract concept, an adjective, or anything non-physical.\n\n"
    "Answer with only one word: Yes or No."
)

# word -> raw model answer, or "ERROR" when the request failed.
results = {}

for word in candidate_words:
    prompt = (
        f"Is the word '{word}' referring to a physical, countable object in the real world? "
        + prompt_instructions
    )
    request = {
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0,
        "max_completion_tokens": 10,
    }

    try:
        response = client.chat.completions.create(**request)
        answer = response.choices[0].message.content.strip()
        results[word] = answer
        print(f"{word}: {answer}")
    except Exception as e:
        # A failed request is recorded and the loop moves on to the next word.
        print(f"[ERROR] {word}: {e}")
        results[word] = "ERROR"


with open("gpt_determine_2.json", "w") as verdicts_file:
    json.dump(results, verdicts_file, indent=2)


#%%
import json
from collections import defaultdict

with open("vqav2_data/counting_summary_combined_2.json", "r") as tally_file:
    full_data = json.load(tally_file)

with open("gpt_determine_2.json", "r") as verdicts_file:
    gpt_results = json.load(verdicts_file)

# base name -> count string -> accumulated {"count", "image_ids"}.
merged_data = defaultdict(lambda: defaultdict(lambda: {"count": 0, "image_ids": []}))
total_images = 0
total_objects = 0

for obj_name, obj_info in full_data.items():
    # Keys with a leading underscore are not object names.
    if obj_name.startswith("_"):
        continue

    base_name = obj_name.rstrip("?").strip().lower()
    verdict = gpt_results.get(base_name, "").strip().lower()
    if "yes" not in verdict:
        continue

    for count_str, info in obj_info.items():
        if not (isinstance(info, dict) and "count" in info):
            continue
        bucket = merged_data[base_name][count_str]
        bucket["count"] += int(info["count"])
        bucket["image_ids"].extend(info["image_ids"])

# Keep objects with a positive overall tally and accumulate the grand totals.
filtered_data = {}
for base_name, count_info in merged_data.items():
    total = sum(entry["count"] for entry in count_info.values())
    if total <= 0:
        continue
    filtered_data[base_name] = count_info
    total_objects += 1
    total_images += total

filtered_data["_total_images"] = total_images
filtered_data["_total_objects"] = total_objects

with open("object_filter_results_2.json", "w") as filtered_file:
    json.dump(filtered_data, filtered_file, indent=2)


#%%
import json

def generate_filtered_summary_json(data, min_threshold=15, max_threshold=30, min_valid_counts=5, count_range=(1, 40)):
    """Summarise how many images each object would contribute per count value.

    Args:
        data: Mapping of object name -> {count string -> {"count": int, ...}}.
            Entries whose key starts with "_" or whose value is not a mapping
            are treated as metadata and ignored.
        min_threshold: Minimum tally a count value needs to be kept.
        max_threshold: Cap on the number of images taken per count value.
        min_valid_counts: Minimum number of kept count values an object needs
            to appear in the result.
        count_range: Inclusive (low, high) range of accepted count values.

    Returns:
        A dict mapping each kept object to {count string -> capped tally,
        "total" -> sum of capped tallies}, plus "_total_images" (sum of all
        object totals) and "_total_objects" (number of kept objects).
    """
    result = {}
    grand_total = 0
    object_count = 0
    low, high = count_range[0], count_range[1]

    for obj, count_dict in data.items():
        # Inputs produced by this pipeline carry integer "_total_images" /
        # "_total_objects" entries next to the per-object mappings; iterating
        # those as mappings would raise AttributeError.
        if (isinstance(obj, str) and obj.startswith("_")) or not isinstance(count_dict, dict):
            continue

        obj_total = 0
        valid_count_items = {}

        for count_str, info in count_dict.items():
            try:
                count_i = int(count_str)
                count = int(info["count"])
            except ValueError:
                # Non-numeric count keys or tallies are not usable.
                continue

            if low <= count_i <= high and count >= min_threshold:
                images_to_add = min(count, max_threshold)
                valid_count_items[count_str] = images_to_add
                obj_total += images_to_add

        if len(valid_count_items) >= min_valid_counts:
            valid_count_items["total"] = obj_total
            result[obj] = valid_count_items
            grand_total += obj_total
            object_count += 1

    result["_total_images"] = grand_total
    result["_total_objects"] = object_count
    return result

# Read the filtered tallies, summarise them, and persist the summary.
with open("object_filter_results_2.json", "r") as filtered_file:
    data = json.load(filtered_file)

summary = generate_filtered_summary_json(data)

with open("summary_output_5_2.json", "w") as summary_file:
    json.dump(summary, summary_file, indent=2)

#%%
import json
import os
import random
import uuid
import shutil

# === Config ===
root_dir = "vqav2"
img_source_dir = "vqav2_data"
output_train_img_dir = os.path.join(root_dir, "train/images")
output_val_img_dir = os.path.join(root_dir, "val/images")
# Make sure both image output folders exist before anything is copied.
for image_dir in (output_train_img_dir, output_val_img_dir):
    os.makedirs(image_dir, exist_ok=True)

# === Load full dataset and filtered keys ===
with open("vqav2_data/counting_summary_combined_2.json", "r") as all_data_file:
    all_data = json.load(all_data_file)

with open("summary_output_5_2.json", "r") as summary_file:
    summary_filter = json.load(summary_file)

# === Refined filtering function ===
def generate_filtered_summary_json(data, filter_keys, min_threshold=15, max_threshold=30, min_valid_counts=5, count_range=(1, 40), excluded_objects=("pieces",)):
    """Select image ids per object and count value for the chosen objects.

    Args:
        data: Mapping of object name -> {count string -> {"count": int,
            "image_ids": list}}.
        filter_keys: Iterable of object names to consider; names missing from
            ``data`` contribute nothing.
        min_threshold: Minimum tally a count value needs to be kept.
        max_threshold: Cap on the number of image ids taken per count value.
        min_valid_counts: Minimum number of kept count values an object needs
            to appear in the result.
        count_range: Inclusive (low, high) range of accepted count values.
        excluded_objects: Object names that are always skipped. Defaults to
            ("pieces",), the exclusion this function has always applied.

    Returns:
        A dict mapping each kept object to {count string -> list of the first
        image ids (at most ``max_threshold``), "total" -> number of selected
        ids}, plus "_total_images" and "_total_objects" entries.
    """
    result = {}
    grand_total = 0
    object_count = 0
    low, high = count_range[0], count_range[1]

    for obj in filter_keys:
        if obj in excluded_objects:
            continue

        count_dict = data.get(obj, {})
        if not isinstance(count_dict, dict):
            # Metadata entries (plain numbers) hold no per-count information.
            continue

        obj_total = 0
        valid_count_items = {}

        for count_str, info in count_dict.items():
            try:
                count_i = int(count_str)
                count = int(info["count"])
            except ValueError:
                continue

            if low <= count_i <= high and count >= min_threshold:
                images_to_add = min(count, max_threshold)
                selected = info["image_ids"][:images_to_add]
                valid_count_items[count_str] = selected
                # Count what was actually selected; the id list may be
                # shorter than the recorded tally.
                obj_total += len(selected)

        if len(valid_count_items) >= min_valid_counts:
            valid_count_items["total"] = obj_total
            result[obj] = valid_count_items
            grand_total += obj_total
            object_count += 1

    result["_total_images"] = grand_total
    result["_total_objects"] = object_count
    return result

# Re-apply the thresholds to the full tallies, restricted to the summary's keys.
filtered_data = generate_filtered_summary_json(all_data, summary_filter.keys())

# === Copy images and generate train/val splits ===
# Accumulators for the output files and for ids whose image was not found.
train_json, val_json, val_ans_json, missing_images = [], [], [], []

# Fixed seed so shuffling through the `random` module is repeatable.
random.seed(42)

def find_and_copy_image(img_id, dest_folder, new_name=None, source_dir=None, prefixes=("train2014", "val2014")):
    """Locate a COCO image by id and copy it into ``dest_folder``.

    The source file is looked up as
    ``<source_dir>/<prefix>/COCO_<prefix>_<12-digit id>.jpg`` for each prefix
    in order; the first existing file is used.

    Args:
        img_id: Image id; anything ``int()`` accepts.
        dest_folder: Existing directory that receives the copy.
        new_name: File name for the copy; defaults to ``"<img_id>.jpg"``.
        source_dir: Root folder containing one sub-folder per prefix. When
            None, the module-level ``img_source_dir`` is used.
        prefixes: Split folder names to search, in order.

    Returns:
        The destination file name, or None when no source file exists.

    Raises:
        ValueError: If ``img_id`` cannot be converted to an integer.
    """
    if source_dir is None:
        source_dir = img_source_dir

    img_id_str = f"{int(img_id):012d}"
    dst_name = new_name if new_name else f"{img_id}.jpg"

    for prefix in prefixes:
        src_path = os.path.join(source_dir, prefix, f"COCO_{prefix}_{img_id_str}.jpg")
        if not os.path.exists(src_path):
            continue
        dst_path = os.path.join(dest_folder, dst_name)
        # An existing destination is left alone, so repeated requests for the
        # same target name do not copy again.
        if not os.path.exists(dst_path):
            shutil.copy2(src_path, dst_path)
        return dst_name

    return None

# Build the train/val samples object by object: split each count bucket
# 80/20, copy the images, and append one record per image to the output lists.
for obj, count_dict in filtered_data.items():
    # Keys with a leading underscore hold totals, not objects.
    if obj.startswith("_"):
        continue

    train_items = []
    val_items = []

    for count_str, image_ids in count_dict.items():
        # "total" is the per-object sum stored next to the count buckets.
        if count_str == "total":
            continue
        count_val = int(count_str)
        # Shuffle a copy so the list inside filtered_data keeps its order.
        image_ids_copy = image_ids.copy()
        random.shuffle(image_ids_copy)
        # First 80% of the shuffled ids go to train, the rest to val.
        split_idx = int(len(image_ids_copy) * 0.8)
        train_items += [(img_id, count_val) for img_id in image_ids_copy[:split_idx]]
        val_items += [(img_id, count_val) for img_id in image_ids_copy[split_idx:]]

    for img_id, count in train_items:
        # Every train sample gets a fresh UUID that serves both as the sample
        # id and as the copied image's file name.
        uid = str(uuid.uuid4())
        copied_name = find_and_copy_image(img_id, output_train_img_dir, new_name=f"{uid}.jpg")
        if copied_name:
            train_json.append({
                "id": uid,
                # NOTE(review): this path is prefixed with "vqav2/" while the
                # val records below start at "val/" — confirm that the two
                # consumers expect different roots.
                "image": f"vqav2/train/images/{copied_name}",
                "conversations": [
                    {
                        "from": "human",
                        "value": f"<image>\nhow many {obj} are there in the image?"
                    },
                    {
                        "from": "gpt",
                        "value": str(count)
                    }
                ]
            })
        else:
            missing_images.append(str(img_id))

    for img_id, count in val_items:
        # Val images keep the image id as file name.
        copied_name = find_and_copy_image(img_id, output_val_img_dir, new_name=f"{img_id}.jpg")
        if copied_name:
            # One UUID links the question record with its answer record.
            qid = str(uuid.uuid4())
            question = f"<image>\nhow many {obj} are there in the image?"
            val_json.append({
                "question_id": qid,
                "image": f"val/images/{img_id}.jpg",
                "category": "default",
                "text": question,
                "id": qid
            })
            val_ans_json.append({
                "question_id": qid,
                "prompt": question,
                "text": str(count),
                "answer_id": None,
                "model_id": None,
                "metadata": {}
            })
        else:
            missing_images.append(str(img_id))

# === Save outputs ===
with open(os.path.join(root_dir, "train/train.json"), "w") as f:
    json.dump(train_json, f, indent=2)

with open(os.path.join(root_dir, "val/val.json"), "w") as f:
    json.dump(val_json, f, indent=2)

with open(os.path.join(root_dir, "val/val_ans.json"), "w") as f:
    json.dump(val_ans_json, f, indent=2)

# One missing image id per line.
with open(os.path.join(root_dir, "missing_images.txt"), "w") as f:
    for mid in missing_images:
        f.write(mid + "\n")

# === Stats ===
# The totals come from filtered_data, i.e. they are computed before copying
# and do not subtract images that turned out to be missing.
print(f"Total filtered object_ids: {filtered_data.get('_total_objects', 'N/A')}")
print(f"Total images used: {filtered_data.get('_total_images', 'N/A')}")
print(f"Missing images: {len(missing_images)} (logged in missing_images.txt)")

#%%
import json
import matplotlib.pyplot as plt

with open("summary_output_5_2.json", "r") as summary_file:
    summary_data = json.load(summary_file)

# Drop the underscore-prefixed total entries; only per-object entries are plotted.
summary_data = {name: per_count for name, per_count in summary_data.items() if not name.startswith("_")}

# Count value -> number of images summed over all objects.
object_count_vs_image_count = {}
for per_count in summary_data.values():
    for count_str, image_count in per_count.items():
        if count_str != "total":
            count = int(count_str)
            object_count_vs_image_count[count] = object_count_vs_image_count.get(count, 0) + image_count

# Object name -> its stored per-object total.
object_name_vs_image_count = {}
for name, per_count in summary_data.items():
    object_name_vs_image_count[name] = per_count["total"]

# Plot 1: Object count vs image count
plt.figure(figsize=(10, 5))
plt.bar(object_count_vs_image_count.keys(), object_count_vs_image_count.values())
plt.xlabel("Object Count (e.g., 1 object, 2 objects...)")
plt.ylabel("Total Image Count")
plt.title("Object Count vs. Image Count")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Plot 2: Object name vs image count (top 30 for visibility)
ranked_objects = sorted(object_name_vs_image_count.items(), key=lambda pair: pair[1], reverse=True)
sorted_objects = dict(ranked_objects[:30])
plt.figure(figsize=(12, 6))
plt.bar(sorted_objects.keys(), sorted_objects.values())
plt.xlabel("Object Name")
plt.ylabel("Image Count")
plt.title("Top 30 Object Names by Image Count")
plt.xticks(rotation=75, ha='right')
plt.tight_layout()
plt.show()


#%%
import json

# Load both verdict files.
with open("gpt_determine.json", "r") as first_file:
    raw_data1 = json.load(first_file)

with open("gpt_determine_2.json", "r") as second_file:
    data2 = json.load(second_file)


# Remove question marks and surrounding whitespace from the first file's keys
# before they are looked up in the second file.
data1 = {}
for raw_key, verdict in raw_data1.items():
    data1[raw_key.replace("?", "").strip()] = verdict

def is_yes(answer):
    """Return True when *answer* contains "yes", ignoring case."""
    normalized = answer.lower()
    return normalized.find("yes") != -1

# Keys judged "yes" in the first file that are absent from, or not "yes" in,
# the second file. A missing key falls back to "", which is never "yes".
mismatched = [
    key
    for key, value in data1.items()
    if is_yes(value) and not is_yes(data2.get(key, ""))
]

for item in mismatched:
    print(f"- {item}")



#%%
