import glob
import os
from typing import Any, Dict, List, Tuple

import jax
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire


@jax.jit
def normalized_entropy(x: jnp.ndarray) -> jnp.ndarray:
    """Entropy of ``x`` along its last axis, divided by ``log(n)``.

    ``n`` is the size of the last axis. The result is
    ``-sum(x * log(x + 1e-9), axis=-1) / log(n)``. The ``1e-9`` offset keeps
    ``log`` finite where an entry is exactly zero.

    NOTE(review): this is presumably meant for probability vectors (for which
    the result lies in [0, 1]); the function does not normalise ``x`` itself.
    """
    num_bins = x.shape[-1]
    # Lower-bounding by the dtype's most negative value leaves finite inputs
    # unchanged; kept so the arithmetic matches the original exactly.
    lowest = jnp.finfo(x.dtype).min
    clipped = jnp.clip(x, lowest)
    neg_entropy = jnp.sum(clipped * jnp.log(clipped + 1e-9), axis=-1)
    return neg_entropy / -jnp.log(num_bins)


def metrics_fn(
    repertoire: MapElitesRepertoire,
) -> Dict[str, jnp.ndarray]:
    """Summary statistics of a MAP-Elites repertoire.

    A cell whose fitness equals ``-inf`` is treated as empty.

    Returns:
        A dict with three entries:
          * ``"qd_score"``: sum of the fitnesses of all non-empty cells.
          * ``"max_fitness"``: largest fitness in the grid (``-inf`` when
            every cell is empty).
          * ``"coverage"``: percentage of non-empty cells, in [0, 100].
    """
    fitnesses = repertoire.fitnesses
    filled = fitnesses != -jnp.inf
    return {
        "qd_score": jnp.sum(fitnesses, where=filled),
        "max_fitness": jnp.max(fitnesses),
        "coverage": 100 * jnp.mean(filled),
    }


