"""Tools for evaluation."""
import logging

import hydra
import numpy as np
from ribs.archives import CVTArchive, GridArchive
from ribs.visualize import grid_archive_heatmap

from src.cvt.cvt_archive_2 import CVTArchive2
from src.domains.domain_base import DomainBase
from src.qd_scatterplot import qd_scatterplot

log = logging.getLogger(__name__)


def compute_centers(archive):
    """Return the cell centers of a discrete archive.

    Args:
        archive: A GridArchive, CVTArchive, or CVTArchive2.

    Returns:
        For a GridArchive, a 2D array with one row per combination of
        per-dimension cell midpoints and one column per measure dimension.
        For the CVT archives, `archive.centroids` is returned as-is.

    Raises:
        ValueError: If `archive` is none of the supported archive types.
    """
    if isinstance(archive, GridArchive):
        # Midpoint of every pair of consecutive boundaries, per dimension.
        midpoints = [(edges[:-1] + edges[1:]) / 2.0
                     for edges in archive.boundaries]
        # NOTE(review): np.meshgrid's default indexing is used, so the row
        # order may not follow the archive's own cell index order -- confirm
        # before relying on row k being cell k.
        mesh = np.meshgrid(*midpoints)
        return np.column_stack([axis.ravel() for axis in mesh])

    if isinstance(archive, (CVTArchive, CVTArchive2)):
        return archive.centroids

    raise ValueError("Cannot compute centers of this archive")


def compute_feature_err(features, target_features, interval_size,
                        normalize_features_before_err):
    """Computes distance/error between features and target_features.

    The two feature arrays are subtracted with numpy broadcasting and the
    Euclidean norm is taken over the last axis. The inputs are never modified.

    Args:
        features: Array of features; the last axis is the feature dimension.
        target_features: Array of features broadcastable against `features`.
        interval_size: Per-dimension size of the feature range. Only used when
            `normalize_features_before_err` is set.
        normalize_features_before_err: Divide the feature distances by the
            interval_size before taking the norm. This ensures that even if the
            two features have different scales, they contribute equally to the
            feature error.

    Returns:
        Array of L2 distances with the last (feature) axis reduced away.
    """
    feature_dists = np.subtract(features, target_features)
    if normalize_features_before_err:
        if not np.issubdtype(feature_dists.dtype, np.inexact):
            # Integer (or bool) differences cannot store a true-division
            # result in place; numpy raises a casting error. Promote first.
            feature_dists = feature_dists.astype(np.float64)
        # `feature_dists` is a fresh array, so in-place division is safe and
        # keeps the dtype of floating-point inputs unchanged.
        feature_dists /= interval_size
    return np.linalg.norm(feature_dists, ord=2, axis=-1)


def infer_and_evaluate(solution_model,
                       domain_module,
                       input_features,
                       inference_batch_size,
                       samples=None):
    """Inputs the given features into the model and evaluates the
    results in the domain.

    Args:
        solution_model: Model exposing `cfg` (with `output_type` and an
            optional `input_objs` flag) and `chunked_inference`.
        domain_module: Domain exposing `config.obj_high` and the
            `evaluate_torch` / `evaluate_images_torch` methods.
        input_features: Array of features with one row per query. Must be 2D
            when `input_objs` is set, since an objective column is appended.
        inference_batch_size: Forwarded to `chunked_inference` as
            `batch_size`.
        samples: Forwarded to `chunked_inference` unchanged.

    Returns:
        Tuple `(solutions, objectives, features)`, each converted with
        `.detach().cpu().numpy()`.

    Raises:
        ValueError: If `solution_model.cfg.output_type` is neither
            "mnist_img" nor "solution". This is checked before inference so
            that a misconfiguration does not cost a full inference pass.
    """
    output_type = solution_model.cfg.output_type
    if output_type not in ("mnist_img", "solution"):
        raise ValueError(f"Unknown model output type {output_type}")

    # Create inputs for the model -- either just the features or include the
    # objectives too.
    if solution_model.cfg.get("input_objs", False):
        # Condition every query on the highest objective of the domain.
        inputs = np.concatenate(
            [
                input_features,
                np.full(
                    (len(input_features), 1), domain_module.config.obj_high),
            ],
            axis=1,
        )
    else:
        inputs = input_features

    solutions = solution_model.chunked_inference(
        inputs=inputs,
        batch_size=inference_batch_size,
        samples=samples,
        verbose=True,
    )

    if output_type == "mnist_img":
        objectives, features, _ = domain_module.evaluate_images_torch(solutions)
    else:
        objectives, features, _ = domain_module.evaluate_torch(solutions)

    solutions = solutions.detach().cpu().numpy()
    objectives = objectives.detach().cpu().numpy()
    features = features.detach().cpu().numpy()

    return solutions, objectives, features


