import logging
import os
import pickle
import sys
import time
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Tuple

import haiku as hk
import hydra
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from models.clip.clip import CLIP

# from models.clip import load_clip
from models.clip.download_clip import load as load_clip
from models.vq_vae import Decoder, Encoder, VQVAEModel
from qdax.core.containers.mapelites_repertoire import (
    MapElitesRepertoire,
    compute_euclidean_centroids,
)
from qdax.core.emitters.emitter import EmitterState
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.core.map_elites import MAPElites
from qdax.types import Descriptor, Fitness, Genotype, Metrics, RNGKey
from scorer import CLIPData, VQVAEdata, build_scoring_fn
from utils import (
    discrete_k_points_crossover,
    discrete_uniform_crossover,
    flip_random_bit_fn,
    metrics_fn,
    plot_grid_clip,
)

# Make the host CPU JAX's default platform for arrays created without an
# explicit device; TPU devices are requested explicitly inside ``main``.
default_device = "cpu"
jax.config.update("jax_platform_name", default_device)


@dataclass
class ExperimentConfig:
    """Schema of the ``map_elites`` Hydra configuration consumed by ``main``.

    Field order is significant: it defines the positional constructor signature.
    """

    # Size of the initial random population and of every emitter batch.
    batch_size: int
    # Number of MAP-Elites update iterations run after initialisation.
    num_iterations: int
    # Not read by ``main``.
    num_descriptors: int
    # ``main`` builds a square archive of floor(sqrt(num_centroids)) ** 2 cells.
    num_centroids: int
    # ``main`` only works with "clip" (centroids derive from the CLIP range).
    descriptor_type: str
    # Forwarded to ``build_scoring_fn``; "clip" also selects the CLIP params.
    scorer_type: str
    # Forwarded as ``num_flip`` to the categorical mutation operator.
    num_flip: int
    # Seed of the JAX PRNG key.
    seed: int

    # Crossover operator: "uniform" or "k_points".
    crossover_type: str
    # Forwarded as ``variation_percentage`` to the mixing emitter.
    crossover_proportion: float
    # Number of cut points, used only when ``crossover_type == "k_points"``.
    k_points: int

    # Not read by ``main``.
    num_steps: int
    # Start from uniformly random codes; the only initialisation implemented.
    init_random: bool
    # Not read by ``main``.
    init_b_and_w: bool

    # Log progress and metrics every ``log_freq`` iterations.
    log_freq: int
    # Re-render the plotting grid every ``log_image_freq`` iterations.
    log_image_freq: int
    # Grid shape of the plotting repertoire: one entry per (2-D) descriptor axis.
    num_plot_centroids_per_ax: Tuple[int, int]

    # Not read by ``main``.
    image_size: Tuple[int, int, int]
    # VQ-VAE codebook size; each latent cell holds an index in [0, num_category).
    num_category: int

    # vqvae config
    num_hiddens: int
    num_residual_hiddens: int
    num_residual_layers: int
    # Spatial shape of the latent code map; a genotype has prod(...) entries.
    latent_map_size: Tuple[int, int]

    # clip params
    # Passed to ``CLIPData`` as ``descriptors_text``.
    clip_descriptors_text: Optional[List[str]]
    # Passed to ``CLIPData`` as ``scoring_text``.
    clip_scoring_text: Optional[str]
    # ``[lower, upper]`` bound applied to both descriptor axes.
    clip_descriptors_range: Optional[List[int]]


