#%%
import json
import random
from pathlib import Path

# Build a per-class image index for FGVC-Aircraft "family" labels and write
# it to JSON: up to 50 classes (each with >= 15 images), up to 30 images each.

# Input files
txt_train = "fgvc-aircraft-2013b/data/images_family_train.txt"
txt_test = "fgvc-aircraft-2013b/data/images_family_test.txt"
image_dir = Path("fgvc-aircraft-2013b/data/images")
output_json = "fgvc_aircraft_by_class_filtered_50.json"

# Step 1: Read and merge both label files.
# A missing label file is still skipped (best effort), but it is reported so
# that a shrunken or empty dataset does not go unnoticed.
all_lines = []
for txt_file in [txt_train, txt_test]:
    if not Path(txt_file).exists():
        print(f"Warning: label file not found, skipping: {txt_file}")
        continue
    with open(txt_file, "r") as f:
        all_lines.extend(f.readlines())

# Step 2: Build mapping from class name to image paths.
# Each usable line is "<image_id> <class name ...>"; the class name may
# contain spaces, so everything after the first token is re-joined.
class_to_images = {}
for line in all_lines:
    parts = line.split()
    if len(parts) < 2:
        # Blank line or an id without a label: nothing to index.
        continue
    image_id, *name_parts = parts
    class_name = " ".join(name_parts)
    image_path = str(image_dir / f"{image_id}.jpg")
    class_to_images.setdefault(class_name, []).append(image_path)

# Step 3: Filter for classes with at least 15 images
filtered_classes = {k: v for k, v in class_to_images.items() if len(v) >= 15}

# Step 4: Randomly sample 50 class names with seed 42 (or all of them when
# fewer than 50 classes pass the filter).
random.seed(42)
sampled_class_names = random.sample(
    list(filtered_classes), min(50, len(filtered_classes))
)

# Step 5: For each selected class, sample up to 30 images
final_output = {}
total_images = 0

for cls in sampled_class_names:
    images = filtered_classes[cls]
    # The RNG is re-seeded for every class, so a class's sample depends only
    # on its own image list and not on the order classes are visited.
    random.seed(42)
    selected = random.sample(images, 30) if len(images) > 30 else images
    final_output[cls] = selected
    total_images += len(selected)

# Capture the class count BEFORE the bookkeeping keys are added; afterwards
# len(final_output) would also count "_total_objects" and "_total_images".
num_classes = len(final_output)
final_output["_total_objects"] = num_classes
final_output["_total_images"] = total_images

with open(output_json, "w") as f:
    json.dump(final_output, f, indent=2)

# Summary of the run; left as the last expression so an interactive cell
# displays it.
summary = {
    "filtered_classes_with_15+": len(filtered_classes),
    "final_output_classes": num_classes,
    "final_output_images": total_images,
}
summary



#%%

import json
import uuid
import random
import shutil
from pathlib import Path
from sklearn.model_selection import train_test_split

# Seed the stdlib RNG so the class/image sampling below is reproducible.
random.seed(42)

input_json = "fgvc_aircraft_by_class_filtered_50.json"
image_root = Path("")
output_dir = Path("aircraft")
train_img_dir = output_dir / "train/images"
val_img_dir = output_dir / "val/images"
missing_file = output_dir / "missing_images.txt"

train_img_dir.mkdir(parents=True, exist_ok=True)
val_img_dir.mkdir(parents=True, exist_ok=True)

# File paths
train_json_path = output_dir / "train/train.json"
train_json_path_2 = output_dir / "train/train_2.json"
val_json_path = output_dir / "val/val.json"
val_json_path_2 = output_dir / "val/val_2.json"
val_ans_path = output_dir / "val/val_ans.json"
val_ans_path_2 = output_dir / "val/val_ans_2.json"

with open(input_json) as f:
    species_to_images = json.load(f)

# Besides the class -> image-list entries, the input JSON also holds integer
# bookkeeping entries ("_total_objects", "_total_images"). Calling len() on
# those raises TypeError, so only list-valued entries are treated as classes.
eligible_classes = {
    cls: imgs
    for cls, imgs in species_to_images.items()
    if isinstance(imgs, list) and len(imgs) >= 15
}
# Cap the sample size: random.sample raises ValueError when asked for more
# items than the population contains.
selected_classes = random.sample(
    list(eligible_classes), min(50, len(eligible_classes))
)

