import joblib
import numpy as np
from cuml.cluster import KMeans as cuMLKMeans
from cuml.decomposition import PCA as cuMLPCA
import cupy as cp

from scripts.baseline_kmeans.discretizers.interface import ClusteringDiscretizer


class ClusteringDiscretizerCuMLKMeans(ClusteringDiscretizer):
    """Discretize vectors into cluster ids using cuML KMeans, optionally after cuML PCA.

    Input data is copied to the GPU with CuPy, optionally projected with PCA, and
    clustered with KMeans. Cluster assignments are copied back to NumPy.
    """

    def __init__(
        self,
        n_clusters: int = 8,
        max_iter: int = 300,
        tol: float = 1e-4,
        n_init: int = 3,
        n_components: int = None,  # For PCA
        random_state: int = None,
        **kwargs,
    ) -> None:
        """
        Initialize the cuML KMeans model with the specified parameters and optionally apply PCA.
        Args:
            n_clusters (int): Number of clusters.
            max_iter (int): Maximum number of iterations.
            tol (float): Tolerance to declare convergence.
            n_init (int): Number of initializations for KMeans.
            n_components (int | None): Number of PCA components (None if PCA is not applied).
            random_state (int | None): Random seed for reproducibility.
            kwargs: Accepted for interface compatibility; currently ignored.
        """
        # Optionally initialize cuML PCA for dimensionality reduction
        self.n_components = n_components
        self.pca = (
            cuMLPCA(n_components=n_components) if n_components is not None else None
        )

        # Initialize the cuML KMeans model
        self.model = cuMLKMeans(
            n_clusters=n_clusters,
            max_iter=max_iter,
            n_init=n_init,
            tol=tol,
            random_state=random_state,
        )

    def train(self, data: np.ndarray) -> None:
        """
        Train the cuML KMeans model on the provided data.
        If PCA is enabled, fit PCA on the data first and train KMeans on the projection.
        Args:
            data (np.ndarray): A 2D array with shape (n_samples, n_features).
        """
        # Convert data to GPU array (cuML uses cupy/cudf for computation)
        data_gpu = cp.asarray(data)

        # Apply PCA if PCA is enabled
        if self.pca is not None:
            # Read the dimension from the GPU array so array-likes without `.shape` work too.
            print(
                f"Applying PCA to the data from dimension {data_gpu.shape[1]} to {self.n_components}..."
            )
            data_gpu = self.pca.fit_transform(data_gpu)

        # Train the KMeans model
        print(f"Training cuML KMeans model with {self.model.n_clusters} clusters...")
        self.model.fit(data_gpu)

    def discrete(self, vectors: np.ndarray) -> np.ndarray[int]:
        """
        Assign the input vectors to clusters using the trained model.
        Args:
            vectors (np.ndarray | Sequence[np.ndarray]): A 2D array of shape
                (n_samples, n_features), or a sequence of vectors that `np.vstack`
                stacks into one.

        Returns:
            np.ndarray[int]: An array of integers representing the cluster assignments.
        """
        # np.vstack accepts either a 2D array or a list of 1D/2D arrays.
        vectors_array = np.vstack(vectors)

        # Convert input vectors to GPU array
        vectors_gpu = cp.asarray(vectors_array)

        # Project with the PCA fitted during training, if PCA is enabled
        if self.pca is not None:
            vectors_gpu = self.pca.transform(vectors_gpu)

        # Return the cluster assignments (computed on GPU, converted back to numpy)
        return cp.asnumpy(self.model.predict(vectors_gpu))

    def save(self, path: str) -> None:
        """
        Save the trained PCA and KMeans model to the specified path.
        Args:
            path (str): The file path to save the model.
        """
        # Save the KMeans model, the PCA (None if PCA is not applied), and the PCA
        # target dimension so a loaded instance has the same attributes as a fresh one.
        model_data = {
            "kmeans_model": self.model,
            "pca_model": self.pca,
            "n_components": self.n_components,
        }
        joblib.dump(model_data, path)

    @classmethod
    def load(cls, path: str) -> "ClusteringDiscretizerCuMLKMeans":
        """
        Load a pre-trained cuML KMeans and PCA model from the specified path.
        Args:
            path (str): The file path from which to load the model.

        Returns:
            ClusteringDiscretizerCuMLKMeans: An instance of this class with the loaded model.
        """
        # Create a new instance of the class without calling __init__
        instance = cls.__new__(cls)

        # Load the saved models (KMeans and PCA).
        # NOTE(security): joblib.load unpickles arbitrary objects; only load trusted files.
        model_data = joblib.load(path)
        instance.model = model_data["kmeans_model"]
        instance.pca = model_data["pca_model"]

        # `train` reads self.n_components, so it must be restored. Files saved before
        # this key existed lack it; fall back to the PCA object's attribute if present.
        n_components = model_data.get("n_components")
        if n_components is None and instance.pca is not None:
            n_components = getattr(instance.pca, "n_components", None)
        instance.n_components = n_components

        return instance
