#!/usr/bin/env python3
import argparse
import logging
import os
from typing import Dict
import datetime

import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray
from sklearn import metrics as sk_metrics
from tqdm import tqdm

from kmeans_jax import run_sdp_clustering
from kmeans_jax.kmeans import (
    compute_loss,
    compute_centroids,
)


def _mkbasedir(path: str) -> None:
    """Ensure the directory ``path`` exists, creating missing parents as needed.

    An empty ``path`` (what ``os.path.dirname`` returns for a bare file name such as
    ``"results.npz"``) denotes the current working directory and is accepted as-is.

    Raises:
        ValueError: if the directory cannot be created (e.g. permission denied, or a
            path component is an existing non-directory file).
    """
    if not path:
        # Nothing to create: the file lives in the current working directory.
        return
    try:
        # exist_ok avoids the exists()/makedirs() race of a separate existence check.
        os.makedirs(path, exist_ok=True)
    except OSError as err:
        raise ValueError("Output path does not exist or cannot be created.") from err


def _make_initial_data(
    key, n_clusters, size_per_cluster, dimension, var_prior, noise_variance
):
    """Draw a synthetic, equally sized Gaussian-mixture dataset.

    Cluster centers are sampled as ``sqrt(var_prior) * N(0, I)`` in ``dimension``
    dimensions. Each cluster contributes ``size_per_cluster`` consecutive points. Each
    point is its center plus ``sqrt(noise_variance) * N(0, I)`` noise.

    Returns:
        ``(data, true_labels)``. ``data`` has shape
        ``(n_clusters * size_per_cluster, dimension)``. ``true_labels`` is
        ``[0]*size_per_cluster + [1]*size_per_cluster + ...``.
    """
    center_key, noise_key = jax.random.split(key, 2)
    n_points = n_clusters * size_per_cluster

    center_scale = jnp.sqrt(var_prior)
    noise_scale = jnp.sqrt(noise_variance)

    true_centers = center_scale * jax.random.normal(
        center_key, shape=(n_clusters, dimension)
    )
    true_labels = jnp.repeat(jnp.arange(n_clusters), size_per_cluster)

    noise = noise_scale * jax.random.normal(noise_key, shape=(n_points, dimension))
    data = true_centers[true_labels] + noise
    return data, true_labels


def run_single_experiment(
    key: PRNGKeyArray,
    noise_variance: Float,
    n_clusters: Int,
    size_per_cluster: Int,
    dimension: Int,
    var_prior: Float,
    max_iters: Int,
) -> Dict[str, Dict[str, Float]]:
    """
    Generate one synthetic Gaussian-mixture dataset and cluster it with SDP.

    **Arguments:**
        - key: PRNG key for this experiment.
        - noise_variance: Variance of the per-point noise around its cluster center.
        - n_clusters: Number of clusters.
        - size_per_cluster: Number of points drawn per cluster.
        - dimension: Ambient dimension of the data.
        - var_prior: Variance of the prior the cluster centers are drawn from.
        - max_iters: Maximum number of iterations passed to ``run_sdp_clustering``.
    **Returns:**
        - Nested dictionary with two entries.
          ``"sdp"`` maps to ``{"nmi", "loss", "n_iter"}``: NMI of the SDP labels
          against the true labels, the loss reported by ``run_sdp_clustering``, and
          its iteration count.
          ``"true_partition"`` maps to ``{"loss"}``: the loss of the ground-truth
          labels evaluated at their own centroids.
    """
    # Only the second subkey drives data generation. Keep this split as-is so that a
    # given key always yields the same dataset.
    _, key_data = jax.random.split(key, 2)

    data, true_labels = _make_initial_data(
        key_data,
        n_clusters=n_clusters,
        size_per_cluster=size_per_cluster,
        dimension=dimension,
        var_prior=var_prior,
        noise_variance=noise_variance,
    )
    # Reference value: loss of the ground-truth partition at its own centroids.
    true_data_averages = compute_centroids(data, true_labels, n_clusters)
    true_loss = compute_loss(data, true_data_averages, true_labels)

    ############################# Alternative methods #############################
    # run_sdp_clustering receives a host NumPy array. Wait for the device
    # computation to finish before converting.
    data = np.asanyarray(data.block_until_ready())
    _, labels_sdp, loss_sdp, n_iter_sdp = run_sdp_clustering(
        data, n_clusters=n_clusters, max_iters=max_iters
    )
    nmi_sdp = sk_metrics.normalized_mutual_info_score(
        np.asarray(true_labels), labels_sdp
    )

    return {
        "sdp": {
            "nmi": nmi_sdp,
            "loss": loss_sdp,
            "n_iter": n_iter_sdp,
        },
        "true_partition": {
            "loss": true_loss,
        },
    }


