"""Provides qd_scatterplot."""
import matplotlib.pyplot as plt
import numpy as np
from ribs.visualize._utils import (retrieve_cmap, set_cbar,
                                   validate_heatmap_visual_args)

# Matplotlib functions tend to have a ton of args.
# pylint: disable = too-many-arguments


def qd_scatterplot(objectives,
                   measures,
                   lower_bounds,
                   upper_bounds,
                   ax=None,
                   *,
                   transpose_measures=False,
                   cmap="magma",
                   aspect=None,
                   vmin=None,
                   vmax=None,
                   cbar="auto",
                   cbar_kwargs=None,
                   rasterized=False,
                   scatter_kwargs=None):
    """Creates scatterplot of a set of points for a QD problem.

    Each point is drawn at its (2D) measures and colored by its objective
    value.

    Args:
        objectives (array-like): (N,) array with objectives to be plotted.
            These determine the color of each point.
        measures (array-like): (N, 2) array with measures to be plotted.
        lower_bounds (array-like): (2,) array with lower bounds of the measure
            space; used as the lower x/y axis limits.
        upper_bounds (array-like): (2,) array with upper bounds of the measure
            space; used as the upper x/y axis limits.
        ax (matplotlib.axes.Axes): Axes on which to draw the scatterplot.
            If ``None``, the current axis will be used.
        transpose_measures (bool): By default, the first measure will appear
            along the x-axis, and the second will be along the y-axis. To
            switch this behavior (i.e. to transpose the axes), set this to
            ``True``.
        cmap (str, list, matplotlib.colors.Colormap): The colormap to use when
            plotting intensity. Either the name of a
            :class:`~matplotlib.colors.Colormap`, a list of RGB or RGBA colors
            (i.e. an :math:`N \\times 3` or :math:`N \\times 4` array), or a
            :class:`~matplotlib.colors.Colormap` object.
        aspect ('auto', 'equal', float): The aspect ratio of the plot (i.e.
            height/width). Defaults to ``'auto'``. ``'equal'`` is the same as
            ``aspect=1``. See :meth:`matplotlib.axes.Axes.set_aspect` for more
            info.
        vmin (float): Minimum objective value to use in the color mapping. If
            ``None``, the minimum finite (non-NaN, non-infinite) objective is
            used; if no objective is finite, matplotlib chooses the limit.
        vmax (float): Maximum objective value to use in the color mapping. If
            ``None``, the maximum finite (non-NaN, non-infinite) objective is
            used; if no objective is finite, matplotlib chooses the limit.
        cbar ('auto', None, matplotlib.axes.Axes): By default, this is set to
            ``'auto'`` which displays the colorbar on the plot's current
            :class:`~matplotlib.axes.Axes`. If ``None``, then colorbar is not
            displayed. If this is an :class:`~matplotlib.axes.Axes`, displays
            the colorbar on the specified Axes.
        cbar_kwargs (dict): Additional kwargs to pass to
            :func:`~matplotlib.pyplot.colorbar`.
        rasterized (bool): Whether to rasterize the scattered points. This can
            be useful for saving to a vector format like PDF. Essentially, only
            the points will be converted to a raster graphic so that they do
            not have to be individually rendered. Meanwhile, the surrounding
            axes, particularly text labels, will remain in vector format. This
            is implemented by passing ``rasterized`` to
            :meth:`~matplotlib.axes.Axes.scatter`, so passing ``"rasterized"``
            in ``scatter_kwargs`` below will raise a ``TypeError`` (duplicate
            keyword argument).
        scatter_kwargs (dict): Additional kwargs to pass to
            :meth:`~matplotlib.axes.Axes.scatter`.

    Raises:
        ValueError: The measures are not 2D (``measures.shape[1] != 2``).
    """
    # Accept array-likes such as nested lists. asanyarray (rather than asarray)
    # keeps ndarray subclasses such as masked arrays intact.
    measures = np.asanyarray(measures)

    validate_heatmap_visual_args(
        aspect, cbar, measures.shape[1], [2],
        "Scatterplot can only be plotted for 2D measures")

    if aspect is None:
        aspect = "auto"

    # Try getting the colormap early in case it fails.
    cmap = retrieve_cmap(cmap)

    if transpose_measures:
        # Since the measures are 2D, transposing amounts to swapping the two
        # entries of each bound and the two columns of the measures.
        lower_bounds = np.flip(lower_bounds)
        upper_bounds = np.flip(upper_bounds)
        measures = np.flip(measures, axis=1)

    # Initialize the axis.
    ax = plt.gca() if ax is None else ax
    ax.set_xlim(lower_bounds[0], upper_bounds[0])
    ax.set_ylim(lower_bounds[1], upper_bounds[1])
    ax.set_aspect(aspect)

    # Fill in missing color limits from the valid objectives only. Plain
    # np.min/np.max would return NaN if any objective is NaN, which maps every
    # point to the colormap's "bad" color. They would also raise on empty input.
    # masked_invalid masks NaN/inf and keeps any mask the caller supplied.
    # matplotlib likewise masks invalid entries of `c`. If nothing is valid,
    # the limits stay None and matplotlib decides.
    valid_objectives = np.ma.masked_invalid(
        np.ma.asarray(objectives, dtype=float))
    if valid_objectives.count() > 0:
        if vmin is None:
            vmin = valid_objectives.min()
        if vmax is None:
            vmax = valid_objectives.max()

    # Create the plot.
    scatter_kwargs = {} if scatter_kwargs is None else scatter_kwargs
    points = ax.scatter(measures[:, 0],
                        measures[:, 1],
                        c=objectives,
                        cmap=cmap,
                        vmin=vmin,
                        vmax=vmax,
                        rasterized=rasterized,
                        **scatter_kwargs)

    # Create color bar.
    set_cbar(points, ax, cbar, cbar_kwargs)
