#%%
import json
import os
import random
import uuid
import shutil
import csv
from collections import defaultdict

# === Config ===
# root_dir receives the generated dataset; img_source_dir holds the original
# COCO image folders (train2014/ and val2014/, see find_and_copy_image).
root_dir, img_source_dir = "vqav2_csv_new", "vqav2_data"

# Per-split image output folders, created up front so copies can land there.
output_train_img_dir, output_val_img_dir = (
    os.path.join(root_dir, f"{split_name}/images") for split_name in ("train", "val")
)
for _img_dir in (output_train_img_dir, output_val_img_dir):
    os.makedirs(_img_dir, exist_ok=True)

# === Load full dataset and filtered keys ===
# all_data: object name -> count (str) -> info dict; the script reads
# info["image_ids"] and info["count"] from it (rest of schema: TODO confirm).
# JSON is UTF-8 by spec, so decode explicitly instead of using the locale default.
with open("vqav2_data/counting_summary_combined_2.json", "r", encoding="utf-8") as f:
    all_data = json.load(f)

# Only the top-level keys of this file are used, as the object whitelist.
with open("summary_output_5_2.json", "r", encoding="utf-8") as f:
    summary_filter = json.load(f)

# === Load VQA questions for mapping ===
with open("vqav2_data/v2_OpenEnded_mscoco_train2014_questions.json", "r", encoding="utf-8") as f:
    train_qs = json.load(f)["questions"]

with open("vqav2_data/v2_OpenEnded_mscoco_val2014_questions.json", "r", encoding="utf-8") as f:
    val_qs = json.load(f)["questions"]

# Group every train + val question by the image it is asked about. Iterate the
# two lists in turn rather than concatenating them into a third large list.
vqa_question_map = defaultdict(list)
for qs in (train_qs, val_qs):
    for q in qs:
        vqa_question_map[q["image_id"]].append(q)

# === Object extractor for question ===
def extract_object(question):
    """Return the counted object of a simple "how many X ..." question, or None.

    The object is the first word after "how many". It is accepted when it ends
    the question (e.g. "how many dogs?"), or when the rest of the question
    contains the word "is" or "are", or the words "be seen" in sequence.

    Args:
        question: Raw question text; matching is case-insensitive.

    Returns:
        The lower-cased object word (with "?" removed), or None when the
        question does not follow that pattern.
    """
    question = question.lower()
    if "how many" not in question:
        return None
    words = question.split("how many", 1)[1].split()
    if not words:
        return None
    obj = words[0]
    if "?" in obj:
        # Object ends the question, e.g. "how many dogs?". A bare "how many?"
        # leaves nothing behind, so report no object instead of "".
        return obj.replace("?", "") or None
    # Compare whole words: a substring test on the joined text lets "is" match
    # "this", "visible", "his", ... and "are" match "square", which accepts
    # nearly every question.
    tail = [w.strip("?.,!") for w in words[1:]]
    if "is" in tail or "are" in tail:
        return obj
    if any(a == "be" and b == "seen" for a, b in zip(tail, tail[1:])):
        return obj
    return None

# === Build image_id + object → question_id & question map
# For every (object, image) pair in the counting summary, collect the VQA
# questions on that image whose extracted object equals the summary object.
print("Building matched question map from VQA...")
matched_records = []
for obj_name, count_dict in all_data.items():
    obj_key = obj_name.lower()
    for count_str, info in count_dict.items():
        for img_id_str in info["image_ids"]:
            try:
                img_id = int(img_id_str)
            except (TypeError, ValueError):
                # Skip ids that are not integer-like; a bare except here would
                # also swallow KeyboardInterrupt and genuine bugs.
                continue
            for q in vqa_question_map.get(img_id, []):
                if extract_object(q["question"]) == obj_key:
                    matched_records.append({
                        "object": obj_name,
                        "count": int(count_str),
                        "image_id": img_id,
                        "question_id": q["question_id"],
                        "question": q["question"]
                    })

# (image_id, lower-cased object) -> (question_id, question text). When several
# questions match the same pair, the last matched record wins.
image_obj_to_question = {
    (item["image_id"], item["object"].lower()): (item["question_id"], item["question"])
    for item in matched_records
}

