from functools import partial
from typing import Any, Callable, Optional, Tuple

import jax
import jax.numpy as jnp
from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire
from qdax.core.emitters.emitter import Emitter, EmitterState
from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, RNGKey


class OMGMEGAEmitterState(EmitterState):
    """
    Emitter state for the OMG-MEGA emitter.

    Args:
        gradients_repertoire: MapElites repertoire containing the gradients
            of the individuals. Its "genotypes" are a dictionary holding the
            gradients under the "gradients" key.
    """

    gradients_repertoire: MapElitesRepertoire


class OMGMEGAEmitter(Emitter):
    """
    Class for the emitter of OMG Mega from "Differentiable Quality Diversity" by
    Fontaine et al.

    This variant operates on categorical genotypes: sampled genotypes are one-hot
    encoded over `num_category` categories, a randomly weighted combination of the
    stored gradients is added in that one-hot space, and the result is mapped back
    to category indices with an argmax.

    NOTE: in order to implement this emitter while staying in the MAPElites
    framework, we had to make two temporary design choices:
    - in the emit function, we use the same random key to sample from the
    genotypes and gradients repertoire, in order to get the gradients that
    correspond to the right genotypes. Although acceptable, this is definitely
    not the best coding practice and we would prefer to get rid of this in a
    future version. A solution that we are discussing with the development team
    is to decompose the sampling function of the repertoire into two phases: one
    sampling the indices to be sampled, the other one retrieving the corresponding
    elements. This would enable to reuse the indices instead of doing this double
    sampling.
    - in the state_update function, we have to insert the gradients in the gradients
    repertoire in the same way the individuals were inserted. Once again, this is
    slightly suboptimal because the same addition mechanism has to be computed two
    times. One solution that we are discussing and that is very similar to the first
    solution discussed above, would be to decompose the addition mechanism in two
    phases: one outputting the indices at which individuals will be added, and then
    the actual insertion step. This would enable to re-use the same indices to add
    the gradients instead of having to recompute them.

    The two design choices seem acceptable and enable to have OMG MEGA compatible
    with the current implementation of the MAPElites and MAPElitesRepertoire classes.

    Our suggested solutions seem quite simple and are likely to be useful for other
    variants implementation. They will be further discussed with the development team
    and potentially added in a future version of the package.
    """

    def __init__(
        self,
        batch_size: int,
        sigma_g: float,
        num_descriptors: int,
        centroids: Centroid,
        num_category: int,
        projection_fn: Optional[Callable[[Genotype], Genotype]] = None,
        embedding_matrix: Optional[jnp.ndarray] = None,
        gradient_mask: Optional[jnp.ndarray] = None,
    ):
        """Creates an instance of the OMGMEGAEmitter class.

        Args:
            batch_size: number of solutions sampled at each iteration
            sigma_g: scale of the distribution the gradient coefficients are
                drawn from. NOTE: it multiplies the identity *covariance*
                matrix, so the per-coefficient standard deviation is
                sqrt(sigma_g), not sigma_g.
            num_descriptors: number of descriptors. Each stored gradient holds
                `num_descriptors + 1` components (presumably the fitness
                gradient plus one per descriptor - confirm against the
                scoring function).
            centroids: centroids used to create the repertoire of solutions.
                This will be used to create the repertoire of gradients.
            num_category: number of categories used to one-hot encode the
                genotypes. Also the size of the default gradient mask and of
                the default (identity) embedding matrix.
            projection_fn: optional function applied to the updated genotypes
                (in one-hot space) before the argmax. Defaults to the identity.
            embedding_matrix: optional matrix contracted with the sampled
                gradients in `emit` (einsum "bzcd,ec->bzed"). Defaults to the
                identity of size `num_category`.
            gradient_mask: optional per-category mask; gradients of categories
                whose mask entry is zero are zeroed out in `emit`. Defaults to
                all ones (nothing masked).
        """

        # explicit None check, consistent with the other optional arguments
        if projection_fn is not None:
            self._projection_fn = projection_fn
        else:
            self._projection_fn = lambda x: x

        # parameters of the normal distribution the coefficients are drawn from
        self._mu = jnp.zeros(num_descriptors + 1)
        self._sigma = jnp.eye(num_descriptors + 1) * sigma_g
        self._batch_size = batch_size
        self._centroids = centroids
        self._num_descriptors = num_descriptors
        self._num_category = num_category

        if gradient_mask is not None:
            self._gradient_mask = gradient_mask
        else:
            self._gradient_mask = jnp.ones(num_category)

        if embedding_matrix is None:
            self._embedding_matrix = jnp.eye(self._num_category)
        else:
            self._embedding_matrix = embedding_matrix

    @partial(jax.jit, static_argnames=("self",))
    def init(
        self, init_genotypes: Genotype, random_key: RNGKey
    ) -> Tuple[OMGMEGAEmitterState, RNGKey]:
        """Initialises the state of the emitter. Creates an empty repertoire
        that will later contain the gradients of the individuals.

        Args:
            init_genotypes: The genotypes of the initial population, given as
                category indices (they are one-hot encoded here, only to derive
                the shape of the gradient buffers).
            random_key: a random key to handle stochastic operations.

        Returns:
            The initial emitter state and the random key (returned unchanged).
        """

        # Initialize grid with default values

        # the one-hot encoded genotypes only serve as a shape/dtype template
        init_genotypes = jax.nn.one_hot(init_genotypes, self._num_category)
        init_genotypes = init_genotypes.astype(jnp.bfloat16)
        num_centroids = self._centroids.shape[0]
        default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)

        embedding_size = self._embedding_matrix.shape[-1]
        # one gradient slot per centroid: the batch axis is replaced by the number
        # of centroids and the category axis by (embedding_size, num_descriptors + 1)
        # NOTE: jax.tree_util.tree_map is used because the top-level jax.tree_map
        # alias is deprecated and unavailable in recent JAX releases.
        default_gradients = jax.tree_util.tree_map(
            lambda x: jnp.zeros(
                shape=(num_centroids,)
                + x.shape[1:-1]
                + (
                    embedding_size,
                    self._num_descriptors + 1,
                ),
                dtype=x.dtype,
            ),
            init_genotypes,
        )
        default_descriptors = jnp.zeros(
            shape=(num_centroids, self._centroids.shape[-1])
        )
        default_genotypes = {
            "gradients": default_gradients,
        }
        # instantiate the gradients repertoire
        gradients_repertoire = MapElitesRepertoire(
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            centroids=self._centroids,
        )

        return (
            OMGMEGAEmitterState(gradients_repertoire=gradients_repertoire),
            random_key,
        )

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def emit(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: OMGMEGAEmitterState,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey, Any]:
        """
        OMG emitter function that samples elements in the repertoire and does a gradient
        update with random coefficients to create new candidates.

        Args:
            repertoire: current repertoire
            emitter_state: current emitter state, contains the gradients
            random_key: random key

        Returns:
            new_genotypes: new candidates to be added to the grid, as category
                indices (argmax over the last axis of the updated genotypes)
            random_key: updated random key
            extra_emit: extra information about the emission (currently an
                empty dictionary)
        """
        # sample genotypes - the key returned by this call is deliberately
        # discarded so that the next call re-uses the very same key
        genotypes, _ = repertoire.sample(random_key, num_samples=self._batch_size)

        # sample gradients - use the same random key for sampling
        # See class docstrings for discussion about this choice
        gradients_dict, random_key = emitter_state.gradients_repertoire.sample(
            random_key, num_samples=self._batch_size
        )

        grads = gradients_dict["gradients"]

        # b: batch size; z: latent map dimension; c: size of a vector of the codebook;
        # d: number of descriptors; e: number of vectors in the codebook
        grads = jnp.einsum("bzcd,ec->bzed", grads, self._embedding_matrix)

        # positions whose category index is 0 receive no gradient update
        mask = genotypes != 0
        genotypes = jax.nn.one_hot(genotypes, self._num_category)

        # Draw random coefficients
        random_key, subkey = jax.random.split(random_key)
        coeffs = jax.random.multivariate_normal(
            subkey,
            shape=(self._batch_size,),
            mean=self._mu,
            cov=self._sigma,
        )
        # the first coefficient is forced to be non-negative
        coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))

        # zero out the gradients of masked categories and of masked positions
        grads = jnp.where(jnp.expand_dims(self._gradient_mask, -1), grads, 0)
        grads = jnp.where(jnp.expand_dims(mask, (-1, -2)), grads, 0)
        # normalise over the (z, e) axes, separately for each sample and each
        # gradient component; exact zeros are kept as they are, which also keeps
        # the output free of NaN when a norm is zero
        norm_grad = jnp.sqrt(jnp.sum(grads**2, axis=(1, 2), keepdims=True))
        grads = jnp.where(grads == 0, grads, grads / norm_grad)

        # weight every gradient component by its coefficient and sum them up
        update_grad = jnp.sum(jax.vmap(lambda x, y: x * y)(coeffs, grads), axis=-1)

        # update the genotypes
        new_genotypes = jax.tree_util.tree_map(
            lambda x, y: x + y, genotypes, update_grad
        )

        # Project if needed
        new_genotypes = self._projection_fn(new_genotypes)
        # go back from the one-hot space to category indices
        new_genotypes = new_genotypes.argmax(axis=-1)

        extra_emit = {}
        return new_genotypes, random_key, extra_emit

    @partial(jax.jit, static_argnames=("self",), donate_argnums=(1,))
    def state_update(
        self,
        emitter_state: OMGMEGAEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> OMGMEGAEmitterState:
        """Update the gradients repertoire to have the right gradients.

        NOTE: see discussion in the class docstrings to see how this could
        be improved.

        NOTE: `emitter_state` is marked as donated to the jitted computation
        (`donate_argnums=(1,)`), so a state passed positionally should not be
        re-used after this call.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire (unused here)
            genotypes: the genotypes of the batch of emitted offspring
                (unused here).
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.
            extra_scores: a dictionary with other values outputted by the
                scoring function. Must contain the "gradients" key.

        Returns:
            The modified emitter state.
        """

        # get gradients out of the extra scores
        assert "gradients" in extra_scores, "Missing gradients or wrong key"
        gradients = extra_scores["gradients"]  # type: ignore
        gradients_dict = {"gradients": gradients}

        # update the gradients repertoire - the gradients are inserted with the
        # descriptors and fitnesses of the offspring they belong to
        gradients_repertoire = emitter_state.gradients_repertoire.add(
            gradients_dict, descriptors, fitnesses
        )

        return emitter_state.replace(  # type: ignore
            gradients_repertoire=gradients_repertoire
        )
