#!/usr/bin/env python3
import argparse
import logging
import os
from typing import Dict
import datetime
from typing import Literal

import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray
from sklearn import metrics as sk_metrics
from sklearn.decomposition import PCA
from tqdm import tqdm

from kmeans_jax.kmeans import (
    compute_loss,
    compute_centroids,
    assign_clusters,
    kmeans_init_from_random_partition,
    kmeans_random_init,
    kmeans_plusplus_init,
    run_lloyd_kmeans,
    run_hartigan_kmeans,
)


def _mkbasedir(path: str) -> None:
    """
    Ensure that directory ``path`` exists, creating it (and any parents) if needed.

    An empty ``path`` (as returned by ``os.path.dirname("results.npz")``) denotes
    the current working directory, which always exists, so nothing is done.

    **Arguments:**
        - path: Directory to create.

    **Raises:**
        - ValueError: if ``path`` cannot be created (e.g. a component is a regular
          file, or permissions are insufficient).
    """
    if not path:
        # os.makedirs("") raises FileNotFoundError; "" means the current directory.
        return
    try:
        # exist_ok avoids a race between a separate existence check and creation.
        os.makedirs(path, exist_ok=True)
    except OSError as err:
        raise ValueError("Output path does not exist or cannot be created.") from err


def _make_initialization(
    key_init, data, data_pca, init_method, n_clusters, true_labels
):
    """
    Build a k-means initialization in both the raw and the PCA space.

    **Arguments:**
        - key_init: PRNG key used by the initializer.
        - data: Data matrix used to choose the initialization.
        - data_pca: PCA projection of ``data`` (same rows, fewer columns).
        - init_method: 'random_centers', 'kmeans++', or 'random_partition'.
        - n_clusters: Number of clusters.
        - true_labels: Ground-truth labels; forwarded to
          ``kmeans_init_from_random_partition`` as ``labels`` (how they are used
          there is not visible here -- TODO confirm).

    **Returns:**
        - init_centroids: Initial centroids in the raw space.
        - init_centroids_pca: Corresponding initial centroids in PCA space.
        - init_partition: Initial cluster assignment of every point.

    **Raises:**
        - ValueError: if ``init_method`` is not supported.
    """
    if init_method in ("random_centers", "kmeans++"):
        seeding_fn = (
            kmeans_random_init
            if init_method == "random_centers"
            else kmeans_plusplus_init
        )
        init_centroids, indices = seeding_fn(data, n_clusters, key_init)
        # The returned row indices select the matching rows of the PCA data,
        # so both spaces start from the same seeds.
        init_centroids_pca = data_pca[indices]
        init_partition = assign_clusters(init_centroids, data)

    elif init_method == "random_partition":
        init_centroids, init_partition = kmeans_init_from_random_partition(
            data, n_clusters, key_init, labels=true_labels
        )
        # Same partition, averaged in PCA space.
        init_centroids_pca = compute_centroids(data_pca, init_partition, n_clusters)

    else:
        raise ValueError(
            f"Unknown initialization method: {init_method}. "
            + "Only 'random_centers', 'kmeans++', and 'random_partition' are supported."
        )

    return init_centroids, init_centroids_pca, init_partition


def _make_initial_data(
    key, n_clusters, size_per_cluster, dimension, var_prior, noise_variance
):
    """
    Sample a balanced Gaussian mixture.

    Cluster centers are standard normal draws scaled by ``sqrt(var_prior)``.
    Each point is its center plus standard normal noise scaled by
    ``sqrt(noise_variance)``. Points are ordered by cluster: the first
    ``size_per_cluster`` rows belong to cluster 0, the next block to cluster 1,
    and so on.

    **Returns:**
        - data: Array of shape ``(n_clusters * size_per_cluster, dimension)``.
        - true_labels: Cluster index of every row of ``data``.
    """
    center_key, noise_key = jax.random.split(key, 2)
    n_points = n_clusters * size_per_cluster

    center_scale = jnp.sqrt(var_prior)
    noise_scale = jnp.sqrt(noise_variance)

    centers = center_scale * jax.random.normal(
        center_key, shape=(n_clusters, dimension)
    )
    labels = jnp.repeat(jnp.arange(n_clusters), size_per_cluster)
    noise = noise_scale * jax.random.normal(noise_key, shape=(n_points, dimension))

    return centers[labels] + noise, labels