# === Refined filtering function ===
def generate_filtered_summary_json(data, filter_keys, min_threshold=15, max_threshold=30, min_valid_counts=5, count_range=(1, 40), exclude=("pieces",)):
    """Select, per object, the count buckets with enough images and cap their size.

    For each object name in ``filter_keys`` that is not in ``exclude``, a bucket
    ``data[obj][count_str]`` is kept when its integer count lies within
    ``count_range`` (inclusive) and ``info["count"]`` is at least
    ``min_threshold``. A kept bucket is truncated to its first
    ``min(info["count"], max_threshold)`` image ids. An object is kept only if at
    least ``min_valid_counts`` buckets survive. Buckets whose key or
    ``info["count"]`` is not an integer are skipped.

    Args:
        data: Mapping object -> count (as str) -> info dict with ``"count"`` and
            ``"image_ids"`` entries.
        filter_keys: Iterable of object names to consider.
        min_threshold: Minimum ``info["count"]`` for a bucket to be kept.
        max_threshold: Maximum number of image ids kept per bucket.
        min_valid_counts: Minimum number of kept buckets for an object.
        count_range: Inclusive ``(low, high)`` range of count values to keep.
        exclude: Collection of object names to skip entirely (``"pieces"`` by
            default, matching the previously hard-coded behaviour).

    Returns:
        Dict mapping each kept object to ``{count_str: [image ids], "total": n}``,
        plus the summary keys ``"_total_images"`` and ``"_total_objects"``.
    """
    result = {}
    grand_total = 0
    object_count = 0
    for obj in filter_keys:
        if obj in exclude:
            continue
        count_dict = data.get(obj, {})
        valid_count_items = {}
        obj_total = 0
        for count_str, info in count_dict.items():
            try:
                count_i = int(count_str)
                count = int(info["count"])
            except ValueError:
                continue
            if count_range[0] <= count_i <= count_range[1] and count >= min_threshold:
                images_to_add = min(count, max_threshold)
                valid_count_items[count_str] = info["image_ids"][:images_to_add]
                # Count what was actually kept; image_ids may be shorter than count.
                obj_total += len(valid_count_items[count_str])
        if len(valid_count_items) >= min_valid_counts:
            valid_count_items["total"] = obj_total
            result[obj] = valid_count_items
            grand_total += obj_total
            object_count += 1
    result["_total_images"] = grand_total
    result["_total_objects"] = object_count
    return result

filtered_data = generate_filtered_summary_json(all_data, filter_keys=summary_filter.keys())

# === Prepare output containers
# Records accumulated by the main loop and written to disk at the end.
train_json, val_json, val_ans_json = [], [], []
train_csv_rows, val_csv_rows = [], []
missing_images = []

# Fixed seed so the per-bucket random.shuffle train/val splits are reproducible
# (uuid.uuid4 does not draw from this generator).
random.seed(42)

def find_and_copy_image(img_id, dest_folder, new_name=None, source_dir=None):
    """Copy a COCO image into ``dest_folder`` and return its file name there.

    Looks for ``COCO_<split>_<12-digit id>.jpg`` under ``<source_dir>/train2014``
    first, then under ``<source_dir>/val2014``. An already existing destination
    file is left untouched, so repeated calls with the same name copy only once.

    Args:
        img_id: COCO image id (int or int-like string).
        dest_folder: Directory to copy into; it must already exist.
        new_name: Destination file name; defaults to ``"<img_id>.jpg"``.
        source_dir: Root of the COCO image folders; defaults to the module-level
            ``img_source_dir``.

    Returns:
        The destination file name, or None if the image was not found.

    Raises:
        ValueError: If ``img_id`` is not int-like.
    """
    if source_dir is None:
        source_dir = img_source_dir
    img_id_str = f"{int(img_id):012d}"
    dst_name = new_name if new_name else f"{img_id}.jpg"
    for prefix in ("train2014", "val2014"):
        src_path = os.path.join(source_dir, prefix, f"COCO_{prefix}_{img_id_str}.jpg")
        if os.path.exists(src_path):
            dst_path = os.path.join(dest_folder, dst_name)
            if not os.path.exists(dst_path):
                shutil.copy2(src_path, dst_path)
            return dst_name
    return None

