"""ES that wraps PGPE."""
import numpy as np
from ribs._utils import readonly
from ribs.emitters.opt._evolution_strategy_base import EvolutionStrategyBase
from threadpoolctl import threadpool_limits


class PGPEOptimizer(EvolutionStrategyBase):
    """Wrapper around PGPE from the ``pgpelib`` package.

    The underlying ``pgpelib.PGPE`` instance is only created in :meth:`reset`,
    so ``pgpelib`` is imported lazily and is only required once the optimizer
    is actually reset.

    Args:
        sigma0 (float): Initial step size. This is passed as ``stdev_init``
            in ``opts``.
        solution_dim (int): Size of the solution space. This is passed as
            ``solution_length`` in ``opts``.
        batch_size (int): Number of solutions to evaluate at a time. This
            is passed directly as ``popsize`` in ``opts``.
        seed (int): Seed for the random number generator. The seed handed to
            PGPE is drawn once from a generator seeded with this value.
        dtype (str or data-type): Data type of the solutions returned by
            :meth:`ask`.
        lower_bounds (float or np.ndarray): scalar or (solution_dim,) array
            indicating lower bounds of the solution space. Scalars specify
            the same bound for the entire space, while arrays specify a
            bound for each dimension. Pass -np.inf in the array or scalar to
            indicate unbounded space. Currently not used by this wrapper.
        upper_bounds (float or np.ndarray): Same as above, but for upper
            bounds (and pass np.inf instead of -np.inf). Currently not used
            by this wrapper.
        opts (dict): Additional options for PGPE. The dict is copied, so the
            caller's dict is never modified. The keys ``solution_length``,
            ``stdev_init``, ``popsize``, and ``seed`` are always overwritten
            by the arguments above, and ``center_init`` is overwritten on
            every :meth:`reset`.
    """

    def __init__(  # pylint: disable = super-init-not-called
            self,
            sigma0,
            solution_dim,
            batch_size=None,
            seed=None,
            dtype=np.float64,
            lower_bounds=None,
            upper_bounds=None,
            opts=None):
        self.sigma0 = sigma0
        self.solution_dim = solution_dim
        self.dtype = dtype

        self._solutions = None
        self._es = None

        # Copy the options so that the dict owned by the caller is not
        # mutated by the assignments below.
        self._opts = dict(opts) if opts else {}
        self._opts["solution_length"] = solution_dim
        self._opts["stdev_init"] = sigma0
        self._opts["popsize"] = batch_size

        self._rng = np.random.default_rng(seed)
        self._opts["seed"] = self._rng.integers(10000)

    @property
    def batch_size(self):
        """int: Number of solutions per iteration.

        Only valid after a call to :meth:`reset`.
        """
        return self._es._popsize

    def reset(self, x0):
        """Resets the optimizer to start at x0.

        A new PGPE instance is created from the stored options, with ``x0``
        passed as ``center_init``.

        Args:
            x0 (np.ndarray): Initial mean.
        Raises:
            ImportError: If ``pgpelib`` is not installed.
        """
        try:
            # We do not want to import at the top because that would require
            # pgpelib to always be installed, as pgpelib would be imported
            # whenever this class is imported.
            # pylint: disable = import-outside-toplevel
            from pgpelib import PGPE
        except ImportError as e:
            raise ImportError("pgpelib must be installed") from e

        # Build the options for this run on a copy so that the stored options
        # are not polluted by the per-reset starting point.
        run_opts = dict(self._opts)
        if x0 is not None:
            run_opts["center_init"] = np.array(x0)

        self._es = PGPE(**run_opts)

    def check_stop(self, ranking_values):
        """Checks if the optimization should stop and be reset.

        Args:
            ranking_values (np.ndarray): Not used.
        Returns:
            bool: Always False.
        """
        # We just run PGPE for a set number of iterations.
        return False

    # Limit OpenBLAS to single thread. This is typically faster than
    # multithreading because our data is too small.
    @threadpool_limits.wrap(limits=1, user_api="blas")
    def ask(self, batch_size=None):
        """Samples new solutions from PGPE.

        Args:
            batch_size (int): Must be None. The number of solutions is fixed
                by the ``batch_size`` passed to the constructor, so a custom
                value is not supported here.
        Returns:
            The solutions sampled by PGPE, converted to ``self.dtype`` and
            passed through ``readonly``.
        Raises:
            ValueError: If ``batch_size`` is not None.
        """
        if batch_size is not None:
            raise ValueError(
                "PGPEOptimizer does not support passing batch_size to ask(); "
                "the batch size is fixed when the optimizer is constructed.")
        self._solutions = np.asarray(self._es.ask())
        return readonly(self._solutions.astype(self.dtype))

    # Limit OpenBLAS to single thread. This is typically faster than
    # multithreading because our data is too small.
    @threadpool_limits.wrap(limits=1, user_api="blas")
    def tell(self, ranking_indices, ranking_values, num_parents):
        """Passes the evaluation results of the last :meth:`ask` to PGPE.

        Args:
            ranking_indices (array-like of int): Not used.
            ranking_values (array-like): Values for each solution, given as a
                (batch_size,) or (batch_size, 1) array.
            num_parents (int): Not used.
        Raises:
            NotImplementedError: If ``ranking_values`` is not 1D after
                squeezing a trailing axis of size 1.
        """
        # Note: num_parents is not used here; PGPE handles it.
        ranking_values = np.asarray(ranking_values)

        # Convert (batch_size, 1) array into (batch_size,).
        if ranking_values.ndim == 2 and ranking_values.shape[1] == 1:
            ranking_values = ranking_values[:, 0]

        if ranking_values.ndim != 1:
            raise NotImplementedError(
                "PGPEOptimizer only supports 1D ranking values.")

        # Directly tell values to the ES since these are just 1D. The
        # ranking_values are presented with higher being better and are
        # passed unchanged, i.e., they are treated as fitnesses to maximize
        # (NOTE(review): confirm against the pgpelib documentation).
        self._es.tell(ranking_values)
