#%%
import random
import json
import shutil
from pathlib import Path

random.seed(42)

src_base = Path("/fs/scratch/PAS2099/dataset/INat21")
# NOTE(review): these directory names say "animalia" although the filter below
# keeps Plantae classes. They are left unchanged because the later cells read
# from these exact paths.
temp_merge_dir = Path("./inat_animalia_merged_temp")
final_dir = Path("./inat_animalia_50_final")
temp_merge_dir.mkdir(exist_ok=True, parents=True)  # created here; nothing in this cell writes into it
final_dir.mkdir(exist_ok=True, parents=True)

# Maps a class folder name to the list of image paths found for it in any split.
class_to_images = {}

# Step 1: Merge same-named Plantae classes across train/val
for split in ["train", "val"]:
    # iterdir() yields entries in arbitrary order; sorting keeps the seeded
    # sampling below reproducible from run to run.
    for folder in sorted((src_base / split).iterdir()):
        if not folder.is_dir():
            continue
        parts = folder.name.split("_")
        # The second "_"-separated token of the folder name is compared
        # against the kingdom name.
        if len(parts) > 1 and parts[1] == "Plantae":
            class_name = folder.name
            # Matches extensions that start with j/J/p/P and end with g/G
            # (e.g. .jpg, .jpeg, .png); sorted for a stable order.
            image_paths = sorted(folder.glob("*.[jJpP]*[gG]"))
            if image_paths:
                merged = class_to_images.setdefault(class_name, [])
                # Images are copied into one flat folder per class later on,
                # so a file name already merged from the other split would
                # collide; keep only the first occurrence.
                known_names = {p.name for p in merged}
                merged.extend(p for p in image_paths if p.name not in known_names)

# Step 2: Filter for classes with ≥ 15 images
filtered_classes = {cls: imgs for cls, imgs in class_to_images.items() if len(imgs) >= 15}

# Step 3: Randomly select 50 classes
candidate_classes = sorted(filtered_classes)
num_selected = min(50, len(candidate_classes))
if num_selected < 50:
    # random.sample raises ValueError when asked for more items than exist.
    print(f"Warning: only {num_selected} classes have >= 15 images; selecting all of them.")
selected_classes = random.sample(candidate_classes, num_selected)

# Step 4: Copy selected class images and generate JSON
output_json = {}
total_images = 0

for cls in selected_classes:
    imgs = filtered_classes[cls]
    target_dir = final_dir / cls
    target_dir.mkdir(parents=True, exist_ok=True)
    rel_paths = []
    for img_path in imgs:
        dest_path = target_dir / img_path.name
        # Skip files that are already in place so the cell can be re-run.
        if not dest_path.exists():
            shutil.copy(img_path, dest_path)
        rel_paths.append(f"{cls}/{img_path.name}")
    output_json[cls] = rel_paths
    total_images += len(rel_paths)

# Step 5: Save JSON file with metadata
# The two "_total_*" entries are integers stored next to the per-class lists.
output_json["_total_objects"] = len(selected_classes)
output_json["_total_images"] = total_images

with open(final_dir / "image_paths.json", "w", encoding="utf-8") as f:
    json.dump(output_json, f, indent=2)

# Bare expression: shown as the cell result in an interactive session.
{
    "total_objects": len(selected_classes),
    "total_images": total_images
}

#%%
import json
import random

# Read the per-class image lists, keep classes with at least 15 images
# (capped at 30 randomly chosen ones), and re-key each class by the last two
# "_"-separated tokens of its folder name.
json_path = "inat_animalia_50_final/image_paths.json"
output_path = "filtered_image_list_cleaned.json"

with open(json_path, "r") as f:
    data = json.load(f)

random.seed(42)

filtered = {}
total_images = 0

for folder_key, image_list in data.items():
    # Non-list values (the "_total_*" counters) and small classes are skipped.
    if not isinstance(image_list, list) or len(image_list) < 15:
        continue

    # Cap large classes at 30 images; smaller ones are kept as they are.
    chosen = random.sample(image_list, 30) if len(image_list) > 30 else image_list

    tokens = folder_key.split("_")
    short_key = " ".join((tokens[-2], tokens[-1]))
    filtered[short_key] = chosen
    total_images += len(chosen)

# The dict literal is evaluated before update() runs, so the class count does
# not include the two metadata entries themselves.
filtered.update({"_total_objects": len(filtered), "_total_images": total_images})

with open(output_path, "w") as f:
    json.dump(filtered, f, indent=2)

# Bare expression: shown as the cell result in an interactive session.
{"final_class_count": filtered["_total_objects"], "final_image_count": filtered["_total_images"]}


#%%
import json
import uuid
import random
import shutil
from pathlib import Path
from sklearn.model_selection import train_test_split

random.seed(42)

input_json = "filtered_image_list_cleaned.json"
image_root = Path("inat_animalia_50_final")
output_dir = Path("plant")
train_img_dir = output_dir / "train/images"
val_img_dir = output_dir / "val/images"
missing_file = output_dir / "missing_images.txt"