# === Main Loop for Train/Val
# For each kept object: split every count bucket 80/20 into train/val, copy the
# images, and emit train conversation records, val question/answer records and
# CSV rows linking each generated sample back to its original VQA question.
# NOTE(review): the split is done per object, so an image id kept for two
# objects may land in train for one and in val for the other — confirm this is
# acceptable.
for obj, count_dict in filtered_data.items():
    if obj.startswith("_"):
        # "_total_images" / "_total_objects" are summary fields, not objects.
        continue

    # Shuffle and cut each count bucket separately so train and val get a
    # similar distribution of count values.
    train_items = []
    val_items = []
    for count_str, image_ids in count_dict.items():
        if count_str == "total":
            continue
        count_val = int(count_str)
        shuffled_ids = image_ids.copy()
        random.shuffle(shuffled_ids)
        cut = int(len(shuffled_ids) * 0.8)
        train_items.extend([(img_id, count_val) for img_id in shuffled_ids[:cut]])
        val_items.extend([(img_id, count_val) for img_id in shuffled_ids[cut:]])

    obj_key = obj.lower()
    train_prompt = f"how many {obj} are there in the image?"
    val_prompt = f"<image>\nhow many {obj} are there in the image?"

    # Train: each sample gets its own uuid-named copy of the image.
    for img_id, count in train_items:
        uid = str(uuid.uuid4())
        copied_name = find_and_copy_image(img_id, output_train_img_dir, new_name=f"{uid}.jpg")
        if not copied_name:
            missing_images.append(str(img_id))
            continue
        # NOTE(review): train paths are prefixed "vqav2/train/images/" while val
        # paths below use "val/images/" — confirm both match the consumer.
        train_json.append({
            "id": uid,
            "image": f"vqav2/train/images/{copied_name}",
            "conversations": [
                {"from": "human", "value": f"<image>\n{train_prompt}"},
                {"from": "gpt", "value": str(count)},
            ],
        })
        qid, orig_q = image_obj_to_question.get((int(img_id), obj_key), ("", ""))
        train_csv_rows.append({
            "uuid": uid,
            "image_id": img_id,
            "question_id": qid,
            "original_question": orig_q,
        })

    # Val: images are named by id, so repeated ids share one copied file.
    for img_id, count in val_items:
        copied_name = find_and_copy_image(img_id, output_val_img_dir, new_name=f"{img_id}.jpg")
        if not copied_name:
            missing_images.append(str(img_id))
            continue
        qid = str(uuid.uuid4())
        val_json.append({
            "question_id": qid,
            "image": f"val/images/{img_id}.jpg",
            "category": "default",
            "text": val_prompt,
            "id": qid,
        })
        val_ans_json.append({
            "question_id": qid,
            "prompt": val_prompt,
            "text": str(count),
            "answer_id": None,
            "model_id": None,
            "metadata": {},
        })
        vqa_qid, orig_q = image_obj_to_question.get((int(img_id), obj_key), ("", ""))
        val_csv_rows.append({
            "uuid": qid,
            "image_id": img_id,
            "question_id": vqa_qid,
            "original_question": orig_q,
        })

# === Save outputs
# Ensure the split folders exist before writing annotation files into them.
for split_name in ("train", "val"):
    os.makedirs(os.path.join(root_dir, split_name), exist_ok=True)

json_outputs = (
    ("train/train.json", train_json),
    ("val/val.json", val_json),
    ("val/val_ans.json", val_ans_json),
)
for rel_path, payload in json_outputs:
    with open(os.path.join(root_dir, rel_path), "w") as fh:
        json.dump(payload, fh, indent=2)

# One image id per line for every image that could not be found on disk.
with open(os.path.join(root_dir, "missing_images.txt"), "w") as fh:
    fh.writelines(f"{mid}\n" for mid in missing_images)

# === Save CSV
# Map each generated sample uuid back to its source image and VQA question.
csv_fieldnames = ["uuid", "image_id", "question_id", "original_question"]
csv_outputs = (
    ("train_uuid_question.csv", train_csv_rows),
    ("val_uuid_question.csv", val_csv_rows),
)
for csv_name, rows in csv_outputs:
    with open(os.path.join(root_dir, csv_name), "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=csv_fieldnames)
        writer.writeheader()
        writer.writerows(rows)

# === Summary
print(f"Done! Train samples: {len(train_json)}, Val samples: {len(val_json)}")
print("CSVs: train_uuid_question.csv / val_uuid_question.csv")
print(f"Missing images: {len(missing_images)}")