def discrete_archive_cqd_score(
    archive_data,
    target_features,
    interval_size,
    normalize_features_before_err,
    obj_min,
    obj_max,
    penalties,
    neighbors,
    kdtree,
):
    """CQD score for a discrete archive.

    For every penalty and every target point, the best value of
    `objective - penalty * feature_error` over the candidate archive entries
    is taken; the result is the mean of those maxima over targets and
    penalties.

    Args:
        archive_data: Dict with data from pyribs archive; see ArchiveBase.data()
            Only the "objective" and "measures" entries are read.
        target_features: (batch_size, feature_dim) array of features that we are
            comparing to.
        interval_size: See compute_feature_err.
        normalize_features_before_err: See compute_feature_err.
        obj_min: Minimum objective.
        obj_max: Maximum objective.
        penalties: 1D array of penalty weights.
        neighbors: Number of neighbors to consider. Regular CQD score allows any
            point in the archive to be considered for each target. This integer
            restricts it to just the `neighbors` closest points in feature space
            to each target. Pass None to turn this off. Values larger than the
            archive are clamped to the archive size.
        kdtree: KDTree to use for finding nearest neighbors. Unused when
            `neighbors` is None. It is queried with
            `target_features / interval_size` when normalizing, so it is
            presumably built over the archive measures in that same scaled
            space and in the row order of `archive_data` -- TODO confirm
            against the code that builds the tree.

    Returns:
        The CQD score as a scalar.
    """
    objectives = archive_data["objective"]
    features = archive_data["measures"]
    n_targets = len(target_features)

    # Scale objectives by the objective range.
    # NOTE(review): obj_min is not subtracted, so this only maps objectives
    # into [0, 1] when obj_min is 0 -- confirm that is intended.
    objectives = objectives / (obj_max - obj_min)

    if neighbors is None:
        # Feature errors between every archive entry and every target,
        # arranged as (n_target_points, len(archive)).
        feature_err = compute_feature_err(features[:, None], target_features,
                                          interval_size,
                                          normalize_features_before_err).T
        # One row that broadcasts against every target.
        objectives = objectives[None, :]
    else:
        # Never ask for more neighbors than there are archive entries;
        # otherwise the tree either errors out or reports missing neighbors
        # with an out-of-range index.
        n_neighbors = min(neighbors, len(objectives))

        if normalize_features_before_err:
            queries = target_features / interval_size
        else:
            queries = target_features
        feature_err, indices = kdtree.query(queries, k=n_neighbors)

        # Some KD-tree implementations squeeze the neighbor axis when k == 1.
        # Force (n_target_points, n_neighbors) so that the subtraction below
        # pairs each objective with its own distance instead of broadcasting
        # to an (n, n) matrix.
        feature_err = np.reshape(feature_err, (n_targets, n_neighbors))
        indices = np.reshape(indices, (n_targets, n_neighbors))
        objectives = objectives[indices]

    # Normalize feature_err to the range [0, 1].
    if normalize_features_before_err:
        # The max dist is between [0, 0, ...] and [1, 1, ...]
        max_feature_dist = np.sqrt(len(interval_size))
    else:
        max_feature_dist = np.linalg.norm(interval_size)
    feature_err = feature_err / max_feature_dist

    score = 0.0
    for penalty in penalties:
        # Known as omega in Kent 2022. Shape is (n_target_points, candidates),
        # where candidates is len(archive) or n_neighbors.
        values = objectives - penalty * feature_err

        # (n_target_points,) array with the best candidate per target.
        max_values_per_target = np.max(values, axis=1)

        score += np.sum(max_values_per_target)

    # Normalize by dividing by the number of target points and the number of
    # penalties.
    score = score / (n_targets * len(penalties))

    return score


