"""Discount Model Search."""
import functools
import logging
import pathlib
import pickle as pkl
import shutil

import hydra
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from ribs.visualize import grid_archive_heatmap
from tqdm.contrib.logging import logging_redirect_tqdm

from src.evaluation import (evaluate_solution_model, make_discount_archive,
                            plot_discount_archive,
                            plot_solution_model_evaluation)
from src.models.discrete_archive import DiscreteArchiveSolutionModel
from src.mpl_styles.utils import mpl_style_file
from src.utils.code_timer import CodeTimer
from src.utils.hydra_utils import define_resolvers
from src.utils.logging import setup_logdir_from_hydra
from src.utils.metric_logger import MetricLogger
from src.visualize import (visualize_discount_points,
                           visualize_discount_points_2)

# Use when debugging warnings.
#  import src.utils.warn_traceback  # pylint: disable = unused-import

log = logging.getLogger(__name__)


def build_qd_algo(cfg, domain_module, device):
    """Instantiates the QD scheduler described by ``cfg.algo``.

    The discount model, solution model, archive, optional result archive, and
    emitters are each created through ``hydra.utils.instantiate`` and then
    handed to the scheduler.

    Args:
        cfg: Experiment config; ``cfg.seed`` and ``cfg.algo`` are read here.
        domain_module: Domain object; only ``initial_solution()`` is called.
        device: Device forwarded to the two models and the archive.

    Returns:
        The object instantiated from ``cfg.algo.scheduler.args``.
    """
    instantiate = hydra.utils.instantiate

    def shifted_seed(offset):
        # With no master seed everything stays unseeded; otherwise each model
        # receives its own fixed offset from the master seed.
        return None if cfg.seed is None else cfg.seed + offset

    # Models first, then the archive that wraps both of them.
    discount_model = instantiate(
        cfg.algo.discount_model,
        seed=shifted_seed(420),
        device=device,
        _recursive_=False,
    )
    solution_model = instantiate(
        cfg.algo.solution_model,
        seed=shifted_seed(4200),
        device=device,
        _recursive_=False,
    )
    archive = instantiate(
        cfg.algo.archive.args,
        discount_model=discount_model,
        solution_model=solution_model,
        seed=cfg.seed,
        device=device,
        _recursive_=False,
    )

    # The result archive is optional and only built when configured.
    result_archive = (instantiate(cfg.algo.result_archive.args, seed=cfg.seed)
                      if "result_archive" in cfg.algo else None)

    # Every emitter gets its own child seed spawned from a SeedSequence, so
    # that emitters of the same type do not all behave identically.
    seed_sequence = np.random.SeedSequence(cfg.seed)
    emitters = []
    for emitter_cfg in cfg.algo.emitters:
        for child_seed in seed_sequence.spawn(emitter_cfg.num):
            emitters.append(
                instantiate(
                    emitter_cfg.type.args,
                    archive=archive,
                    x0=domain_module.initial_solution(),
                    seed=child_seed,
                ))

    scheduler = instantiate(
        cfg.algo.scheduler.args,
        archive=archive,
        emitters=emitters,
        result_archive=result_archive,
    )

    algo_choice = HydraConfig.get().runtime.choices.algo
    log.info(f"Created {scheduler.__class__.__name__} for {algo_choice}")

    return scheduler


