import numpy as np

from core.archives.archive import Archive


class MapElitesArchive(Archive):
    """Implements the MAP-Elites archive from the ME paper.

    The behaviour space is divided into a regular grid with
    ``nb_cells_per_dimension`` cells along each of the ``dim_bc`` dimensions.
    Solutions themselves are stored in the base-class container
    (``self.container["params" | "behaviors" | "fitnesses"]``), while
    ``self.cells`` maps every grid cell to the container index of the fittest
    solution recorded in that cell, or ``None`` while the cell is empty.
    """

    def __init__(self, config, capacity, k, threshold=0.01):
        """Build an empty grid.

        Args:
            config: object exposing ``dim_bc`` (number of behaviour
                dimensions), ``nb_cells_per_dimension`` (cells per dimension)
                and ``min_max_bcs`` (``min_max_bcs[i][0]`` / ``[1]`` are the
                lower / upper bound of dimension ``i``).
            capacity: forwarded unchanged to ``Archive.__init__``.
            k: forwarded unchanged to ``Archive.__init__``.
            threshold: in ``add``, a solution whose nearest stored behaviour
                lies in the same cell is only appended if that neighbour is
                farther away than this distance.
        """
        super().__init__(capacity=capacity, k=k)
        # Minimum density of the archive
        self.threshold = threshold

        self.dim_bc = config.dim_bc  # number of dimensions of the bc
        self.nb_cells_per_dimension = config.nb_cells_per_dimension  # number of cells per dimension for map discretization
        self.min_max_bcs = config.min_max_bcs  # min and max values for each dimension of the bc

        n_per_dim = self.nb_cells_per_dimension
        n_cells = n_per_dim ** self.dim_bc
        # Flat cell id for every tuple of per-dimension grid coordinates
        self.cell_ids = np.arange(n_cells).reshape([n_per_dim] * self.dim_bc)
        # Container index of each cell's elite, or None while the cell is empty
        self.cells = [None] * n_cells

        self.boundaries = []  # cell edges for each dimension
        self.cell_sizes = []  # cell width for each dimension
        for i in range(self.dim_bc):
            bc_min = self.min_max_bcs[i][0]
            bc_max = self.min_max_bcs[i][1]
            # linspace always yields exactly n_per_dim + 1 edges. The former
            # arange(bc_min, bc_max + 1e-5, step) could lose the last edge when
            # the range is huge (the 1e-5 vanishes in rounding) or gain an
            # extra one when the step is tiny, corrupting the cell lookup.
            boundaries = np.linspace(bc_min, bc_max, n_per_dim + 1)
            # Open the outer edges so out-of-range behaviours land in the border cells
            boundaries[0] = -np.inf
            boundaries[-1] = np.inf
            self.boundaries.append(boundaries)
            self.cell_sizes.append((bc_max - bc_min) / n_per_dim)

    def add(self, param, behavior, fitness, from_novelty):
        """Add a solution to the archive and the grid.

        While ``self.init`` is falsy, every solution is appended. Afterwards,
        the nearest stored behaviour is found with ``self.neigh.kneighbors``:
          * if it is farther than ``self.threshold`` or lies in a different
            cell, the solution is appended as a new entry;
          * otherwise the solution overwrites that neighbour when its fitness
            is at least as high, and is discarded when it is not.
        """
        cell_id = self.find_cell_id(behavior)
        if not self.init:
            # NOTE(review): assumes self.position is the container index that
            # the following super().add() writes to — confirm against Archive.
            self._add_to_grid(cell_id, fitness, self.position)
            super().add(param, behavior, fitness, from_novelty)
            return

        # Check if the distance is high enough from its closest neighbor
        neigh_dist, neigh_ind = self.neigh.kneighbors(X=behavior.reshape(1, -1), n_neighbors=1, return_distance=True)
        # kneighbors returns (n_queries, n_neighbors) arrays; take the single
        # entry explicitly, because float()/int() on an ndim > 0 array is
        # deprecated in NumPy.
        neigh_dist = float(np.ravel(neigh_dist)[0])
        neigh_ind = int(np.ravel(neigh_ind)[0])
        neigh_cell_id = self.find_cell_id(self.container["behaviors"][neigh_ind])

        if neigh_dist > self.threshold or cell_id != neigh_cell_id:
            # Far enough from other solutions (or in another cell): keep as a new entry
            self._add_to_grid(cell_id, fitness, self.position)
            super().add(param, behavior, fitness, from_novelty)
        elif fitness >= self.container["fitnesses"][neigh_ind]:
            # Too close to a same-cell neighbour but at least as fit: erase the
            # neighbour. The grid update must run before the container is
            # overwritten because it compares against the stored fitnesses.
            self._add_to_grid(cell_id, fitness, neigh_ind)
            self._replace_in_container(neigh_ind, param, behavior, fitness, from_novelty)

    def _add_to_grid(self, cell_id, fitness, position):
        """Record ``position`` as the elite of ``cell_id`` when the cell is
        empty or ``fitness`` strictly exceeds the current elite's fitness."""
        current = self.cells[cell_id]
        if current is None or fitness > self.container["fitnesses"][current]:
            self.cells[cell_id] = position

    def find_cell_id(self, bc):
        """Return the flat id of the grid cell that contains behaviour ``bc``.

        Along each dimension ``j`` the coordinate is the index of the last
        edge in ``self.boundaries[j]`` that is strictly below ``bc[j]``.
        Because the outer edges are infinite, values outside the configured
        range fall into the border cells.

        Raises:
            ValueError: if a component of ``bc`` is NaN (it belongs to no cell).
        """
        last_cell = self.nb_cells_per_dimension - 1
        coords = []
        for j in range(self.dim_bc):
            value = bc[j]
            if np.isnan(value):
                raise ValueError(f"Behaviour component {j} is NaN; cannot locate its cell")
            # side="left" counts the edges strictly below `value`, so minus one
            # is the index of the last edge < value.
            idx = int(np.searchsorted(self.boundaries[j], value, side="left")) - 1
            # Only -inf has no edge strictly below it; keep it in the first cell.
            coords.append(min(max(idx, 0), last_cell))
        return self.cell_ids[tuple(coords)]

    def get_best(self):
        """Return the params of the fittest elite currently held in the grid.

        Ties are resolved in favour of the first occupied cell (``np.argmax``).

        Raises:
            AssertionError: if every cell is still empty.
        """
        occupied = [ind for ind in self.cells if ind is not None]
        assert len(occupied) > 0, "No solution found in the grid"
        fitnesses = np.array([self.container["fitnesses"][ind] for ind in occupied])
        best_individual = occupied[int(np.argmax(fitnesses))]
        return self.container["params"][best_individual]