# NOTE: every copy gets a fresh UUID file name and existing files are never
# removed, so re-running this cell adds new copies next to the old ones.
train_img_dir.mkdir(parents=True, exist_ok=True)
val_img_dir.mkdir(parents=True, exist_ok=True)

# File paths
train_json_path = output_dir / "train/train.json"
train_json_path_2 = output_dir / "train/train_2.json"
val_json_path = output_dir / "val/val.json"
val_json_path_2 = output_dir / "val/val_2.json"
val_ans_path = output_dir / "val/val_ans.json"
val_ans_path_2 = output_dir / "val/val_ans_2.json"

with open(input_json, encoding="utf-8") as f:
    species_to_images = json.load(f)

# The input file also carries integer "_total_objects" / "_total_images"
# entries; len() on an int raises TypeError, so only list values are kept.
eligible_classes = {
    cls: imgs
    for cls, imgs in species_to_images.items()
    if isinstance(imgs, list) and len(imgs) >= 15
}
# random.sample raises ValueError when asked for more items than exist, so
# take at most 50 classes.
selected_classes = random.sample(
    list(eligible_classes.keys()), min(50, len(eligible_classes))
)

# At most 30 images per selected class.
result = {}
for cls in selected_classes:
    imgs = eligible_classes[cls]
    result[cls] = random.sample(imgs, 30) if len(imgs) > 30 else imgs

# NOTE(review): the prompts say "Animal" while the output directory is "plant";
# the wording is left unchanged here - confirm which one is intended.
basic_prompt = "<image>\nWhat is the species name of Animal in the image?"

# Numbered multiple-choice options, in the order the classes were sampled.
species_list = list(result.keys())
choice_map = {s: f"{i+1}. {s}" for i, s in enumerate(species_list)}
choice_lines = [choice_map[s] for s in species_list]
choice_text = basic_prompt + " Choose one from below:\n" + "\n".join(choice_lines)

# Flatten to (class, relative image path) pairs and split 80/20.
all_items = [(cls, img) for cls, imgs in result.items() for img in imgs]
train_items, val_items = train_test_split(all_items, test_size=0.2, random_state=42)

# Output containers
train_json, train_json_2 = [], []
val_json, val_json_2 = [], []
val_ans_json, val_ans_json_2 = [], []
missing = []


def _copy_with_new_id(rel_path, dest_dir):
    """Copy ``image_root / rel_path`` into *dest_dir* as ``<uuid>.jpg``.

    Returns the new UUID string, or ``None`` when the source file does not
    exist (its path is then recorded in ``missing``).
    """
    src = image_root / rel_path
    if not src.exists():
        missing.append(str(src))
        return None
    uid = str(uuid.uuid4())
    shutil.copy2(src, dest_dir / f"{uid}.jpg")
    return uid


def _conversation_record(uid, question, answer):
    """Build one training entry: a human question and the expected reply."""
    return {
        "id": uid,
        "image": f"train/images/{uid}.jpg",
        "conversations": [
            {"from": "human", "value": question},
            {"from": "gpt", "value": answer},
        ],
    }


def _question_record(uid, prompt):
    """Build one validation question entry."""
    return {
        "question_id": uid,
        "image": f"val/images/{uid}.jpg",
        "category": "default",
        "text": prompt,
        "id": uid,
    }


def _answer_record(uid, prompt, answer):
    """Build the reference-answer entry matching a validation question."""
    return {
        "question_id": uid,
        "prompt": prompt,
        "text": answer,
        "answer_id": None,
        "model_id": None,
        "metadata": {},
    }


# Each image yields two records: a free-form question and a multiple-choice
# variant (the "_2" files) that share the same id and image.
for cls, img_path in train_items:
    uid = _copy_with_new_id(img_path, train_img_dir)
    if uid is None:
        continue
    train_json.append(_conversation_record(uid, basic_prompt, cls))
    train_json_2.append(_conversation_record(uid, choice_text, choice_map[cls]))

for cls, img_path in val_items:
    uid = _copy_with_new_id(img_path, val_img_dir)
    if uid is None:
        continue
    val_json.append(_question_record(uid, basic_prompt))
    val_ans_json.append(_answer_record(uid, basic_prompt, cls))
    val_json_2.append(_question_record(uid, choice_text))
    val_ans_json_2.append(_answer_record(uid, choice_text, choice_map[cls]))

# Save all outputs
for path, payload in (
    (train_json_path, train_json),
    (train_json_path_2, train_json_2),
    (val_json_path, val_json),
    (val_json_path_2, val_json_2),
    (val_ans_path, val_ans_json),
    (val_ans_path_2, val_ans_json_2),
):
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)

# One missing source path per line; the file is empty when nothing was missing.
with open(missing_file, "w", encoding="utf-8") as f:
    f.write("\n".join(missing))

# Bare expression: shown as the cell result in an interactive session.
len(train_json), len(train_json_2), len(val_json), len(val_json_2)
