import numpy as np
import random
from scipy.special import softmax


class MapElitesSelection:
    """Implements various selection mechanisms for the MAP-Elites grid.

    The wrapped ``archive`` is used through the following attributes:

    * ``cells`` -- an iterable whose entries are either ``None`` (empty cell)
      or an index into the lists stored in ``container``;
    * ``container`` -- a mapping with a ``"params"`` list and, depending on the
      selection method, ``"fitnesses"`` and ``"behaviors"`` lists;
    * ``novelty_score(behaviors)`` -- only used by the novelty-based methods.
    """

    def __init__(self, archive, config=None):
        """Wrap *archive* for selection.

        Args:
            archive: the MAP-Elites archive to select solutions from.
            config: currently ignored; kept for interface compatibility.
        """
        self.archive = archive

    def _candidate_indices(self, size):
        """Return container indices of occupied cells, padded up to *size* entries.

        When fewer than *size* cells are occupied, the list is topped up with
        indices drawn uniformly at random over the whole ``"params"`` container
        (duplicates allowed), so that *size* positions can always be drawn
        without replacement.

        Raises:
            ValueError: if padding is required but the container is empty.
        """
        indices = [x for x in self.archive.cells if x is not None]
        missing = size - len(indices)
        if missing > 0:
            n_params = len(self.archive.container["params"])
            if n_params == 0:
                raise ValueError(
                    "Cannot select {} solution(s): the archive container "
                    "holds no params".format(size))
            for _ in range(missing):
                indices.append(random.randint(0, n_params - 1))
        return indices

    def _draw(self, candidates, size, probas):
        """Draw *size* distinct positions of *candidates* weighted by *probas*.

        Returns:
            list: the ``container["params"]`` entries of the drawn indices.
        """
        probas = np.asarray(probas, dtype=float)
        # Softmax can underflow to exactly 0 for candidates far below the
        # maximum; np.random.choice(replace=False) then raises because fewer
        # than `size` entries are non-zero. Only in that case, give every
        # candidate a negligible floor so the draw can still complete; the
        # common path samples from the exact softmax distribution.
        if np.count_nonzero(probas) < size:
            probas = probas + np.finfo(float).eps
            probas = probas / probas.sum()
        inds = np.random.choice(candidates, size=size, replace=False, p=probas)
        return [self.archive.container["params"][ind] for ind in inds]

    def _novelties(self, candidates):
        """Return the archive novelty score of each candidate's behavior."""
        # TODO: adapt to new novelty computation method (not adapted to RND).
        behaviors = self.archive.container["behaviors"]
        return np.array([
            self.archive.novelty_score(behaviors[ind].reshape(1, -1))[0]
            for ind in candidates
        ])

    def select_random_uniform(self, size):
        """Select *size* solutions uniformly at random in the grid.

        Returns:
            list: the ``container["params"]`` entries of the selected indices.

        Raises:
            ValueError: if padding is required but the container is empty.
        """
        candidates = self._candidate_indices(size)
        inds = random.sample(candidates, size)
        return [self.archive.container["params"][ind] for ind in inds]

    def select_proportional_fitness(self, size):
        """Select *size* distinct grid positions with softmax(fitness) weights.

        Returns:
            list: the ``container["params"]`` entries of the selected indices.

        Raises:
            ValueError: if padding is required but the container is empty.
        """
        candidates = self._candidate_indices(size)
        fitnesses = np.array(
            [self.archive.container["fitnesses"][ind] for ind in candidates])
        return self._draw(candidates, size, softmax(fitnesses))

    def select_proportional_novelty(self, size):
        """Select *size* distinct grid positions with softmax(novelty) weights.

        Returns:
            list: the ``container["params"]`` entries of the selected indices.

        Raises:
            ValueError: if padding is required but the container is empty.
        """
        candidates = self._candidate_indices(size)
        return self._draw(candidates, size, softmax(self._novelties(candidates)))

    def select_proportional_fitness_and_novelty(self, size):
        """Select *size* distinct grid positions weighted by fitness and novelty.

        The weight of each candidate is the mean of its softmax(fitness) and
        softmax(novelty) probabilities.

        Returns:
            list: the ``container["params"]`` entries of the selected indices.

        Raises:
            ValueError: if padding is required but the container is empty.
        """
        candidates = self._candidate_indices(size)

        fitnesses = np.array(
            [self.archive.container["fitnesses"][ind] for ind in candidates])
        fitness_probas = softmax(fitnesses)

        novelties_probas = softmax(self._novelties(candidates))

        # Both terms sum to 1, so their mean is again a distribution.
        probas = (fitness_probas + novelties_probas) / 2
        return self._draw(candidates, size, probas)