def run_gmm_experiments(
    dimension_vals: Int[Array, " n_dims"],
    noise_variance_vals: Float[Array, " n_noise_variances"],
    prior_variance: Float,
    n_clusters: Int,
    size_per_cluster: Int,
    n_experiments: Int,
    path_to_output: str,
    *,
    max_iters: Int = 1000,
    seed: Int = 0,
    overwrite: Bool = False,
) -> Dict[str, Float[Array, "n_dims n_noise_variances n_experiments"]]:
    """
    Run clustering with SDP over a grid of (dimension, noise variance) settings.

    For every pair ``(dimension_vals[i], noise_variance_vals[j])``, this function:
    - draws ``n_experiments`` synthetic Gaussian-mixture datasets;
    - clusters each one with SDP;
    - records the NMI, the SDP loss, the SDP iteration count, and the loss of the
      true partition.

    Results are checkpointed to ``path_to_output`` after every (i, j) setting.

    **Arguments:**
        - dimension_vals: Array of dimensions to test.
        - noise_variance_vals: Array of noise variances to test.
        - prior_variance: Prior variance for the data generation.
        - n_clusters: Number of clusters.
        - size_per_cluster: Size of each cluster.
        - n_experiments: Number of experiments to run for each setting.
        - path_to_output: Path to save the results. ``.npz`` is appended when
          missing, because that is where ``numpy.savez`` actually writes.
        - max_iters: Maximum number of iterations for the SDP clustering.
        - seed: Random seed for reproducibility.
        - overwrite: Whether to overwrite existing output files.
    **Returns:**
        - results: Dictionary containing the results of the experiments.
        The results are the NMI vs the true labels and loss values for each
        experiment, plus the run parameters and the last completed indices
        ``"i"``/``"j"``.
    **Raises:**
        - ValueError: if the output file exists and ``overwrite`` is False.
    """
    assert jnp.all(dimension_vals > 0)
    assert jnp.all(noise_variance_vals > 0)
    assert jnp.all(prior_variance > 0)
    assert jnp.all(n_clusters > 1)
    assert jnp.all(size_per_cluster > 0)
    assert jnp.all(n_experiments > 0)
    assert jnp.all(max_iters > 0)

    # numpy.savez silently appends ".npz" to paths lacking it. Normalize up front so
    # the existence check below inspects the file that will actually be written.
    path_to_output = os.fspath(path_to_output)
    if not path_to_output.endswith(".npz"):
        path_to_output += ".npz"

    if os.path.exists(path_to_output) and not overwrite:
        raise ValueError(
            f"Output file {path_to_output} exists, but overwrite was set to False"
        )

    key = jax.random.key(seed)

    logging.info("Starting experiments")
    logging.info("=" * 100)

    shape_outputs = (
        len(dimension_vals),
        len(noise_variance_vals),
        n_experiments,
    )

    algorithm_names = ["sdp"]
    metric_names = ("nmi", "loss", "n_iter")
    results = {}
    for alg in algorithm_names:
        results[alg] = {
            "nmi": np.zeros(shape_outputs),
            "loss": np.zeros(shape_outputs),
            "n_iter": np.zeros(shape_outputs, dtype=int),
        }
    results["true_partition"] = {
        "loss": np.zeros(shape_outputs),
    }
    # Run parameters are saved next to the metrics. "i"/"j" hold the indices of the
    # most recent setting written to disk; the resume path reads them back.
    params_dict = {
        "dimension_vals": dimension_vals,
        "noise_variance_vals": noise_variance_vals,
        "prior_variance": prior_variance,
        "n_clusters": n_clusters,
        "size_per_cluster": size_per_cluster,
        "n_experiments": n_experiments,
        "max_iters": max_iters,
        "i": 0,
        "j": 0,
    }
    results.update(params_dict)
    for name, value in params_dict.items():
        logging.info(f"{name}: {value}")

    # Checkpoints are written here first and then renamed over the real file.
    tmp_path = path_to_output + ".tmp.npz"

    for i in tqdm(range(len(dimension_vals))):
        results["i"] = i
        dimension = dimension_vals[i]
        logging.info(f"  Running for d = {dimension}")
        for j in range(len(noise_variance_vals)):
            results["j"] = j
            noise_variance = noise_variance_vals[j]
            logging.info(f"    Running for noise_variance_vals = {noise_variance}")

            logging.info("      Running experiments")
            for k in range(n_experiments):
                # One split per (i, j, k), in this exact order. The resume function
                # replays the same sequence of splits.
                key, subkey = jax.random.split(key)
                experiment_result = run_single_experiment(
                    key=subkey,
                    noise_variance=noise_variance,
                    n_clusters=n_clusters,
                    size_per_cluster=size_per_cluster,
                    dimension=dimension,
                    var_prior=prior_variance,
                    max_iters=max_iters,
                )

                for alg in algorithm_names:
                    for metric in metric_names:
                        results[alg][metric][i, j, k] = experiment_result[alg][metric]

                results["true_partition"]["loss"][i, j, k] = experiment_result[
                    "true_partition"
                ]["loss"]

            # Mirrors run_gmm_experiments_continue. Presumably this keeps JAX's
            # compilation caches from growing as array shapes change with d.
            jax.clear_caches()

            logging.info("      Done running experiments. Moving to next setting.")
            logging.info("=" * 100)

            # Write-then-rename makes the checkpoint atomic. An interruption mid-write
            # cannot corrupt the results saved so far.
            jnp.savez(
                tmp_path,
                **results,
            )
            os.replace(tmp_path, path_to_output)

            logging.info(f"Saved preliminary results to {path_to_output}")
    logging.info("Finished running all experiments.")
    return results


