"""Discrete QD algorithms."""
import functools
import logging
import pickle

import hydra
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from ribs.archives import GridArchive
from ribs.visualize import grid_archive_heatmap
from tqdm.contrib.logging import logging_redirect_tqdm

from src.evaluation import compute_centers, evaluate_solution_model
from src.models.discrete_archive import DiscreteArchiveSolutionModel
from src.utils.code_timer import CodeTimer
from src.utils.hydra_utils import define_resolvers
from src.utils.logging import setup_logdir_from_hydra
from src.utils.metric_logger import MetricLogger

# Use when debugging warnings.
#  import src.utils.warn_traceback  # pylint: disable = unused-import

log = logging.getLogger(__name__)


def build_qd_algo(cfg, domain_module):
    """Builds the archive(s), emitters, and scheduler described by ``cfg.algo``.

    Args:
        cfg: Experiment config. ``cfg.algo`` provides the archive, the optional
            result archive, the emitter list (each entry has ``type.args`` and
            ``num``), and the scheduler definition. ``cfg.seed`` seeds the
            archives and is the root of the emitter seeds.
        domain_module: Domain whose ``initial_solution()`` is passed as ``x0``
            to every emitter (called once per emitter).

    Returns:
        The scheduler instantiated from ``cfg.algo.scheduler.args`` (presumably
        a ``ribs.schedulers.Scheduler`` or subclass; confirm against config).
    """
    algo_cfg = cfg.algo

    archive = hydra.utils.instantiate(algo_cfg.archive.args, seed=cfg.seed)

    # The result archive is optional; without it the scheduler receives None.
    result_archive = (hydra.utils.instantiate(algo_cfg.result_archive.args,
                                              seed=cfg.seed)
                      if algo_cfg.get("result_archive") else None)

    # Each emitter gets its own child seed spawned from one root
    # SeedSequence, so emitters do not all follow the same random stream.
    root_seed = np.random.SeedSequence(cfg.seed)
    emitters = []
    for emitter_cfg in algo_cfg.emitters:
        for child_seed in root_seed.spawn(emitter_cfg.num):
            emitters.append(
                hydra.utils.instantiate(
                    emitter_cfg.type.args,
                    archive=archive,
                    x0=domain_module.initial_solution(),
                    seed=child_seed,
                ))

    scheduler = hydra.utils.instantiate(
        algo_cfg.scheduler.args,
        archive=archive,
        emitters=emitters,
        result_archive=result_archive,
    )

    scheduler_name = scheduler.__class__.__name__
    algo_name = HydraConfig.get().runtime.choices.algo
    log.info(f"Created {scheduler_name} for "
             f"{algo_name}")
    return scheduler