@hydra.main(config_path="config", config_name="map_elites")
def main(config: ExperimentConfig) -> None:
    """Run CLIP-guided MAP-Elites over VQ-VAE latent codes.

    Genotypes are flat integer vectors of length ``prod(latent_map_size)``.
    Each entry indexes the VQ-VAE codebook, with values in ``[0, num_category)``.
    Genotypes are evaluated by the function returned by ``build_scoring_fn`` and
    stored in a 2-D MAP-Elites archive. The archive centroids span
    ``clip_descriptors_range`` on both axes. Every ``log_image_freq``
    iterations, the archive is projected onto a plotting grid and the codes
    are decoded back to images.

    Args:
        config: Hydra-composed configuration with the fields of
            ``ExperimentConfig``.

    Raises:
        NotImplementedError: if ``descriptor_type`` is not ``"clip"``,
            ``init_random`` is false, or ``crossover_type`` is unknown.
        ValueError: if a required CLIP descriptor/scoring field is unset.
    """
    script_dir = sys.path[0]
    os.makedirs("images", exist_ok=True)

    # Setup logging. basicConfig is a no-op when the root logger already has a
    # handler (e.g. one installed by Hydra); the first root handler is then
    # capped at INFO.
    logging.basicConfig(level=logging.DEBUG)
    logging.getLogger().handlers[0].setLevel(logging.INFO)
    logger = logging.getLogger(__name__)

    # Only CLIP descriptors are implemented: the archive centroids and the
    # plotting grid are both derived from ``clip_descriptors_range``. Validate
    # up front instead of failing with a NameError after the (slow) model loads.
    if config.descriptor_type != "clip":
        raise NotImplementedError(
            f"descriptor_type={config.descriptor_type!r} is not supported; "
            "only 'clip' is implemented."
        )
    for field_name in (
        "clip_descriptors_range",
        "clip_descriptors_text",
        "clip_scoring_text",
    ):
        if getattr(config, field_name) is None:
            raise ValueError(
                f"config.{field_name} must be set when descriptor_type is 'clip'."
            )

    random_key = jax.random.PRNGKey(config.seed)
    dataset = "ImageNet"
    logger.info(f"Setting default device to {default_device}")
    # NOTE(review): the first TPU device is deliberately skipped; the reason is
    # not visible here. Presumably fails when no TPU backend is available.
    devices = jax.devices("tpu")[1:]
    logger.info(f"TPU devices: {devices}")

    # ---- VQ-VAE decoder weights and codebook -------------------------------
    models_dir = os.path.join(script_dir, "model_weights")
    latent_map_full_size = int(jnp.prod(jnp.array(config.latent_map_size)))
    vqvae_weights_path = os.path.join(
        models_dir,
        "vqvae",
        dataset.lower(),
        f"vq-vae-{config.num_category}-{latent_map_full_size}.pkl",
    )
    # Local checkpoint file: unpickling can execute arbitrary code, so this must
    # never be pointed at untrusted data.
    with open(vqvae_weights_path, "rb") as f:
        decoder_params, state = pickle.load(f)
    # Transposed so that ``one_hot(code, num_category) @ codebook`` selects one
    # codebook row per latent cell (see the image decoding below).
    codebook = state["vector_quantizer_ema"]["embeddings"].T
    decoder_params = jax.tree_util.tree_map(
        lambda x: jnp.array(x, dtype=jnp.float32), decoder_params
    )

    # ---- CLIP model ----------------------------------------------------------
    clip_data = CLIPData(
        image_apply=None,
        text_apply=None,
        similarity_score=None,
        scoring_text=config.clip_scoring_text,
        descriptors_text=config.clip_descriptors_text,
    )

    # descriptor_type == "clip" is guaranteed above, so CLIP is always needed.
    clip_weights_path = os.path.join(models_dir, "clip", "ViT-B-32.pt")
    if os.path.exists(clip_weights_path):
        image_fn, text_fn, similarity_score, clip_params, _ = load_clip(
            clip_weights_path
        )
    else:
        # No local checkpoint: load by model name on CPU.
        image_fn, text_fn, similarity_score, clip_params, _ = load_clip(
            "ViT-B/32", "cpu"
        )
    clip_data.image_apply = image_fn
    clip_data.text_apply = text_fn
    clip_data.similarity_score = similarity_score

    descriptor_params = clip_params
    # Non-CLIP scorers get no parameters built here; pass None explicitly
    # rather than leaving the name unbound.
    scorer_params = clip_params if config.scorer_type == "clip" else None

    # Both descriptor axes share the same [lower, upper] bounds.
    descriptors_minval = config.clip_descriptors_range[:1] * 2
    descriptors_maxval = config.clip_descriptors_range[1:] * 2

    # Square archive grid, floor(sqrt(num_centroids)) cells per side.
    num_img = int(np.sqrt(config.num_centroids) // 1)
    centroids = compute_euclidean_centroids(
        grid_shape=[num_img] * 2,
        minval=descriptors_minval,
        maxval=descriptors_maxval,
    )

    # ---- Scoring -------------------------------------------------------------
    def decoder_fn(x: jnp.ndarray) -> jnp.ndarray:
        """Decode a map of codebook embeddings with the VQ-VAE decoder."""
        d = Decoder(
            config.num_hiddens, config.num_residual_layers, config.num_residual_hiddens
        )
        return d(x)

    decoder_apply = hk.without_apply_rng(hk.transform(decoder_fn)).apply

    vqvae_data = VQVAEdata(
        codebook=codebook,
        decoder_params=decoder_params,
        decoder_apply=decoder_apply,
        latent_map_size=config.latent_map_size,
    )

    scoring_fun = build_scoring_fn(
        scorer_type=config.scorer_type,
        descriptor_type=config.descriptor_type,
        scorer_params=scorer_params,
        descriptor_params=descriptor_params,
        vqvae=vqvae_data,
        clip_data=clip_data,
        compute_gradients=False,
    )

    def scoring_fn(
        x: Genotype, random_key: RNGKey
    ) -> Tuple[Fitness, Descriptor, Dict[str, jnp.ndarray], RNGKey]:
        """Adapt ``scoring_fun`` to the QDax signature; the key passes through."""
        fitness, descriptors, extra_scores, _ = scoring_fun(x)
        return fitness, descriptors, extra_scores, random_key

    # ---- Variation operators and initial population --------------------------
    random_mutation = partial(
        flip_random_bit_fn,
        categorical=True,
        num_category=config.num_category,
        num_flip=config.num_flip,
    )

    if not config.init_random:
        raise NotImplementedError("Only init_random=True is supported.")

    # Draw the codes with the fresh subkey so that ``random_key`` (handed to the
    # MAP-Elites init below) is not consumed twice.
    random_key, subkey = jax.random.split(random_key)
    random_code = jax.random.randint(
        subkey,
        minval=0,
        maxval=config.num_category,
        shape=(config.batch_size,) + config.latent_map_size,
        dtype=jnp.int16,
    )
    # One flat genotype per latent code map.
    initial_population = random_code.reshape((config.batch_size, -1))

    if config.crossover_type == "uniform":
        crossover_fn = discrete_uniform_crossover
    elif config.crossover_type == "k_points":
        crossover_fn = partial(discrete_k_points_crossover, k=config.k_points)
    else:
        raise NotImplementedError(
            f"crossover_type={config.crossover_type!r} is not supported; "
            "expected 'uniform' or 'k_points'."
        )

    emitter = MixingEmitter(
        mutation_fn=random_mutation,
        variation_fn=crossover_fn,
        variation_percentage=config.crossover_proportion,
        batch_size=config.batch_size,
    )

    map_elites = MAPElites(
        scoring_function=scoring_fn, emitter=emitter, metrics_function=metrics_fn
    )

    repertoire, emitter_state, random_key = map_elites.get_distributed_init_fn_v2(
        centroids=centroids,
        devices=devices,
    )(init_genotypes=initial_population, random_key=random_key)

    update_fn = map_elites.get_distributed_update_fn_v2(
        num_iterations=1, devices=devices
    )

    # ---- Plotting repertoire -------------------------------------------------
    plot_centroids = compute_euclidean_centroids(
        grid_shape=config.num_plot_centroids_per_ax,
        minval=descriptors_minval,
        maxval=descriptors_maxval,
    )
    num_plot_centroids = int(np.prod(config.num_plot_centroids_per_ax))

    # Starts empty (all fitnesses -inf); filled from the main archive when
    # images are rendered.
    plot_repertoire = MapElitesRepertoire.init(
        genotypes=jnp.zeros((num_plot_centroids,) + initial_population.shape[1:]),
        fitnesses=jnp.ones(
            num_plot_centroids,
        )
        * -jnp.inf,
        descriptors=jnp.ones((num_plot_centroids, 2)),
        centroids=plot_centroids,
    )

    # ---- Main loop -----------------------------------------------------------
    start = time.time()
    initial_time = start

    for i in range(config.num_iterations + 1):
        if i > 0:
            repertoire, emitter_state, random_key, metrics = update_fn(
                repertoire, emitter_state, random_key
            )
        else:
            # Iteration 0 only reports metrics of the initial archive. The
            # leading axis is presumably there to match update_fn's metric shape.
            metrics = metrics_fn(repertoire)
            metrics = jax.tree_util.tree_map(lambda x: jnp.ravel(x[None]), metrics)

        current_step = i * config.batch_size

        if i % config.log_freq == 0:
            end = time.time()
            logger.info(f"Number of steps: {current_step:.0f}")
            logger.info(
                f"Time for {config.log_freq:.0f} iterations: {end - start:.2f}s"
            )
            start = end
            logger.info(f"Metrics: {metrics}")

        if i % config.log_image_freq == 0:
            # Project the main archive onto the plotting grid.
            plot_repertoire = plot_repertoire.add(
                repertoire.genotypes, repertoire.descriptors, repertoire.fitnesses
            )
            # Integer codes -> one-hot -> codebook embeddings per latent cell.
            z = (
                jax.nn.one_hot(
                    plot_repertoire.genotypes, config.num_category
                ).reshape((-1,) + config.latent_map_size + (config.num_category,))
                @ codebook
            )
            # Shift decoder output by 0.5 and clip to [0, 1] for display.
            images = jnp.clip(decoder_apply(decoder_params, z) + 0.5, 0, 1)
            # Empty plotting cells (fitness -inf) are rendered as white.
            images = jnp.where(
                plot_repertoire.fitnesses[:, None, None, None] == -jnp.inf,
                1.0,
                images,
            )
            images = jax.image.resize(
                images, (images.shape[0], 128, 128, 3), method="bilinear"
            )
            images = images.astype(jnp.float32)
            fig, _ = plot_grid_clip(
                images=images,
                positions=plot_repertoire.centroids,
                descriptors_range=config.clip_descriptors_range,
                scoring_text=config.clip_scoring_text,
                descriptors_text=config.clip_descriptors_text,
            )
            # NOTE(review): the figure is rendered but never saved; the former
            # (commented-out) savefig used ``config.result_folder``, which is not
            # an ExperimentConfig field. Persist it (e.g. under ``images/``) if
            # the plots are needed.
            plt.close(fig)

    logger.info(f"Elapsed time: {time.time() - initial_time:.0f} s")


if __name__ == "__main__":
    # ``hydra.main`` composes the ``map_elites`` config from the ``config``
    # directory and injects it, so ``main`` is called without arguments.
    main()
