from dataclasses import dataclass
from functools import partial
from typing import Callable, Optional, Tuple

import jax
import jax.numpy as jnp
from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire
from qdax.core.emitters.binary_gide import DiscreteGradientState
from qdax.core.emitters.emitter import Emitter
from qdax.types import (
    Centroid,
    Descriptor,
    ExtraScores,
    Fitness,
    Genotype,
    Gradient,
    RNGKey,
)
from tensorflow_probability.substrates.jax.math import find_root_chandrupatla


@jax.jit
def normalized_entropy(x: jnp.ndarray) -> jnp.ndarray:
    """Entropy of `x` along its last axis, rescaled by the maximum entropy.

    The result is `sum(x * log(x + 1e-9)) / (-log(n))` where `n` is the size
    of the last axis, i.e. the Shannon entropy divided by `log(n)`.

    Args:
        x: array whose last axis holds the values of a distribution
            (floating dtype, since `jnp.finfo` is queried on it).

    Returns:
        Array with the last axis of `x` reduced away.
    """
    # NOTE(review): finfo.min is the most negative finite value, so this clip
    # leaves every finite input untouched; a small positive floor may have
    # been intended. The 1e-9 inside the log is what guards against log(0).
    lower_bound = jnp.finfo(x.dtype).min
    clipped = jnp.clip(x, lower_bound)

    # Sum of p * log(p): the negative of the entropy.
    neg_entropy = jnp.sum(clipped * jnp.log(clipped + 1e-9), axis=-1)
    # Negative of the entropy of a uniform distribution over the last axis.
    neg_max_entropy = -jnp.log(x.shape[-1])
    return neg_entropy / neg_max_entropy


def create_target_function(target_entropy: jnp.ndarray, d: jnp.ndarray) -> Callable:
    """Build the residual function whose root is the sought temperature.

    Args:
        target_entropy: normalized entropy the proposal should reach.
        d: logits; the softmax is taken jointly over the last two axes and
            the first axis is the one averaged over.

    Returns:
        A function of the temperature `t` returning, with a trailing axis of
        size one, the mean normalized entropy of `softmax(d / t)` minus
        `target_entropy`.
    """
    leading_dim = d.shape[0]

    def residual(t: jnp.ndarray) -> jnp.ndarray:
        # Joint softmax over the last two axes, then one flat distribution
        # per entry of the leading axis.
        probs = jax.nn.softmax(d / t, axis=(-1, -2))
        flat_probs = probs.reshape((leading_dim, -1))
        mean_entropy = normalized_entropy(flat_probs).mean(axis=0)
        return jnp.expand_dims(mean_entropy - target_entropy, -1)

    return residual


@dataclass
class CategoricalGIDEConfig:
    """Configuration of the CategoricalGIDE emitter.

    The per-field comments describe how each value is consumed by the
    emitter code visible in this file.
    """

    # Number of parents (and gradients) sampled from the repertoires per emit.
    batch_size: int
    # Number of descriptors; mixing coefficients have num_descriptors + 1
    # entries, index 0 being the fitness coefficient.
    num_descriptors: int
    # Centroids of the gradients repertoire; the first axis is the number of
    # cells, the last axis the descriptor dimension.
    centroids: Centroid
    # Number of categories: depth of the one-hot encoding of the genotypes and
    # size of the default gradient mask / identity embedding matrix.
    num_category: int = 512
    # Optional mask over categories; where it is zero the gradients are zeroed
    # and the category is excluded from the proposal. None means all ones.
    gradient_mask: Optional[jnp.ndarray] = None

    # Softmax temperature of the proposal (ignored when auto_temperature).
    temperature: float = 2
    # Number of positions resampled per parent.
    num_flip: int = 1
    # Fresh parents are drawn once the emitter state's step counter reaches
    # this value; otherwise the genotypes and gradients stored by the last
    # state update are reused.
    num_steps: int = 1

    # If True, subtract the gradient of the currently active category from
    # the proposal logits.
    normalize_proposal: bool = False
    # If True, the temperature is found by root-finding so that the mean
    # normalized entropy of the proposal matches target_entropy.
    auto_temperature: bool = False
    target_entropy: float = 0.8

    # How mixing coefficients are drawn: "gaussian" or "discrete_uniform".
    coef_sampling_type: str = "gaussian"
    # Scale of the (diagonal) covariance of the Gaussian coefficients; when
    # exactly 0.0 the Gaussian coefficients are replaced by zeros.
    sigma_diag: float = 1.0

    # If True, all descriptor coefficients are set to zero.
    only_fitness_gradient: bool = False
    # NOTE(review): not read anywhere in the visible emitter code.
    only_diversity_gradient: bool = False

    # Only used when coef_sampling_type == "discrete_uniform": each sample
    # follows the fitness gradient (scaled by fitness_scale) with probability
    # fitness_proportion, otherwise a single randomly signed descriptor
    # gradient (scaled by diversity_scale).
    fitness_scale: float = 1.0
    diversity_scale: float = 1.0
    fitness_proportion: float = 0.5

    # Setting this to True makes the emitter constructor raise
    # NotImplementedError.
    use_adam: bool = False
    # When > 0, crossover is enabled and int(crossover_percentage * batch_size)
    # parents go through the gradient-informed proposal.
    # NOTE(review): the name suggests the fraction devoted to crossover, which
    # is the opposite of how it is used — confirm the intended meaning.
    crossover_percentage: float = 0.0


