#%%
import random, json, uuid, shutil
from pathlib import Path
from collections import defaultdict
from sklearn.model_selection import train_test_split

# Seed the stdlib RNG that drives the random.sample calls in this script.
random.seed(42)

# --- Inputs -----------------------------------------------------------------
# Text file listing the images, and the directory those paths are relative to.
input_txt = "CUB_200_2011/images.txt"
src_root = Path("CUB_200_2011/images")

# --- Outputs ----------------------------------------------------------------
dst_root = Path("bird")
train_img_dir = dst_root / "train" / "images"
val_img_dir = dst_root / "val" / "images"
missing_file = dst_root / "missing_images.txt"

# Create both image directories up front (and their parents); re-running the
# script with the directories already present is not an error.
for out_dir in (train_img_dir, val_img_dir):
    out_dir.mkdir(parents=True, exist_ok=True)

# Step 1: Parse image list into class -> images
# For every non-blank line, the last whitespace-separated token is taken as
# the image path. The class name is whatever follows the first "." in the
# path's leading directory component (e.g. "<prefix>.<class>/<file>").
class_to_images = defaultdict(list)
with open(input_txt, "r", encoding="utf-8") as f:
    for line in f:
        parts = line.split()
        if not parts:
            # Blank / whitespace-only line: nothing to index, and parts[-1]
            # would raise IndexError.
            continue
        path = parts[-1]
        cls = path.split("/")[0].split(".", 1)[1]
        class_to_images[cls].append(path)

# Step 2: Filter & sample
# Sampling limits, named so they can be tuned in one place.
MIN_IMAGES_PER_CLASS = 15  # classes with fewer images are not eligible
NUM_CLASSES = 100          # number of classes to draw
MAX_IMAGES_PER_CLASS = 30  # larger classes are down-sampled to this size

eligible_classes = [
    c for c, imgs in class_to_images.items() if len(imgs) >= MIN_IMAGES_PER_CLASS
]
# random.sample raises ValueError when asked for more items than the
# population holds, so never request more classes than are eligible.
selected_classes = random.sample(
    eligible_classes, min(NUM_CLASSES, len(eligible_classes))
)

# class name -> list of image paths kept for that class
result = {}
for c in selected_classes:
    imgs = class_to_images[c]
    if len(imgs) > MAX_IMAGES_PER_CLASS:
        result[c] = random.sample(imgs, MAX_IMAGES_PER_CLASS)
    else:
        # Small enough already: keep every image (same list object).
        result[c] = imgs

# Numbered multiple-choice options, in selection order ("1. <species>", ...).
species_list = list(result.keys())
choice_map = {s: f"{i}. {s}" for i, s in enumerate(species_list, start=1)}
choice_lines = list(choice_map.values())
prompt_with_choices = "<image>\nWhat species of bird is in the image? Choose one from below:\n" + "\n".join(choice_lines)

# Step 3: Split train/val
# Flatten the per-class image lists into (class, image_path) pairs, then hold
# out 20% of the pairs as the validation set using a fixed split seed.
all_items = []
for cls, imgs in result.items():
    all_items.extend((cls, img) for img in imgs)

train_set, val_set = train_test_split(all_items, test_size=0.2, random_state=42)

# Step 4: Containers
# The "_2" lists hold the multiple-choice variant of each record.
train_json = []
train_json_2 = []
val_json = []
val_json_2 = []
val_ans_json = []
val_ans_json_2 = []
# Source paths that were listed but not found on disk.
missing = []

# Step 5: Train set
# Every training image found under src_root is copied into train_img_dir
# under a fresh UUID-based name and recorded twice: once with an open-ended
# question (train_json) and once with the numbered choice list (train_json_2).
open_question = "<image>\nWhat is the species name of Bird in the image?"

for species, rel_path in train_set:
    sample_id = str(uuid.uuid4())
    source_file = src_root / rel_path
    if not source_file.exists():
        # Record the absent source path and skip this sample.
        missing.append(str(source_file))
        continue

    file_name = f"{sample_id}.png"
    # NOTE(review): the destination name always ends in ".png" regardless of
    # the source file's extension -- confirm downstream readers are fine
    # with that.
    shutil.copy(source_file, train_img_dir / file_name)

    # (target list, question, expected answer):
    #   train.json   - no choices, answer is the bare class name
    #   train_2.json - with choices, answer is the numbered choice
    variants = (
        (train_json, open_question, species),
        (train_json_2, prompt_with_choices, choice_map[species]),
    )
    for records, question, answer in variants:
        records.append({
            "id": sample_id,
            "image": f"train/images/{file_name}",
            "conversations": [
                {"from": "human", "value": question},
                {"from": "gpt", "value": answer},
            ],
        })

# Step 6: Val set
# Every validation image found under src_root is copied into val_img_dir
# under a fresh UUID-based name. For each one a question record and a
# matching answer record are emitted, in both the open-ended and the
# multiple-choice flavour.
prompt_no_choices = "<image>\nWhat is the species name of Bird in the image?"

for species, rel_path in val_set:
    sample_id = str(uuid.uuid4())
    source_file = src_root / rel_path
    if not source_file.exists():
        # Record the absent source path and skip this sample.
        missing.append(str(source_file))
        continue

    file_name = f"{sample_id}.png"
    # NOTE(review): the destination name always ends in ".png" regardless of
    # the source file's extension -- confirm downstream readers are fine
    # with that.
    shutil.copy(source_file, val_img_dir / file_name)

    # (question list, answer list, prompt, expected answer):
    #   val.json / val_ans.json     - no choices
    #   val_2.json / val_ans_2.json - with choices
    variants = (
        (val_json, val_ans_json, prompt_no_choices, species),
        (val_json_2, val_ans_json_2, prompt_with_choices, choice_map[species]),
    )
    for questions, answers, prompt, expected in variants:
        questions.append({
            "question_id": sample_id,
            "image": f"val/images/{file_name}",
            "category": "default",
            "text": prompt,
            "id": sample_id,
        })
        answers.append({
            "question_id": sample_id,
            "prompt": prompt,
            "text": expected,
            "answer_id": None,
            "model_id": None,
            "metadata": {},
        })

# Step 7: Save all JSONs
# Make sure the per-split output directories exist before writing into them.
for split_name in ("train", "val"):
    (dst_root / split_name).mkdir(parents=True, exist_ok=True)

# Output path (relative to dst_root) -> records to serialise there.
json_outputs = {
    "train/train.json": train_json,
    "train/train_2.json": train_json_2,
    "val/val.json": val_json,
    "val/val_2.json": val_json_2,
    "val/val_ans.json": val_ans_json,
    "val/val_ans_2.json": val_ans_json_2,
}
for rel_name, records in json_outputs.items():
    with open(dst_root / rel_name, "w") as f:
        json.dump(records, f, indent=2)

# Missing-source report: one path per line, no trailing newline; the file is
# written (possibly empty) even when nothing was missing.
with open(missing_file, "w") as f:
    f.write("\n".join(missing))

