"""Core components of the MAP-Elites algorithm."""
from __future__ import annotations

from functools import partial
from typing import Any, Callable, List, Optional, Tuple

import jax
import jax.numpy as jnp
from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire
from qdax.core.emitters.emitter import Emitter, EmitterState
from qdax.types import (
    Centroid,
    Descriptor,
    ExtraScores,
    Fitness,
    Genotype,
    Metrics,
    RNGKey,
)


@jax.jit
def unshard_fn(sharded_tree: Any) -> Any:
    """Merge the two leading axes of every leaf of a pytree.

    A leaf of shape ``(a, b, ...)`` becomes a leaf of shape ``(a * b, ...)``.
    This undoes the reshape performed by ``shard_fn``, e.g. to flatten the
    per-device outputs of a ``jax.pmap`` call back into one batch.

    Args:
        sharded_tree: pytree whose leaves all have at least two dimensions.

    Returns:
        A pytree with the same structure in which the first two axes of each
        leaf have been collapsed into one.
    """

    def _merge_leading_axes(leaf: Any) -> Any:
        merged = leaf.shape[0] * leaf.shape[1]
        return jnp.reshape(leaf, (merged,) + leaf.shape[2:])

    # jax.tree_util.tree_map is the supported spelling; the jax.tree_map alias
    # is deprecated and absent from recent JAX releases.
    return jax.tree_util.tree_map(_merge_leading_axes, sharded_tree)


@partial(jax.jit, static_argnames=("num_devices",))
def shard_fn(tree: Any, num_devices: int) -> Any:
    """Split the leading axis of every leaf of a pytree into per-device chunks.

    A leaf of shape ``(n, ...)`` becomes a leaf of shape
    ``(num_devices, n // num_devices, ...)``, which is the layout expected by
    ``jax.pmap`` for its mapped arguments.

    Args:
        tree: pytree whose leaves all have at least one dimension. The size of
            the leading axis of each leaf must be a multiple of ``num_devices``,
            otherwise the reshape fails.
        num_devices: number of chunks to create along the new leading axis.
            Static under ``jax.jit``: each distinct value triggers a new trace.

    Returns:
        A pytree with the same structure and the reshaped leaves.
    """

    def _split_leading_axis(leaf: Any) -> Any:
        per_device = leaf.shape[0] // num_devices
        return jnp.reshape(leaf, (num_devices, per_device) + leaf.shape[1:])

    # jax.tree_util.tree_map is the supported spelling; the jax.tree_map alias
    # is deprecated and absent from recent JAX releases.
    return jax.tree_util.tree_map(_split_leading_axis, tree)


