#!/usr/bin/env python3
import os
import json
import glob
import random
import argparse
import numpy as np
import faiss
from PIL import Image
from tqdm import tqdm
import torchvision.transforms as transforms


def parse_image_id(filepath: str) -> int:
    """Return the integer encoded in the file name of ``filepath``.

    The directory part and the final extension are discarded and the
    remaining stem is parsed with ``int``, so ``"000000123456.jpg"``
    yields ``123456``.

    Raises:
        ValueError: If the stem is not a valid integer literal.
    """
    filename = os.path.basename(filepath)
    stem, _extension = os.path.splitext(filename)
    return int(stem)


def load_and_flatten(filepath: str, transform) -> np.ndarray:
    """Load an image, apply ``transform`` and return a flat float32 vector.

    Args:
        filepath: Path of the image file to open.
        transform: Callable applied to the RGB PIL image. Its result must
            provide ``reshape(-1)`` and ``numpy()`` (e.g. a torch tensor).

    Returns:
        A one-dimensional ``np.float32`` array containing every element of
        the transformed image.
    """
    # Use the context manager so the underlying file is closed as soon as the
    # pixel data has been converted, instead of relying on garbage collection
    # while many files are processed in a loop.
    with Image.open(filepath) as img:
        rgb = img.convert("RGB")
    # reshape(-1) flattens like view(-1) but also accepts non-contiguous input.
    return transform(rgb).reshape(-1).numpy().astype(np.float32)


def _merge_split(existing: dict, grouping: dict, split: str) -> dict:
    """Merge the ``split`` lists of ``grouping`` into ``existing`` in place.

    Ids already stored for a cluster are skipped, so running the script twice
    for the same split does not duplicate entries. Clusters or splits missing
    from ``existing`` are created on demand.

    Returns:
        ``existing``, after the merge.
    """
    for cid, per_split in grouping.items():
        cluster = existing.setdefault(cid, {"train": [], "val": []})
        target = cluster.setdefault(split, [])
        seen = set(target)
        for image_id in per_split[split]:
            if image_id not in seen:
                seen.add(image_id)
                target.append(image_id)
    return existing


def main():
    """Assign every ``*.jpg`` in ``--data_dir`` to its nearest random centroid.

    ``--K`` centroids are drawn uniformly from [0, 1) using ``--seed``, and a
    Faiss flat L2 index over them is queried with the flattened, normalized
    image vectors. The resulting ``{cluster_id: {"train": [...], "val": [...]}}``
    mapping is written to ``grouping_K_<K>_seed_<seed>.json`` in
    ``--output_dir``. If that file already exists, the current split is merged
    into it.
    """
    parser = argparse.ArgumentParser(description="Cluster COCO images using random centroids + Faiss.")
    parser.add_argument("--data_dir", type=str, required=True,
                        help="Path to the folder containing COCO images (train2017 or val2017).")
    parser.add_argument("--split", type=str, required=True, choices=["train", "val"],
                        help="Which split (train or val) these images belong to.")
    parser.add_argument("--K", type=int, required=True, help="Number of clusters.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for processing images.")
    parser.add_argument("--output_dir", type=str, default=".", help="Where to save clustering JSONs.")
    args = parser.parse_args()

    # Fail early with a usage message: K <= 0 leaves the index empty (search
    # then yields -1) and batch_size <= 0 is not a valid range() step.
    if args.K <= 0:
        parser.error("--K must be a positive integer.")
    if args.batch_size <= 0:
        parser.error("--batch_size must be a positive integer.")

    random.seed(args.seed)
    np.random.seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize((520, 520)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Push a blank image through the transform to learn the flattened
    # feature dimension without hard-coding it.
    dummy_img = Image.new("RGB", (520, 520))
    d = transform(dummy_img).numel()

    centroids = np.random.rand(args.K, d).astype(np.float32)
    index = faiss.IndexFlatL2(d)
    index.add(centroids)

    image_paths = sorted(glob.glob(os.path.join(args.data_dir, "*.jpg")))
    print(f"Found {len(image_paths)} images in {args.data_dir}.")

    grouping = {str(i): {"train": [], "val": []} for i in range(args.K)}

    # tqdm takes the number of batches from the range itself, which is exact
    # even when the image count is a multiple of the batch size.
    for start_idx in tqdm(range(0, len(image_paths), args.batch_size)):
        batch_slice = image_paths[start_idx:start_idx + args.batch_size]
        batch_vecs = np.stack([load_and_flatten(p, transform) for p in batch_slice], axis=0)
        _, indices = index.search(batch_vecs, 1)

        for filepath, nearest in zip(batch_slice, indices[:, 0]):
            # Convert the NumPy integer to a plain int before building the key.
            cluster_id = int(nearest)
            grouping[str(cluster_id)][args.split].append(parse_image_id(filepath))

    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, f"grouping_K_{args.K}_seed_{args.seed}.json")
    if os.path.exists(output_path):
        # merge with existing file (other split, or an earlier run of this one)
        with open(output_path, "r", encoding="utf-8") as f:
            existing = json.load(f)
        grouping = _merge_split(existing, grouping, args.split)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(grouping, f)

    print(f"\nSaved clusters to: {output_path}")


# Run the clustering only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
