import torch
import os
import json
import numpy as np

from collections import defaultdict
from tqdm import tqdm
from torchvision.io import read_image, ImageReadMode
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torchvision import transforms

try:
    import faiss
except ImportError as err:
    # faiss is missing: install the GPU build on the fly, but only when CUDA
    # is present (an explicit raise, unlike ``assert``, survives ``python -O``).
    import subprocess
    import sys

    if not torch.cuda.is_available():
        raise ImportError("faiss-gpu requires CUDA") from err
    # Install into the interpreter running this script (a bare ``pip`` on PATH
    # may belong to another environment) and fail loudly if the install fails,
    # rather than ignoring the exit status as ``os.system`` did.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "faiss-gpu"])
    import faiss

import warnings
# Suppress every warning for the remainder of the run.
warnings.filterwarnings('ignore')

# Cluster id -> names of the images assigned to that cluster; filled and
# serialized to JSON by get_train_clusters().
train_group = defaultdict(list)


class ImagenetTrainClassDataset(Dataset):
    """Images belonging to a single ImageNet training class.

    The class is chosen by its index in the sorted listing of the ``train``
    directory; its images are the sorted file names inside that class
    directory. Each item is read in RGB mode and passed through ``transform``.
    """

    def __init__(self, path: str, class_id: int, transform):
        """Select class ``class_id`` under ``path`` and list its image files.

        Args:
            path: The ImageNet ``train`` directory. Its final path component
                must be ``train``; a trailing separator is accepted.
            class_id: Index into the sorted entries of ``path``.
            transform: Callable applied to each image tensor in ``__getitem__``.

        Raises:
            ValueError: If the last component of ``path`` is not ``train``.
            IndexError: If ``class_id`` is out of range for the listing.
        """
        # normpath drops a trailing separator, which made the former
        # ``path.split('/')[-1]`` check reject ".../train/".
        if os.path.basename(os.path.normpath(path)) != 'train':
            raise ValueError(f"expected a path whose last component is 'train', got {path!r}")
        super().__init__()
        class_names = sorted(os.listdir(path))
        self.class_name = class_names[class_id]
        self.class_path = os.path.join(path, self.class_name)

        self.img_names = sorted(os.listdir(self.class_path))
        self.transform = transform

    def __getitem__(self, idx):
        """Read image ``idx`` via torchvision ``read_image`` in RGB mode and transform it."""
        img_path = os.path.join(self.class_path, self.img_names[idx])
        image = read_image(img_path, ImageReadMode.RGB)
        return self.transform(image)

    def __len__(self):
        """Number of image files in the selected class directory."""
        return len(self.img_names)

def get_train_clusters(train_path, num_centroids, seed, *, batch_size=1000,
                       num_workers=2, num_classes=1000):
    """Assign every ImageNet training image to its nearest random "centroid" image.

    ``num_centroids`` random RGB images (pixel values drawn uniformly from
    0..255 by a generator seeded with ``seed``) are added to a FAISS exact-L2
    index copied to all available GPUs via ``faiss.index_cpu_to_all_gpus``.
    Every training image is resized to 224x224, flattened, and assigned to
    its nearest centroid. The mapping ``centroid id -> [image names without
    extension]`` is written to
    ``imagenet_train_clusters/seed_<seed>/<num_centroids>/train_group.json``.

    Args:
        train_path: ImageNet ``train`` directory (one sub-directory per class).
        num_centroids: Number of random centroid images.
        seed: Seed for the centroid generator.
        batch_size: Images per DataLoader batch.
        num_workers: DataLoader worker processes.
        num_classes: Number of classes (in sorted directory order) to cluster.

    Returns:
        The module-level ``train_group`` mapping, reset and refilled by this call.

    Raises:
        RuntimeError: If the number of assignments does not match the number
            of image names.
    """
    output_dir = os.path.join("imagenet_train_clusters", f"seed_{seed}", str(num_centroids))
    transform = transforms.Compose([
        transforms.Resize((224, 224))
    ])
    d = 224 * 224 * 3
    # A private legacy RandomState yields exactly the values that
    # ``np.random.seed(seed); np.random.randint(...)`` produced, without
    # reseeding NumPy's global generator as a side effect.
    # NOTE: randint materializes a default-int (int64 on most platforms) array
    # before the float32 cast; at 10000 centroids that is roughly 12 GB.
    rng = np.random.RandomState(seed)
    centroid_images = rng.randint(256, size=(num_centroids, d)).astype('float32')

    cpu_index = faiss.IndexFlatL2(d)
    index = faiss.index_cpu_to_all_gpus(cpu_index)
    index.add(centroid_images)

    subsets = []
    name_list = []
    print("Loading images...")
    for class_id in range(num_classes):
        class_subset = ImagenetTrainClassDataset(train_path, class_id=class_id, transform=transform)
        subsets.append(class_subset)
        # Drop the file extension (e.g. ".JPEG"); order matches ConcatDataset below.
        name_list.extend(os.path.splitext(name)[0] for name in class_subset.img_names)

    dataset = ConcatDataset(subsets)
    train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                  num_workers=num_workers)

    print("Clustering...")
    # Gather per-batch results and concatenate once: calling np.append in the
    # loop re-copied the whole array every batch (quadratic in dataset size).
    batch_assignments = []
    for images in tqdm(train_dataloader):
        flat = images.reshape(images.size(0), -1).numpy().astype('float32')
        _, nearest = index.search(flat, 1)
        batch_assignments.append(nearest.reshape(-1))
    if batch_assignments:
        cluster_ids = np.concatenate(batch_assignments)
    else:
        cluster_ids = np.array([], dtype=int)

    if len(cluster_ids) != len(name_list):
        raise RuntimeError(
            f"got {len(cluster_ids)} cluster assignments for {len(name_list)} images"
        )

    # Reset the shared module-level mapping. Without this, every call after
    # the first also wrote the assignments of all earlier (seed, num_centroids)
    # configurations into its JSON file.
    train_group.clear()
    for name, cluster_id in zip(name_list, cluster_ids.tolist()):
        train_group[cluster_id].append(name)

    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, 'train_group.json'), 'w') as f:
        json.dump(train_group, f, indent=2)
    return train_group

if __name__ == '__main__':
    # Dataset layout as expected relative to the working directory.
    image_dir = "imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC"
    assert os.path.exists(image_dir), "ImageNet dataset is not available"

    train_path = f"{image_dir}/train"
    assert os.path.exists(train_path), "ImageNet training set is not available"

    # Sweep: every seed in 42..46 crossed with every centroid count, seed-major.
    seeds = range(42, 47)
    centroid_counts = (100, 200, 300, 1000, 5000, 10000)
    sweep = ((s, c) for s in seeds for c in centroid_counts)
    for seed, num_centroids in sweep:
        print(f"Seed: {seed}, Num Centroids: {num_centroids}")
        get_train_clusters(train_path, num_centroids, seed)
