import joblib
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sklearn.decomposition import PCA

from scripts.baseline_kmeans.discretizers.interface import ClusteringDiscretizer


class ClusteringDiscretizerScikitKMeans(ClusteringDiscretizer):
    """
    Discretizer that maps vectors to cluster ids with scikit-learn's MiniBatchKMeans.

    An optional PCA step can be fitted during training and is then applied to
    every input before cluster assignment. Both estimators are persisted
    together by `save` and restored by `load`.
    """

    def __init__(
        self,
        n_clusters: int = 8,
        max_iter: int = 300,
        tol: float = 1e-4,
        batch_size: int = 1000,
        n_init: int = 3,
        n_components: int = None,  # For PCA
        random_state: int = None,
        **kwargs,
    ) -> None:
        """
        Initialize the MiniBatchKMeans model with the specified parameters and optionally apply PCA.
        Args:
            n_clusters (int): Number of clusters.
            max_iter (int): Maximum number of iterations.
            tol (float): Tolerance to declare convergence.
            batch_size (int): Batch size for MiniBatchKMeans.
            n_init (int): Number of initializations for KMeans.
            n_components (int): Number of PCA components (None if PCA is not applied).
            random_state (int): Random seed for reproducibility (forwarded to KMeans only).
            kwargs: Additional parameters. They are accepted so that callers can pass
                a shared configuration, but they are not used by this class.
        """
        # Optionally initialize PCA for dimensionality reduction
        self.n_components = n_components
        self.pca = PCA(n_components=n_components) if n_components is not None else None

        # Initialize the MiniBatchKMeans model (verbose output is always enabled)
        self.model = MiniBatchKMeans(
            n_clusters=n_clusters,
            max_iter=max_iter,
            n_init=n_init,
            tol=tol,
            batch_size=batch_size,
            verbose=1,
            random_state=random_state,
        )

    def train(self, data: np.ndarray) -> None:
        """
        Train the MiniBatchKMeans model on the provided data.
        If PCA is enabled, fit PCA on the data and train on the projected data.
        Args:
            data (np.ndarray): A 2D numpy array with shape (n_samples, n_features).
        """
        # Apply PCA if PCA is enabled
        if self.pca is not None:
            # Read the target dimension from the PCA object itself so that this
            # also works on instances restored through `load`.
            print(
                f"Applying PCA to the data from dimension {data.shape[1]} to {self.pca.n_components}..."
            )
            data = self.pca.fit_transform(data)

        # Train the KMeans model
        print(f"Training KMeans model with {self.model.n_clusters} clusters...")
        print(self.model.get_params())
        self.model.fit(data)

    def discrete(self, vectors: np.ndarray) -> np.ndarray[int]:
        """
        Assign the input vectors to clusters using the trained model.
        Args:
            vectors (list[np.ndarray]): A list of vectors to be clustered.

        Returns:
            np.ndarray[int]: An array of integers representing the cluster assignments.
                An empty input yields an empty array.
        """
        # np.vstack cannot stack an empty sequence, so answer the empty case directly.
        if len(vectors) == 0:
            return np.empty(0, dtype=int)

        vectors_array = np.vstack(vectors)

        # Apply PCA transformation if PCA is enabled
        if self.pca is not None:
            vectors_array = self.pca.transform(vectors_array)

        return self.model.predict(vectors_array)

    def save(self, path: str) -> None:
        """
        Save the trained PCA and KMeans model to the specified path.
        Args:
            path (str): The file path to save the model.
        """
        # Save both the KMeans model and the PCA (None when PCA is not applied)
        model_data = {"kmeans_model": self.model, "pca_model": self.pca}
        joblib.dump(model_data, path)

    @classmethod
    def load(cls, path: str) -> "ClusteringDiscretizerScikitKMeans":
        """
        Load a pre-trained KMeans and PCA model from the specified path.

        The file is expected to be a dictionary with the keys "kmeans_model"
        and "pca_model", as written by `save`.

        Args:
            path (str): The file path from which to load the model.

        Returns:
            ClusteringDiscretizerScikitKMeans: An instance of this class with the loaded model.
        """
        # Create a new instance of the class without calling __init__
        instance = cls.__new__(cls)

        # Load the saved models (KMeans and PCA)
        # SECURITY: joblib.load unpickles arbitrary objects; only load files
        # from trusted sources.
        model_data = joblib.load(path)
        instance.model = model_data["kmeans_model"]
        instance.pca = model_data["pca_model"]

        # __init__ was bypassed, so restore the attribute it would have set.
        instance.n_components = (
            instance.pca.n_components if instance.pca is not None else None
        )

        return instance