def evaluate_discrete_archive_solution_model(cfg, solution_model, domain_module,
                                             input_features):
    """Specialized evaluation for discrete archives.

    Each input feature is matched to its nearest archive entry through the
    KD-tree returned by `solution_model.make_kd_tree()`; the distance to that
    entry is the feature error.

    Args:
        cfg: Global config; the `eval` section is read.
        solution_model: Object whose `model` attribute is the archive and
            which provides `make_kd_tree()`.
        domain_module: Domain whose `config` holds the feature bounds, the
            objective bounds, `feature_dim` and `solution_dim`.
        input_features: Array of feature points to evaluate at, one per row.

    Returns:
        eval_metrics: Dict of scalar metrics.
        eval_info: Dict with the archives / points needed for plotting.
    """
    domain_cfg = domain_module.config
    eval_cfg = cfg.eval
    normalize = eval_cfg.normalize_features_before_err
    archive = solution_model.model
    archive_is_empty = len(archive) == 0

    interval_size = np.subtract(np.asarray(domain_cfg.feature_high),
                                np.asarray(domain_cfg.feature_low))

    if archive_is_empty:
        # Nothing to match against: report the largest possible distance for
        # every point and use all-zero placeholder solutions.
        if normalize:
            # The max dist is between [0, 0, ...] and [1, 1, ...]
            worst_dist = np.sqrt(domain_cfg.feature_dim)
        else:
            worst_dist = np.linalg.norm(interval_size)
        n_points = len(input_features)
        feature_error = np.full(n_points, worst_dist)
        solutions = np.zeros((n_points, domain_cfg.solution_dim))
    else:
        # The nearest archive entry to each point gives the feature error.
        kdtree, archive_data = solution_model.make_kd_tree()
        queries = input_features / interval_size if normalize \
            else input_features
        feature_error, nearest = kdtree.query(queries)
        solutions = archive_data["solution"][nearest]

    eval_info = {"objective_archive": archive}
    if eval_cfg.eval_points == "archive_centers":
        # Store the feature error of each center in an archive of its own.
        error_archive = hydra.utils.instantiate(
            eval_cfg.archive.args,
            solution_dim=domain_cfg.solution_dim,
            extra_fields=None)
        eval_info["feature_error_archive"] = error_archive
        error_archive.add(solutions, feature_error, input_features)
    elif eval_cfg.eval_points == "uniform":
        # Enough info to plot the feature error points.
        eval_info["points"] = {
            "input_features": input_features,
            "feature_err": feature_error,
        }
    eval_info["corrected_archive"] = archive

    corrected_stats = eval_info["corrected_archive"].stats
    eval_metrics = {
        "Mean Feature Error": np.mean(feature_error),
        "Min Feature Error": np.min(feature_error),
        "Max Feature Error": np.max(feature_error),
        "Corrected QD Score": corrected_stats.qd_score,
        "Corrected Coverage": corrected_stats.coverage,
    }

    if archive_is_empty or eval_cfg.skip_cqd:
        eval_metrics["CQD Score"] = 0.0
        eval_metrics["Limited CQD Score"] = 0.0
    else:
        # Both CQD variants share everything except the neighbor limit.
        shared_args = (
            archive_data,
            input_features,
            interval_size,
            normalize,
            domain_cfg.obj_low,
            domain_cfg.obj_high,
            eval_cfg.cqd_penalties,
        )
        eval_metrics["CQD Score"] = discrete_archive_cqd_score(
            *shared_args, None, kdtree)
        eval_metrics["Limited CQD Score"] = discrete_archive_cqd_score(
            *shared_args, eval_cfg.cqd_discrete_neighbors, kdtree)

    return eval_metrics, eval_info


def choose_min_feature_err(
    solutions,
    objectives,
    features,
    target_features,
    interval_size,
    normalize_features_before_err,
    initial_samples,
    considered_samples,
    return_err=False,
):
    """Pick, for each target, the sample whose features are closest to it.

    The inputs hold `initial_samples` consecutive samples per target. Only the
    first `considered_samples` of them compete; the one with the smallest
    feature error to its target is kept.

    Args:
        solutions: Parameters of the input solutions.
            (batch_size * initial_samples, *solution_shape)
        objectives: Objectives of the input solutions.
            (batch_size * initial_samples, 1)
        features: Features of the input solutions.
            (batch_size * initial_samples, feature_dim)
        target_features: (batch_size, feature_dim) array of features that we
            are comparing to.
        interval_size: See compute_feature_err.
        normalize_features_before_err: See compute_feature_err.
        initial_samples: Number of samples for the input solutions.
        considered_samples: Number of samples to consider, e.g., we can look at
            just the first 10 samples.
        return_err: Whether to return the feature error.

    Returns:
        `(solutions, objectives, features)` of the selected samples, one per
        target, followed by their feature errors when `return_err` is set.
    """
    n_targets, n_dims = target_features.shape

    def _first_samples(array, trailing_shape):
        # Put the sample index on axis 1, then keep the leading samples.
        # `trailing_shape` is left open because solutions need not be flat
        # vectors, e.g. when images are being modeled.
        per_target = array.reshape((n_targets, initial_samples) +
                                   tuple(trailing_shape))
        return per_target[:, :considered_samples]

    solutions = _first_samples(solutions, solutions.shape[1:])
    objectives = _first_samples(objectives, ())
    features = _first_samples(features, (n_dims,))

    # Error of every considered sample against its own target.
    errors = compute_feature_err(
        features.reshape(n_targets * considered_samples, n_dims),
        np.repeat(target_features, considered_samples, axis=0),
        interval_size,
        normalize_features_before_err,
    ).reshape((n_targets, considered_samples))

    # (row, column) index pair selecting the best sample of each target.
    best = (np.arange(n_targets), errors.argmin(axis=1))
    chosen = (solutions[best], objectives[best], features[best])

    if return_err:
        return chosen + (errors[best],)
    return chosen