def save_scheduler(scheduler, save_name, savedir: pathlib.Path,
                   tmpdir: pathlib.Path):
    """Writes a scheduler checkpoint into ``savedir``, staged via ``tmpdir``.

    The checkpoint is first written completely into a fresh ``tmpdir``; only
    afterwards is the previous ``savedir`` deleted and ``tmpdir`` renamed to
    take its place.

    Args:
        scheduler: Scheduler to pickle.
        save_name: Label written (followed by a newline) to ``name.txt``.
        savedir: Final checkpoint directory; any existing one is replaced.
        tmpdir: Staging directory; wiped and recreated on every call.
    """
    # Begin with an empty staging directory.
    shutil.rmtree(tmpdir, ignore_errors=True)
    tmpdir.mkdir()

    # Stage the pickled scheduler and the checkpoint label.
    scheduler_path = tmpdir / "scheduler.pkl"
    with scheduler_path.open("wb") as pickle_file:
        pkl.dump(scheduler, pickle_file)
    (tmpdir / "name.txt").write_text(f"{save_name}\n")

    # For discrete-archive solution models, also dump the raw archive data.
    solution_model = scheduler.archive.solution_model
    if isinstance(solution_model, DiscreteArchiveSolutionModel):
        archive_data = solution_model.model.data()
        np.savez_compressed(tmpdir / "archive.npz", **archive_data)

    # Everything is staged, so swap it in. Deleting the old checkpoint only at
    # this point means a failure while saving does not destroy it.
    shutil.rmtree(savedir, ignore_errors=True)
    tmpdir.rename(savedir)