class MAPElites:
    """Core elements of the MAP-Elites algorithm.

    Note: Although very similar to the GeneticAlgorithm, we decided to keep the
    MAPElites class independent of the GeneticAlgorithm class at the moment to keep
    elements explicit.

    Three flavours of the init/update pair are provided:

    * ``init`` / ``update`` / ``get_update_fn``: plain jitted functions.
    * ``_init_pmap`` / ``_update_pmap`` (exposed through
      ``get_distributed_init_fn`` / ``get_distributed_update_fn``): the whole
      step runs inside ``jax.pmap``; each device emits and scores its own batch
      and the results are all-gathered before being added to the repertoire.
    * ``_init_pmap_v2`` / ``_update_pmap_v2`` (exposed through
      ``get_distributed_init_fn_v2`` / ``get_distributed_update_fn_v2``): only
      the scoring function is pmapped; the batch is sharded before scoring and
      flattened again afterwards.

    Args:
        scoring_function: a function that takes a batch of genotypes and compute
            their fitnesses and descriptors
        emitter: an emitter is used to suggest offsprings given a MAPELites
            repertoire. It has two compulsory functions. A function that takes
            emits a new population, and a function that update the internal state
            of the emitter.
        metrics_function: a function that takes a MAP-Elites repertoire and compute
            any useful metric to track its evolution
    """

    # Name of the mapped axis shared by every pmap call and collective below.
    _AXIS_NAME = "p"

    def __init__(
        self,
        scoring_function: Callable[
            [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
        ],
        emitter: Emitter,
        metrics_function: Callable[[MapElitesRepertoire], Metrics],
    ) -> None:
        self._scoring_function = scoring_function
        self._emitter = emitter
        self._metrics_function = metrics_function

    # ------------------------------------------------------------------ helpers

    @staticmethod
    @partial(jax.jit, static_argnames=("num_devices",))
    def _split_device_keys(
        random_key: RNGKey, num_devices: int
    ) -> Tuple[RNGKey, RNGKey]:
        """Split a key into a fresh carry key and one scoring key per device.

        Args:
            random_key: the key to consume.
            num_devices: number of per-device keys to create.

        Returns:
            A new key to carry forward and an array stacking ``num_devices``
            keys along its first axis.
        """
        subkeys = jax.random.split(random_key, num=1 + num_devices)
        # Slicing keeps the leading device axis even when num_devices == 1, so
        # no special case is needed. The carry key is a fresh child: the parent
        # key must not be reused once it has been split.
        return subkeys[0], subkeys[1:]

    def _all_gather_concat(self, tree: Any) -> Any:
        """Gather every leaf across the pmapped axis and concatenate on axis 0.

        Must be called from inside a ``jax.pmap`` whose axis is ``_AXIS_NAME``.
        """
        axis_name = self._AXIS_NAME
        return jax.tree_util.tree_map(
            lambda x: jnp.concatenate(
                jax.lax.all_gather(x, axis_name=axis_name), axis=0
            ),
            tree,
        )

    def _score_on_devices(
        self, devices: List[Any], genotypes: Genotype, device_keys: RNGKey
    ) -> Tuple[Fitness, Descriptor, ExtraScores]:
        """Score a batch of genotypes by sharding it over ``devices``.

        Args:
            devices: devices on which the scoring function is pmapped.
            genotypes: unsharded batch; the size of its leading axis must be a
                multiple of ``len(devices)``.
            device_keys: one random key per device, stacked on the first axis.

        Returns:
            Fitnesses, descriptors and extra scores with the device axis
            flattened back into the batch axis. The keys returned by the
            scoring function are discarded.
        """
        sharded_genotypes = shard_fn(genotypes, num_devices=len(devices))
        fitnesses, descriptors, extra_scores, _ = jax.pmap(
            self._scoring_function, devices=devices, axis_name=self._AXIS_NAME
        )(sharded_genotypes, device_keys)
        return unshard_fn(  # type: ignore
            sharded_tree=(fitnesses, descriptors, extra_scores)
        )

    def _build_scan_update_fn(
        self, step_fn: Callable, num_iterations: int
    ) -> Callable:
        """Wrap a one-step update into a function running it ``num_iterations``
        times with ``jax.lax.scan``.

        Args:
            step_fn: function mapping ``(repertoire, emitter_state, random_key)``
                to ``(repertoire, emitter_state, metrics, random_key)``.
            num_iterations: length of the scan.

        Returns:
            A function mapping ``(repertoire, emitter_state, random_key)`` to
            ``(repertoire, emitter_state, random_key, metrics)`` where the
            metrics of all iterations are stacked along a new leading axis.
        """

        # No jit decorator here: lax.scan already traces and compiles its body.
        def _scan_update(
            carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey],
            unused: Any,
        ) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]:
            """Adapt ``step_fn`` to the ``(carry, x) -> (carry, y)`` scan API."""
            repertoire, emitter_state, random_key = carry
            repertoire, emitter_state, metrics, random_key = step_fn(
                repertoire,
                emitter_state,
                random_key,
            )
            return (repertoire, emitter_state, random_key), metrics

        def update_fn(repertoire, emitter_state, random_key):  # type: ignore
            (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(
                _scan_update,
                (repertoire, emitter_state, random_key),
                (),
                length=num_iterations,
            )
            return repertoire, emitter_state, random_key, metrics

        return update_fn

    # ----------------------------------------------------------- initialisation

    @partial(jax.jit, static_argnames=("self",))
    def init(
        self,
        init_genotypes: Genotype,
        centroids: Centroid,
        random_key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]:
        """
        Initialize a Map-Elites repertoire with an initial population of genotypes.
        Requires the definition of centroids that can be computed with any method
        such as CVT or Euclidean mapping.

        Args:
            init_genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            centroids: tessellation centroids, forwarded unchanged to
                ``MapElitesRepertoire.init``
            random_key: a random key used for stochastic operations.

        Returns:
            An initialized MAP-Elite repertoire with the initial state of the emitter,
            and a random key.
        """
        # score initial genotypes
        fitnesses, descriptors, extra_scores, random_key = self._scoring_function(
            init_genotypes, random_key
        )

        # init the repertoire
        repertoire = MapElitesRepertoire.init(
            genotypes=init_genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            centroids=centroids,
        )

        # get initial state of the emitter
        emitter_state, random_key = self._emitter.init(
            init_genotypes=init_genotypes, random_key=random_key
        )

        # update emitter state
        emitter_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=init_genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        return repertoire, emitter_state, random_key

    @partial(jax.jit, static_argnames=("self",))
    def _init_pmap(
        self,
        init_genotypes: Genotype,
        centroids: Centroid,
        random_key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]:
        """
        Per-device body of the distributed initialization. Must run inside a
        ``jax.pmap`` whose axis name is ``_AXIS_NAME``.

        Each device scores its own genotypes; genotypes, fitnesses and
        descriptors are then gathered across devices so that every device
        builds its repertoire from the full population. The emitter is
        initialized and updated with the local (non-gathered) data.

        Args:
            init_genotypes: this device's initial genotypes, pytree in which
                leaves have shape (batch_size, num_features)
            centroids: tessellation centroids, forwarded unchanged to
                ``MapElitesRepertoire.init``
            random_key: a random key used for stochastic operations.

        Returns:
            An initialized MAP-Elite repertoire with the initial state of the emitter,
            and a random key.
        """
        # score initial genotypes
        fitnesses, descriptors, extra_scores, random_key = self._scoring_function(
            init_genotypes, random_key
        )

        # gather across all devices
        (
            gathered_genotypes,
            gathered_fitnesses,
            gathered_descriptors,
        ) = self._all_gather_concat((init_genotypes, fitnesses, descriptors))

        # init the repertoire
        repertoire = MapElitesRepertoire.init(
            genotypes=gathered_genotypes,
            fitnesses=gathered_fitnesses,
            descriptors=gathered_descriptors,
            centroids=centroids,
        )

        # get initial state of the emitter
        emitter_state, random_key = self._emitter.init(
            init_genotypes=init_genotypes, random_key=random_key
        )

        # update emitter state
        emitter_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=init_genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        return repertoire, emitter_state, random_key

    def _init_pmap_v2(
        self,
        devices: List[Any],
        init_genotypes: Genotype,
        centroids: Centroid,
        random_key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]:
        """
        Initialize the repertoire, pmapping only the scoring function.

        The initial genotypes are sharded over ``devices`` for scoring and the
        scores are flattened back before building the repertoire and the
        emitter state.

        Args:
            devices: devices used to score the genotypes in parallel.
            init_genotypes: initial genotypes; the size of the leading axis of
                each leaf must be a multiple of ``len(devices)``.
            centroids: tessellation centroids, forwarded unchanged to
                ``MapElitesRepertoire.init``
            random_key: a random key used for stochastic operations.

        Returns:
            An initialized MAP-Elite repertoire with the initial state of the emitter,
            and a random key.
        """
        # one fresh key per device for scoring, plus a fresh key to carry on
        random_key, device_keys = self._split_device_keys(
            random_key, num_devices=len(devices)
        )

        # score initial genotypes in parallel
        fitnesses, descriptors, extra_scores = self._score_on_devices(
            devices, init_genotypes, device_keys
        )

        # init the repertoire
        repertoire = MapElitesRepertoire.init(
            genotypes=init_genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            centroids=centroids,
        )

        # get initial state of the emitter
        emitter_state, random_key = self._emitter.init(
            init_genotypes=init_genotypes, random_key=random_key
        )

        # update emitter state
        emitter_state = self._emitter.state_update(
            emitter_state,
            repertoire=repertoire,
            genotypes=init_genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        return repertoire, emitter_state, random_key

    def get_distributed_init_fn(
        self, centroids: Centroid, devices: List[Any]
    ) -> Callable:
        """
        Build the pmapped initialization function.

        Args:
            centroids: tessellation centroids, bound by keyword to ``_init_pmap``.
            devices: devices over which ``_init_pmap`` is pmapped.

        Returns:
            A pmapped function. Because ``centroids`` is bound by keyword, the
            remaining arguments must be passed as keywords
            (``init_genotypes=...``, ``random_key=...``).
        """
        return jax.pmap(  # type: ignore
            partial(self._init_pmap, centroids=centroids),
            devices=devices,
            axis_name=self._AXIS_NAME,
        )

    def get_distributed_init_fn_v2(
        self, centroids: Centroid, devices: List[Any]
    ) -> Callable:
        """
        Build the initialization function that pmaps only the scoring.

        Args:
            centroids: tessellation centroids, bound by keyword to
                ``_init_pmap_v2``.
            devices: devices used to score the genotypes in parallel.

        Returns:
            ``_init_pmap_v2`` with ``devices`` and ``centroids`` bound. The
            remaining arguments must be passed as keywords
            (``init_genotypes=...``, ``random_key=...``).
        """
        return partial(self._init_pmap_v2, devices=devices, centroids=centroids)

    # ------------------------------------------------------------------- update

    @partial(jax.jit, static_argnames=("self",))
    def update(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: Optional[EmitterState],
        random_key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]:
        """
        Performs one iteration of the MAP-Elites algorithm.
        1. A batch of genotypes is sampled in the repertoire and the genotypes
            are copied.
        2. The copies are mutated and crossed-over
        3. The obtained offsprings are scored and then added to the repertoire.


        Args:
            repertoire: the MAP-Elites repertoire
            emitter_state: state of the emitter
            random_key: a jax PRNG random key

        Returns:
            the updated MAP-Elites repertoire
            the updated (if needed) emitter state
            metrics about the updated repertoire, merged with the extra
                information returned by the emitter
            a new jax PRNG key
        """
        # generate offsprings with the emitter
        genotypes, random_key, extra_emit = self._emitter.emit(
            repertoire, emitter_state, random_key
        )
        # scores the offsprings
        fitnesses, descriptors, extra_scores, random_key = self._scoring_function(
            genotypes, random_key
        )

        # add genotypes in the repertoire
        repertoire = repertoire.add(genotypes, descriptors, fitnesses)

        # update emitter state after scoring is made
        emitter_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        # update the metrics
        metrics = self._metrics_function(repertoire)
        metrics.update(extra_emit)
        return repertoire, emitter_state, metrics, random_key

    @partial(jax.jit, static_argnames=("self",))
    def _update_pmap(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: Optional[EmitterState],
        random_key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]:
        """
        Per-device body of one distributed MAP-Elites iteration. Must run
        inside a ``jax.pmap`` whose axis name is ``_AXIS_NAME``.

        Each device emits and scores its own offsprings; genotypes, fitnesses
        and descriptors are then gathered across devices and added to the
        repertoire, while the emitter state is updated with the local
        (non-gathered) data.

        Args:
            repertoire: the MAP-Elites repertoire
            emitter_state: state of the emitter
            random_key: a jax PRNG random key

        Returns:
            the updated MAP-Elites repertoire
            the updated (if needed) emitter state
            metrics about the updated repertoire, merged with the extra
                information returned by the emitter
            a new jax PRNG key
        """
        # generate offsprings with the emitter (same three-value contract as
        # in `update` and `_update_pmap_v2`)
        genotypes, random_key, extra_emit = self._emitter.emit(
            repertoire, emitter_state, random_key
        )
        # scores the offsprings
        fitnesses, descriptors, extra_scores, random_key = self._scoring_function(
            genotypes, random_key
        )

        # gather across all devices
        (
            gathered_genotypes,
            gathered_fitnesses,
            gathered_descriptors,
        ) = self._all_gather_concat((genotypes, fitnesses, descriptors))

        # add genotypes in the repertoire
        repertoire = repertoire.add(
            gathered_genotypes, gathered_descriptors, gathered_fitnesses
        )

        # update emitter state after scoring is made
        emitter_state = self._emitter.state_update(
            emitter_state,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        # update the metrics
        metrics = self._metrics_function(repertoire)
        metrics.update(extra_emit)

        return repertoire, emitter_state, metrics, random_key

    def _update_pmap_v2(
        self,
        devices: List[Any],
        repertoire: MapElitesRepertoire,
        emitter_state: Optional[EmitterState],
        random_key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]:
        """
        Performs one MAP-Elites iteration, pmapping only the scoring function.

        Args:
            devices: devices used to score the offsprings in parallel. The
                number of emitted genotypes must be a multiple of
                ``len(devices)``.
            repertoire: the MAP-Elites repertoire
            emitter_state: state of the emitter
            random_key: a jax PRNG random key

        Returns:
            the updated MAP-Elites repertoire
            the updated (if needed) emitter state
            metrics about the updated repertoire, merged with the extra
                information returned by the emitter
            a new jax PRNG key
        """
        # one fresh key per device for scoring, plus a fresh key to carry on
        random_key, device_keys = self._split_device_keys(
            random_key, num_devices=len(devices)
        )

        # generate offsprings with the emitter
        genotypes, random_key, extra_emit = self._emitter.emit(
            repertoire, emitter_state, random_key
        )

        # score the offsprings in parallel
        fitnesses, descriptors, extra_scores = self._score_on_devices(
            devices, genotypes, device_keys
        )

        # add genotypes in the repertoire
        repertoire = repertoire.add(genotypes, descriptors, fitnesses)

        # update emitter state after scoring is made
        emitter_state = self._emitter.state_update(
            emitter_state,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        # update the metrics
        metrics = self._metrics_function(repertoire)
        metrics.update(extra_emit)

        return repertoire, emitter_state, metrics, random_key

    def get_update_fn(self, num_iterations: int) -> Callable:
        """
        Build a function running ``num_iterations`` calls of ``update`` in a
        single ``jax.lax.scan``.

        Args:
            num_iterations: number of MAP-Elites iterations per call.

        Returns:
            A function mapping ``(repertoire, emitter_state, random_key)`` to
            ``(repertoire, emitter_state, random_key, metrics)`` where the
            metrics of all iterations are stacked along a new leading axis.
        """
        return self._build_scan_update_fn(self.update, num_iterations)

    def get_distributed_update_fn(
        self, num_iterations: int, devices: List[Any]
    ) -> Callable:
        """
        Build a pmapped function running ``num_iterations`` calls of
        ``_update_pmap`` in a single ``jax.lax.scan`` on every device.

        Args:
            num_iterations: number of MAP-Elites iterations per call.
            devices: devices over which the scanned update is pmapped.

        Returns:
            A pmapped function mapping ``(repertoire, emitter_state,
            random_key)`` to ``(repertoire, emitter_state, random_key,
            metrics)``.
        """
        update_fn = self._build_scan_update_fn(self._update_pmap, num_iterations)
        return jax.pmap(  # type: ignore
            update_fn, devices=devices, axis_name=self._AXIS_NAME
        )

    def get_distributed_update_fn_v2(
        self, num_iterations: int, devices: List[Any]
    ) -> Callable:
        """
        Build a function running ``num_iterations`` calls of
        ``_update_pmap_v2`` in a Python loop.

        Args:
            num_iterations: number of MAP-Elites iterations per call; must be
                at least 1 because the metrics of all iterations are stacked.
            devices: devices used to score the offsprings in parallel.

        Returns:
            A function mapping ``(repertoire, emitter_state, random_key)`` to
            ``(repertoire, emitter_state, random_key, metrics)`` where the
            metrics of all iterations are stacked along a new leading axis.

        Raises:
            ValueError: if ``num_iterations`` is smaller than 1.
        """
        if num_iterations < 1:
            raise ValueError(
                f"num_iterations must be at least 1, got {num_iterations}"
            )

        def update_fn(repertoire, emitter_state, random_key):  # type: ignore
            all_metrics = []
            for _ in range(num_iterations):
                (
                    repertoire,
                    emitter_state,
                    metrics,
                    random_key,
                ) = self._update_pmap_v2(
                    devices=devices,
                    repertoire=repertoire,
                    emitter_state=emitter_state,
                    random_key=random_key,
                )
                all_metrics.append(metrics)
            # stack the per-iteration metrics leaf by leaf
            stacked_metrics = jax.tree_util.tree_map(
                lambda *leaves: jnp.stack(leaves, axis=0),
                *all_metrics,
            )
            return repertoire, emitter_state, random_key, stacked_metrics

        return update_fn  # type: ignore
