import numpy as np
import faiss
import os
import pickle
import submitit
import tqdm

DATA_PATH = "../metadata/cc12m/clip_embeddings/"


class Faiss_Dbsacn():
    """Inner-product k-nearest-neighbour search backed by a GPU faiss index.

    The database vectors ``X`` are added to an exact (flat) inner-product
    index when the object is constructed. :meth:`train` then searches every
    row of the query matrix ``Y`` against that index. Despite the class name,
    no DBSCAN clustering is performed here; only the neighbour search runs.

    Attributes set by :meth:`train`:
        Y: the query matrix.
        D: per-query inner-product scores returned by faiss
            (one row per query, ``search_size`` columns).
        I: per-query indices into ``X`` (same shape as ``D``). faiss pads
            slots it could not fill with ``-1``.
        n_neighbors: per-query count of real (non ``-1``) neighbours.
        num_samples / num_features: row / column count of ``X``.
    """

    def __init__(self, X, gpu_device_num=0):
        """Index ``X`` on GPU ``gpu_device_num``.

        Args:
            X: database vectors, one per row. NOTE(review): faiss expects a
                C-contiguous float32 array; the caller in this file casts
                with ``astype(np.float32)``.
            gpu_device_num: CUDA device id that holds the index.
        """
        self.X = X
        self.gpu_index = self.get_gpu_index(X, gpu_device_num)

    @staticmethod
    def get_gpu_index(X, gpu_device_num=0):
        """Build a ``GpuIndexFlatIP`` on ``gpu_device_num`` populated with ``X``.

        Args:
            X: 2-D array of database vectors. Its column count sets the
                index dimension.
            gpu_device_num: CUDA device id for the index.

        Returns:
            The populated faiss GPU index.
        """
        d = X.shape[1]
        res = faiss.StandardGpuResources()
        flat_config = faiss.GpuIndexFlatConfig()
        flat_config.device = gpu_device_num
        index = faiss.GpuIndexFlatIP(res, d, flat_config)
        index.add(X)
        return index

    @staticmethod
    def _split_rows(Y, num_chunks):
        """Split ``Y`` row-wise into at most ``num_chunks`` contiguous pieces.

        Raises:
            ValueError: if ``num_chunks`` is less than 1.
        """
        if num_chunks < 1:
            raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")
        # Never ask for more pieces than there are rows, which would create
        # empty chunks. An empty ``Y`` still yields a single empty chunk.
        num_chunks = max(1, min(num_chunks, Y.shape[0]))
        # Unlike np.split, array_split accepts a row count that is not an
        # exact multiple of num_chunks.
        return np.array_split(Y, num_chunks, axis=0)

    @staticmethod
    def _count_valid_neighbors(I):
        """Count, per row, the indices that are not faiss's ``-1`` padding."""
        return np.sum(I != -1, axis=1)

    def neighborhood_search(self, search_size=2048, num_chunks=1):
        """Search every row of ``self.Y`` for its ``search_size`` best matches.

        Args:
            search_size: neighbours (``k``) to return per query.
                NOTE(review): faiss GPU indexes cap ``k`` (2048 in many
                builds). Confirm before raising the default.
            num_chunks: number of row batches in which ``self.Y`` is searched.
                Increase it to bound the size of each GPU search call. The
                default of 1 searches all queries in a single call.

        Returns:
            ``(D, I)``: scores and neighbour indices, one row per query.

        Raises:
            ValueError: if ``num_chunks`` is less than 1.
        """
        # Split first so invalid arguments fail before any GPU work starts.
        chunks = self._split_rows(self.Y, num_chunks)
        D_list = []
        I_list = []
        for chunk in tqdm.tqdm(chunks):
            D, I = self.gpu_index.search(chunk, search_size)
            D_list.append(D)
            I_list.append(I)
        D = np.concatenate(D_list, axis=0)
        I = np.concatenate(I_list, axis=0)
        # Internal sanity checks: chunking must not drop or duplicate rows.
        assert D.shape[0] == self.Y.shape[0], (
            "Distance array shape mismatch: "
            f"{D.shape[0]} vs {self.Y.shape[0]}"
        )
        assert I.shape[0] == self.Y.shape[0], (
            "Index array shape mismatch: "
            f"{I.shape[0]} vs {self.Y.shape[0]}"
        )
        return D, I

    def train(self, Y, search_size=2048, num_chunks=1):
        """Run the neighbour search for the queries ``Y`` and store the results.

        Args:
            Y: query vectors, one per row, with the same width as ``X``.
            search_size: forwarded to :meth:`neighborhood_search`.
            num_chunks: forwarded to :meth:`neighborhood_search`.

        Returns:
            ``self``, with ``Y``, ``D``, ``I`` and ``n_neighbors`` populated.
        """
        self.Y = Y
        self.num_samples = self.X.shape[0]
        self.num_features = self.X.shape[1]

        # Conduct neighborhood search
        self.D, self.I = self.neighborhood_search(
            search_size=search_size, num_chunks=num_chunks
        )
        # faiss marks unfilled neighbour slots with index -1. Its integer
        # labels are never NaN, so an isnan test would count every slot.
        self.n_neighbors = self._count_valid_neighbors(self.I)

        return self