def eval_and_make_plots_full(
    scheduler,
    cfg,
    domain_module,
    logdir,
    rng,
    name,
    filetype="png",
    discount_train_info=None,
):
    """Evaluates the solution model, saves plots / samples, and logs metrics.

    Args:
        scheduler: Scheduler whose ``archive`` provides the ``solution_model``
            and ``discount_model`` that are evaluated and plotted.
        cfg: Experiment config. Reads ``fast_metrics``,
            ``plot_discount_points``, ``compressed_plot``, and ``eval``.
        domain_module: Domain object; its ``config`` supplies ``feature_dim``
            and the objective bounds used for plotting.
        logdir: Logging directory helper; ``pfile`` is used to obtain output
            paths.
        rng: Random generator forwarded to ``evaluate_solution_model``.
        name: Label for this evaluation. Integers are zero-padded to six
            digits in filenames; strings are used as-is.
        filetype: Extension of the main heatmap figure.
        discount_train_info: Info from the most recent discount model training,
            forwarded to the discount point visualizations.

    Returns:
        Dict of evaluation metrics. It always contains "Corrected QD Score",
        "Corrected Coverage", "Mean Feature Error", "Min Feature Error",
        "Max Feature Error", "CQD Score", and "Limited CQD Score" when
        ``cfg.fast_metrics`` is set; otherwise it is whatever
        ``evaluate_solution_model`` returns.
    """
    filename = f"{name:06}" if isinstance(name, int) else name

    if cfg.fast_metrics:
        # Skip the full evaluation and read stats straight off the archive.
        # The remaining entries are placeholders. The min / max feature error
        # placeholders are included because the caller registers those metrics
        # from this dict and would otherwise hit a KeyError.
        discrete_archive = scheduler.archive.solution_model.model
        eval_metrics = {
            "Corrected QD Score": discrete_archive.stats.qd_score,
            "Corrected Coverage": discrete_archive.stats.coverage,
            "Mean Feature Error": 0.0,
            "Min Feature Error": 0.0,
            "Max Feature Error": 0.0,
            "CQD Score": 0.0,
            "Limited CQD Score": 0.0,
        }
        eval_info = None
    else:
        eval_metrics, eval_info = evaluate_solution_model(
            cfg, scheduler.archive.solution_model, domain_module, rng)

    # Heatmaps are only drawn for 2D feature spaces.
    if domain_module.config.feature_dim == 2:
        if cfg.plot_discount_points:
            fig, axs = plt.subplots(nrows=1, ncols=5, figsize=((27, 4)))
        else:
            fig, axs = plt.subplots(nrows=1, ncols=4, figsize=((20, 4)))

        # NOTE(review): eval_info is None when cfg.fast_metrics is set --
        # confirm that plot_solution_model_evaluation handles that.
        plot_solution_model_evaluation(cfg, eval_info, axs[:3])

        # Discount archive is only needed if we want to plot it. Also note that
        # the discount archive is pretty cheap to make since the discount model
        # is small.
        discount_archive = make_discount_archive(
            scheduler.archive.discount_model, cfg.eval)
        plot_discount_archive(discount_archive, axs[3], domain_module.config)
        if cfg.plot_discount_points:
            visualize_discount_points(discount_train_info, discount_archive,
                                      axs[4], domain_module.config)

        fig.tight_layout()
        fig.savefig(logdir.pfile(f"heatmaps_{filename}.{filetype}"), dpi=300)
        plt.close(fig)

        if cfg.compressed_plot:
            # Compact two-panel figure (archive + discount model), always saved
            # as a PDF with its own style file.
            with mpl_style_file("discount_viz.mplstyle") as f:
                with plt.style.context(f):
                    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=((6, 2)))
                    discrete_archive = scheduler.archive.solution_model.model

                    grid_archive_heatmap(
                        discrete_archive,
                        ax=axs[0],
                        rasterized=True,
                        vmin=domain_module.config.obj_low,
                        vmax=domain_module.config.obj_high,
                        cbar=None,
                    )
                    visualize_discount_points_2(discount_train_info,
                                                discount_archive, axs[1],
                                                domain_module.config)

                    axs[0].set_aspect("equal")
                    axs[1].set_aspect("equal")
                    axs[0].set_title("Archive")
                    axs[1].set_title("Discount Model")

                    # Both panels only show ticks at the archive bounds.
                    x_ticks = [
                        discrete_archive.lower_bounds[0],
                        discrete_archive.upper_bounds[0]
                    ]
                    y_ticks = [
                        discrete_archive.lower_bounds[1],
                        discrete_archive.upper_bounds[1]
                    ]
                    for ax in (axs[0], axs[1]):
                        ax.set_xticks(x_ticks)
                        ax.set_yticks(y_ticks)

                    fig.tight_layout()
                    fig.savefig(logdir.pfile(f"small_heatmaps_{filename}.pdf"),
                                dpi=600)
                    plt.close(fig)

    # Domain-specific sample images, only when the archive is non-empty. The
    # domain imports are kept local so they only load for the matching domain.
    discrete_archive = scheduler.archive.solution_model.model
    domain_name = HydraConfig.get().runtime.choices.domain

    if domain_name == "triangles_mnist" and len(discrete_archive) > 0:
        imgs = discrete_archive.sample_elites(
            min(len(discrete_archive), 100),
            replace=False,
        )["measures"]
        from src.domains.triangles import render_mnist_batch_pil
        pil_image = render_mnist_batch_pil(imgs.reshape(-1, 28, 28))
        pil_image.save(logdir.pfile(f"mnist_samples_{filename}.png"))

    if (domain_name in ["triangles_afhq", "triangles_afhq_l2"] and
            len(discrete_archive) > 0):
        from src.domains.triangles import afhq_pil_image
        sample = discrete_archive.sample_elites(
            min(len(discrete_archive), 100),
            replace=False,
        )
        pil_image = afhq_pil_image(sample["solution"], sample["index"], 128)
        pil_image.save(logdir.pfile(f"afhq_samples_{filename}.png"))

    if domain_name == "lsi_face" and len(discrete_archive) > 0:
        from src.domains.lsi_face import human_landscape_face
        sample = discrete_archive.sample_elites(
            min(len(discrete_archive), 40),
            replace=False,
        )
        pil_image = human_landscape_face(sample["solution"],
                                         domain_module.classifier,
                                         sample["index"])
        pil_image.save(logdir.pfile(f"face_samples_{filename}.png"))

    log.info(f"{name:>5} | "
             f"Corr Cov: {eval_metrics['Corrected Coverage'] * 100:.3f}%  "
             f"Corr QD: {eval_metrics['Corrected QD Score']:.3f}  "
             f"Mean Feat Err: {eval_metrics['Mean Feature Error']:.6f}  "
             f"CQD: {eval_metrics['CQD Score']:.6f}  "
             f"LCQD: {eval_metrics['Limited CQD Score']:.6f}")

    return eval_metrics