def eval_and_make_plots_full(
    scheduler,
    cfg,
    domain_module,
    solution_model,
    logdir,
    rng,
    name,
    filetype="png",
):
    """Saves archive visualizations and evaluates the solution model.

    Meant to be bound with ``functools.partial`` so callers only pass ``name``
    (and optionally ``filetype``).

    Args:
        scheduler: Scheduler whose ``archive`` and ``result_archive`` are
            visualized.
        cfg: Full experiment config.
        domain_module: Instantiated domain. Passed to
            ``evaluate_solution_model``; for ``lsi_face`` its ``classifier`` is
            also used to render samples.
        solution_model: Model wrapping the result archive, passed to
            ``evaluate_solution_model``.
        logdir: Logging directory; output files are created via
            ``logdir.pfile``.
        rng: Random generator passed to ``evaluate_solution_model``.
        name: Iteration number (zero-padded to 6 digits in file names) or a
            string used verbatim, e.g. ``"final"``.
        filetype: File extension for the heatmap figure.

    Returns:
        The metrics dict returned by ``evaluate_solution_model``.

    Raises:
        ValueError: If ``cfg.domain.config.feature_dim == 2`` but the result
            archive is not a ``GridArchive``.
    """
    filename = f"{name:06}" if isinstance(name, int) else name

    if cfg.domain.config.feature_dim == 2:
        # Validate before creating the figure so this error cannot leave an
        # open matplotlib figure behind.
        if not isinstance(scheduler.result_archive, GridArchive):
            raise ValueError("Only plots GridArchive for now.")

        # We only plot discount function for GridArchive for now. In the case
        # where learning rate is 1.0, this is just the thresholds.
        has_discount_func = isinstance(scheduler.archive, GridArchive)

        ncols = 1 + int(has_discount_func)
        fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 5, 4))
        if not has_discount_func:
            # With a single column, subplots() returns a bare Axes object.
            axs = [axs]

        # Close the figure even if plotting or saving raises, so figures do
        # not accumulate across the many calls made during a run.
        try:
            axs[0].set_title("Result Archive")
            grid_archive_heatmap(
                scheduler.result_archive,
                ax=axs[0],
                vmin=cfg.domain.config.obj_low,
                vmax=cfg.domain.config.obj_high,
                rasterized=True,
            )

            if has_discount_func:
                # Make archive that shows the true discount function.
                true_discount_archive = GridArchive(
                    solution_dim=0,
                    dims=cfg.algo.archive.args.dims,
                    ranges=cfg.algo.archive.args.ranges,
                )
                feature_coords = compute_centers(true_discount_archive)

                # Default discount of threshold_min in every cell.
                if "threshold_min" in cfg.algo.archive.args:
                    true_discount_archive.add(
                        np.empty((len(feature_coords), 0)),
                        np.full(len(feature_coords),
                                cfg.algo.archive.args.threshold_min),
                        feature_coords,
                    )

                # Add in thresholds from the archive, using them as the
                # "objective" so the heatmap shows the threshold per cell.
                data = scheduler.archive.data()
                true_discount_archive.add(
                    np.empty((len(data["solution"]), 0)),
                    data["threshold"],
                    data["measures"],
                )

                axs[1].set_title("Discount Function")
                grid_archive_heatmap(
                    true_discount_archive,
                    ax=axs[1],
                    vmin=cfg.domain.config.obj_low,
                    vmax=cfg.domain.config.obj_high,
                    rasterized=True,
                )

            fig.tight_layout()
            fig.savefig(logdir.pfile(f"heatmaps_{filename}.{filetype}"),
                        dpi=300)
        finally:
            plt.close(fig)

    # Hydra's runtime choices are fixed for the run; look the domain up once.
    domain_name = HydraConfig.get().runtime.choices.domain
    result_archive = scheduler.result_archive

    # Domain-specific sample renderings; only one domain can match, and there
    # is nothing to render from an empty archive.
    if len(result_archive) > 0:
        if domain_name == "triangles_mnist":
            # For this domain the measures are the rendered images.
            imgs = result_archive.sample_elites(
                min(len(result_archive), 100),
                replace=False,
            )["measures"]
            from src.domains.triangles import render_mnist_batch_pil
            pil_image = render_mnist_batch_pil(imgs.reshape(-1, 28, 28))
            pil_image.save(logdir.pfile(f"mnist_samples_{filename}.png"))
        elif domain_name in ("triangles_afhq", "triangles_afhq_l2"):
            from src.domains.triangles import afhq_pil_image
            sample = result_archive.sample_elites(
                min(len(result_archive), 100),
                replace=False,
            )
            pil_image = afhq_pil_image(sample["solution"], sample["index"],
                                       128)
            pil_image.save(logdir.pfile(f"afhq_samples_{filename}.png"))
        elif domain_name == "lsi_face":
            from src.domains.lsi_face import human_landscape_face
            sample = result_archive.sample_elites(
                min(len(result_archive), 40),
                replace=False,
            )
            pil_image = human_landscape_face(sample["solution"],
                                             domain_module.classifier,
                                             sample["index"])
            pil_image.save(logdir.pfile(f"face_samples_{filename}.png"))

    eval_metrics, _ = evaluate_solution_model(cfg, solution_model,
                                              domain_module, rng)
    return eval_metrics