def main(i):
    """Run the text-to-image neighbour search for one complexity level.

    Loads the text embeddings ``..._text_c{i}.npy`` (queries) and the image
    embeddings ``..._img.npy`` (database) from ``DATA_PATH``. It searches the
    queries against the database with ``Faiss_Dbsacn``, then pickles the
    resulting scores (``D``) and neighbour indices (``I``) as
    ``..._distances_c{i}.pkl`` and ``..._labels_c{i}.pkl`` under
    ``DATA_PATH/clustering_clip_dbscan_IP``.

    Args:
        i: complexity level used in the input and output file names.
    """
    print(f"RUNNING complexity {i}")

    stem = "ViT-SO400M-14-SigLIP-384_webli"
    text_file = f"{DATA_PATH}/{stem}_text_c{i}.npy"
    image_file = f"{DATA_PATH}/{stem}_img.npy"

    anchor = np.load(text_file).astype(np.float32)
    X = np.load(image_file).astype(np.float32)
    print(f"Loaded embedding shape: {X.shape}")

    gd = Faiss_Dbsacn(X)
    gd.train(anchor)
    print(gd.I)

    out_dir = f"{DATA_PATH}/clustering_clip_dbscan_IP"
    os.makedirs(out_dir, exist_ok=True)

    # Pickle the scores first, then the neighbour indices.
    for suffix, payload in (("distances", gd.D), ("labels", gd.I)):
        with open(f"{out_dir}/{stem}_{suffix}_c{i}.pkl", "wb") as handle:
            pickle.dump(payload, handle, protocol=4)


if __name__ == '__main__':
    # Submit one Slurm array task per text-complexity level (0-3). Each task
    # runs `main(i)` with one GPU.
    executor = submitit.AutoExecutor(folder="../logs/dbscan_logs/")
    executor.update_parameters(
        timeout_min=60*2,
        # NOTE(review): mem_gb=0 presumably maps to Slurm --mem=0
        # ("all memory on the node"). Confirm this is intended.
        mem_gb=0,
        name="faiss_dbscan",
        slurm_array_parallelism=4,
        slurm_nodes=1,
        slurm_gpus_per_node=1,
        slurm_tasks_per_node=1,
        slurm_cpus_per_task=10,
        # NOTE(review): empty partition; presumably a placeholder to fill in
        # for the target cluster.
        slurm_partition="",
    )
    jobs = []
    with executor.batch():
        for complexity in range(4):
            print(f"Submitting job: i={complexity}")
            jobs.append(executor.submit(main, i=complexity))
    # Inside `executor.batch()` the jobs are only placeholders. They are
    # actually submitted (and receive job ids) when the context exits, so
    # report them only after that point.
    for job in jobs:
        print(job)