class CategoricalGIDE(Emitter):
    """Gradient-informed emitter for categorical (integer-coded) genotypes.

    Parents are sampled from the repertoire together with gradients stored in
    a parallel gradients repertoire held in the emitter state. The fitness and
    descriptor gradients are mixed with random coefficients, turned into a
    softmax proposal over (position, category) pairs, and `num_flip` positions
    of each parent are resampled from that proposal. When crossover is
    enabled, offspring produced by a user-supplied crossover function are
    appended to the batch.
    """

    def __init__(
        self,
        config: CategoricalGIDEConfig,
        embedding_matrix: Optional[jnp.ndarray] = None,
        crossover_fn: Optional[
            Callable[[Genotype, Genotype, RNGKey], Tuple[Genotype, RNGKey]]
        ] = None,
    ):
        """Store the configuration and derive the batch split.

        Args:
            config: emitter configuration.
            embedding_matrix: matrix used to project the stored gradients onto
                the categories (indexed "ec" in the einsum of `emit`).
                Defaults to the identity of size `config.num_category`.
            crossover_fn: function taking two batches of parents and a random
                key and returning the crossed genotypes and a new random key.
                Required when `config.crossover_percentage > 0`.

        Raises:
            NotImplementedError: if `config.use_adam` is set.
            ValueError: if crossover is enabled but `crossover_fn` is missing.
        """
        self._config = config
        # Mean / covariance of the Gaussian mixing coefficients: one entry for
        # the fitness gradient plus one per descriptor gradient.
        self._mu = jnp.zeros((config.num_descriptors + 1,))
        self._sigma = jnp.eye(config.num_descriptors + 1) * config.sigma_diag

        if config.use_adam:
            raise NotImplementedError

        if config.gradient_mask is not None:
            self._gradient_mask = config.gradient_mask
        else:
            self._gradient_mask = jnp.ones(config.num_category)

        if embedding_matrix is None:
            self._embedding_matrix = jnp.eye(config.num_category)
        else:
            self._embedding_matrix = embedding_matrix

        if config.crossover_percentage > 0.0:
            # Fail here rather than with an AttributeError while tracing emit.
            if crossover_fn is None:
                raise ValueError(
                    "crossover_fn must be provided when crossover_percentage > 0"
                )
            # NOTE(review): the gradient-informed share of the batch is
            # crossover_percentage * batch_size, which looks inverted with
            # respect to the option name. Kept as is; confirm intent.
            self.GIDE_batch_size = int(config.crossover_percentage * config.batch_size)
            self.crossover_batch_size = config.batch_size - self.GIDE_batch_size
        else:
            self.GIDE_batch_size = config.batch_size
            self.crossover_batch_size = 0
        self._crossover_fn = crossover_fn

    @partial(jax.jit, static_argnames=("self",))
    def emit(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: DiscreteGradientState,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey, ExtraScores]:
        """
        Emitter based on gradient proposal.

        Args:
            repertoire: a repertoire of genotypes.
            emitter_state: the state of the emitter, holding the gradients
                repertoire and the genotypes/gradients of the previous step.
            random_key: a random key to handle random operations.

        Returns:
            A batch of offspring, a new random key and a dictionary of
            metrics describing the proposal distribution.
        """
        assert self._config.coef_sampling_type in [
            "gaussian",
            "discrete_uniform",
        ], "Unknown coef sampling type"

        batch_size = self._config.batch_size
        gide_batch_size = self.GIDE_batch_size
        num_category = self._config.num_category
        num_flip = self._config.num_flip

        # Sample genotypes and gradients with the SAME random key on purpose
        # (presumably so both repertoires yield the same cells — TODO confirm).
        genotypes, _ = repertoire.sample(random_key, num_samples=batch_size)
        gradients_dict, random_key = emitter_state.gradients_repertoire.sample(
            random_key, num_samples=batch_size
        )
        gradients = gradients_dict["gradients"]

        # Either start from freshly sampled parents or keep refining the
        # genotypes stored by the last state update.
        sample_new = emitter_state.current_step >= self._config.num_steps
        genotypes = jnp.where(sample_new, genotypes, emitter_state.genotypes)
        grads = jnp.where(sample_new, gradients, emitter_state.gradients)
        genotypes = genotypes[:gide_batch_size]
        grads = grads[:gide_batch_size]
        # b: batch size; z: latent map dimension; c: size of a vector of the
        # codebook; d: number of descriptors + 1 (fitness comes first);
        # e: number of vectors in the codebook
        grads = jnp.einsum("bzcd,ec->bzed", grads, self._embedding_matrix)

        # Positions holding category 0 are masked out: their gradients are
        # zeroed and they are excluded from the proposal.
        mask = genotypes != 0
        genotypes = jax.nn.one_hot(genotypes, num_category)

        num_features = genotypes.shape[-2]
        num_descriptors = grads.shape[-1] - 1

        # Draw random coefficients
        if self._config.coef_sampling_type == "gaussian":
            random_key, subkey = jax.random.split(random_key)
            coeffs = jax.random.multivariate_normal(
                subkey,
                shape=(gide_batch_size,),
                mean=self._mu,
                cov=self._sigma,
            )

            if self._config.sigma_diag == 0.0:
                coeffs = jnp.zeros_like(coeffs)

        elif self._config.coef_sampling_type == "discrete_uniform":
            # Each sample follows either the fitness gradient or one randomly
            # chosen, randomly signed descriptor gradient.
            random_key, subkey = jax.random.split(random_key)
            is_fitness_grad = jax.random.bernoulli(
                subkey,
                p=self._config.fitness_proportion,
                shape=(gide_batch_size, 1),
            )
            grad_mask = jax.nn.one_hot(0, num_descriptors + 1)
            random_key, subkey = jax.random.split(random_key)
            # Use the fresh subkey: drawing from random_key itself would
            # reuse a key that is split again right after.
            descriptor_idx = jax.random.randint(
                subkey,
                minval=1,
                maxval=num_descriptors + 1,
                shape=(gide_batch_size,),
            )
            descriptor_mask = jax.nn.one_hot(descriptor_idx, num_descriptors + 1)
            random_key, subkey = jax.random.split(random_key)
            descriptor_mask *= (
                jax.random.bernoulli(subkey, shape=(gide_batch_size, 1)) * 2 - 1
            )
            fitness_coeff = grad_mask * self._config.fitness_scale
            descriptor_coeff = descriptor_mask * self._config.diversity_scale
            coeffs = jnp.where(is_fitness_grad, fitness_coeff, descriptor_coeff)

        # Set positive coefficients for fitness
        coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))
        if self._config.only_fitness_gradient:
            coeffs = coeffs.at[:, 1:].set(jnp.zeros_like(coeffs[:, 1:]))
        # NOTE(review): config.only_diversity_gradient is never consulted.

        # Zero the gradients of masked categories and of masked positions.
        grads = jnp.where(jnp.expand_dims(self._gradient_mask, -1), grads, 0)
        grads = jnp.where(jnp.expand_dims(mask, (-1, -2)), grads, 0)
        # Normalize each objective's gradient over positions and categories
        norm_grad = jnp.sqrt(jnp.sum(grads**2, axis=(1, 2), keepdims=True))
        # Handle the case of all-zeros gradients (avoids propagating 0 / 0)
        grads = jnp.where(grads == 0, grads, grads / norm_grad)
        # Compute update gradient: coefficient-weighted sum over objectives
        update_grad = jnp.sum(jax.vmap(lambda x, y: x * y)(coeffs, grads), axis=-1)

        # Compute probabilities of flipping
        p_flip, metrics = self.compute_p_flip(genotypes, update_grad, mask=mask)

        # Sample the positions to replace, without replacement, proportionally
        # to the total proposal mass of each position.
        def _choice_feature_fn(key: RNGKey, p: jnp.ndarray) -> jnp.ndarray:
            return jax.random.choice(
                key=key,
                a=num_features,
                p=p,
                replace=False,
                shape=(num_flip,),
            )

        p_feature = jnp.sum(p_flip, axis=-1)
        # Derive per-sample keys from a dedicated subkey so that the key
        # carried forward (and eventually returned) is never consumed.
        random_key, subkey = jax.random.split(random_key)
        feature_keys = jax.random.split(subkey, gide_batch_size)

        # shape (num_flip, batch)
        features_to_flip_idx = jax.vmap(_choice_feature_fn)(
            key=feature_keys, p=p_feature
        ).T

        # Sample random codes for replacing chosen ones
        def _choice_code_fn(key: RNGKey, p: jnp.ndarray) -> jnp.ndarray:
            return jax.random.choice(key=key, a=num_category, p=p)

        batch_idx = jnp.arange(gide_batch_size)
        # shape (batch, num_flip, num_category)
        p_code = p_flip[batch_idx, features_to_flip_idx].transpose(1, 0, 2)
        active_features = genotypes[batch_idx, features_to_flip_idx].transpose(
            1, 0, 2
        )
        # Never resample the category that is already active.
        p_code = jnp.where(active_features, jnp.zeros_like(p_code), p_code)

        random_key, subkey = jax.random.split(random_key)
        code_keys = jax.random.split(subkey, gide_batch_size * num_flip)
        new_codes = jax.vmap(_choice_code_fn)(
            key=code_keys,
            p=p_code.reshape((gide_batch_size * num_flip, -1)),
        )

        # Write the new one-hot codes at the chosen positions, then go back
        # to the integer representation.
        new_genotypes = genotypes.at[batch_idx, features_to_flip_idx].set(
            jax.nn.one_hot(new_codes, num_category)
            .reshape((gide_batch_size, num_flip, -1))
            .transpose(1, 0, 2)
        )
        new_genotypes = new_genotypes.argmax(axis=-1)

        extra_emit = dict(metrics)

        if self._config.crossover_percentage > 0.0:
            # NOTE(review): GIDE_batch_size parents are drawn here, not
            # crossover_batch_size; kept as in the original, confirm intent.
            random_key, key_1, key_2 = jax.random.split(random_key, 3)
            genotypes_1, _ = repertoire.sample(key_1, num_samples=gide_batch_size)
            genotypes_2, _ = repertoire.sample(key_2, num_samples=gide_batch_size)
            genotypes_crossed, random_key = self._crossover_fn(
                genotypes_1, genotypes_2, random_key
            )
            new_genotypes = jax.tree_util.tree_map(
                lambda x_1, x_2: jnp.concatenate([x_1, x_2], axis=0),
                new_genotypes,
                genotypes_crossed,
            )

        return new_genotypes, random_key, extra_emit

    @partial(jax.jit, static_argnames=("self",))
    def init(
        self, init_genotypes: Genotype, random_key: RNGKey
    ) -> Tuple[DiscreteGradientState, RNGKey]:
        """Initialises the state of the emitter. Creates an empty repertoire
        that will later contain the gradients of the individuals.

        Args:
            init_genotypes: The genotypes of the initial population.
            random_key: a random key to handle stochastic operations.

        Returns:
            The initial emitter state and the (unchanged) random key.
        """
        # The one-hot encoding is only used to derive the shape and dtype
        # (bfloat16) of the gradient buffers.
        init_genotypes = jax.nn.one_hot(init_genotypes, self._config.num_category)
        init_genotypes = init_genotypes.astype(jnp.bfloat16)

        # Initialize grid with default values
        num_centroids = self._config.centroids.shape[0]
        default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)

        embedding_size = self._embedding_matrix.shape[-1]

        # One zero gradient per cell, with trailing axes
        # (embedding_size, num_descriptors + 1).
        default_gradients = jax.tree_util.tree_map(
            lambda x: jnp.zeros(
                shape=(num_centroids,)
                + x.shape[1:-1]
                + (
                    embedding_size,
                    self._config.num_descriptors + 1,
                ),
                dtype=x.dtype,
            ),
            init_genotypes,
        )
        default_genotypes = {
            "gradients": default_gradients,
        }

        default_descriptors = jnp.zeros(
            shape=(num_centroids, self._config.centroids.shape[-1])
        )

        # instantiate the gradients repertoire
        gradients_repertoire = MapElitesRepertoire(
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            centroids=self._config.centroids,
        )

        # The remaining positional fields are presumably genotypes, gradients
        # and current_step; they are filled in by state_update.
        return (
            DiscreteGradientState(gradients_repertoire, None, None, 0),
            random_key,
        )

    @partial(jax.jit, static_argnames=("self",), donate_argnums=(1,))
    def state_update(
        self,
        emitter_state: DiscreteGradientState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: Optional[ExtraScores] = None,  # type: ignore
        # Note: extra_scores might be better not "Optional"
    ) -> DiscreteGradientState:
        """Store the latest gradients and advance the step counter.

        Args:
            emitter_state: current emitter state (its buffers are donated).
            repertoire: unused here.
            genotypes: the genotypes that were just evaluated.
            fitnesses: their fitnesses.
            descriptors: their descriptors.
            extra_scores: must contain the key "gradients".

        Returns:
            The emitter state with updated genotypes, gradients, gradients
            repertoire and step counter.
        """
        # get gradients out of the extra scores
        keys = extra_scores.keys()  # type: ignore
        assert "gradients" in keys, "Missing gradients or wrong key"
        gradients = extra_scores["gradients"]  # type: ignore
        gradients_dict = {"gradients": gradients}

        # Increment the counter and wrap it to 0 once it exceeds num_steps.
        current_step = emitter_state.current_step + 1
        current_step = current_step * (current_step <= self._config.num_steps)

        # update the gradients repertoire
        gradients_repertoire = emitter_state.gradients_repertoire.add(
            gradients_dict, descriptors, fitnesses
        )

        return emitter_state.replace(  # type: ignore
            genotypes=genotypes,
            gradients=gradients,
            gradients_repertoire=gradients_repertoire,
            current_step=current_step,
        )

    def compute_p_flip(
        self, genotypes: Genotype, grads: Gradient, mask: jnp.ndarray
    ) -> Tuple[jnp.ndarray, ExtraScores]:
        """
        Informed proposal based on the gradients.

        Args:
            genotypes: one-hot genotypes of shape
                (batch_size, seq_len, vocab_size)
            grads: corresponding update gradients of shape
                (batch_size, seq_len, vocab_size)
            mask: boolean array of shape (batch_size, seq_len); positions
                where it is False are excluded from the proposal

        Returns:
            p: probability of flipping to a token at each position, array of
                shape (batch_size, seq_len, vocab_size), normalized jointly
                over the last two axes
            metrics: normalized entropies of the proposal and the inverse
                temperature scaling
        """
        # Large negative logit for the currently active category ...
        d = grads - 1e9 * genotypes
        # ... for masked categories and for masked positions.
        d = jnp.where(self._gradient_mask, d, -1e9)
        d = jnp.where(jnp.expand_dims(mask, -1), d, -1e9)
        if self._config.normalize_proposal:
            # Subtract the gradient of the currently active category.
            d -= (grads * genotypes).sum(axis=-1, keepdims=True)

        if self._config.auto_temperature:
            # Only the first 100 samples are used for the temperature search.
            temp = self.find_optimal_temperature(
                target_entropy=self._config.target_entropy, d=d[:100]
            )
        else:
            temp = self._config.temperature
        p = jax.nn.softmax(d / temp, axis=(-1, -2))

        # Per-position distribution over categories, flattened over
        # (batch, position); positions with zero mass give NaN, hence nanmean.
        p_bit = (p / p.sum(axis=-1, keepdims=True)).reshape(
            (-1, self._config.num_category)
        )
        # Per-sample distribution over positions.
        p_feat = p.sum(axis=-1)
        total_entropy = normalized_entropy(p.reshape((p.shape[0], -1)))
        total_entropy = jnp.nanmean(total_entropy)
        bit_entropy = normalized_entropy(p_bit)
        bit_entropy = jnp.nanmean(bit_entropy)
        feat_entropy = jnp.nanmean(normalized_entropy(p_feat))

        metrics = {
            "mean_entropy_category_distribution": bit_entropy,
            "mean_entropy_feature_distribution": feat_entropy,
            "mean_entropy_total": total_entropy,
            "inverse_temperature_scaling": 2 / temp,
        }
        return p, metrics

    def find_optimal_temperature(
        self, target_entropy: jnp.ndarray, d: jnp.ndarray
    ) -> jnp.ndarray:
        """Find the temperature at which the proposal reaches a target entropy.

        Args:
            target_entropy: desired mean normalized entropy of the proposal.
            d: proposal logits.

        Returns:
            The root estimated by Chandrupatla's method on the bracket
            [0, 100].
        """
        f = create_target_function(target_entropy=target_entropy, d=d)
        # NOTE(review): the lower bracket is 0, where d / t is not finite;
        # confirm the root finder copes with that endpoint.
        res = find_root_chandrupatla(f, low=0, high=100)
        return res.estimated_root