def flip_random_bit_fn(
    x: jnp.ndarray,
    random_key: jnp.ndarray,
    categorical: bool = False,
    num_category: int = 256,
    num_flip: int = 1,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Mutate ``num_flip`` distinct positions of every sequence in a batch.

    Args:
        x: integer array indexed as ``x[batch, position]`` with values in
            ``[0, num_category)`` (presumably shape
            ``(batch_size, sequence_size)`` -- TODO confirm with callers).
        random_key: JAX PRNG key.
        categorical: if True, each chosen position is moved to a uniformly
            random *different* category in ``[0, num_category)``. If False,
            each chosen position is flipped: 0 becomes 1 and any non-zero
            value becomes 0.
        num_category: number of categories used for the one-hot encoding
            (must be >= 2).
        num_flip: number of distinct positions mutated per sequence.
            Positions are drawn without replacement, so this must not exceed
            the sequence length.

    Returns:
        The mutated integer array (same shape as ``x``) and a fresh PRNG key
        that was not used for this mutation.
    """
    one_hot = jax.nn.one_hot(x, num_category)
    batch_size = one_hot.shape[0]
    sequence_size = one_hot.shape[1]
    random_key, sub_key = jax.random.split(random_key)
    keys = jax.random.split(sub_key, num=batch_size)

    def _mutation_fn(seq: jnp.ndarray, key: jnp.ndarray) -> jnp.ndarray:
        # Pick `num_flip` distinct positions uniformly at random.
        key, sub_key = jax.random.split(key)
        positions = jax.random.choice(
            sub_key, jnp.arange(sequence_size), replace=False, shape=(num_flip,)
        )
        if categorical:
            new_seq = seq
            for k in range(num_flip):
                key, sub_key = jax.random.split(key)
                # A cyclic shift by 1..num_category-1 of a one-hot vector
                # always lands on a different category.
                shift = jax.random.randint(
                    sub_key, minval=1, maxval=num_category, shape=(1,)
                )
                idx = positions[k]
                new_seq = new_seq.at[idx].set(jnp.roll(new_seq[idx], shift))
            return new_seq

        # Invert the one-hot rows: after argmax a 0 becomes 1 and a non-zero
        # value becomes 0. Positions are distinct, so a single scatter flips
        # all of them. Writing from the accumulated array is what keeps every
        # flip; the previous loop restarted from `x` and kept only the last.
        return seq.at[positions].set(1 - seq[positions])

    mutated = jax.vmap(_mutation_fn)(one_hot, keys)
    return mutated.argmax(axis=-1), random_key


def discrete_uniform_crossover(
    x1: jnp.ndarray,
    x2: jnp.ndarray,
    random_key: jnp.ndarray,
    x1_proportion: float = 0.5,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Randomly selects elements from x1 with probability x1_proportion and from x2 with
    probability 1-x1_proportion.

    The choice is made independently for every element of ``x1``.

    Args:
        x1: first parent array.
        x2: second parent array (same shape as ``x1``, or broadcastable to it).
        random_key: JAX PRNG key.
        x1_proportion: probability of taking each element from ``x1``.

    Returns:
        The offspring array and a fresh PRNG key that was not used for the
        sampling.
    """
    random_key, sub_key = jax.random.split(random_key)
    # One independent uniform draw per element of x1.
    take_x1 = jax.random.uniform(sub_key, shape=x1.shape) < x1_proportion
    return jnp.where(take_x1, x1, x2), random_key


def discrete_k_points_crossover(
    x1: jnp.ndarray, x2: jnp.ndarray, random_key: jnp.ndarray, k: int = 1
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Randomly selects k elements (that are not equal in x1 and x2) from x1 and
    insert them into x2.

    Works row by row over the leading (batch) axis. For each row, k distinct
    positions are sampled without replacement, restricted to positions where
    the two rows differ. If fewer than k positions differ, every differing
    position is copied; the remaining picks fall on positions where x1 == x2,
    so copying them is a no-op. Identical rows therefore yield the x2 row
    unchanged.

    Args:
        x1: donor array, indexed as ``x1[batch, position]``.
        x2: receiver array with the same shape as ``x1``.
        random_key: JAX PRNG key.
        k: number of positions to copy per row; must not exceed the row
            length (sampling is without replacement).

    Returns:
        The offspring array (same shape as ``x2``) and a fresh PRNG key that
        was not used for any of the sampling.
    """
    n = x1.shape[-1]
    batch_size = x1.shape[0]
    # Keep one key for the caller and derive the per-row keys from the other
    # half, so the returned key is never one that was consumed below.
    random_key, sub_key = jax.random.split(random_key)
    row_keys = jax.random.split(sub_key, batch_size)

    def crossover(
        row1: jnp.ndarray, row2: jnp.ndarray, key: jnp.ndarray
    ) -> jnp.ndarray:
        # Zero weight on positions where the parents agree.
        weights = (row1 != row2).astype(jnp.float32)
        swap_points = jax.random.choice(
            key,
            jnp.arange(n),
            shape=(k,),
            replace=False,
            p=weights,
        )
        return row2.at[swap_points].set(row1[swap_points])

    y = jax.vmap(crossover)(x1, x2, row_keys)
    return y, random_key


def plot_grid_clip(
    images: jnp.ndarray,
    positions: jnp.ndarray,
    descriptors_range: List[int],
    scoring_text: str,
    descriptors_text: str,
) -> Tuple[Figure, List[Axes]]:
    """Tile ``images`` on a square grid, each placed by its descriptor position.

    The grid has ``int(sqrt(len(images)))`` rows and columns, spanning
    ``descriptors_range`` on both axes. Image ``k`` is drawn in the cell that
    contains ``positions[k]``. Row 0 is the top of the figure, so larger ``y``
    values appear higher.

    Args:
        images: sequence of images accepted by ``Axes.imshow``. ``len(images)``
            is expected to be a perfect square.
        positions: per-image ``(x, y)`` descriptor coordinates. NOTE(review):
            presumably cell centres inside ``descriptors_range``; points
            outside the range are clamped to the nearest edge cell.
        descriptors_range: ``(low, high)`` bounds shared by both axes.
        scoring_text: text appended to "Score: " in the figure title.
        descriptors_text: indexable pair; ``[0]`` is the x-axis label and
            ``[1]`` the y-axis label.

    Returns:
        The figure and ``fig.axes``: all grid axes, followed by the background
        axes that carries the descriptor scale.

    Side effect:
        Updates the global ``matplotlib.rcParams``.
    """
    font_size = 16
    mpl_params = {
        "axes.labelsize": font_size,
        "axes.titlesize": font_size,
        "legend.fontsize": font_size,
        "xtick.labelsize": font_size,
        "ytick.labelsize": font_size,
        "text.usetex": False,
        "xtick.major.pad": 10,
        "axes.labelpad": 40,
    }

    mpl.rcParams.update(mpl_params)

    num_img = int(np.sqrt(len(images)))

    # squeeze=False keeps `axes` 2-D even for a 1x1 grid, so axes[i, j] works.
    fig, axes = plt.subplots(num_img, num_img, figsize=(9, 9), squeeze=False)

    low = descriptors_range[0]
    h = (descriptors_range[1] - low) / num_img

    def _cell_index(value: Any) -> int:
        # Round (half-to-even, as before) and clamp to a valid grid index.
        # Without the clamp, a point on the upper boundary rounds to num_img
        # when num_img is even (IndexError), and a point slightly below the
        # range becomes -1 and silently wraps to the opposite edge.
        return min(max(round(float(value)), 0), num_img - 1)

    for k, image in enumerate(images):
        x, y = positions[k]
        i = _cell_index(num_img - ((y - low) / h + 1 / 2))
        j = _cell_index((x - low) / h - 1 / 2)

        ax = axes[i, j]
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.imshow(image, extent=[-1, 1, -1, 1])

    # Background axes behind the grid, showing the descriptor scale.
    big_ax = fig.add_axes([0.125, 0.125, 0.775, 0.755], zorder=-1)
    big_ax.set_xlim(*descriptors_range)
    big_ax.set_ylim(*descriptors_range)
    fig.subplots_adjust(wspace=0, hspace=0)
    fig.supxlabel(f"{descriptors_text[0]}", fontsize=font_size)
    fig.supylabel(f"{descriptors_text[1]}", fontsize=font_size)
    fig.suptitle(
        f"Score: {scoring_text}",
        fontsize=font_size,
        weight="bold",
        x=0.5,
        y=0.92,
    )
    return fig, fig.axes
