"""Provides ModelBasedArchive."""
import logging
import warnings

import numpy as np
import torch
from hydra.core.hydra_config import HydraConfig
from ribs._utils import np_scalar, validate_batch
from ribs.archives import ArchiveDataFrame, ArrayStore, GridArchive
from ribs.archives._archive_base import _ARCHIVE_FIELDS, parse_dtype
from scipy.spatial import cKDTree

from src.cvt.cvt_archive_2 import CVTArchive2

log = logging.getLogger(__name__)


class ModelBasedArchive:
    """The discount model maps from measures to threshold / discount values.

    Args:
        solution_dim: Dimensionality of solutions.
        ranges (array-like of (float, float)): Lower and upper bound of each
            dimension of the measure space, e.g. ``[(-1, 1), (-2, 2)]``
            indicates the first dimension should have bounds :math:`[-1,1]`
            (inclusive), and the second dimension should have bounds
            :math:`[-2,2]` (inclusive). The length of ``ranges`` determines
            the dimensionality of the measure space.
        threshold_min: CMA-MAE initial threshold value.
        learning_rate: CMA-MAE archive learning rate.
        buffer_size: Initial size of the buffer. The buffer will double in size
            every time it fills up.
        only_record_added_sols: If True, the buffer will only record solutions
            that exceeded the discount threshold in add(). Otherwise, the buffer
            will store all solutions passed to add().
        probe_from_archive: If True, we will ignore probe_info in
            train_discount_model() and instead query directly from the discrete
            archive.
        use_old_probing_data: When probe_data is not provided, if this flag is
            true, then we will use previous probing data. Only applies when
            probe_from_archive is False.
        discount_model: Model for discount function.
        solution_model: Storage for solutions.
        device: PyTorch device for the models (the archive data itself is stored
            on CPU).
        seed (int): Value to seed the random number generator. Set to None to
            avoid a fixed seed.
        dtype (str or data-type or dict): Data type of the solutions,
            objectives, and measures. We only support ``"f"`` / ``np.float32``
            and ``"d"`` / ``np.float64``. Alternatively, this can be a dict
            specifying separate dtypes, of the form ``{"solution": <dtype>,
            "objective": <dtype>, "measures": <dtype>}``.
        extra_fields (dict): Description of extra fields of data that is stored
            next to elite data like solutions and objectives. The description is
            a dict mapping from a field name (str) to a tuple of ``(shape,
            dtype)``. For instance, ``{"foo": ((), np.float32), "bar": ((10,),
            np.float32)}`` will create a "foo" field that contains scalar values
            and a "bar" field that contains 10D values. Note that field names
            must be valid Python identifiers, and names already used in the
            archive are not allowed.
    """

    def __init__(
        self,
        solution_dim,
        ranges,
        threshold_min,
        learning_rate,
        buffer_size,
        use_discount_buffer: bool,
        discount_buffer_size: int,
        only_record_added_sols: bool,
        probe_from_archive: bool,
        use_old_probing_data: bool,
        discount_model,
        solution_model,
        device,
        seed=None,
        # Note: ribs.archives.ArchiveBase defaults to float64. This defaults to
        # float32 to work more easily with PyTorch.
        dtype=np.float32,
        extra_fields=None,
    ):
        """Sets up the solution buffer, measure bounds, CMA-MAE thresholds,
        and the discount buffer.

        Most arguments are described in the class docstring. The two
        arguments specific to the discount buffer are:

        Args:
            use_discount_buffer: Stored as-is; train_discount_model() reads it
                to decide whether the measures held in the discount buffer
                are appended to the training set.
            discount_buffer_size: Number of rows allocated for the discount
                buffer (a fixed-size reservoir of measures and their discount
                targets).

        Raises:
            ValueError: If ``extra_fields`` reuses a reserved archive field
                name, or if ``threshold_min`` and ``learning_rate`` are set
                inconsistently.
        """
        self.solution_dim = solution_dim
        self.measure_dim = len(ranges)

        self.seed = seed
        self.rng = np.random.default_rng(seed)

        extra_fields = extra_fields or {}
        if _ARCHIVE_FIELDS & extra_fields.keys():
            raise ValueError("The following names are not allowed in "
                             f"extra_fields: {_ARCHIVE_FIELDS}")

        dtype = parse_dtype(dtype)
        self.buffer = ArrayStore(
            field_desc={
                "solution": ((self.solution_dim,), dtype["solution"]),
                "objective": ((), dtype["objective"]),
                "measures": ((self.measure_dim,), dtype["measures"]),
                **extra_fields,
            },
            capacity=buffer_size,
        )

        # Transpose [(lo_0, hi_0), (lo_1, hi_1), ...] into (lows, highs).
        ranges2 = list(zip(*ranges))
        self.lower_bounds = np.array(ranges2[0], dtype=self.dtypes["measures"])
        self.upper_bounds = np.array(ranges2[1], dtype=self.dtypes["measures"])
        self.interval_size = self.upper_bounds - self.lower_bounds

        if threshold_min != -np.inf and learning_rate is None:
            raise ValueError(
                "You set threshold_min without setting learning_rate. "
                "Please note that threshold_min is only used in CMA-MAE; "
                "it is not intended to be used for only filtering archive "
                "solutions. If you would like to run CMA-MAE, please also set "
                "learning_rate.")
        if learning_rate is None:
            learning_rate = 1.0  # Default value.
        if threshold_min == -np.inf and learning_rate != 1.0:
            raise ValueError("threshold_min can only be -np.inf if "
                             "learning_rate is 1.0")
        self.learning_rate = np_scalar(learning_rate, self.dtypes["objective"])
        self.threshold_min = np_scalar(threshold_min, self.dtypes["objective"])

        self.only_record_added_sols = only_record_added_sols

        self.discount_model = discount_model
        self.solution_model = solution_model
        self.device = device

        self.empty_dist = self.discount_model.cfg.train.empty_dist

        self.probe_from_archive = probe_from_archive
        self.use_old_probing_data = use_old_probing_data
        self.probing_cache = {
            "empty_features": np.empty((0, self.measure_dim)),
            "non_empty_features": np.empty((0, self.measure_dim)),
            "n_empty": 0,
        }

        # Discount buffer.
        self.use_discount_buffer = use_discount_buffer
        self.discount_buffer_size = discount_buffer_size
        self.discount_buffer_measures = np.empty(
            (self.discount_buffer_size, self.measure_dim))
        self.discount_buffer_values = np.empty(self.discount_buffer_size)
        self.discount_buffer_occupied = 0

        # Reservoir sampling state for the discount buffer. The reservoir is
        # the discount buffer, so its size -- not the initial capacity of the
        # solution buffer -- must parameterize the draws below. This matches
        # the update performed in add() after each replacement.
        if self.discount_buffer_size > 0:
            # The acceptance threshold for the buffer.
            self._w = np.exp(
                np.log(self.rng.uniform()) / self.discount_buffer_size)
            # Number of solutions to skip.
            self._n_skip = int(
                np.log(self.rng.uniform()) / np.log(1 - self._w))
        else:
            # A zero-capacity reservoir can never store anything; avoid the
            # division by zero that the formulas above would hit.
            self._w = 0.0
            self._n_skip = 0

    @property
    def empty(self):
        """bool: True when the buffer currently holds no entries."""
        n_buffered = len(self.buffer)
        return not n_buffered

    @property
    def dtypes(self):
        """dict: Field-name-to-dtype mapping, as reported by the underlying
        buffer."""
        store = self.buffer
        return store.dtypes

    def sample_elites(self, n):
        """Randomly samples ``n`` elites from the solution model's archive.

        If that archive reports itself as empty, the buffered solutions are
        first flushed into it via train_solution_model().

        See ribs.archives.ArchiveBase.sample_elites for background.
        """
        needs_training = self.solution_model.model.empty
        if needs_training:
            self.train_solution_model()

        archive = self.solution_model.model
        return archive.sample_elites(n)

    def add(self, solution, objective, measures, **fields):
        """Adds a batch of solutions to the buffer and the discount buffer.

        See ribs.archives.ArchiveBase.add for background.

        Each solution's objective is compared with the discount model's
        prediction at its measures. Depending on ``only_record_added_sols``,
        either all solutions or only those exceeding their discount are
        appended to the solution buffer, which grows by doubling when needed.
        The measures and their discount targets are then offered to the
        fixed-size discount buffer.

        Args:
            solution: Batch of solutions.
            objective: Batch of objective values.
            measures: Batch of measures.
            fields: Extra fields, validated together with the other inputs.

        Returns:
            dict: ``"status"`` (2 where the objective exceeded the discount,
            0 elsewhere), ``"value"`` (objective minus discount), and
            ``"discount"`` (the discount model's output for the batch).
        """
        data = validate_batch(
            self,
            {
                "solution": solution,
                "objective": objective,
                "measures": measures,
                **fields,
            },
        )
        # Work with the validated arrays from here on so that list inputs
        # behave the same as array inputs.
        batch_objective = data["objective"]
        batch_measures = data["measures"]

        # NOTE(review): assumes chunked_inference returns one discount value
        # per row of measures, i.e. shape (batch_size,) -- confirm.
        discount = self.discount_model.chunked_inference(
            batch_measures).detach().cpu().numpy()
        added = batch_objective > discount

        add_info = {
            # If objective > discount, then the solution exceeded the threshold
            # and thus was "added" -- we multiply by 2 since the new status is
            # 2, and we don't distinguish between new and improve status in this
            # archive.
            "status": 2 * added,
            "value": batch_objective - discount,
            "discount": discount,
        }

        if self.only_record_added_sols:
            new_size = len(self.buffer) + np.sum(added)
            new_data = {
                name: None if arr is None else arr[added]
                for name, arr in data.items()
            }
        else:
            new_size = len(self.buffer) + len(batch_objective)
            new_data = data

        if new_size > self.buffer.capacity:
            # Resize the buffer by doubling its capacity. We may need to double
            # the capacity multiple times. The log2 below indicates how many
            # times we would need to double the capacity. We obtain the final
            # multiplier by raising to a power of 2.
            multiplier = 2**int(
                np.ceil(np.log2(new_size / self.buffer.capacity)))
            self.buffer.resize(multiplier * self.buffer.capacity)

        self.buffer.add(
            indices=np.arange(len(self.buffer), new_size, dtype=np.int32),
            new_data=new_data,
            extra_args={},
            transforms=[],
        )

        # Threshold update rule: solutions that beat their discount pull the
        # target towards their objective; all others keep the current
        # discount.
        discount_targets = np.where(
            added,
            (1.0 - self.learning_rate) * discount +
            self.learning_rate * batch_objective,
            discount,
        )
        self._update_discount_buffer(batch_measures, discount_targets)

        return add_info

    def _update_discount_buffer(self, measures, targets):
        """Offers a batch of (measures, target) pairs to the discount buffer.

        The buffer is first filled front to back. Once it is full, it is
        downsampled with reservoir sampling, following
        https://dl.acm.org/doi/pdf/10.1145/198429.198435: ``self._n_skip``
        holds how many upcoming items are discarded before the next one
        replaces a uniformly random slot, and it carries over between calls.

        Args:
            measures: Array with one row of measures per item.
            targets: Array with one discount target per item, aligned with
                ``measures``.
        """
        capacity = self.discount_buffer_size
        if capacity <= 0:
            # Nothing can be stored in a zero-capacity buffer.
            return

        batch_size = len(measures)

        # Fill the buffer.
        n_fill = max(
            0, min(capacity - self.discount_buffer_occupied, batch_size))
        if n_fill > 0:
            start = self.discount_buffer_occupied
            stop = start + n_fill
            self.discount_buffer_measures[start:stop] = measures[:n_fill]
            self.discount_buffer_values[start:stop] = targets[:n_fill]
            self.discount_buffer_occupied = stop

        # Replace entries in the buffer using reservoir sampling. `pos` is the
        # index (into the full batch) of the next item that has not been
        # consumed yet, so measures and targets always stay aligned.
        pos = n_fill
        while batch_size - pos > self._n_skip:
            # Discard the skipped items and take the one right after them.
            pos += self._n_skip
            replace = self.rng.integers(capacity)
            self.discount_buffer_measures[replace] = measures[pos]
            self.discount_buffer_values[replace] = targets[pos]
            pos += 1

            # Draw the next acceptance threshold and skip count.
            self._w *= np.exp(np.log(self.rng.uniform()) / capacity)
            self._n_skip = int(
                np.log(self.rng.uniform()) / np.log(1 - self._w))

        # Everything left in this batch is skipped; remember how many skips
        # remain for the next batch.
        self._n_skip -= batch_size - pos

    def buffer_data(self, fields=None, return_type="dict"):
        """Returns the data currently held in the buffer.

        Mirrors ArchiveBase.data, except that only the buffer is queried. When
        ``return_type`` is ``"pandas"``, the result is wrapped in an
        ArchiveDataFrame.
        """
        raw = self.buffer.data(fields, return_type)
        return ArchiveDataFrame(raw) if return_type == "pandas" else raw

    def train_solution_model(self):
        """Flushes the buffer into the solution model's archive.

        All buffered entries (minus their buffer indices) are passed to the
        archive's add(), after which the buffer is cleared.

        Returns:
            dict: ``"losses"`` (always ``[0.0]``) and ``"index"`` (the
            ``"index"`` entry returned by the archive's add()).
        """
        # "Training" the archive simply means inserting everything we have
        # buffered so far; the buffer's own indices are meaningless to it.
        buffered = self.buffer.data()
        buffered.pop("index")
        archive = self.solution_model.model
        insert_info = archive.add(**buffered)

        # The buffered data has been consumed, so drop it.
        self.buffer.clear()

        result = {"losses": [0.0]}
        result["index"] = insert_info["index"]
        return result

    def initialize_discount_model_to_min(self):
        """Trains the discount model to (roughly) output threshold_min
        everywhere.

        When ``cfg.init.train_min`` is falsy, no training happens and an array
        of zeros with one entry per configured training epoch is returned.
        Otherwise, the model is trained on a set of measures, all with
        threshold_min as the target: the archive centroids for a fixed set of
        domains, and uniform samples within the measure bounds for every other
        domain.

        Returns:
            The result of the discount model's training_loop(), or the zero
            array described above.
        """
        cfg = self.discount_model.cfg

        if not cfg.init.train_min:
            # Return zero training losses.
            return np.zeros(cfg.train.epochs)

        domain = HydraConfig.get().runtime.choices.domain
        if domain in ("triangles_mnist", "triangles_afhq", "lsi_face"):
            points = self.solution_model.model.centroids
        else:
            points = self.rng.uniform(
                low=self.lower_bounds,
                high=self.upper_bounds,
                size=(cfg.init.train_points, self.measure_dim),
            )

        measures = torch.tensor(
            points,
            dtype=torch.float32,
            device=self.device,
        )
        targets = torch.full(
            (len(measures),),
            self.threshold_min,
            dtype=torch.float32,
            device=self.device,
        )

        return self.discount_model.training_loop(measures, targets)

    def _rejection_sample_same_features(self, cur_features):
        """Rejection-samples points where the discount function should keep
        its current value.

        Accepted points (1) lie within the bounds of the archive's feature
        space and (2) are more than ``same_dist`` away from their nearest
        neighbor among ``cur_features``.

        Returns:
            float32 array with ``n_same`` rows and ``measure_dim`` columns.
        """
        train_cfg = self.discount_model.cfg.train
        n_same = train_cfg.n_same

        if train_cfg.normalize_features_before_dist:
            # Normalized distance is usually (actual - query) / interval_size.
            # A kd-tree only sees points, so instead every point is divided by
            # interval_size before it is handed to the tree -- both when
            # building it and when querying it.
            def to_tree_space(points):
                return points / self.interval_size
        else:

            def to_tree_space(points):
                return points

        tree = cKDTree(to_tree_space(cur_features))

        # Seed with an empty chunk so concatenation works even if n_same=0.
        accepted = [np.empty((0, self.measure_dim))]
        n_accepted = 0
        n_rounds = 0

        # Draw rounds of candidates until enough of them have been accepted.
        while n_accepted < n_same:
            # Oversample relative to n_same so that few rounds are needed.
            candidates = self.rng.uniform(
                low=self.lower_bounds,
                high=self.upper_bounds,
                size=(3 * n_same, self.measure_dim),
            )

            # Distance from each candidate to its nearest current feature.
            nearest_dist, _ = tree.query(to_tree_space(candidates), k=1)
            survivors = candidates[nearest_dist > train_cfg.same_dist]

            accepted.append(survivors)
            n_accepted += len(survivors)

            # Warn if we have resampled too many times.
            n_rounds += 1
            if n_rounds > 100:
                warnings.warn("Rejection sampling re-sampled a lot of times!")

        # Merge all rounds and truncate to exactly n_same rows.
        merged = np.concatenate(accepted).astype(np.float32)
        return merged[:n_same]

    def train_discount_model(self, data, add_info, probe_data, probe_add_info,
                             current_itr):
        """Trains the discount model based on information from evaluations.

        The training set is the concatenation of:

        * the new measures in ``data``, whose targets follow the threshold
          update rule (moved towards the objective where the objective exceeds
          the discount, unchanged otherwise);
        * centers of empty cells sampled from the solution archive, whose
          target is ``threshold_min``;
        * rejection-sampled points away from the above, whose target is the
          model's current prediction;
        * if ``use_discount_buffer`` is set, the measures in the discount
          buffer, whose target is also the model's current prediction.

        Args:
            data: Dict with ``"objective"`` and ``"measures"`` arrays for the
                batch.
            add_info: Dict whose ``"discount"`` entry holds the discount of
                each batch entry (e.g., the dict returned by add()).
            probe_data: Unused.
            probe_add_info: Unused.
            current_itr: Unused.

        Returns:
            dict: Training statistics -- see the comments on the return
            statement for the individual entries.

        Raises:
            NotImplementedError: If ``probe_from_archive`` is False, or if the
                solution archive is neither a GridArchive nor a CVTArchive2.
                These cases have no source of empty features here.
        """
        if not self.probe_from_archive:
            # Without this check the method would fail later with an
            # UnboundLocalError, since only the archive-probing path defines
            # the empty features.
            raise NotImplementedError(
                "train_discount_model() only supports probe_from_archive=True; "
                "probing via probe_data is not implemented in this archive.")

        empty_features = self._sample_empty_cell_features(
            self.solution_model.model)
        non_empty_features = np.empty((0, self.measure_dim))
        n_empty = len(empty_features)

        objective = data["objective"]
        new_features = data["measures"]

        # Find features where the discount model should maintain the same value.
        same_features = self._rejection_sample_same_features(
            np.concatenate([new_features, empty_features]))

        feature_list = [new_features, empty_features, same_features]
        target_list = [
            # Original features from the batch result in the threshold
            # update rule.
            np.where(
                objective > add_info["discount"],
                (1.0 - self.learning_rate) * add_info["discount"] +
                self.learning_rate * objective,
                add_info["discount"],
            ),
            # Empty features get threshold_min.
            np.full(len(empty_features), self.threshold_min),
            # Same features get the same value.
            (self.discount_model.chunked_inference(same_features).detach().cpu(
            ).numpy() if len(same_features) > 0 else np.empty((0,))),
        ]

        if self.use_discount_buffer and self.discount_buffer_occupied > 0:
            buffered_measures = (
                self.discount_buffer_measures[:self.discount_buffer_occupied])
            feature_list.append(buffered_measures)
            # The targets are the model's current predictions at the buffered
            # measures, not the values stored in discount_buffer_values.
            target_list.append(
                self.discount_model.chunked_inference(
                    buffered_measures).detach().cpu().numpy())

        features = torch.tensor(
            np.concatenate(feature_list),
            dtype=torch.float32,
            device=self.device,
        )

        targets = torch.tensor(
            np.concatenate(target_list),
            dtype=torch.float32,
            device=self.device,
        )

        losses = self.discount_model.training_loop(features, targets)

        return {
            # Number of points marked empty.
            "n_empty": n_empty,
            # New features from the emitters.
            "new_features": new_features,
            # Features that were marked as empty.
            "empty_features": empty_features,
            # Features that were considered when probing but are not considered
            # empty.
            "non_empty_features": non_empty_features,
            # Features that were marked as staying the same.
            "same_features": same_features,
            # Training losses.
            "losses": losses,
            # Training epochs.
            "epochs": len(losses),
        }

    def _sample_empty_cell_features(self, base_archive):
        """Samples up to ``cfg.train.n_empty`` unoccupied cells of
        ``base_archive`` and returns one point in measure space per cell: the
        cell center for a GridArchive, the centroid for a CVTArchive2."""
        is_grid = isinstance(base_archive, GridArchive)
        if not is_grid and not isinstance(base_archive, CVTArchive2):
            raise NotImplementedError(
                "Probing from the archive requires a GridArchive or "
                f"CVTArchive2, but got {type(base_archive).__name__}.")

        # Sample empty indices in the archive to determine where the
        # threshold should be held at threshold_min.
        empty_indices = np.arange(
            base_archive.cells)[~base_archive._store.occupied]
        empty_indices = self.rng.choice(
            empty_indices,
            size=min(len(empty_indices),
                     self.discount_model.cfg.train.n_empty),
            replace=False)

        if is_grid:
            # Find the center of the corresponding cells.
            empty_grid_indices = base_archive.int_to_grid_index(empty_indices)
            return (((empty_grid_indices + 0.5) / base_archive.dims) *
                    base_archive.interval_size + base_archive.lower_bounds)

        return base_archive.centroids[empty_indices]