def _centroid_objectives(domain_name, domain_module, result_archive,
                         num_solutions, objectives, measures):
    """Recomputes objectives for domain/objective combos tied to the archive.

    Some domains score a solution by how close its measures land to the
    nearest centroid of ``result_archive``. For those combinations this
    replaces (or, for ``lsi_face``, blends into) the objectives returned by
    ``domain_module.evaluate``. All other combinations pass through unchanged.

    Args:
        domain_name: Hydra runtime choice for the domain.
        domain_module: Instantiated domain; ``domain_module.config`` selects
            the objective variant.
        result_archive: Archive whose centroids / nearest-neighbor index the
            objective is measured against.
        num_solutions: Number of solutions that were evaluated.
        objectives: Objectives returned by ``domain_module.evaluate``.
        measures: Measures returned by ``domain_module.evaluate``.

    Returns:
        Tuple ``(objectives, measures)``. ``measures`` only changes for
        ``triangles_mnist`` with ``rollouts > 1``, where it becomes the mean
        over rollouts.
    """
    if (domain_name == "triangles_mnist" and
            domain_module.config.objective == "centroid_distance"):
        # Compute distance from centroids to the measures.
        if domain_module.config.rollouts > 1:
            rollouts = domain_module.config.rollouts
            mean_measures = np.mean(
                measures.reshape((num_solutions, rollouts, -1)),
                axis=1,
            )
            indices = result_archive.index_of(mean_measures)
            # Per-rollout score against the centroid of the mean measure.
            objectives = 1.0 - np.mean(
                np.square(
                    np.repeat(result_archive.centroids[indices],
                              rollouts,
                              axis=0) - measures),
                axis=1,
            )
            objectives = np.mean(objectives.reshape((num_solutions, rollouts)),
                                 axis=1)
            measures = mean_measures
        else:
            indices = result_archive.index_of(measures)
            # (batch_size,)
            objectives = 1.0 - np.mean(
                # (batch_size, measure_dim)
                np.square(result_archive.centroids[indices] - measures),
                axis=1,
            )
    elif (domain_name == "triangles_afhq_l2" and
          domain_module.config.objective == "centroid_distance"):
        _, distances = result_archive.index_of(measures,
                                               return_distances=True)
        # The distance is L2, so convert it to MSE.
        objectives = 1.0 - np.square(distances) / measures.shape[1]
    elif (domain_name == "triangles_afhq" and
          domain_module.config.objective == "centroid_cosine_distance"):
        # NOTE(review): the raw nearest-neighbor value is used directly as the
        # objective; confirm the NN metric returns larger-is-better values.
        # pylint: disable-next = protected-access
        nn_values, _ = result_archive._sklearn_nn.kneighbors(
            measures, n_neighbors=1, return_distance=True)
        objectives = nn_values.reshape(len(measures))
    elif domain_name == "lsi_face" and domain_module.config.add_centroid_dist:
        # pylint: disable-next = protected-access
        distances, _ = result_archive._sklearn_nn.kneighbors(
            measures, n_neighbors=1, return_distance=True)
        distances = distances.reshape(len(measures))
        # Maps [-1, 1] to [0, 1]. NOTE(review): confirm that is the range of
        # the NN metric's output.
        distances = (distances + 1.0) / 2.0
        # Add the two objectives together, weighting them equally.
        objectives = (objectives + distances) / 2.0
    return objectives, measures


