#!/usr/bin/env python
import json
import os
import random
import argparse
import shutil


def build_image_filename(image_id: int) -> str:
    """Return the COCO val2014 image file name for ``image_id``.

    The numeric id is zero-padded to 12 digits and wrapped with the
    dataset prefix and the ``.jpg`` extension, for example
    ``393226 -> COCO_val2014_000000393226.jpg``.
    """
    padded_id = format(image_id, "012d")
    return "COCO_val2014_" + padded_id + ".jpg"


def count_jpg_images(folder: str) -> int:
    """Count the regular files in ``folder`` that have a JPEG extension.

    The extension check is case-insensitive (``.jpg`` / ``.jpeg`` in any
    letter case), so mixed-case names such as ``photo.Jpg`` are counted too.
    Sub-directories are never counted, even if their name ends in ``.jpg``.

    Args:
        folder: Directory to inspect (not searched recursively).

    Returns:
        The number of JPEG files directly inside ``folder``, or ``0`` when
        ``folder`` does not exist or is not a directory.
    """
    if not os.path.isdir(folder):
        return 0
    jpeg_suffixes = (".jpg", ".jpeg")
    return sum(
        1
        for name in os.listdir(folder)
        # Lower-case the name so every casing of the extension matches.
        if name.lower().endswith(jpeg_suffixes)
        and os.path.isfile(os.path.join(folder, name))
    )


def _subset_is_complete(question_file: str, image_folder: str, num_samples: int) -> bool:
    """Return True if ``question_file`` holds ``num_samples`` entries whose images all exist in ``image_folder``."""
    try:
        with open(question_file, "r", encoding="utf-8") as fh:
            entries = json.load(fh)
    except (OSError, ValueError):
        # Unreadable or malformed JSON: there is no usable subset to reuse.
        return False
    if not isinstance(entries, list) or len(entries) != num_samples:
        return False
    return all(
        isinstance(entry, dict)
        and isinstance(entry.get("image"), str)
        and os.path.isfile(os.path.join(image_folder, entry["image"]))
        for entry in entries
    )


