import logging
import os
import pickle
import sys
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import haiku as hk
import hydra
import jax
import jax.numpy as jnp
import jmp
import matplotlib.pyplot as plt
import numpy as np

from models.clip.clip import CLIP
from models.clip.download_clip import load as load_clip
from models.vq_vae import Decoder, Encoder, VQVAEModel
from qdax.core.containers.mapelites_repertoire import (
    MapElitesRepertoire,
    compute_euclidean_centroids,
)
from qdax.core.emitters.categorical_gide import CategoricalGIDE, CategoricalGIDEConfig
from qdax.core.emitters.emitter import EmitterState
from qdax.core.map_elites import MAPElites
from qdax.types import Descriptor, Fitness, Genotype, Metrics, RNGKey
from scorer import CLIPData, VQVAEdata, build_scoring_fn
from utils import metrics_fn, normalized_entropy, plot_grid_clip

# Make CPU the default JAX platform. main() separately requests TPU devices
# and passes them to the distributed MAP-Elites init/update functions.
default_device = "cpu"
jax.config.update("jax_platform_name", default_device)

# Disabled experiment: apply a jmp mixed-precision policy (float32 params,
# bfloat16 compute/output) to the VQ-VAE and CLIP Haiku modules. While it stays
# commented out, no mixed-precision policy is installed by this file.
# policy = jmp.get_policy("params=float32,compute=bfloat16,output=bfloat16")
# for _cls in [Encoder, Decoder, hk.Conv2D, hk.nets.VectorQuantizerEMA, CLIP]:
#     hk.mixed_precision.set_policy(_cls, policy)


@dataclass
class ExperimentConfig:
    """Typed description of the settings ``main`` reads from its config.

    Fields are grouped as: run/search settings, dataset settings, logging,
    GIDE emitter options, VQ-VAE architecture and CLIP prompts.

    NOTE(review): ``main`` also reads ``config.dataset`` (used to build the
    VQ-VAE weights path), which is not declared here - confirm whether a
    ``dataset: str`` field is missing from this schema.
    """

    # Size of the random initial population and of the emitter batch; also
    # used by main() to convert an iteration index into a step count.
    batch_size: int
    # Number of update iterations main() runs after the initial evaluation.
    num_iterations: int
    # Forwarded unchanged to CategoricalGIDEConfig.
    num_descriptors: int
    # main() uses floor(sqrt(num_centroids)) as the per-axis size of the
    # 2-D centroid grid.
    num_centroids: int
    # Passed to build_scoring_fn; the value "clip" makes main() load CLIP.
    descriptor_type: str
    scorer_type: str
    # Forwarded unchanged to CategoricalGIDEConfig.
    num_flip: int
    # Number of VQ-VAE codebook entries: used in the weights file name, as the
    # exclusive upper bound of the random codes and as the one-hot depth.
    num_category: int
    # Seed of the JAX PRNG key created in main().
    seed: int
    # Forwarded unchanged to CategoricalGIDEConfig.
    num_steps: int
    # dataset config
    # Not read by main() in this file.
    image_size: Tuple[int, int, int]
    # Must be true: main() raises NotImplementedError for any other init mode.
    init_random: bool

    # Metrics and timing are logged every `log_freq` iterations.
    log_freq: int
    # The image grid is rebuilt every `log_image_freq` iterations.
    log_image_freq: int
    # Grid shape of the coarser centroid set used only for plotting.
    # NOTE(review): annotated as a 1-tuple, but the plotting repertoire uses
    # 2-D descriptors - presumably this should hold two entries; confirm.
    num_plot_centroids_per_ax: Tuple[int]

    # GIDE config
    # All fields below are forwarded unchanged to CategoricalGIDEConfig.
    only_fitness_gradient: bool
    only_diversity_gradient: bool
    coef_sampling_type: str
    sigma_diag: float
    fitness_scale: float
    diversity_scale: float
    fitness_proportion: float
    use_adam: bool
    normalize_proposal: bool
    target_entropy: float
    auto_temperature: bool

    # vqvae config
    # Positional constructor arguments of the Encoder/Decoder modules
    # (main() passes them in the order hiddens, residual_layers,
    # residual_hiddens).
    num_hiddens: int
    num_residual_hiddens: int
    num_residual_layers: int
    # Shape of the latent code map of one genotype; its product also appears
    # in the VQ-VAE weights file name.
    latent_map_size: Tuple[int, int]

    # clip params
    # Text prompts handed to CLIPData for the descriptors and for the score.
    clip_descriptors_text: Optional[List[str]]
    clip_scoring_text: Optional[str]
    # main() repeats the first element as the lower bound and the remaining
    # element(s) as the upper bound of both descriptor axes, so this is
    # presumably a [min, max] pair - confirm against the config files.
    clip_descriptors_range: Optional[List[int]]


