"""
Cluster iWildCam domains using average CLIP embeddings and KMeans.
Saves domain->cluster mapping CSV and prints summary.
"""

import os
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from sklearn.cluster import KMeans
import open_clip
from wilds import get_dataset
from torch.utils.data import DataLoader

# --- CONFIG ---
# Dataset selection: which WILDS benchmark, where it is stored, which split to embed.
dataset_name = "iwildcam"
dataset_dir = "/home/datasets/iwildcam/"
split_type = "train"

# Clustering settings: number of KMeans clusters and the seed that makes runs repeatable.
n_clusters = 10
random_seed = 42

# Run the encoder on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# --- Load CLIP model ---
# The second return value is discarded; the third is used as the image
# transform handed to the dataset below.
clip_bundle = open_clip.create_model_and_transforms('ViT-B-16', pretrained='openai')
model, _, preprocess = clip_bundle
model = model.to(device)
model.eval()  # inference only: switch off training-mode behaviour

# --- Load dataset ---
dataset = get_dataset(
    dataset=dataset_name,
    root_dir=dataset_dir,
    download=True,
)
subset = dataset.get_subset(split_type, transform=preprocess)

# --- Dataloader ---
# shuffle=False keeps the iteration order deterministic between runs.
dataloader = DataLoader(subset, shuffle=False, batch_size=128)

# --- Extract metadata info ---
# Column of the per-sample metadata that holds the 'location' field, which
# this script treats as the domain identifier.
metadata_fields = subset.metadata_fields
domain_field_idx = metadata_fields.index('location')

# --- Collect per-domain image embeddings ---
# Only the per-domain mean is used afterwards, so keep a running sum and an
# image count per domain instead of retaining one tensor per image. The sums
# are accumulated in float64 to limit rounding error over many additions.
domain_feature_sums = {}
domain_image_counts = {}

print(">>> Encoding images and grouping by domain...")
for images, _, metadata in tqdm(dataloader):
    with torch.no_grad():
        batch_feats = model.encode_image(images.to(device)).cpu().double()

    # Convert the whole domain column once per batch rather than calling
    # .item() for every image.
    batch_domains = metadata[:, domain_field_idx].tolist()

    for domain_id, feat in zip(batch_domains, batch_feats):
        domain_id = int(domain_id)
        if domain_id not in domain_feature_sums:
            # Dict insertion order is first-seen order; it determines the row
            # order of the matrix that is clustered later.
            domain_feature_sums[domain_id] = torch.zeros_like(feat)
            domain_image_counts[domain_id] = 0
        domain_feature_sums[domain_id] += feat
        domain_image_counts[domain_id] += 1

# --- Average embeddings per domain ---
print(">>> Averaging embeddings per domain...")
domain_ids = list(domain_feature_sums)

# Number of images behind each domain mean, aligned with domain_ids.
sample_weights = [domain_image_counts[domain_id] for domain_id in domain_ids]

# Mean = sum / count, cast back to float32 for the clustering step.
avg_embeddings = [
    (domain_feature_sums[domain_id] / domain_image_counts[domain_id]).float()
    for domain_id in domain_ids
]

avg_embeddings_tensor = torch.stack(avg_embeddings)  # [#domains, 512]

# --- KMeans clustering ---
print(f">>> Clustering {len(domain_ids)} domains into {n_clusters} clusters...")

# One row per domain (its mean embedding), in the same order as domain_ids.
domain_matrix = avg_embeddings_tensor.numpy()

cluster_kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=random_seed)
cluster_kmeans.fit(domain_matrix)

# labels_[i] is the cluster assigned to domain_ids[i].
cluster_labels = cluster_kmeans.labels_

# --- Save mapping ---
# The CSV is written next to the dataset, one row per domain.
csv_path = os.path.join(dataset_dir, "domain_to_cluster_map.csv")

mapping_columns = {
    'domain_name': domain_ids,
    'cluster_label': cluster_labels,
}
mapping_df = pd.DataFrame(mapping_columns)
mapping_df.to_csv(csv_path, index=False)  # index=False: no row-number column

print(f">>> Saved domain_to_cluster_map.csv to: {csv_path}")

# --- Print summary ---
print("\nSummary of clusters:")
cluster_image_counts = {i: 0 for i in range(n_clusters)}
cluster_class_labels = {i: set() for i in range(n_clusters)}
cluster_domain_sets = {i: set() for i in range(n_clusters)}

# Store cluster ids as plain ints so they match the dict keys created above.
domain_to_cluster = {
    domain_id: int(cluster_id)
    for domain_id, cluster_id in zip(domain_ids, cluster_labels)
}
for domain_id, cluster_id in domain_to_cluster.items():
    cluster_domain_sets[cluster_id].add(domain_id)

# Labels and metadata are read straight from the subset's arrays. Iterating
# the DataLoader again would decode and preprocess every image only to throw
# the pixels away.
image_domains = subset.metadata_array[:, domain_field_idx].tolist()
image_classes = subset.y_array.tolist()
for domain_id, class_id in zip(image_domains, image_classes):
    cluster_id = domain_to_cluster[int(domain_id)]
    cluster_image_counts[cluster_id] += 1
    cluster_class_labels[cluster_id].add(int(class_id))

# --- Print cluster stats ---
for cluster_id in sorted(cluster_image_counts):
    image_count = cluster_image_counts[cluster_id]
    domain_count = len(cluster_domain_sets[cluster_id])
    class_count = len(cluster_class_labels[cluster_id])
    print(f"Cluster {cluster_id}: {domain_count} domains, {image_count} images, {class_count} unique classes")