def run_single_experiment(
    key: PRNGKeyArray,
    noise_variance: Float,
    n_clusters: Int,
    size_per_cluster: Int,
    dimension: Int,
    var_prior: Float,
    max_iters: Int,
    init_method: str,
) -> Dict[str, Dict[str, Float]]:
    """
    Sample one synthetic data set and cluster it with every method under study.

    Methods:
        - Lloyd's and Hartigan's k-means on the raw data.
        - Lloyd's k-means on a PCA projection.
        - For ``n_clusters == 2`` only, a split on the sign of the first principal
          component.

    The k-means variants all start from the same initialization; the PCA
    variant uses its counterpart in PCA space.

    **Arguments:**
        - key: PRNG key, split into a data-generation key and an initialization key.
        - noise_variance: Noise variance passed to the data generator.
        - n_clusters: Number of clusters.
        - size_per_cluster: Number of points generated per cluster.
        - dimension: Ambient dimension of the data.
        - var_prior: Prior variance of the cluster centers.
        - max_iters: Maximum number of k-means iterations.
        - init_method: One of 'random_centers', 'kmeans++', 'random_partition'.

    **Returns:**
        - results: Nested dictionary keyed by "lloyd", "hartigan", "pca",
          "pca_split", "true_partition" and "init". Each entry maps metric names
          ("nmi", "loss", "n_iter", where applicable) to scalars.
          NMI is measured against the true labels, and every loss is computed on
          the full-dimensional data. For ``n_clusters != 2`` the "pca_split"
          entry holds placeholder values (nmi 0.0, loss inf, n_iter 0).
    """
    key_init, key_data = jax.random.split(key, 2)

    data, true_labels = _make_initial_data(
        key_data,
        n_clusters=n_clusters,
        size_per_cluster=size_per_cluster,
        dimension=dimension,
        var_prior=var_prior,
        noise_variance=noise_variance,
    )

    # sklearn's PCA rejects n_components > min(n_samples, n_features), so clamp
    # against both dimensions of the data matrix.
    n_samples, n_features = data.shape
    n_components = min(max(4, n_clusters), n_features, n_samples)
    data_host = np.array(data)
    pca = PCA(n_components=n_components)
    pca.fit(data_host)
    data_pca = jnp.array(pca.transform(data_host))

    # Reference: loss of the ground-truth partition.
    true_data_averages = compute_centroids(data, true_labels, n_clusters)
    true_loss = compute_loss(data, true_data_averages, true_labels)

    init_centroids, init_centroids_pca, init_partition = _make_initialization(
        key_init, data, data_pca, init_method, n_clusters, true_labels
    )

    init_data_averages = compute_centroids(data, init_partition, n_clusters)
    init_loss = compute_loss(data, init_data_averages, init_partition)
    init_nmi = sk_metrics.normalized_mutual_info_score(true_labels, init_partition)

    ################################# K-Means ###############################
    _, labels_lloyd, loss_lloyd, n_iter_lloyd = run_lloyd_kmeans(
        data, init_centroids, max_iters=max_iters
    )
    nmi_kmeans = sk_metrics.normalized_mutual_info_score(true_labels, labels_lloyd)

    # Hartigan k-means
    _, labels_hartigan, loss_hartigan, n_iter_hartigan = run_hartigan_kmeans(
        data, init_centroids, max_iters=max_iters
    )
    nmi_hartigan = sk_metrics.normalized_mutual_info_score(true_labels, labels_hartigan)

    ################################# PCA ###############################
    _, labels_pca, _, n_iter_pca = run_lloyd_kmeans(
        data_pca, init_centroids_pca, max_iters=max_iters
    )
    nmi_kmeans_pca = sk_metrics.normalized_mutual_info_score(true_labels, labels_pca)
    # Evaluate the PCA partition's loss in the original space so it is
    # comparable with the other methods.
    loss_pca = compute_loss(
        data, compute_centroids(data, labels_pca, n_clusters), labels_pca
    )

    # Split PCA: partition by the sign of the first principal component score.
    if n_clusters == 2:
        labels_pca_split = jnp.where(data_pca[:, 0] > 0.0, 0, 1).astype(int)
        nmi_split_pca = sk_metrics.normalized_mutual_info_score(
            true_labels, labels_pca_split
        )
        loss_pca_split = compute_loss(
            data, compute_centroids(data, labels_pca_split, 2), labels_pca_split
        )
        n_iter_pca_split = 1
    else:
        # Not defined for K != 2: store placeholder values.
        nmi_split_pca = 0.0
        loss_pca_split = jnp.inf
        n_iter_pca_split = 0

    return {
        "lloyd": {
            "nmi": nmi_kmeans,
            "loss": loss_lloyd,
            "n_iter": n_iter_lloyd,
        },
        "hartigan": {
            "nmi": nmi_hartigan,
            "loss": loss_hartigan,
            "n_iter": n_iter_hartigan,
        },
        "pca": {
            "nmi": nmi_kmeans_pca,
            "loss": loss_pca,
            "n_iter": n_iter_pca,
        },
        "pca_split": {
            "nmi": nmi_split_pca,
            "loss": loss_pca_split,
            "n_iter": n_iter_pca_split,
        },
        "true_partition": {
            "loss": true_loss,
        },
        "init": {
            "loss": init_loss,
            "nmi": init_nmi,
        },
    }