@hydra.main(version_base=None, config_path="../config", config_name="config")
def main(cfg: DictConfig):
    """Runs a discrete QD experiment end to end.

    Builds the domain and scheduler from ``cfg``. Then runs ``cfg.itrs``
    ask-tell iterations, with an extra DQD ask-tell per iteration when
    ``cfg.algo.dqd`` is set. Every ``cfg.log_freq`` iterations (and on the
    last one) it evaluates the result archive and saves plots. At the end it
    writes the metrics, the result archive, and the pickled scheduler to the
    logging directory.

    Args:
        cfg: Hydra config composed from ``../config/config``.
    """
    define_resolvers()

    logdir = setup_logdir_from_hydra()
    log.info(f"Logging directory: {logdir.logdir}")

    ## COMPONENT INITIALIZATION ##

    rng = np.random.default_rng(cfg.seed)
    if cfg.seed is not None:
        torch.manual_seed(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    domain_module = hydra.utils.instantiate(
        cfg.domain,
        seed=None if cfg.seed is None else cfg.seed + 42,
        device=device,
    )
    scheduler = build_qd_algo(cfg, domain_module)
    # Hydra's runtime choices are fixed for the run; look this up once instead
    # of on every iteration.
    domain_name = HydraConfig.get().runtime.choices.domain

    # Wrap the discrete archive.
    solution_model = DiscreteArchiveSolutionModel(
        cfg=OmegaConf.create({"normalize_features_before_dist": True}),
        seed=None if cfg.seed is None else cfg.seed + 42,
        device=device,
        archive=scheduler.result_archive,
    )

    eval_and_make_plots = functools.partial(eval_and_make_plots_full, scheduler,
                                            cfg, domain_module, solution_model,
                                            logdir, rng)
    eval_metrics_dict = eval_and_make_plots(0)

    # Set up metrics.
    timer = CodeTimer([
        "All",  # Everything in the iteration (except metrics, which are negligible).
        "Algorithm",  # Algorithm only -- excludes logging and saving, includes evals.
        "Internal",  # Algorithm time excluding evals.
        "Domain Evaluation",  # Evaluation in the domain for main solutions.
    ])
    dense_metric_list = [
        ("Evaluations", True, 0),
        ("QD Score", True, 0.0),
        ("Archive Coverage", True, 0.0),
        ("Objective Max", False),
        ("Restarts", True, 0),
        ("Unique Cells", False),
        ("Solutions Per Cell", False),
        *timer.metric_list(),
    ]
    sparse_metric_list = [
        ("Corrected Coverage", True, eval_metrics_dict["Corrected Coverage"]),
        ("Corrected QD Score", True, eval_metrics_dict["Corrected QD Score"]),
        ("Mean Feature Error", True, eval_metrics_dict["Mean Feature Error"]),
        ("Min Feature Error", True, eval_metrics_dict["Min Feature Error"]),
        ("Max Feature Error", True, eval_metrics_dict["Max Feature Error"]),
        ("CQD Score", True, eval_metrics_dict["CQD Score"]),
        ("Limited CQD Score", True, eval_metrics_dict["Limited CQD Score"]),
    ]

    # Metrics that are recorded on every iteration.
    metrics = MetricLogger(dense_metric_list)
    # Metrics that are recorded only every log_freq iterations.
    metrics_sparse = MetricLogger(dense_metric_list + sparse_metric_list,
                                  x_scale=cfg.log_freq)

    ## EXECUTION LOOP ##

    with logging_redirect_tqdm():
        for itr in tqdm.trange(1, cfg.itrs + 1):
            timer.start(["Algorithm", "All"])
            # Number of solutions passed to domain_module.evaluate() in this
            # iteration, including the DQD batch.
            num_evals = 0

            ## DQD ASK-TELL ##

            if cfg.algo.get("dqd"):
                dqd_solutions = scheduler.ask_dqd()
                objectives, measures, info = domain_module.evaluate(
                    dqd_solutions, grad=True)
                num_evals += len(dqd_solutions)
                # Stack gradients into one Jacobian with the objective
                # gradient as the first row.
                jacobian = np.concatenate(
                    (
                        info["objective_grads"][:, None, :],
                        info["measure_grads"],
                    ),
                    axis=1,
                )

                fields = {}
                if cfg.domain.config.get("is_mnist"):
                    fields["mnist_img"] = info["mnist_img"]

                scheduler.tell_dqd(objectives, measures, jacobian, **fields)

            ## PRIMARY ASK-TELL ##

            solutions = scheduler.ask()

            ## START EVAL ##
            # NOTE(review): only this primary evaluation is timed as "Domain
            # Evaluation"; the DQD evaluation above falls into "Internal".
            timer.start("Domain Evaluation")
            objectives, measures, info = domain_module.evaluate(solutions)
            objectives, measures = _centroid_objectives(
                domain_name,
                domain_module,
                scheduler.result_archive,
                len(solutions),
                objectives,
                measures,
            )
            timer.end("Domain Evaluation")
            num_evals += len(solutions)
            ## END EVAL ##

            fields = {}
            if cfg.domain.config.get("is_mnist"):
                fields["mnist_img"] = info["mnist_img"]

            _, add_info = scheduler.tell(objectives, measures, **fields)

            # Update time.
            timer.end("Algorithm")
            timer.itr_time["Internal"] = (timer.itr_time["Algorithm"] -
                                          timer.itr_time["Domain Evaluation"])

            # Read stats after tell() so the log line and the metrics below
            # both describe this iteration.
            stats = scheduler.result_archive.stats

            # Logging and output (occasional).
            final_itr = itr == cfg.itrs
            log_itr = itr % cfg.log_freq == 0 or final_itr
            if log_itr:
                eval_metrics_dict = eval_and_make_plots(itr)

                log.info(
                    f"{itr:5d} | "
                    f"Cov: {stats.coverage * 100:.3f}%  "
                    f"Size: {stats.num_elites:5d}  "
                    f"QD: {stats.qd_score:.3f}  "
                    f"Mean Feat Err: {eval_metrics_dict['Mean Feature Error']:.6f}  "
                    f"CQD: {eval_metrics_dict['CQD Score']:.3f}  "
                    f"LCQD: {eval_metrics_dict['Limited CQD Score']:.3f}")

            timer.end("All")

            # Metrics.
            timer.calc_totals()
            timer_dict = timer.metrics_dict()
            timer.clear()

            metrics.start_itr()
            total_restarts = sum(e.restarts
                                 for e in scheduler.emitters
                                 if hasattr(e, "restarts"))
            added_indices = add_info["index"]
            unique_indices = len(set(added_indices))
            dense_metrics_dict = {
                "Evaluations": metrics.get_last("Evaluations") + num_evals,
                "QD Score": stats.qd_score,
                "Archive Coverage": stats.coverage,
                "Objective Max": stats.obj_max,
                "Restarts": total_restarts,
                "Unique Cells": unique_indices,
                # Guard against an empty batch, which would divide by zero.
                "Solutions Per Cell": (len(added_indices) / unique_indices
                                       if unique_indices else 0.0),
            }
            dense_metrics_dict.update(timer_dict)
            metrics.add_dict(dense_metrics_dict)
            metrics.end_itr()

            if log_itr:
                # Add sparse metrics.
                metrics_sparse.start_itr()
                metrics_sparse.add_dict(dense_metrics_dict | eval_metrics_dict)
                metrics_sparse.end_itr()

    # Plot final archives as a PDF (not a PNG like during the run).
    eval_and_make_plots("final", "pdf")

    # Save and plot metrics.
    metrics.to_json(logdir.file("metrics.json"))
    metrics.plot_graphic(logdir.file("metrics_final.svg"))

    # Save the final archive as numpy.
    np.savez_compressed(logdir.pfile("archive.npz"),
                        **scheduler.result_archive.data())

    # Save scheduler.
    with logdir.pfile("scheduler.pkl", touch=True).open("wb") as file:
        pickle.dump(scheduler, file)

    log.info("Summary:")
    for name, val in metrics.summary().items():
        log.info(f"- {name}:\t{val}")
    log.info(f"Logging directory: {logdir.logdir}")
    log.info("Done")


if __name__ == "__main__":
    # The @hydra.main decorator composes and supplies ``cfg``, so main() is
    # invoked without arguments.
    main()  # pylint: disable = no-value-for-parameter
