import argparse

import numpy as np

from scripts.baseline_kmeans.discretizers.factory import (
    get_clustering_discretizer,
    clustering_model_choices,
)


def main(args: argparse.Namespace) -> None:
    """
    Load embeddings from a memmap, train a clustering model on them, and save it.

    The implementation is selected by ``args.model`` and constructed through
    ``get_clustering_discretizer``. The resulting object's ``train`` method is
    called with the memmapped embeddings, and its ``save`` method with
    ``args.model_path``.

    Args:
    - args (argparse.Namespace): Parsed command-line arguments. Fields read:
      ``use_fp16``, ``memmap_file``, ``dstore_size``, ``dimension``, ``model``,
      ``num_clusters``, ``n_components``, ``n_init``, ``batch_size``,
      ``max_iter``, ``tolerance``, ``use_gpu``, ``random_state``,
      ``model_path``.
    """
    # np.memmap interprets the file's raw bytes, so the dtype (and the shape
    # below) must match how the embeddings were originally written.
    dtype = np.float16 if args.use_fp16 else np.float32

    # Read-only memory map: data is paged in lazily when the model accesses it.
    print(f"Loading embeddings from {args.memmap_file}...")
    embeddings = np.memmap(
        args.memmap_file,
        dtype=dtype,
        mode="r",
        shape=(args.dstore_size, args.dimension),
    )
    print(f"Embeddings loaded with shape: {embeddings.shape}!")

    # Initialize the clustering model using the factory function.
    print(f"Training {args.model} model with {args.num_clusters} clusters...")
    clustering_model = get_clustering_discretizer(
        model=args.model,
        n_clusters=args.num_clusters,
        n_components=args.n_components,
        n_init=args.n_init,
        batch_size=args.batch_size,
        max_iter=args.max_iter,
        tol=args.tolerance,
        use_gpu=args.use_gpu,
        random_state=args.random_state,
    )

    # Train the model.
    clustering_model.train(embeddings)
    print("Model training completed!")

    # Save the trained model to disk.
    print(f"Saving trained model to {args.model_path}...")
    clustering_model.save(args.model_path)

    print("Model saved successfully!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train a KMeans model on precomputed embeddings and save the model to disk."
    )

    # Memmap arguments
    parser.add_argument(
        "--memmap_file",
        type=str,
        help="File path of the memmap containing embeddings",
    )
    parser.add_argument(
        "--dstore_size",
        type=int,
        default=1000,
        help="Total number of embeddings stored in the memmap",
    )
    parser.add_argument(
        "--dimension", type=int, default=768, help="Dimension of the embeddings"
    )

    # KMeans arguments
    parser.add_argument(
        "--model",
        type=str,
        default="scikit",
        choices=clustering_model_choices,
        help="KMeans model implementation to use",
    )
    parser.add_argument(
        "--num_clusters", type=int, default=128, help="Number of clusters for KMeans"
    )
    parser.add_argument(
        "--n_components",
        type=int,
        default=None,
        help="Number of PCA components (None if PCA is not applied)",
    )
    parser.add_argument(
        "--n_init", type=int, default=3, help="Number of initializations for KMeans"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1024,
        help="Number of samples per batch for MiniBatchKMeans",
    )
    parser.add_argument(
        "--max_iter",
        type=int,
        default=300,
        help="Maximum number of iterations for KMeans",
    )
    parser.add_argument(
        "--tolerance",
        type=float,
        default=1e-4,
        help="Relative tolerance to declare convergence",
    )
    parser.add_argument(
        "--random_state",
        type=int,
        default=2,
        help="Random state for centroid initialization",
    )
    parser.add_argument(
        "--use_gpu",
        action="store_true",
        help="Indicate if the implementation should use GPU acceleration (if supported)",
    )

    # Processing arguments
    parser.add_argument(
        "--use_fp16",
        action="store_true",
        help="Indicate if the embeddings are stored in float16 format",
    )

    # Model saving arguments
    parser.add_argument(
        "--model_path",
        type=str,
        default="/home/data/test_kmeans_model.pkl",
        help="File path to save the trained KMeans model",
    )

    args = parser.parse_args()
    main(args)