def run_gmm_experiments(
    dimension_vals: Int[Array, " n_dims"],
    noise_variance_vals: Float[Array, " n_noise_variances"],
    prior_variance: Float,
    n_clusters: Int,
    size_per_cluster: Int,
    n_experiments: Int,
    path_to_output: str,
    init_method: Literal["random_centers", "kmeans++", "random_partition"],
    *,
    max_iters: Int = 1000,
    seed: Int = 0,
    overwrite: Bool = False,
) -> Dict[str, Float[Array, "n_dims n_noise_variances n_experiments"]]:
    """
    Run clustering experiments with Lloyds' and Hartigan's k-means, PCA + Lloyd, and PCA + Split (K = 2).

    Results are checkpointed to the output file after every (dimension, noise
    variance) cell. The stored "i"/"j" entries are the indices of the cell that
    had just completed at that save, which is what
    `run_gmm_experiments_continue` resumes from.

    **Arguments:**
        - dimension_vals: Array of dimensions to test.
        - noise_variance_vals: Array of noise variances to test.
        - prior_variance: Prior variance for the data generation.
        - n_clusters: Number of clusters.
        - size_per_cluster: Size of each cluster.
        - n_experiments: Number of experiments to run for each setting.
        - path_to_output: Path to save the results. ``np.savez`` always writes a
          ``.npz`` file, so that suffix is appended when missing.
        - init_method: Initialization method for k-means ('random_centers', 'kmeans++', 'random_partition').
        - max_iters: Maximum number of iterations for k-means.
        - seed: Random seed for reproducibility.
        - overwrite: Whether to overwrite existing output files.

    **Returns:**
        - results: Dictionary containing the results of the experiments.

    **Raises:**
        - ValueError: if the output file exists and ``overwrite`` is False.
    """
    assert jnp.all(dimension_vals > 0)
    assert jnp.all(noise_variance_vals > 0)
    assert jnp.all(prior_variance > 0)
    assert jnp.all(n_clusters > 1)
    assert jnp.all(size_per_cluster > 0)
    assert jnp.all(n_experiments > 0)
    assert jnp.all(max_iters > 0)

    # np.savez silently appends ".npz"; normalize up front so that the
    # overwrite check and log messages refer to the file actually written.
    if not path_to_output.endswith(".npz"):
        path_to_output = path_to_output + ".npz"

    if os.path.exists(path_to_output) and not overwrite:
        raise ValueError(
            f"Output file {path_to_output} exists, but overwrite was set to False"
        )

    key = jax.random.key(seed)

    logging.info("Starting experiments")
    logging.info("=" * 100)

    shape_outputs = (
        len(dimension_vals),
        len(noise_variance_vals),
        n_experiments,
    )

    # Must match the method keys returned by run_single_experiment.
    algorithm_names = [
        "lloyd",
        "hartigan",
        "pca",
        "pca_split",
    ]
    results = {}
    for alg in algorithm_names:
        results[alg] = {
            "nmi": np.zeros(shape_outputs),
            "loss": np.zeros(shape_outputs),
            "n_iter": np.zeros(shape_outputs, dtype=int),
        }
    results["true_partition"] = {
        "loss": np.zeros(shape_outputs),
    }
    results["init"] = {
        "nmi": np.zeros(shape_outputs),
        "loss": np.zeros(shape_outputs),
    }
    params_dict = {
        "dimension_vals": dimension_vals,
        "noise_variance_vals": noise_variance_vals,
        "prior_variance": prior_variance,
        "n_clusters": n_clusters,
        "size_per_cluster": size_per_cluster,
        "n_experiments": n_experiments,
        "max_iters": max_iters,
        "i": 0,
        "j": 0,
    }
    results.update(params_dict)
    for dkey, dval in params_dict.items():
        logging.info(f"{dkey}: {dval}")

    for i in tqdm(range(len(dimension_vals))):
        results["i"] = i
        logging.info(f"  Running for d = {dimension_vals[i]}")
        for j in range(len(noise_variance_vals)):
            results["j"] = j
            logging.info(
                f"    Running for noise_variance_vals = {noise_variance_vals[j]}"
            )

            logging.info("      Running experiments")
            for k in range(n_experiments):
                key, subkey = jax.random.split(key)
                experiment_result = run_single_experiment(
                    key=subkey,
                    noise_variance=noise_variance_vals[j],
                    n_clusters=n_clusters,
                    size_per_cluster=size_per_cluster,
                    dimension=dimension_vals[i],
                    var_prior=prior_variance,
                    max_iters=max_iters,
                    init_method=init_method,
                )

                for alg in algorithm_names:
                    results[alg]["nmi"][i, j, k] = experiment_result[alg]["nmi"]
                    results[alg]["loss"][i, j, k] = experiment_result[alg]["loss"]
                    results[alg]["n_iter"][i, j, k] = experiment_result[alg]["n_iter"]

                results["true_partition"]["loss"][i, j, k] = experiment_result[
                    "true_partition"
                ]["loss"]
                results["init"]["nmi"][i, j, k] = experiment_result["init"]["nmi"]
                results["init"]["loss"][i, j, k] = experiment_result["init"]["loss"]

            logging.info("      Done running experiments. Moving to next setting.")
            logging.info("=" * 100)

            # Checkpoint after every completed cell; "i"/"j" now identify it.
            jnp.savez(
                path_to_output,
                **results,
            )

            logging.info(f"Saved preliminary results to {path_to_output}")
    logging.info("Finished running all experiments.")
    return results