def run_gmm_experiments_continue(
    dimension_vals: Int[Array, " n_dims"],
    noise_variance_vals: Float[Array, " n_noise_variances"],
    prior_variance: Float,
    n_clusters: Int,
    size_per_cluster: Int,
    n_experiments: Int,
    path_to_output: str,
    *,
    max_iters: Int = 1000,
    seed: Int = 0,
) -> Dict[str, Float[Array, "n_dims n_noise_variances n_experiments"]]:
    """
    Continue running clustering experiments with SDP from a saved results file.

    The PRNG key is re-derived from ``seed`` and split once per experiment in
    (i, j, k) order. Given the same seed, every experiment therefore receives the same
    key as in an uninterrupted run.

    Settings strictly before the saved ``(i, j)`` checkpoint are skipped without
    re-saving. The checkpoint setting itself and all later ones are (re)computed, and
    the file is re-saved after each one.

    **Arguments:**
        - dimension_vals: Array of dimensions to test.
        - noise_variance_vals: Array of noise variances to test.
        - prior_variance: Prior variance for the data generation.
        - n_clusters: Number of clusters.
        - size_per_cluster: Size of each cluster.
        - n_experiments: Number of experiments to run for each setting.
        - path_to_output: Path of the existing results file. ``.npz`` is appended
          when missing, because that is where ``numpy.savez`` writes.
        - max_iters: Maximum number of iterations for the SDP clustering.
        - seed: Random seed for reproducibility.
    **Returns:**
        - results: Dictionary containing the results of the experiments.
        The results are the NMI vs the true labels and loss values for each
        experiment.
    **Raises:**
        - ValueError: if the results file does not exist, or its stored arrays do not
          match the grid implied by the arguments.
    """
    assert jnp.all(dimension_vals > 0)
    assert jnp.all(noise_variance_vals > 0)
    assert jnp.all(prior_variance > 0)
    assert jnp.all(n_clusters > 1)
    assert jnp.all(size_per_cluster > 0)
    assert jnp.all(n_experiments > 0)
    assert jnp.all(max_iters > 0)

    path_to_output = os.fspath(path_to_output)
    if not path_to_output.endswith(".npz"):
        path_to_output += ".npz"

    if not os.path.exists(path_to_output):
        raise ValueError(
            f"Output file {path_to_output} does not exist, cannot continue."
        )

    # allow_pickle is required because the nested per-algorithm dicts are stored as
    # 0-d object arrays. Only resume from results files you produced yourself.
    # The context manager closes the underlying zip file handle.
    with np.load(path_to_output, allow_pickle=True) as npz:
        loaded = dict(npz)
    # Unwrap size-1 top-level entries back into Python objects. This covers the
    # nested result dicts, the scalar parameters, and the "i"/"j" checkpoint.
    results = {
        name: (
            value.item()
            if isinstance(value, np.ndarray) and value.size == 1
            else value
        )
        for name, value in loaded.items()
    }

    algorithm_names = ["sdp"]
    metric_names = ("nmi", "loss", "n_iter")

    # Refuse to mix results from a different grid into the stored arrays.
    expected_shape = (len(dimension_vals), len(noise_variance_vals), n_experiments)
    stored_arrays = [results["true_partition"]["loss"]] + [
        results[alg][metric] for alg in algorithm_names for metric in metric_names
    ]
    if any(arr.shape != expected_shape for arr in stored_arrays):
        raise ValueError(
            f"Results in {path_to_output} do not match the requested grid "
            f"{expected_shape}; cannot continue."
        )

    key = jax.random.key(seed)

    last_i = results["i"]
    last_j = results["j"]

    logging.info("Starting experiments")
    logging.info(f"Resuming from setting i = {last_i}, j = {last_j}")
    logging.info("=" * 100)

    # Checkpoints are written here first and then renamed over the real file.
    tmp_path = path_to_output + ".tmp.npz"

    for i in tqdm(range(len(dimension_vals))):
        dimension = dimension_vals[i]
        logging.info(f"  Running for d = {dimension}")
        for j in range(len(noise_variance_vals)):
            if (i, j) < (last_i, last_j):
                # Finished in an earlier run. Still consume this setting's splits so
                # later settings get the keys an uninterrupted run would give them.
                for _ in range(n_experiments):
                    key = jax.random.split(key)[0]
                continue

            results["i"] = i
            results["j"] = j
            noise_variance = noise_variance_vals[j]
            logging.info(f"    Running for noise_variance_vals = {noise_variance}")

            logging.info("      Running experiments")
            for k in range(n_experiments):
                key, subkey = jax.random.split(key)
                experiment_result = run_single_experiment(
                    key=subkey,
                    noise_variance=noise_variance,
                    n_clusters=n_clusters,
                    size_per_cluster=size_per_cluster,
                    dimension=dimension,
                    var_prior=prior_variance,
                    max_iters=max_iters,
                )

                for alg in algorithm_names:
                    for metric in metric_names:
                        results[alg][metric][i, j, k] = experiment_result[alg][metric]

                results["true_partition"]["loss"][i, j, k] = experiment_result[
                    "true_partition"
                ]["loss"]
            jax.clear_caches()

            logging.info("      Done running experiments. Moving to next setting.")
            logging.info("=" * 100)

            # Write-then-rename makes the checkpoint atomic. An interruption mid-write
            # cannot corrupt the results saved so far.
            jnp.savez(
                tmp_path,
                **results,
            )
            os.replace(tmp_path, path_to_output)

            logging.info(f"Saved preliminary results to {path_to_output}")
    logging.info("Finished running all experiments.")
    return results