def solution_model_cqd_score(
    objectives,
    features,
    target_features,
    interval_size,
    normalize_features_before_err,
    obj_min,
    obj_max,
    initial_samples,
    considered_samples,
    penalties,
):
    """CQD score of sampled solutions against their target features.

    For every penalty and every target, the best value of
    `objective - penalty * feature_error` over the considered samples is
    taken; the score is the mean of those maxima over targets and penalties.

    Args:
        objectives: Objectives of the input solutions.
            (batch_size * samples, 1)
        features: Features of the input solutions.
            (batch_size * samples, feature_dim)
        target_features: (batch_size, feature_dim) array of features that we are
            comparing to.
        interval_size: See compute_feature_err.
        normalize_features_before_err: See compute_feature_err.
        obj_min: Minimum objective.
        obj_max: Maximum objective.
        initial_samples: Number of samples for the input solutions.
        considered_samples: Number of samples to consider, e.g., we can look at
            just the first 10 samples.
        penalties: 1D array of penalty weights.

    Returns:
        The CQD score as a scalar.
    """
    n_targets, n_dims = target_features.shape

    # Put the sample index on axis 1 and keep only the leading samples.
    objectives = objectives.reshape(
        (n_targets, initial_samples))[:, :considered_samples]
    features = features.reshape(
        (n_targets, initial_samples, n_dims))[:, :considered_samples]

    # Error of every considered sample against its own target.
    flat_features = features.reshape(n_targets * considered_samples, n_dims)
    tiled_targets = np.repeat(target_features, considered_samples, axis=0)
    feature_err = compute_feature_err(
        flat_features,
        tiled_targets,
        interval_size,
        normalize_features_before_err,
    ).reshape((n_targets, considered_samples))

    # Scale the feature errors by the largest possible distance.
    if normalize_features_before_err:
        # The max dist is between [0, 0, ...] and [1, 1, ...]
        largest_dist = np.sqrt(len(interval_size))
    else:
        largest_dist = np.linalg.norm(interval_size)
    scaled_err = feature_err / largest_dist

    # Scale the objectives by the objective range.
    # NOTE(review): obj_min is not subtracted, so the result only lies in
    # [0, 1] when obj_min is 0 -- confirm that is intended.
    scaled_obj = objectives / (obj_max - obj_min)

    # Per penalty: best sample per target, summed over the targets.
    total = sum(
        (np.sum(np.max(scaled_obj - penalty * scaled_err, axis=1))
         for penalty in penalties),
        0.0,
    )

    # Average over the number of target points and the number of penalties.
    return total / (n_targets * len(penalties))


def evaluate_solution_model(
    cfg,
    solution_model,
    domain_module: DomainBase,
    rng,
    input_features: np.ndarray = None,
):
    """Computes metrics.

    Args:
        cfg: Global config object.
        solution_model: Model to evaluate.
        domain_module: Needed so that we can evaluate solutions from the model
            in the given domain.
        rng: Numpy random generator for sampling points. Only used when
            `input_features` is None and `cfg.eval.eval_points` is "uniform".
        input_features: Can be passed in to indicate where the model should be
            evaluated. When None, the points are the archive centers
            (`eval_points == "archive_centers"`) or `cfg.eval.eval_n` uniform
            samples within the feature bounds (`eval_points == "uniform"`).
    Returns:
        eval_metrics: Dict of computed metrics.
        eval_info: Contains any info for visualizing the evaluation. For
            example, it can contain various intermediate archives.
    Raises:
        ValueError: If `input_features` is None and `cfg.eval.eval_points` is
            not one of the supported modes, since no default points can be
            generated.
    """

    domain_cfg = domain_module.config
    eval_cfg = cfg.eval

    # Default values for input_features.
    if input_features is None:
        if eval_cfg.eval_points == "archive_centers":
            # A throwaway archive is instantiated only to read its centers.
            input_features = compute_centers(
                hydra.utils.instantiate(eval_cfg.archive.args, solution_dim=0))
        elif eval_cfg.eval_points == "uniform":
            input_features = rng.uniform(
                domain_cfg.feature_low,
                domain_cfg.feature_high,
                (eval_cfg.eval_n, domain_cfg.feature_dim),
            )
        else:
            # Fail here instead of passing None on to the evaluation, where it
            # would surface as an unrelated-looking error.
            raise ValueError(
                "Cannot choose default input_features for eval_points="
                f"{eval_cfg.eval_points!r}; expected 'archive_centers' or "
                "'uniform'")

    return evaluate_discrete_archive_solution_model(cfg, solution_model,
                                                    domain_module,
                                                    input_features)


