import faiss
import joblib
import numpy as np

# from sklearn.decomposition import PCA

from scripts.baseline_kmeans.discretizers.interface import ClusteringDiscretizer


class ClusteringDiscretizerFaissKMeans(ClusteringDiscretizer):
    """
    Discretizer that maps vectors to cluster ids using a FAISS KMeans model.

    The model is trained with ``train``, queried with ``discrete`` and persisted
    with ``save`` / ``load``. The FAISS index is written to ``path`` and the
    remaining state (centroids, hyper-parameters, optional PCA object) is written
    with joblib to ``path + ".metadata"``.

    NOTE: PCA support is currently disabled; ``self.pca`` is always ``None`` after
    construction, so ``n_components`` is only recorded. The PCA branches in
    ``train`` / ``discrete`` only run when a PCA object comes from loaded metadata
    or is assigned to ``self.pca`` by the caller.
    """

    def __init__(
        self,
        n_clusters: int = 128,
        max_iter: int = 300,
        n_init: int = 3,
        n_components: int = None,  # For PCA
        random_state: int = None,
        gpu: bool = True,
        **kwargs,
    ) -> None:
        """
        Store the hyper-parameters; the FAISS model itself is created in ``train``.
        Args:
            n_clusters (int): Number of clusters.
            max_iter (int): Maximum number of iterations.
            n_init (int): Number of initializations for KMeans (nredo in FAISS).
            n_components (int): Number of PCA components (recorded only; PCA is
                currently disabled, see the class docstring).
            random_state (int): Random seed for reproducibility.
            gpu (bool): Whether to use GPU for FAISS operations.
            kwargs: Additional parameters (accepted and ignored).
        """
        self.n_components = n_components
        # PCA is disabled: no PCA object is created even if n_components is given.
        self.pca = None
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.n_init = n_init
        self.gpu = gpu
        self.random_state = random_state

        # FAISS KMeans model, created by train() or load()
        self.kmeans = None

    @staticmethod
    def _gpu_available() -> bool:
        """Return True when this FAISS build reports at least one usable GPU."""
        get_num_gpus = getattr(faiss, "get_num_gpus", None)
        return bool(get_num_gpus is not None and get_num_gpus() > 0)

    def _require_trained(self) -> None:
        """Raise a clear error when the model has not been trained or loaded."""
        if self.kmeans is None or getattr(self.kmeans, "index", None) is None:
            raise RuntimeError(
                "FAISS KMeans model is not trained; call train() or load() first."
            )

    def train(self, data: np.ndarray) -> None:
        """
        Train the FAISS KMeans model on the provided data.
        If a PCA object is set, it is fitted and applied before training.
        Args:
            data (np.ndarray): A 2D numpy array with shape (n_samples, n_features).
        """
        # Apply PCA if PCA is enabled
        if self.pca is not None:
            print(
                f"Applying PCA to the data from dimension {data.shape[1]} to {self.n_components}..."
            )
            data = self.pca.fit_transform(data)

        # FAISS operates on C-contiguous float32 matrices.
        data = np.ascontiguousarray(data, dtype=np.float32)

        d = data.shape[1]  # Dimensionality of input data
        print(f"Training FAISS KMeans model with {self.n_clusters} clusters...")

        # Initialize FAISS KMeans model
        self.kmeans = faiss.Kmeans(
            d=d,
            k=self.n_clusters,
            niter=self.max_iter,
            verbose=True,
            gpu=self.gpu,
            nredo=self.n_init,
            max_points_per_centroid=1000000,  # to ensure we use all datapoints
            seed=self.random_state if self.random_state is not None else -1,
        )

        # Train the FAISS KMeans model
        self.kmeans.train(data)

    def discrete(self, vectors: np.ndarray) -> np.ndarray[int]:
        """
        Assign the input vectors to clusters using the trained model.
        Args:
            vectors (np.ndarray): A 2D numpy array (or sequence of arrays that
                can be stacked row-wise) of vectors to be clustered.

        Returns:
            np.ndarray[int]: A 1D array with one cluster index per input row.

        Raises:
            RuntimeError: If the model has not been trained or loaded.
        """
        self._require_trained()

        vectors_array = np.vstack(vectors)

        # Apply PCA transformation if PCA is enabled
        if self.pca is not None:
            vectors_array = self.pca.transform(vectors_array)

        # FAISS search expects C-contiguous float32 input; stacking or the PCA
        # transform may produce another dtype (e.g. float64).
        vectors_array = np.ascontiguousarray(vectors_array, dtype=np.float32)

        # Assign vectors to the nearest cluster
        _, cluster_indices = self.kmeans.index.search(vectors_array, 1)
        return cluster_indices.ravel()

    def save(self, path: str) -> None:
        """
        Save the trained FAISS KMeans model and PCA (if used) to the specified path.
        The index goes to ``path``; metadata goes to ``path + ".metadata"``.
        Args:
            path (str): The file path to save the model.

        Raises:
            RuntimeError: If the model has not been trained or loaded.
        """
        self._require_trained()
        print(f"Saving FAISS KMeans model to {path}...")

        # Only convert from GPU when GPUs are actually present: with gpu=True on
        # a CPU-only setup the index lives on the CPU and the GPU helpers may
        # not even exist in the faiss module.
        if self.gpu and self._gpu_available():
            cpu_index = faiss.index_gpu_to_cpu(self.kmeans.index)
        else:
            cpu_index = self.kmeans.index
        faiss.write_index(cpu_index, f"{path}")

        # Metadata (including PCA if applied) and the wrapper hyper-parameters,
        # so that a loaded instance can be re-trained with the same settings.
        kmeans_metadata = {
            "n_clusters": self.kmeans.k,
            "centroids": self.kmeans.centroids,
            "d": self.kmeans.d,
            "obj": self.kmeans.obj,
            "seed": self.kmeans.cp.seed,
            "iteration_stats": self.kmeans.iteration_stats,
            "pca": self.pca,
            "gpu": self.gpu,
            "n_components": self.n_components,
            "max_iter": getattr(self, "max_iter", None),
            "n_init": getattr(self, "n_init", None),
            "random_state": getattr(self, "random_state", None),
        }
        joblib.dump(kmeans_metadata, f"{path}.metadata")

    @classmethod
    def load(cls, path: str) -> "ClusteringDiscretizerFaissKMeans":
        """
        Load a pre-trained FAISS KMeans model and PCA from the specified path.
        Args:
            path (str): The file path from which to load the model.

        Returns:
            ClusteringDiscretizerFaissKMeans: An instance of this class with the loaded model.
        """
        print(f"Loading FAISS KMeans model from {path}...")

        # Create a new instance of the class without calling __init__
        instance = cls.__new__(cls)

        # Load the FAISS index and metadata
        loaded_index = faiss.read_index(f"{path}")
        # NOTE(security): joblib.load unpickles arbitrary objects; only load
        # metadata files from trusted sources.
        kmeans_metadata = joblib.load(f"{path}.metadata")

        # Manually set attributes from metadata
        instance.n_clusters = kmeans_metadata["n_clusters"]
        instance.gpu = kmeans_metadata.get(
            "gpu", True
        )  # Default to True if not present
        instance.n_components = kmeans_metadata.get("n_components", None)
        instance.pca = kmeans_metadata.get("pca", None)

        # Hyper-parameters are absent from metadata written by older versions;
        # fall back to the values that used to be hard-coded here.
        max_iter = kmeans_metadata.get("max_iter", None)
        n_init = kmeans_metadata.get("n_init", None)
        instance.max_iter = max_iter if max_iter is not None else 100
        instance.n_init = n_init if n_init is not None else 4
        instance.random_state = kmeans_metadata.get("random_state", None)

        # Initialize the FAISS KMeans model using metadata
        instance.kmeans = faiss.Kmeans(
            d=kmeans_metadata["d"],
            k=instance.n_clusters,
            niter=instance.max_iter,
            verbose=True,
            gpu=instance.gpu,  # Use the loaded GPU setting
            nredo=instance.n_init,
            seed=kmeans_metadata["seed"],
        )

        instance.kmeans.centroids = kmeans_metadata["centroids"]
        instance.kmeans.obj = kmeans_metadata["obj"]
        instance.kmeans.iteration_stats = kmeans_metadata["iteration_stats"]

        # Load the index back into GPU if it was trained on GPU and a GPU is
        # available here; otherwise keep the CPU index so the model stays usable.
        if instance.gpu and cls._gpu_available():
            print("Moving the index to GPU...")
            gpu_resources = faiss.StandardGpuResources()
            instance.kmeans.index = faiss.index_cpu_to_gpu(
                gpu_resources, 0, loaded_index
            )
        else:
            instance.kmeans.index = loaded_index

        return instance