def main():
    """Sample a random subset of VQA v2 questions and copy their images.

    Steps:
      1. Parse the command-line options and seed ``random``.
      2. Return early if a previously generated subset is already on disk.
      3. Load the question and annotation JSON files and draw
         ``--num-samples`` questions; questions without an annotation are
         skipped with a warning.
      4. Write the sampled questions and their answer lists as JSON.
      5. Remove old JPEG files from ``--subset-image-folder`` and copy the
         images referenced by the sample from ``--image-root``.

    Raises:
        ValueError: If ``--num-samples`` exceeds the number of questions, or
            if ``--subset-image-folder`` and ``--image-root`` resolve to the
            same directory (step 5 would delete the source images).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--question-json",
        type=str,
        default="/path/to/VQA_v2/vqa_v2/v2_OpenEnded_mscoco_val2014_questions.json",
    )
    parser.add_argument(
        "--annotation-json",
        type=str,
        default="/path/to/VQA_v2/vqa_v2/v2_mscoco_val2014_annotations.json",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=500,
        help="number of questions to sample",
    )
    parser.add_argument(
        "--out-question-file",
        type=str,
        default="/path/to/VQA_v2/sample500_questions.json",
    )
    parser.add_argument(
        "--out-answer-file",
        type=str,
        default="/path/to/VQA_v2/sample500_answers.json",
    )
    parser.add_argument(
        "--image-root",
        type=str,
        default="/path/to/VQA_v2/coco2014/val2014",
        help="COCO val2014 source image folder",
    )
    parser.add_argument(
        "--subset-image-folder",
        type=str,
        default="/path/to/VQA_v2/sample500_images",
        help="folder to copy the subset images into",
    )
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()
    random.seed(args.seed)

    img_count = count_jpg_images(args.subset_image_folder)
    json_files_exist = os.path.isfile(args.out_question_file) and os.path.isfile(
        args.out_answer_file
    )
    # Several sampled questions can share one image, so a complete subset may
    # contain fewer than num_samples images. Besides the plain count check,
    # accept a subset whose question file references only images that are
    # already present in the subset folder.
    if json_files_exist and (
        img_count >= args.num_samples
        or _subset_is_complete(
            args.out_question_file, args.subset_image_folder, args.num_samples
        )
    ):
        print(
            f"[INFO] Found existing subset: {img_count} images and json files. "
            "Skip sampling and copying."
        )
        return

    # The subset folder is purged of JPEG files before copying; refuse to run
    # if that folder is the source image folder itself.
    if os.path.realpath(args.subset_image_folder) == os.path.realpath(args.image_root):
        raise ValueError(
            "subset-image-folder must differ from image-root; "
            "existing images in the subset folder are deleted before copying"
        )

    print(f"[INFO] Load questions from {args.question_json}")
    with open(args.question_json, "r", encoding="utf-8") as f:
        q_data = json.load(f)

    print(f"[INFO] Load annotations from {args.annotation_json}")
    with open(args.annotation_json, "r", encoding="utf-8") as f:
        a_data = json.load(f)

    questions = q_data["questions"]
    annotations = a_data["annotations"]

    # Index annotations by question id for O(1) lookup while pairing.
    ann_by_qid = {ann["question_id"]: ann for ann in annotations}

    if args.num_samples > len(questions):
        raise ValueError(
            f"num_samples({args.num_samples}) exceeds total number of questions ({len(questions)})"
        )

    sampled_questions = random.sample(questions, args.num_samples)

    out_questions = []
    out_answers = []
    needed_images = set()

    for q in sampled_questions:
        qid = q["question_id"]
        image_id = q["image_id"]
        question_str = q["question"]

        if qid not in ann_by_qid:
            print(f"[WARN] annotation not found for question_id={qid}, skip")
            continue

        ann = ann_by_qid[qid]
        image_file = build_image_filename(image_id)
        needed_images.add(image_file)

        out_questions.append(
            {
                "question_id": qid,
                "image_id": image_id,
                "image": image_file,
                "text": question_str,
            }
        )

        out_answers.append(
            {
                "question_id": qid,
                "image_id": image_id,
                "answers": [a["answer"] for a in ann["answers"]],
            }
        )

    print(f"[INFO] Sampled {len(out_questions)} questions")
    print(f"[INFO] Unique images in subset: {len(needed_images)}")

    for out_path in (args.out_question_file, args.out_answer_file):
        parent_dir = os.path.dirname(out_path)
        # dirname is "" for a bare file name (current directory), and
        # os.makedirs("") raises FileNotFoundError, so skip that case.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    with open(args.out_question_file, "w", encoding="utf-8") as f:
        json.dump(out_questions, f, indent=2, ensure_ascii=False)
    print(f"[INFO] Saved questions to {args.out_question_file}")

    with open(args.out_answer_file, "w", encoding="utf-8") as f:
        json.dump(out_answers, f, indent=2, ensure_ascii=False)
    print(f"[INFO] Saved answers to {args.out_answer_file}")

    os.makedirs(args.subset_image_folder, exist_ok=True)

    # Drop JPEG files left over from an earlier run so the folder ends up
    # holding only the images of the current sample.
    for stale_name in os.listdir(args.subset_image_folder):
        stale_path = os.path.join(args.subset_image_folder, stale_name)
        if os.path.isfile(stale_path) and stale_name.lower().endswith((".jpg", ".jpeg")):
            os.remove(stale_path)

    print(f"[INFO] Copy images to {args.subset_image_folder}")
    missing = 0
    for img_name in sorted(needed_images):
        src_path = os.path.join(args.image_root, img_name)
        dst_path = os.path.join(args.subset_image_folder, img_name)

        if not os.path.isfile(src_path):
            print(f"[WARN] source image not found: {src_path}")
            missing += 1
            continue

        shutil.copy2(src_path, dst_path)

    if missing > 0:
        print(f"[WARN] {missing} images were missing in the source folder.")
    else:
        print("[INFO] All subset images copied successfully.")


# Run the sampling pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