def run_gmm_experiments_continue(
    dimension_vals: Int[Array, " n_dims"],
    noise_variance_vals: Float[Array, " n_noise_variances"],
    prior_variance: Float,
    n_clusters: Int,
    size_per_cluster: Int,
    n_experiments: Int,
    path_to_output: str,
    init_method: Literal["random_centers", "kmeans++", "random_partition"],
    *,
    max_iters: Int = 1000,
    seed: Int = 0,
) -> Dict[str, Float[Array, "n_dims n_noise_variances n_experiments"]]:
    """
    Continue running clustering experiments with Lloyds' and Hartigan's k-means, PCA + Lloyd, and PCA + Split (K = 2).

    Loads the checkpoint written by `run_gmm_experiments`. Its "i"/"j" entries
    identify the last (dimension, noise variance) cell that finished before the
    checkpoint was saved. That cell and every earlier one are skipped.

    For skipped cells the PRNG key is still advanced exactly as in the
    experiment loop, so the remaining cells receive the same subkeys as an
    uninterrupted run with the same ``seed``. The same grid arguments and
    ``init_method`` as the original run must be passed; they are not stored in
    the checkpoint.

    **Arguments:**
        - dimension_vals: Array of dimensions to test.
        - noise_variance_vals: Array of noise variances to test.
        - prior_variance: Prior variance for the data generation.
        - n_clusters: Number of clusters.
        - size_per_cluster: Size of each cluster.
        - n_experiments: Number of experiments to run for each setting.
        - path_to_output: Path of the checkpoint to resume from. ``.npz`` is
          appended if missing, matching where ``np.savez`` writes.
        - init_method: Initialization method for k-means ('random_centers', 'kmeans++', 'random_partition').
        - max_iters: Maximum number of iterations for k-means.
        - seed: Random seed for reproducibility.

    **Returns:**
        - results: Dictionary containing the results of the experiments: NMI
          against the true labels, loss values and iteration counts for each
          experiment.

    **Raises:**
        - ValueError: if the checkpoint does not exist, or its result arrays do
          not match the requested grid shape.
    """
    assert jnp.all(dimension_vals > 0)
    assert jnp.all(noise_variance_vals > 0)
    assert jnp.all(prior_variance > 0)
    assert jnp.all(n_clusters > 1)
    assert jnp.all(size_per_cluster > 0)
    assert jnp.all(n_experiments > 0)
    assert jnp.all(max_iters > 0)

    # np.savez appends ".npz" when missing, so look for the file it actually wrote.
    if not path_to_output.endswith(".npz"):
        path_to_output = path_to_output + ".npz"

    if not os.path.exists(path_to_output):
        raise ValueError(
            f"Output file {path_to_output} does not exist, cannot continue."
        )

    # NOTE(security): allow_pickle is needed because the nested result dicts are
    # stored as pickled object arrays. Only resume from checkpoints this script
    # wrote. The context manager closes the archive before it is overwritten below.
    with np.load(path_to_output, allow_pickle=True) as checkpoint:
        results = dict(checkpoint)

    # np.savez wrapped scalars and dicts into 0-d arrays; unwrap them.
    for dkey in list(results):
        value = results[dkey]
        if isinstance(value, np.ndarray) and value.size == 1:
            results[dkey] = value.item()

    # Must match the method keys written by run_gmm_experiments and returned by
    # run_single_experiment.
    algorithm_names = [
        "lloyd",
        "hartigan",
        "pca",
        "pca_split",
    ]

    shape_outputs = (len(dimension_vals), len(noise_variance_vals), n_experiments)
    for alg in algorithm_names:
        stored_shape = np.shape(results[alg]["nmi"])
        if stored_shape != shape_outputs:
            raise ValueError(
                f"Checkpoint {path_to_output} holds results of shape {stored_shape} "
                f"for '{alg}', but the requested grid has shape {shape_outputs}."
            )

    key = jax.random.key(seed)

    logging.info("Starting experiments")
    logging.info("=" * 100)

    last_i = int(results["i"])
    last_j = int(results["j"])

    for i in tqdm(range(len(dimension_vals))):
        logging.info(f"  Running for d = {dimension_vals[i]}")
        for j in range(len(noise_variance_vals)):
            if (i, j) <= (last_i, last_j):
                # Already completed and saved. Advance the key as the experiment
                # loop would, so later cells see the same subkeys.
                for _ in range(n_experiments):
                    key = jax.random.split(key)[0]
                continue

            results["i"] = i
            results["j"] = j
            logging.info(
                f"    Running for noise_variance_vals = {noise_variance_vals[j]}"
            )

            logging.info("      Running experiments")
            for k in range(n_experiments):
                key, subkey = jax.random.split(key)
                experiment_result = run_single_experiment(
                    key=subkey,
                    noise_variance=noise_variance_vals[j],
                    n_clusters=n_clusters,
                    size_per_cluster=size_per_cluster,
                    dimension=dimension_vals[i],
                    var_prior=prior_variance,
                    max_iters=max_iters,
                    init_method=init_method,
                )

                for alg in algorithm_names:
                    results[alg]["nmi"][i, j, k] = experiment_result[alg]["nmi"]
                    results[alg]["loss"][i, j, k] = experiment_result[alg]["loss"]
                    results[alg]["n_iter"][i, j, k] = experiment_result[alg][
                        "n_iter"
                    ]

                results["true_partition"]["loss"][i, j, k] = experiment_result[
                    "true_partition"
                ]["loss"]
                results["init"]["nmi"][i, j, k] = experiment_result["init"]["nmi"]
                results["init"]["loss"][i, j, k] = experiment_result["init"]["loss"]

            # Drop JAX's compilation caches after each cell -- presumably to bound
            # memory growth across many input shapes; TODO confirm.
            jax.clear_caches()

            logging.info("      Done running experiments. Moving to next setting.")
            logging.info("=" * 100)

            jnp.savez(
                path_to_output,
                **results,
            )

            logging.info(f"Saved preliminary results to {path_to_output}")
    logging.info("Finished running all experiments.")
    return results


