import argparse
import json
import os
import random
import glob

import faiss
import numpy as np
import xml.etree.ElementTree as ET
from tqdm import tqdm
from PIL import Image

import torchvision.transforms as transforms


def parse_val_name(filename: str) -> str:
    """Return the id part of a val image name.

    Example: 'ILSVRC2012_val_00000001.JPEG' -> '00000001'.

    Any leading directories and the final file extension are dropped first;
    every occurrence of the 'ILSVRC2012_val_' marker is then removed from
    what remains.
    """
    leaf = os.path.basename(filename)
    stem, _ext = os.path.splitext(leaf)
    return stem.replace("ILSVRC2012_val_", "")


def parse_train_name(filepath: str) -> str:
    """Return the extension-less file name of a train image path.

    Example: 'n01440764/n01440764_10040.JPEG' -> 'n01440764_10040'.

    Only '/' is treated as a directory separator.
    """
    without_ext, _ext = os.path.splitext(filepath)
    # rpartition yields the whole string as the tail when no '/' is present.
    _head, _sep, name = without_ext.rpartition("/")
    return name


def load_and_flatten(file: str) -> np.ndarray:
    """Load an image and return it as a flat, ImageNet-normalized float32 vector.

    Pipeline: convert to RGB, resize with bilinear interpolation
    (`Resize(256)` takes a single int, which torchvision documents as scaling
    the shorter side to 256 while keeping the aspect ratio), center-crop to
    224x224, convert to a tensor, and normalize with the ImageNet channel
    mean/std.

    Args:
        file: Path to an image file readable by PIL.

    Returns:
        1-D np.float32 array holding the flattened 3x224x224 tensor
        (224 * 224 * 3 values).
    """
    transform = transforms.Compose([
        transforms.Resize(256, interpolation=Image.BILINEAR),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Open via a context manager so the file handle is released
    # deterministically; convert() yields an image that no longer depends on
    # the open file, so it is safe to use after the `with` block.
    with Image.open(file) as raw:
        rgb = raw.convert("RGB")
    tensor = transform(rgb)
    return tensor.numpy().astype(np.float32).flatten()


def extract_class_from_xml(xml_path: str) -> str:
    """Read the class label (wnid) from an ImageNet val annotation file.

    The label is the text of the <name> child of the first <object> child of
    the document root.
    """
    annotation = ET.parse(xml_path).getroot()
    first_object = annotation.find("object")
    name_node = first_object.find("name")
    return name_node.text


def make_val_split(val_dir: str, val_xml_dir: str, seed: int):
    """
    Split the val images into a per-class-stratified subset and the remainder.

    Every image matching 'ILSVRC2012_val_*.JPEG' in `val_dir` is assigned the
    class read from its annotation 'ILSVRC2012_val_<name>.xml' in
    `val_xml_dir`. For each class, the images are sorted by name, shuffled
    with a generator seeded by `seed`, and the first 10 go to the small
    split; all others go to the large split.

    Args:
      val_dir: directory containing the val images.
      val_xml_dir: directory containing one XML annotation per val image.
      seed: seed for the per-class shuffle. seed=100 reproduces the split
        that earlier versions produced with their hard-coded seed.

    Returns:
      val10k_items: list of tuples (class_id, val_name, img_path)
      test40k_items: list of tuples (class_id, val_name, img_path)
    """
    # Collect all val image paths (sorted so the result does not depend on
    # filesystem listing order).
    val_paths = sorted(glob.glob(os.path.join(val_dir, "ILSVRC2012_val_*.JPEG")))
    if len(val_paths) != 50000:
        print(f"[Warning] Found {len(val_paths)} val images (expected 50000). Proceeding anyway...")

    # Map class_id -> list[(val_name, img_path)]
    by_class = {}
    for p in tqdm(val_paths, desc="Indexing val by class"):
        val_name = parse_val_name(p)
        xml_path = os.path.join(val_xml_dir, f"ILSVRC2012_val_{val_name}.xml")
        class_id = extract_class_from_xml(xml_path)  # wnid
        by_class.setdefault(class_id, []).append((val_name, p))

    # Sanity: expect ~1000 classes with ~50 images each
    print(f"Discovered {len(by_class)} classes in val.")

    # Honor the caller's seed; it was previously ignored in favor of a
    # hard-coded 100, so different --seed values produced identical splits.
    rng = random.Random(seed)
    val10k_items, test40k_items = [], []

    # Classes are visited in first-seen order, which is deterministic because
    # val_paths is sorted.
    for cls, items in by_class.items():
        # Keep stable order then shuffle for reproducibility
        items = sorted(items, key=lambda x: x[0])
        rng.shuffle(items)
        val10k_items.extend((cls, vn, p) for vn, p in items[:10])
        test40k_items.extend((cls, vn, p) for vn, p in items[10:])

    print(f"val_new (10k): {len(val10k_items)}  |  test (40k): {len(test40k_items)}")
    return val10k_items, test40k_items


def main():
    """Command-line entry point.

    Steps:
      1. Build the stratified val10k / test40k split and write both name
         lists ('<class>_<val_name>') as JSON into --output_dir.
      2. Draw K random Gaussian centroids in flattened-image space
         (224 * 224 * 3 dimensions), assign every val10k image to its nearest
         centroid under L2 distance, and write the resulting
         cluster-id -> image-name mapping as JSON.

    Exits via argparse (status 2) when --K or --batch_size is not positive.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--val_dir", type=str, required=True,
                        help="Path to val images (50k) named ILSVRC2012_val_XXXXXX.JPEG.")
    parser.add_argument("--val_xml_dir", type=str, required=True,
                        help="Path to XML annotations for val images.")
    parser.add_argument("--output_dir", type=str, default=".",
                        help="Where to save outputs.")
    parser.add_argument("--seed", type=int, default=100,
                        help="Random seed for split & centroids.")
    parser.add_argument("--K", type=int, required=True, help="Number of clusters")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for feature extraction.")
    args = parser.parse_args()

    # Fail fast, before the expensive split: K <= 0 would only surface as an
    # error during assignment, batch_size == 0 makes range() raise, and a
    # negative batch_size would silently write an empty grouping.
    if args.K < 1:
        parser.error("--K must be a positive integer")
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")

    random.seed(args.seed)
    np.random.seed(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)

    # 1) Create val10k / test40k split, stratified
    val10k_items, test40k_items = make_val_split(args.val_dir, args.val_xml_dir, args.seed)
    val10k_list = [f"{cls}_{vn}" for (cls, vn, _) in val10k_items]
    test40k_list = [f"{cls}_{vn}" for (cls, vn, _) in test40k_items]
    with open(os.path.join(args.output_dir, f"val10k_list_seed{args.seed}.json"), "w") as f:
        json.dump(val10k_list, f)
    with open(os.path.join(args.output_dir, f"test40k_list_seed{args.seed}.json"), "w") as f:
        json.dump(test40k_list, f)

    # 2) Clustering on val10k
    print("\n[Clustering] Running clustering on val10k only...")
    # Centroids are random N(0, 1) vectors drawn from a generator seeded with
    # --seed, so the same seed and K always give the same centroids.
    d = 224 * 224 * 3
    rng = np.random.RandomState(args.seed)
    centroid_vectors = rng.normal(0, 1, (args.K, d)).astype(np.float32)

    # Build index once for val10k
    index = faiss.IndexFlatL2(d)
    index.add(centroid_vectors)

    # Keys are strings so the mapping is directly JSON-serializable.
    val_grouping = {str(i): [] for i in range(args.K)}
    for start in tqdm(range(0, len(val10k_items), args.batch_size), desc="Assigning val10k"):
        batch = val10k_items[start:start+args.batch_size]
        batch_vecs = np.stack([load_and_flatten(p) for _, _, p in batch], axis=0).astype(np.float32)
        # Nearest single centroid per image.
        _, idx = index.search(batch_vecs, 1)
        for i, (cls, vn, _) in enumerate(batch):
            cluster_id = int(idx[i, 0])
            val_grouping[str(cluster_id)].append(f"{cls}_{vn}")

    grouping_name = f"val10k_grouping_K{args.K}_seed{args.seed}.json"
    with open(os.path.join(args.output_dir, grouping_name), "w") as f:
        json.dump(val_grouping, f)
    print(f"Saved {grouping_name}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