def plot_solution_model_evaluation(cfg, eval_info, axs):
    """Plots info from eval_solution_model.

    Three panels are drawn on `axs[0]`, `axs[1]` and `axs[2]`: objective,
    feature error and corrected objective. Each panel is a heatmap when the
    matching archive is present in `eval_info`, and otherwise a scatterplot
    of the arrays stored under `eval_info["points"]`.

    Args:
        cfg: Global config; `cfg.domain.config` and
            `cfg.eval.normalize_features_before_err` are read.
        eval_info: Dict of archives / points produced by the evaluation.
        axs: Indexable collection holding at least three axes.
    """
    domain_cfg = cfg.domain.config

    def _draw_panel(ax, archive_key, values_key, features_key, vmin, vmax):
        # Heatmap if the archive is available, scatterplot of points if not.
        if archive_key in eval_info:
            grid_archive_heatmap(
                eval_info[archive_key],
                ax=ax,
                rasterized=True,
                vmin=vmin,
                vmax=vmax,
            )
            return
        qd_scatterplot(
            eval_info["points"][values_key],
            eval_info["points"][features_key],
            domain_cfg.feature_low,
            domain_cfg.feature_high,
            ax=ax,
            rasterized=True,
            vmin=vmin,
            vmax=vmax,
            scatter_kwargs={"s": 32},
        )

    axs[0].set_title("Objective")
    _draw_panel(axs[0], "objective_archive", "objectives", "input_features",
                domain_cfg.obj_low, domain_cfg.obj_high)

    if cfg.eval.normalize_features_before_err:
        # The max dist is between [0, 0, ...] and [1, 1, ...]
        max_feature_dist = np.sqrt(domain_cfg.feature_dim)
        feature_err_title = "Normalized Feature Error"
    else:
        feature_span = (np.array(domain_cfg.feature_high) -
                        np.array(domain_cfg.feature_low))
        max_feature_dist = np.linalg.norm(feature_span)
        feature_err_title = "Feature Error"

    axs[1].set_title(feature_err_title)
    # Cap the color scale at a tenth of the max distance so that small errors
    # remain distinguishable.
    _draw_panel(axs[1], "feature_error_archive", "feature_err",
                "input_features", 0, max_feature_dist / 10.0)

    axs[2].set_title("Corrected")
    _draw_panel(axs[2], "corrected_archive", "objectives", "features",
                domain_cfg.obj_low, domain_cfg.obj_high)


def make_discount_archive(discount_model, eval_cfg):
    """Build an archive holding the discount model's value at each cell.

    Args:
        discount_model: Model providing `chunked_inference`; it is called on
            the cell centers and its output is converted with
            `.detach().cpu().numpy()`.
        eval_cfg: Eval config; `eval_cfg.archive.args` describes the archive
            to instantiate.

    Returns:
        The instantiated archive after the discount values have been added at
        the cell centers.
    """
    archive = hydra.utils.instantiate(eval_cfg.archive.args, solution_dim=0)
    centers = compute_centers(archive)

    predictions = discount_model.chunked_inference(centers)
    discounts = predictions.detach().cpu().numpy()

    # The archive is created with solution_dim=0, so the solutions passed in
    # are zero-width placeholders.
    placeholder_solutions = np.empty((len(centers), 0))
    archive.add(placeholder_solutions, discounts, centers)

    return archive


def plot_discount_archive(discount_archive, ax, domain_cfg):
    """Draw a heatmap of the discount values on `ax`.

    Args:
        discount_archive: Archive whose stored values are plotted.
        ax: Axis to draw on; its title is set to "Discount Model".
        domain_cfg: Mapping providing "obj_low" and "obj_high", used as the
            limits of the color scale.
    """
    ax.set_title("Discount Model")
    # Note: This assumes the min threshold / discount is same as the min
    # objective.
    color_limits = {
        "vmin": domain_cfg["obj_low"],
        "vmax": domain_cfg["obj_high"],
    }
    grid_archive_heatmap(
        discount_archive,
        ax=ax,
        rasterized=True,
        **color_limits,
    )