def add_args(parser):
    """
    Register this script's command-line options on ``parser``.

    ``--init`` is restricted to the supported names, so a typo is reported by
    argparse with a usage message instead of failing later in ``main``.

    **Arguments:**
        - parser: An ``argparse.ArgumentParser`` to extend.

    **Returns:**
        - The same parser, for chaining.
    """
    parser.add_argument("--output_file", type=str, help="Output File", required=True)
    parser.add_argument(
        "--resume", action="store_true", help="Continue from existing output file"
    )
    parser.add_argument(
        "--overwrite", action="store_true", help="Overwrite existing output file"
    )
    parser.add_argument("--n_clusters", type=int, default=2, help="Number of clusters")
    parser.add_argument(
        "--init",
        type=str,
        choices=["kmeanspp", "random", "random_partition"],
        help="kmeanspp, random, or random_partition",
        required=True,
    )
    return parser


def main(output_file, continues_from_existing, overwrite, n_clusters, init):
    """
    Command-line entry point: set up file logging, then run or resume the sweep.

    **Arguments:**
        - output_file: Path of the results file. A log file named after today's
          date is written to the same directory.
        - continues_from_existing: Resume from ``output_file`` instead of
          starting fresh.
        - overwrite: Allow a fresh run to replace an existing results file.
        - n_clusters: Number of clusters in the synthetic mixture.
        - init: Initialization name: "kmeanspp", "random", or "random_partition".

    **Raises:**
        - ValueError: if ``init`` is not one of the supported names.
    """
    init_parser = {
        "kmeanspp": "kmeans++",
        "random": "random_centers",
        "random_partition": "random_partition",
    }
    # Validate before creating directories or log files.
    if init not in init_parser:
        raise ValueError(
            f"Unknown init {init!r}; expected one of {sorted(init_parser)}."
        )
    init_method = init_parser[init]

    # os.path.dirname("results.npz") is "", meaning the current directory.
    basedir = os.path.dirname(output_file) or "."
    _mkbasedir(basedir)

    # set up logger on the root logger; close (not just drop) old handlers so
    # any files they hold are released.
    logger = logging.getLogger()
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    logger_fname = datetime.datetime.now().strftime("%Y-%m-%d")
    logger_fname = os.path.join(basedir, logger_fname + ".log")

    fhandler = logging.FileHandler(filename=logger_fname, mode="a")
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    fhandler.setFormatter(formatter)
    logger.addHandler(fhandler)
    logger.setLevel(logging.INFO)

    prior_variance = 1.0
    size_per_cluster = 20

    n_experiments = 100

    # 20 log-spaced integer dimensions from about 10**0.8 (~6) up to about 10**4
    # (cast to int, i.e. truncated).
    dimension_vals = jnp.logspace(0.8, 4, 20, dtype=int)
    noise_variance_vals = jnp.logspace(-1.0, 2.0, 20)

    if continues_from_existing:
        _ = run_gmm_experiments_continue(
            dimension_vals,
            noise_variance_vals,
            prior_variance,
            n_clusters,
            size_per_cluster,
            n_experiments,
            init_method=init_method,
            path_to_output=output_file,
            max_iters=500,
            seed=0,
        )

    else:
        _ = run_gmm_experiments(
            dimension_vals,
            noise_variance_vals,
            prior_variance,
            n_clusters,
            size_per_cluster,
            n_experiments,
            init_method=init_method,
            path_to_output=output_file,
            max_iters=500,
            seed=0,
            overwrite=overwrite,
        )
    return


if __name__ == "__main__":
    # Parse the CLI options and hand them to main() in its positional order.
    cli_args = add_args(argparse.ArgumentParser()).parse_args()
    main(
        cli_args.output_file,
        cli_args.resume,
        cli_args.overwrite,
        cli_args.n_clusters,
        cli_args.init,
    )