@hydra.main(config_path="config", config_name="me_gide")
def main(config: ExperimentConfig) -> None:
    """Run MAP-Elites with a categorical GIDE emitter over VQ-VAE codes.

    The function loads pickled VQ-VAE decoder weights and the CLIP model,
    builds a scoring function from them, initialises a repertoire from random
    latent codes and then runs ``config.num_iterations`` distributed update
    steps, logging metrics every ``config.log_freq`` iterations and rebuilding
    an image grid every ``config.log_image_freq`` iterations.

    Args:
        config: Experiment settings supplied by Hydra.

    Raises:
        NotImplementedError: If ``descriptor_type`` or ``scorer_type`` is not
            ``"clip"`` (no other path defines the centroids / model
            parameters this script needs), or if ``init_random`` is false.
        ValueError: If one of the CLIP text/range settings is missing.
    """
    # Fail fast on unsupported settings instead of crashing with an
    # UnboundLocalError after the models have already been loaded.
    if config.descriptor_type != "clip":
        raise NotImplementedError(
            f"descriptor_type={config.descriptor_type!r} is not supported; "
            "only 'clip' defines centroids and descriptor parameters"
        )
    if config.scorer_type != "clip":
        raise NotImplementedError(
            f"scorer_type={config.scorer_type!r} is not supported; "
            "only 'clip' defines scorer parameters"
        )
    if not config.init_random:
        raise NotImplementedError(
            "only init_random=True is supported for the initial population"
        )
    # Explicit checks rather than asserts, which disappear under `python -O`.
    missing = [
        name
        for name in (
            "clip_descriptors_range",
            "clip_descriptors_text",
            "clip_scoring_text",
        )
        if getattr(config, name) is None
    ]
    if missing:
        raise ValueError(
            "missing CLIP settings required by the 'clip' descriptor: "
            + ", ".join(missing)
        )

    # The script directory is taken from sys.path[0] rather than from the
    # current working directory.
    script_dir = sys.path[0]
    os.makedirs("images", exist_ok=True)
    # Setup logging
    logging.basicConfig(level=logging.DEBUG)
    logging.getLogger().handlers[0].setLevel(logging.INFO)
    logger = logging.getLogger(__name__)
    random_key = jax.random.PRNGKey(config.seed)

    print(f"Setting default device to {default_device}")
    # NOTE(review): the first TPU device is deliberately skipped by the slice;
    # the reason is not visible in this file.
    devices = jax.devices("tpu")[1:]
    print(f"TPU devices: {devices}")

    models_dir = os.path.join(script_dir, "model_weights")
    # Normalise to a plain tuple once so it can be concatenated with other
    # tuples when building shapes, whatever sequence type the config holds.
    latent_map_size = tuple(config.latent_map_size)
    latent_map_full_size = int(np.prod(latent_map_size))
    vqvae_weights_path = os.path.join(
        models_dir,
        "vqvae",
        config.dataset.lower(),
        f"vq-vae-{config.num_category}-{latent_map_full_size}.pkl",
    )
    # SECURITY: pickle executes arbitrary code on load; only point this at
    # weight files that come from a trusted source.
    with open(vqvae_weights_path, "rb") as f:
        decoder_params, state = pickle.load(f)

    codebook = state["vector_quantizer_ema"]["embeddings"].T
    decoder_params = jax.tree_util.tree_map(
        lambda x: jnp.array(x, dtype=jnp.float32), decoder_params
    )

    # Prefer local CLIP weights; otherwise let the loader resolve the model
    # by name.
    clip_weights_path = os.path.join(models_dir, "clip", "ViT-B-32.pt")
    if os.path.exists(clip_weights_path):
        image_fn, text_fn, similarity_score, clip_params, _ = load_clip(
            clip_weights_path
        )
    else:
        image_fn, text_fn, similarity_score, clip_params, _ = load_clip(
            "ViT-B/32", "cpu"
        )

    clip_data = CLIPData(
        image_apply=None,
        text_apply=None,
        similarity_score=None,
        scoring_text=config.clip_scoring_text,
        descriptors_text=config.clip_descriptors_text,
    )
    clip_data.image_apply = image_fn
    clip_data.text_apply = text_fn
    clip_data.similarity_score = similarity_score

    # CLIP provides the parameters for both the fitness and the descriptors.
    descriptor_params = clip_params
    scorer_params = clip_params

    # Same bounds on both descriptor axes: first entry of the range is the
    # minimum, the remainder the maximum.
    descriptor_min = config.clip_descriptors_range[:1] * 2
    descriptor_max = config.clip_descriptors_range[1:] * 2

    num_img = int(np.sqrt(config.num_centroids) // 1)
    centroids = compute_euclidean_centroids(
        grid_shape=[num_img] * 2,
        minval=descriptor_min,
        maxval=descriptor_max,
    )

    def decoder_fn(x: jnp.ndarray) -> jnp.ndarray:
        """Decode quantised latents with a freshly constructed Decoder."""
        d = Decoder(
            config.num_hiddens, config.num_residual_layers, config.num_residual_hiddens
        )
        return d(x)

    def vq_vae_fn(data: jnp.ndarray) -> jnp.ndarray:
        """Run the full VQ-VAE (encoder, quantiser, decoder) on ``data``."""
        encoder = Encoder(
            config.num_hiddens, config.num_residual_layers, config.num_residual_hiddens
        )
        decoder = Decoder(
            config.num_hiddens, config.num_residual_layers, config.num_residual_hiddens
        )
        pre_vq_conv1 = hk.Conv2D(
            output_channels=64,
            kernel_shape=(1, 1),
            stride=(1, 1),
            name="to_vq",
        )

        vq_vae = hk.nets.VectorQuantizerEMA(
            embedding_dim=64,
            num_embeddings=config.num_category,
            commitment_cost=0.2,
            decay=0.99,
        )

        model = VQVAEModel(encoder, decoder, vq_vae, pre_vq_conv1, data_variance=1)
        return model(data, False)

    # NOTE(review): the transformed full model is built but not used anywhere
    # below; only the decoder is applied.
    vq_vae = hk.without_apply_rng(hk.transform_with_state(vq_vae_fn))
    decoder_apply = hk.without_apply_rng(hk.transform(decoder_fn)).apply

    vqvae_data = VQVAEdata(
        codebook=codebook,
        decoder_params=decoder_params,
        decoder_apply=decoder_apply,
        latent_map_size=config.latent_map_size,
    )

    scoring_fun = build_scoring_fn(
        scorer_type=config.scorer_type,
        descriptor_type=config.descriptor_type,
        scorer_params=scorer_params,
        descriptor_params=descriptor_params,
        vqvae=vqvae_data,
        clip_data=clip_data,
        compute_gradients=True,
    )

    @jax.jit
    def scoring_fn(
        x: Genotype, random_key: RNGKey
    ) -> Tuple[Fitness, Descriptor, Dict[str, jnp.ndarray], RNGKey]:
        """Adapt ``scoring_fun`` to the (genotypes, key) -> (..., key) form.

        The key is returned untouched; the scorer itself takes no key.
        """
        fitness, descriptors, extra_scores, _ = scoring_fun(x)
        return fitness, descriptors, extra_scores, random_key

    # Sample the initial codes with the dedicated subkey so that `random_key`
    # stays fresh for the repertoire initialisation below (a JAX key must not
    # be consumed twice).
    random_key, subkey = jax.random.split(random_key)
    random_code = jax.random.randint(
        subkey,
        minval=0,
        maxval=config.num_category,
        shape=(config.batch_size,) + latent_map_size,
    )
    # One flat vector of category indices per individual.
    initial_population = random_code.reshape((config.batch_size, -1))

    emitter_config = CategoricalGIDEConfig(
        batch_size=config.batch_size,
        num_descriptors=config.num_descriptors,
        num_flip=config.num_flip,
        num_steps=config.num_steps,
        centroids=centroids,
        only_fitness_gradient=config.only_fitness_gradient,
        only_diversity_gradient=config.only_diversity_gradient,
        sigma_diag=config.sigma_diag,
        diversity_scale=config.diversity_scale,
        fitness_scale=config.fitness_scale,
        fitness_proportion=config.fitness_proportion,
        coef_sampling_type=config.coef_sampling_type,
        use_adam=config.use_adam,
        normalize_proposal=config.normalize_proposal,
        target_entropy=config.target_entropy,
        auto_temperature=config.auto_temperature,
    )

    emitter = CategoricalGIDE(config=emitter_config, embedding_matrix=codebook)

    map_elites = MAPElites(
        scoring_function=scoring_fn, emitter=emitter, metrics_function=metrics_fn
    )

    # Compute initial repertoire and emitter state
    repertoire, emitter_state, random_key = map_elites.get_distributed_init_fn_v2(
        centroids=centroids,
        devices=devices,
    )(init_genotypes=initial_population, random_key=random_key)

    update_fn = map_elites.get_distributed_update_fn_v2(
        num_iterations=1, devices=devices
    )

    # Coarser repertoire used only for plotting: it starts empty (fitness
    # -inf everywhere) and is refilled from the main repertoire when plotting.
    plot_centroids = compute_euclidean_centroids(
        grid_shape=config.num_plot_centroids_per_ax,
        minval=descriptor_min,
        maxval=descriptor_max,
    )
    num_plot_centroids = int(np.prod(config.num_plot_centroids_per_ax))

    plot_repertoire = MapElitesRepertoire.init(
        genotypes=jnp.zeros((num_plot_centroids,) + initial_population.shape[1:]),
        fitnesses=jnp.ones(num_plot_centroids) * -jnp.inf,
        descriptors=jnp.ones((num_plot_centroids, 2)),
        centroids=plot_centroids,
    )

    start = time.time()
    initial_time = start

    # Iteration 0 only reports the freshly initialised repertoire; updates
    # start at iteration 1.
    for i in range(config.num_iterations + 1):

        if i > 0:
            repertoire, emitter_state, random_key, metrics = update_fn(
                repertoire, emitter_state, random_key
            )
        else:
            metrics = metrics_fn(repertoire)
            # Flatten every metric to a 1-D array.
            metrics = jax.tree_util.tree_map(
                lambda x: jnp.ravel(x[None]), metrics
            )

        current_step = i * config.batch_size

        if i % config.log_freq == 0:

            end = time.time()

            logger.info(f"Number of steps: {current_step:.0f}")
            logger.info(
                f"Time for {config.log_freq:.0f} iterations: {end - start:.2f}s"
            )
            start = end

            print(metrics)

        if i % config.log_image_freq == 0:
            plot_repertoire = plot_repertoire.add(
                repertoire.genotypes, repertoire.descriptors, repertoire.fitnesses
            )
            # Turn category indices into codebook embeddings laid out on the
            # latent map, then decode them to images.
            z = (
                jax.nn.one_hot(plot_repertoire.genotypes, config.num_category).reshape(
                    (-1,) + latent_map_size + (config.num_category,)
                )
                @ codebook
            )
            images = jnp.clip(decoder_apply(decoder_params, z) + 0.5, 0, 1)
            # Cells that were never filled are rendered as all-ones images.
            images = jnp.where(
                plot_repertoire.fitnesses[:, None, None, None] == -jnp.inf, 1.0, images
            )
            images = jax.image.resize(
                images, (images.shape[0], 128, 128, 3), method="bilinear"
            )
            images = images.astype(jnp.float32)
            fig, _ = plot_grid_clip(
                images=images,
                positions=plot_repertoire.centroids,
                descriptors_range=config.clip_descriptors_range,
                scoring_text=config.clip_scoring_text,
                descriptors_text=config.clip_descriptors_text,
            )

            # NOTE(review): the figure is closed without being written to
            # disk (the original savefig call was disabled), so the "images"
            # directory created above stays empty - confirm this is intended.
            plt.close(fig)

    logger.info(f"Elapsed time: {time.time()-initial_time:.0f} s")


# Script entry point. main() is called without arguments because the
# @hydra.main decorator is responsible for supplying `config`.
if __name__ == "__main__":

    # Optional profiling wrapper, currently disabled:
    # with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    main()