# Keep at most 30 images per selected class.
result = {}
for cls in selected_classes:
    imgs = eligible_classes[cls]
    result[cls] = random.sample(imgs, 30) if len(imgs) > 30 else imgs

# Numbered multiple-choice options ("1. <class>", "2. <class>", ...) shared by
# every "_2" record; choice_map gives the expected answer line for each class.
# NOTE(review): the prompt wording ("species name of Animal") looks carried
# over from a different dataset; it is left unchanged because it is written
# into the generated files -- confirm whether it should mention aircraft.
species_list = list(result.keys())
choice_lines = [f"{i+1}. {s}" for i, s in enumerate(species_list)]
choice_map = dict(zip(species_list, choice_lines))
choice_text = "<image>\nWhat is the species name of Animal in the image? Choose one from below:\n" + "\n".join(choice_lines)

# Flatten to (class, image path) pairs and hold out 20% for validation.
all_items = [(cls, img) for cls, imgs in result.items() for img in imgs]
train_items, val_items = train_test_split(all_items, test_size=0.2, random_state=42)

# Output containers
train_json, train_json_2 = [], []
val_json, val_json_2 = [], []
val_ans_json, val_ans_json_2 = [], []
missing = []

# Copy each training image under a fresh UUID file name and record it twice:
# once with the open-ended question (train_json) and once with the numbered
# multiple-choice question (train_json_2). Images whose source file does not
# exist are logged in `missing` and skipped.
for label, rel_path in train_items:
    uid = str(uuid.uuid4())
    source_path = image_root / rel_path
    if not source_path.exists():
        missing.append(str(source_path))
        continue
    shutil.copy2(source_path, train_img_dir / f"{uid}.jpg")

    image_ref = f"train/images/{uid}.jpg"
    # (target list, question shown to the model, expected answer)
    variants = (
        (train_json, "<image>\nWhat is the species name of Animal in the image?", label),
        (train_json_2, choice_text, choice_map[label]),
    )
    for records, question, answer in variants:
        records.append({
            "id": uid,
            "image": image_ref,
            "conversations": [
                {"from": "human", "value": question},
                {"from": "gpt", "value": answer},
            ],
        })

# Copy each validation image under a fresh UUID file name and emit, for both
# prompt styles, one question record and one matching answer record. Images
# whose source file does not exist are logged in `missing` and skipped.
basic_prompt = "<image>\nWhat is the species name of Animal in the image?"

for label, rel_path in val_items:
    uid = str(uuid.uuid4())
    source_path = image_root / rel_path
    if not source_path.exists():
        missing.append(str(source_path))
        continue
    shutil.copy2(source_path, val_img_dir / f"{uid}.jpg")

    image_ref = f"val/images/{uid}.jpg"
    # (question list, answer list, prompt text, expected answer)
    variants = (
        (val_json, val_ans_json, basic_prompt, label),
        (val_json_2, val_ans_json_2, choice_text, choice_map[label]),
    )
    for questions, answers, prompt, expected in variants:
        questions.append({
            "question_id": uid,
            "image": image_ref,
            "category": "default",
            "text": prompt,
            "id": uid,
        })
        answers.append({
            "question_id": uid,
            "prompt": prompt,
            "text": expected,
            "answer_id": None,
            "model_id": None,
            "metadata": {},
        })

# Save all outputs: each record list goes to its own pretty-printed JSON file.
json_outputs = (
    (train_json_path, train_json),
    (train_json_path_2, train_json_2),
    (val_json_path, val_json),
    (val_json_path_2, val_json_2),
    (val_ans_path, val_ans_json),
    (val_ans_path_2, val_ans_json_2),
)
for target_path, payload in json_outputs:
    with open(target_path, "w") as handle:
        json.dump(payload, handle, indent=2)

# One missing source path per line (empty file when nothing was missing).
with open(missing_file, "w") as handle:
    handle.write("\n".join(missing))

# Record counts, left as the last expression so an interactive cell shows them.
len(train_json), len(train_json_2), len(val_json), len(val_json_2)


#%%
