from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire
from qdax.core.emitters.emitter import Emitter
from qdax.core.emitters.omg_mega_emitter import OMGMEGAEmitterState
from qdax.types import (
    Centroid,
    Descriptor,
    ExtraScores,
    Fitness,
    Genotype,
    Gradient,
    RNGKey,
)


class DiscreteGradientState(OMGMEGAEmitterState):
    """Emitter state carried between calls of the BinaryGIDE emitter.

    Extends the inherited OMG-MEGA emitter state (which the emitter accesses
    through its ``gradients_repertoire`` field) with the data needed to keep
    mutating the previously evaluated batch over several consecutive steps.
    """

    # Genotypes passed to the most recent ``state_update`` call.
    genotypes: Genotype
    # Gradients read from ``extra_scores["gradients"]`` in that same call.
    gradients: Gradient
    # Step counter advanced by ``state_update``; it wraps back to 0 once it
    # exceeds the emitter's ``num_steps``.
    current_step: int


class BinaryGIDE(Emitter):
    """Gradient-informed bit-flip emitter for binary genotypes.

    At each call of `emit` the emitter:

    1. takes a batch of genotypes together with their gradients, either
       freshly sampled from the repertoires or carried over from the previous
       evaluation (see `num_steps`);
    2. draws one coefficient per objective (index 0 is the fitness, indices
       1.. are the descriptors) and combines the normalised gradients into a
       single update direction per genotype;
    3. turns that direction into a probability of flipping each position
       (`compute_p_flip`) and flips `num_flip` distinct positions.

    The gradients are expected in ``extra_scores["gradients"]`` and are stored
    in a dedicated repertoire held by the emitter state.
    """

    _COEF_SAMPLING_TYPES = ("gaussian", "discrete_uniform")

    def __init__(
        self,
        batch_size: int,
        num_descriptors: int,
        centroids: Centroid,
        temperature: float = 2.0,
        num_flip: int = 1,
        num_steps: int = 1,
        only_fitness_gradient: bool = False,
        only_diversity_gradient: bool = False,
        coef_sampling_type: str = "gaussian",
        sigma_diag: float = 1.0,
        fitness_scale: float = 1.0,
        diversity_scale: float = 1.0,
        fitness_proportion: float = 0.5,
    ):
        """Stores the hyper-parameters of the emitter.

        Args:
            batch_size: number of offspring produced by each call of `emit`.
            num_descriptors: number of descriptor dimensions; coefficient
                vectors have `num_descriptors + 1` entries (fitness first).
            centroids: centroids used to build the gradients repertoire.
            temperature: divisor applied to the logits in `compute_p_flip`.
            num_flip: number of distinct positions flipped per offspring.
            num_steps: threshold on the state's step counter; fresh genotypes
                are sampled from the repertoire when the counter reaches it.
            only_fitness_gradient: if True, descriptor coefficients are zeroed.
            only_diversity_gradient: if True, the fitness coefficient is
                zeroed.
            coef_sampling_type: "gaussian" or "discrete_uniform".
            sigma_diag: diagonal value of the covariance used by the
                "gaussian" sampling; 0 yields all-zero coefficients.
            fitness_scale: magnitude of the fitness coefficient in the
                "discrete_uniform" sampling.
            diversity_scale: magnitude of the descriptor coefficient in the
                "discrete_uniform" sampling.
            fitness_proportion: probability of following the fitness gradient
                in the "discrete_uniform" sampling.

        Raises:
            ValueError: if both `only_fitness_gradient` and
                `only_diversity_gradient` are set, or if `coef_sampling_type`
                is not supported.
        """
        # Fail fast on configurations that could never emit anything.
        if only_fitness_gradient and only_diversity_gradient:
            raise ValueError(
                "Can't set only fitness gradient AND only diversity gradient"
            )
        if coef_sampling_type not in self._COEF_SAMPLING_TYPES:
            raise ValueError(
                f"Unknown coef_sampling_type {coef_sampling_type!r}, "
                f"expected one of {self._COEF_SAMPLING_TYPES}"
            )

        self._batch_size = batch_size
        self._temperature = temperature
        self._mu = jnp.zeros((num_descriptors + 1,))
        self._sigma = jnp.eye(num_descriptors + 1) * sigma_diag
        self._centroids = centroids
        self._num_descriptors = num_descriptors

        self._coef_sampling_type = coef_sampling_type
        self._sigma_diag = sigma_diag
        self._num_flip = num_flip
        self._num_steps = num_steps
        self._only_fitness_gradient = only_fitness_gradient
        self._only_diversity_gradient = only_diversity_gradient

        self._fitness_proportion = fitness_proportion
        self._fitness_scale = fitness_scale
        self._diversity_scale = diversity_scale

    def emit(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: DiscreteGradientState,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey, ExtraScores]:
        """
        Emitter based on gradient proposal.

        Args:
            repertoire: a repertoire of genotypes.
            emitter_state: the state of the emitter; holds the gradients
                repertoire and the previously evaluated batch.
            random_key: a random key to handle random operations.

        Returns:
            A batch of offspring, a new random key and an (empty) dictionary
            of extra information.
        """
        # Both repertoires are deliberately sampled with the SAME key.
        # Presumably they are filled through identical additions (see
        # `state_update`), so one key selects matching cells and every
        # genotype is paired with its own gradient.
        # NOTE(review): confirm against MapElitesRepertoire.sample.
        genotypes, _ = repertoire.sample(random_key, num_samples=self._batch_size)
        gradients, random_key = emitter_state.gradients_repertoire.sample(
            random_key, num_samples=self._batch_size
        )

        num_features = genotypes.shape[-1]

        # Use the fresh samples only when the step counter has reached
        # `num_steps`; otherwise keep mutating the batch stored in the state.
        # The stored batch must be broadcastable against the sampled one.
        sample_new = emitter_state.current_step >= self._num_steps
        genotypes = jnp.where(sample_new, genotypes, emitter_state.genotypes)
        grads = jnp.where(sample_new, gradients, emitter_state.gradients)

        # Draw one coefficient per objective for every individual
        random_key, coeffs = self._sample_coefficients(random_key)

        # Normalise each individual's gradients by their joint norm (taken
        # over features AND objectives). Exact zeros are passed through
        # untouched so an all-zero gradient does not become NaN.
        # NOTE(review): per-objective normalisation may have been intended.
        norm_grad = jnp.sqrt(jnp.sum(grads**2, axis=(1, 2), keepdims=True))
        grads = jnp.where(grads == 0, grads, grads / norm_grad)

        # Weighted sum of the per-objective gradients -> (batch, num_features)
        update_grad = jnp.sum(coeffs[:, None, :] * grads, axis=-1)

        # Compute probabilities of flipping
        p_flip = self.compute_p_flip(genotypes, update_grad)

        # One independent key per individual, plus one that is handed back to
        # the caller and never consumed here.
        keys = jax.random.split(random_key, num=self._batch_size + 1)
        random_key, sample_keys = keys[0], keys[1:]
        flip_mask = jax.vmap(self._generate_random_index, in_axes=(0, 0, None))(
            sample_keys, p_flip, num_features
        )
        new_genotypes = jnp.where(flip_mask, 1 - genotypes, genotypes)

        extra_emit: ExtraScores = {}
        return new_genotypes, random_key, extra_emit

    def _sample_coefficients(
        self, random_key: RNGKey
    ) -> Tuple[RNGKey, jnp.ndarray]:
        """Draws the gradient coefficients for one batch.

        Args:
            random_key: a random key to handle random operations.

        Returns:
            A new random key and the coefficients, an array of shape
            (batch_size, num_descriptors + 1) whose column 0 (fitness) is
            non-negative.

        Raises:
            ValueError: if the configured sampling type is not supported.
        """
        num_objectives = self._num_descriptors + 1
        random_key, subkey = jax.random.split(random_key)

        if self._coef_sampling_type == "gaussian":
            if self._sigma_diag == 0.0:
                # Degenerate covariance: the coefficients are exactly zero,
                # so skip sampling from a singular Gaussian altogether.
                coeffs = jnp.zeros(
                    (self._batch_size, num_objectives),
                    dtype=jnp.result_type(self._mu, self._sigma),
                )
            else:
                coeffs = jax.random.multivariate_normal(
                    subkey,
                    shape=(self._batch_size,),
                    mean=self._mu,
                    cov=self._sigma,
                )

        elif self._coef_sampling_type == "discrete_uniform":
            # Each individual follows either the fitness gradient (with
            # probability `fitness_proportion`) or a single, uniformly chosen
            # descriptor gradient with a random sign.
            is_fitness_grad = jax.random.bernoulli(
                subkey, p=self._fitness_proportion, shape=(self._batch_size, 1)
            )
            fitness_mask = jax.nn.one_hot(0, num_objectives)

            random_key, subkey = jax.random.split(random_key)
            descriptor_idx = jax.random.randint(
                subkey,
                minval=1,
                maxval=num_objectives,
                shape=(self._batch_size,),
            )
            descriptor_mask = jax.nn.one_hot(descriptor_idx, num_objectives)

            random_key, subkey = jax.random.split(random_key)
            signs = jax.random.bernoulli(subkey, shape=(self._batch_size, 1)) * 2 - 1
            descriptor_mask = descriptor_mask * signs

            coeffs = (
                is_fitness_grad * fitness_mask * self._fitness_scale
                + (1 - is_fitness_grad)
                * (1 - fitness_mask)
                * descriptor_mask
                * self._diversity_scale
            )

        else:
            raise ValueError(
                f"Unknown coef_sampling_type {self._coef_sampling_type!r}, "
                f"expected one of {self._COEF_SAMPLING_TYPES}"
            )

        # Set positive coefficients for fitness
        coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))

        if self._only_fitness_gradient:
            coeffs = coeffs.at[:, 1:].set(0.0)

        if self._only_diversity_gradient:
            coeffs = coeffs.at[:, 0].set(0.0)

        return random_key, coeffs

    def init(
        self, init_genotypes: Genotype, random_key: RNGKey
    ) -> Tuple[DiscreteGradientState, RNGKey]:
        """Initialises the state of the emitter. Creates an empty repertoire
        that will later contain the gradients of the individuals.

        The stored batch is initialised with `init_genotypes` and all-zero
        gradients so that the state has the same structure before and after
        the first call of `state_update`.

        Args:
            init_genotypes: The genotypes of the initial population.
            random_key: a random key to handle stochastic operations.

        Returns:
            The initial emitter state and the (unchanged) random key.
        """
        num_centroids = self._centroids.shape[0]
        num_objectives = self._num_descriptors + 1

        # Initialize grid with default values
        default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)
        # One gradient per objective for every feature of every cell
        default_gradients = jax.tree_util.tree_map(
            lambda x: jnp.zeros(
                shape=(num_centroids,) + x.shape[1:] + (num_objectives,)
            ),
            init_genotypes,
        )
        default_descriptors = jnp.zeros(
            shape=(num_centroids, self._centroids.shape[-1])
        )

        # instantiate the gradients repertoire
        gradients_repertoire = MapElitesRepertoire(
            genotypes=default_gradients,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            centroids=self._centroids,
        )

        # Placeholder gradients for the initial batch (overwritten by the
        # first `state_update`).
        initial_gradients = jax.tree_util.tree_map(
            lambda x: jnp.zeros(shape=x.shape + (num_objectives,)),
            init_genotypes,
        )

        emitter_state = DiscreteGradientState(
            gradients_repertoire=gradients_repertoire,
            genotypes=init_genotypes,
            gradients=initial_gradients,
            current_step=0,
        )
        return emitter_state, random_key

    def state_update(
        self,
        emitter_state: DiscreteGradientState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: Optional[ExtraScores] = None,  # type: ignore
    ) -> DiscreteGradientState:
        """Stores the evaluated batch and its gradients in the emitter state.

        Args:
            emitter_state: current state of the emitter.
            repertoire: the repertoire of genotypes (not used here).
            genotypes: the genotypes that have just been evaluated.
            fitnesses: fitnesses of those genotypes.
            descriptors: descriptors of those genotypes.
            extra_scores: must contain the key "gradients".

        Returns:
            The updated emitter state.

        Raises:
            ValueError: if `extra_scores` is missing or has no "gradients"
                entry.
        """
        # get gradients out of the extra scores
        if extra_scores is None or "gradients" not in extra_scores:
            raise ValueError("Missing gradients or wrong key")
        gradients = extra_scores["gradients"]

        # Advance the step counter and wrap it to 0 once it exceeds
        # `num_steps`; it therefore takes the values 0..num_steps.
        # NOTE(review): the resulting period is num_steps + 1 - confirm that
        # this is the intended schedule.
        current_step = emitter_state.current_step + 1
        current_step = current_step * (current_step <= self._num_steps)

        # update the gradients repertoire
        gradients_repertoire = emitter_state.gradients_repertoire.add(
            gradients, descriptors, fitnesses
        )

        return emitter_state.replace(  # type: ignore
            genotypes=genotypes,
            gradients=gradients,
            gradients_repertoire=gradients_repertoire,
            current_step=current_step,
        )

    def compute_p_flip(
        self,
        genotypes: Genotype,
        grads: Gradient,
    ) -> jnp.ndarray:
        """
        Informed proposal based on the gradients.

        Args:
            genotypes: sampled genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            grads: corresponding gradients, pytree in which leaves
                have shape (batch_size, num_features)

        Returns:
            p: probability of flipping each position, array of
                shape (batch_size, num_features)
        """
        # (2 * g - 1) maps a bit in {0, 1} to {-1, +1}: the logit of a
        # position is high when the gradient points away from its current
        # value, i.e. when flipping it follows the gradient.
        logits = -(2 * genotypes - 1) * grads / self._temperature
        return jax.nn.softmax(logits, axis=-1)

    def _generate_random_index(
        self,
        random_key: RNGKey,
        p: jnp.ndarray,
        num_features: int,
    ) -> jnp.ndarray:
        """
        Draws `num_flip` distinct positions according to the probabilities p.

        Args:
            random_key: a random key to handle random operations.
            p: probability of picking each position, shape (num_features,).
            num_features: length of a genotype.

        Returns:
            A boolean mask of shape (num_features,) that is True at the
            drawn positions.
        """
        flip_positions = jax.random.choice(
            random_key, num_features, shape=(self._num_flip,), replace=False, p=p
        )
        flip_mask = jnp.zeros((num_features,), dtype=bool)
        return flip_mask.at[flip_positions].set(True)