@hydra.main(version_base=None, config_path="../config", config_name="config")
def main(cfg: DictConfig):
    """Runs Discount Model Search as configured by ``cfg``.

    Sets up the logging directory, domain, and scheduler, initializes the
    discount model, runs ``cfg.itrs`` ask-tell iterations with periodic
    evaluation and checkpointing, and finally writes out metrics and plots.

    Args:
        cfg: Hydra config for the experiment.
    """
    define_resolvers()

    logdir = setup_logdir_from_hydra()
    log.info(f"Logging directory: {logdir.logdir}")

    ## COMPONENT INITIALIZATION ##

    rng = np.random.default_rng(cfg.seed)
    if cfg.seed is not None:
        torch.manual_seed(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    domain_module = hydra.utils.instantiate(
        cfg.domain,
        seed=None if cfg.seed is None else cfg.seed + 42,
        device=device,
    )
    scheduler = build_qd_algo(cfg, domain_module, device)

    log.info(f"Discount model: {scheduler.archive.discount_model}")
    log.info(
        f"Discount model params: {scheduler.archive.discount_model.num_params()}"
    )

    eval_and_make_plots = functools.partial(eval_and_make_plots_full, scheduler,
                                            cfg, domain_module, logdir, rng)

    ## MODEL INITIALIZATION ##

    if cfg.skip_initial_eval:
        scheduler.archive.initialize_discount_model_to_min()
        eval_metrics_dict = {
            "Corrected QD Score": 0.0,
            "Corrected Coverage": 0.0,
            "Mean Feature Error": 0.0,
            "Min Feature Error": 0.0,
            "Max Feature Error": 0.0,
            "CQD Score": 0.0,
            "Limited CQD Score": 0.0,
        }
    else:
        # These plots would show what the discount model looks like before it is
        # initialized to the min values.
        #  eval_and_make_plots("pre")
        losses = scheduler.archive.initialize_discount_model_to_min()
        log.info(f"Losses from Discount Init: {losses}")
        eval_metrics_dict = eval_and_make_plots(0)

    ## SET UP METRICS ##

    timer = CodeTimer([
        "All",  # Everything in the iteration (except metrics, which are negligible).
        "Algorithm",  # Algorithm only -- excludes logging and saving, includes evals.
        "Internal",  # Algorithm time excluding evals.
        "Domain Evaluation",  # Evaluation in the domain for main solutions.
        "Model Evaluation",  # Logging / evaluation of model.
    ])
    dense_metric_list = [
        ("Evaluations", True, 0),
        *([
            ("QD Score", True, 0.0),
            ("Archive Coverage", True, 0.0),
            ("Objective Max", False),
        ] if isinstance(scheduler.archive.solution_model,
                        DiscreteArchiveSolutionModel) else []),
        # Loss on the last epoch of solution / discount model training on each
        # iteration.
        ("Final Solution Loss", False),
        ("Final Discount Loss", False),
        ("Num Empty", False),
        ("Discount Epochs", False),
        ("Restarts", True, 0),
        ("Unique Cells", False),
        ("Solutions Per Cell", False),
        *timer.metric_list(),
    ]
    sparse_metric_list = [
        ("Corrected QD Score", True, eval_metrics_dict["Corrected QD Score"]),
        ("Corrected Coverage", True, eval_metrics_dict["Corrected Coverage"]),
        ("Mean Feature Error", True, eval_metrics_dict["Mean Feature Error"]),
        ("Min Feature Error", True, eval_metrics_dict["Min Feature Error"]),
        ("Max Feature Error", True, eval_metrics_dict["Max Feature Error"]),
        ("CQD Score", True, eval_metrics_dict["CQD Score"]),
        ("Limited CQD Score", True, eval_metrics_dict["Limited CQD Score"]),
    ]

    # Metrics that are recorded on every iteration.
    metrics = MetricLogger(dense_metric_list)
    # Metrics that are recorded only every log_freq iterations.
    metrics_sparse = MetricLogger(dense_metric_list + sparse_metric_list,
                                  x_scale=cfg.log_freq)

    ## EXECUTION LOOP ##

    # The selected domain does not change during the run, so look it up once
    # rather than on every iteration.
    domain_name = HydraConfig.get().runtime.choices.domain

    with logging_redirect_tqdm():
        for itr in tqdm.trange(1, cfg.itrs + 1):
            # Evaluation plots and sparse metrics happen every log_freq
            # iterations and always on the last iteration.
            should_log = itr % cfg.log_freq == 0 or itr == cfg.itrs

            timer.start(["All", "Algorithm"])

            ## DQD ASK-TELL ##

            if cfg.algo.get("dqd"):
                solutions = scheduler.ask_dqd()
                objectives, features, info = domain_module.evaluate(solutions,
                                                                    grad=True)
                # Stack the objective gradient on top of the measure gradients
                # along axis 1 to form the jacobian passed to tell_dqd.
                jacobian = np.concatenate(
                    (
                        info["objective_grads"][:, None, :],
                        info["measure_grads"],
                    ),
                    axis=1,
                )

                scheduler.tell_dqd(objectives, features, jacobian)

            ## PRIMARY ASK-TELL ##

            # Ask, eval, and tell regular solutions.
            solutions = scheduler.ask()
            timer.start("Domain Evaluation")
            objectives, features, info = domain_module.evaluate(solutions)

            # Some domain / objective combinations replace or adjust the
            # objective using the distance from the features to the archive
            # centroids. The config attribute in each condition is only read
            # when the domain matches.
            if (domain_name == "triangles_mnist" and
                    domain_module.config.objective == "centroid_distance"):
                discrete_archive = scheduler.archive.solution_model.model
                # Compute distance from centroids to the measures.
                indices = discrete_archive.index_of(features)
                # (batch_size,)
                objectives = 1.0 - np.mean(
                    # (batch_size, measure_dim)
                    np.square(discrete_archive.centroids[indices] - features),
                    axis=1,
                )
            elif (domain_name == "triangles_afhq_l2" and
                  domain_module.config.objective == "centroid_distance"):
                discrete_archive = scheduler.archive.solution_model.model
                _, distances = discrete_archive.index_of(
                    features, return_distances=True)
                # The distance is L2, so convert it to MSE.
                objectives = 1.0 - np.square(distances) / features.shape[1]
            elif (domain_name == "triangles_afhq" and
                  domain_module.config.objective == "centroid_cosine_distance"):
                discrete_archive = scheduler.archive.solution_model.model
                # pylint: disable-next = protected-access
                objectives, _ = discrete_archive._sklearn_nn.kneighbors(
                    features, n_neighbors=1, return_distance=True)
                objectives = objectives.reshape(len(features))
            elif (domain_name == "lsi_face" and
                  domain_module.config.add_centroid_dist):
                discrete_archive = scheduler.archive.solution_model.model
                # pylint: disable-next = protected-access
                distances, _ = discrete_archive._sklearn_nn.kneighbors(
                    features, n_neighbors=1, return_distance=True)
                distances = distances.reshape(len(features))
                # Rescale to [0, 1].
                distances = (distances + 1.0) / 2.0
                # Add the two objectives together, weighting them equally.
                objectives = (objectives + distances) / 2.0
            timer.end("Domain Evaluation")

            data, add_info = scheduler.tell(objectives, features)

            # Train solution model.
            if itr % cfg.algo.solution_model.cfg.train.freq == 0:
                solution_train_info = scheduler.archive.train_solution_model()
            else:
                # No training on this iteration. The placeholder must carry
                # every key that the metrics below read ("losses" and
                # "index"); omitting "index" raised a KeyError whenever the
                # training frequency was not 1.
                solution_train_info = {"losses": [0.0], "index": []}

            # Train discount model.
            if itr % cfg.algo.discount_model.cfg.train.freq == 0:
                discount_train_info = scheduler.archive.train_discount_model(
                    data, add_info, None, None, itr)
            else:
                discount_train_info = {
                    "losses": [0.0],
                    "n_empty": 0,
                    "epochs": 0
                }

            timer.end("Algorithm")
            timer.itr_time["Internal"] = (timer.itr_time["Algorithm"] -
                                          timer.itr_time["Domain Evaluation"])

            # Logging.
            timer.start("Model Evaluation")
            if should_log:
                eval_metrics_dict = eval_and_make_plots(
                    itr,
                    # On iterations where the discount model was not trained,
                    # this is the placeholder dict built above.
                    discount_train_info=discount_train_info,
                )
            timer.end("Model Evaluation")

            # Saving.
            if itr % cfg.save_freq == 0 or itr == cfg.itrs:
                save_scheduler(scheduler, f"{itr}", logdir.pdir("scheduler"),
                               logdir.pdir("scheduler-tmp"))

            timer.end("All")

            # Metrics

            timer.calc_totals()
            timer_dict = timer.metrics_dict()
            timer.clear()

            # Logging and output (every iteration).
            metrics.start_itr()
            train_indices = solution_train_info["index"]
            unique_indices = len(set(train_indices))
            dense_metrics_dict = {
                "Evaluations":
                    metrics.get_last("Evaluations") + len(solutions),
                "Unique Cells":
                    unique_indices,
                # Guard against dividing by zero when no cells were touched
                # (e.g., the solution model was not trained this iteration).
                "Solutions Per Cell":
                    (len(train_indices) /
                     unique_indices if unique_indices else 0.0),
            }
            if isinstance(scheduler.archive.solution_model,
                          DiscreteArchiveSolutionModel):
                stats = scheduler.archive.solution_model.model.stats
                dense_metrics_dict.update({
                    "QD Score": stats.qd_score,
                    "Archive Coverage": stats.coverage,
                    "Objective Max": stats.obj_max,
                })
            # Only emitters that expose a restart counter contribute.
            total_restarts = sum(e.restarts
                                 for e in scheduler.emitters
                                 if hasattr(e, "restarts"))
            dense_metrics_dict.update({
                "Final Solution Loss": solution_train_info["losses"][-1],
                "Final Discount Loss": discount_train_info["losses"][-1],
                "Num Empty": discount_train_info["n_empty"],
                "Discount Epochs": discount_train_info["epochs"],
                "Restarts": total_restarts,
            })
            dense_metrics_dict.update(timer_dict)
            metrics.add_dict(dense_metrics_dict)
            metrics.end_itr()

            if should_log:
                # Add sparse metrics.
                metrics_sparse.start_itr()
                metrics_sparse.add_dict(dense_metrics_dict | eval_metrics_dict)
                metrics_sparse.end_itr()

    # Plot final archives as a PDF (not a PNG like during the run).
    eval_and_make_plots("final", filetype="pdf")

    # Save and plot metrics.
    metrics.to_json(logdir.file("metrics.json"))
    metrics.plot_graphic(logdir.file("metrics_final.svg"))
    metrics_sparse.to_json(logdir.file("metrics_sparse.json"))
    metrics_sparse.plot_graphic(logdir.file("metrics_sparse_final.svg"))

    log.info("Summary:")
    for name, val in metrics_sparse.summary().items():
        log.info(f"- {name}:\t{val}")
    log.info(f"Logging directory: {logdir.logdir}")
    log.info("Done")


# Script entry point. The `cfg` argument of `main` is supplied by the
# `hydra.main` decorator, hence the pylint suppression on the bare call.
if __name__ == '__main__':
    main()  # pylint: disable = no-value-for-parameter