def add_args(parser):
    """Register this script's command-line options on ``parser`` and return it.

    Options:
    - ``--output_file``: required path of the results file.
    - ``--resume``, ``--overwrite``: boolean flags.
    - ``--n_clusters``: int, default 2.
    """
    parser.add_argument(
        "--output_file",
        type=str,
        help="Output file",
        required=True,
    )

    boolean_flags = (
        ("--resume", "Continue from existing output file"),
        ("--overwrite", "Overwrite existing output file"),
    )
    for flag, description in boolean_flags:
        parser.add_argument(flag, action="store_true", help=description)

    parser.add_argument(
        "--n_clusters",
        type=int,
        default=2,
        help="Number of clusters",
    )
    return parser


def main(output_file, continues_from_existing, overwrite, n_clusters):
    """Configure file logging and run (or resume) the SDP experiment grid.

    Args:
        output_file: Path of the results file. Its directory (or the current working
            directory, for a bare file name) also receives a ``YYYY-MM-DD.log``
            log file.
        continues_from_existing: If True, resume from ``output_file`` via
            ``run_gmm_experiments_continue``. Otherwise start a fresh run.
        overwrite: Forwarded to ``run_gmm_experiments``; unused when resuming.
        n_clusters: Number of clusters in the generated data.
    """
    basedir = os.path.dirname(output_file)
    # A bare file name has no directory component. It lives in the current working
    # directory, which needs no creation (os.makedirs("") would raise).
    if basedir:
        _mkbasedir(basedir)

    # set up logger: route root-logger records to a per-day file next to the output.
    logger = logging.getLogger()
    # Close handlers while detaching them so their file handles are released.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    logger_fname = datetime.datetime.now().strftime("%Y-%m-%d")
    logger_fname = os.path.join(basedir, logger_fname + ".log")

    fhandler = logging.FileHandler(filename=logger_fname, mode="a")
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    fhandler.setFormatter(formatter)
    logger.addHandler(fhandler)
    logger.setLevel(logging.INFO)

    # Fixed experiment grid for this study.
    prior_variance = 1.0
    size_per_cluster = 20
    n_experiments = 100
    max_iters = 500
    seed = 0
    # 20 log-spaced dimensions between 10**0.8 and 10**4, cast to int.
    dimension_vals = jnp.logspace(0.8, 4, 20, dtype=int)
    noise_variance_vals = jnp.logspace(-1.0, 2.0, 20)

    grid_args = (
        dimension_vals,
        noise_variance_vals,
        prior_variance,
        n_clusters,
        size_per_cluster,
        n_experiments,
    )

    if continues_from_existing:
        run_gmm_experiments_continue(
            *grid_args,
            path_to_output=output_file,
            max_iters=max_iters,
            seed=seed,
        )
    else:
        run_gmm_experiments(
            *grid_args,
            path_to_output=output_file,
            max_iters=max_iters,
            seed=seed,
            overwrite=overwrite,
        )


if __name__ == "__main__":
    # Parse the command line and hand the options to main() in its positional order.
    cli_args = add_args(argparse.ArgumentParser()).parse_args()
    main(
        cli_args.output_file,
        cli_args.resume,
        cli_args.overwrite,
        cli_args.n_clusters,
    )
