import logging
import os
import pickle
import sys
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import haiku as hk
import hydra
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

# from models.clip import load_clip
from models.clip.download_clip import load as load_clip
from models.vq_vae import Decoder, Encoder, VQVAEModel
from qdax.core.containers.mapelites_repertoire import (
    MapElitesRepertoire,
    compute_euclidean_centroids,
)
from qdax.core.emitters.emitter import EmitterState
from qdax.core.emitters.omg_mega_emitter import OMGMEGAEmitter
from qdax.core.map_elites import MAPElites
from qdax.types import Descriptor, Fitness, Genotype, Metrics, RNGKey
from scorer import CLIPData, VQVAEdata, build_scoring_fn
from utils import metrics_fn, plot_grid_clip

# Make CPU the default JAX platform at import time. main() separately looks up
# the TPU devices and hands them to the distributed MAP-Elites functions.
default_device = "cpu"
jax.config.update("jax_platform_name", default_device)


@dataclass
class ExperimentConfig:
    """Typed view of the Hydra configuration consumed by ``main``.

    Field order is part of the positional constructor signature; do not
    reorder.

    NOTE(review): ``main`` also reads ``config.dataset`` (expected to be
    "CIFAR-10" or "Imagenet"), which is not declared here -- confirm whether
    it should become a field.
    """

    # Number of genotypes in the initial population and in each emitter batch.
    batch_size: int
    # Number of update iterations run after the initial evaluation.
    num_iterations: int
    # Forwarded to the emitter as ``num_descriptors``.
    num_descriptors: int
    # Target repertoire size; main builds a floor(sqrt(n)) x floor(sqrt(n)) grid.
    num_centroids: int
    # Which descriptor / scorer back-end to use ("clip" is the one handled in main).
    descriptor_type: str
    scorer_type: str
    # Number of discrete codes a latent position can take (codebook size).
    num_category: int
    # Seed for the JAX PRNG key.
    seed: int
    # Not read in this file.
    num_steps: int

    # Metrics are logged every ``log_freq`` iterations, images are rendered
    # every ``log_image_freq`` iterations.
    log_freq: int
    log_image_freq: int
    # Grid shape (one entry per descriptor axis) of the coarse repertoire that
    # is only used for plotting.
    num_plot_centroids_per_ax: Tuple[int, ...]

    # dataset config
    image_size: Tuple[int, int, int]
    # When False, main raises NotImplementedError.
    init_random: bool

    # GDP config
    # Forwarded to the emitter as ``sigma_g``.
    sigma_diag: float

    # vqvae config
    num_hiddens: int
    num_residual_hiddens: int
    num_residual_layers: int
    # Spatial shape of the latent code map; its product selects the checkpoint.
    latent_map_size: Tuple[int, int]

    # clip params
    clip_descriptors_text: Optional[List[str]]
    clip_scoring_text: Optional[str]
    # Two-element [min, max] range applied to every descriptor axis.
    clip_descriptors_range: Optional[List[int]]


def projection_fn(x: Genotype) -> Genotype:
    """Return a float32 mask marking the maxima along the last axis of ``x``.

    Every entry equal to the maximum of its last-axis slice becomes 1.0 and
    all others 0.0, so tied maxima each receive a 1.0.
    """
    peak = x.max(axis=-1, keepdims=True)
    at_peak = x == peak
    return at_peak.astype(jnp.float32)


@hydra.main(config_path="config", config_name="gdp_me")
def main(config: ExperimentConfig) -> None:
    """Run distributed MAP-Elites over discrete VQ-VAE latent codes.

    Steps:
      1. Load the pickled VQ-VAE decoder parameters and codebook.
      2. Load CLIP and build the scoring function.
      3. Sample a random initial population of latent codes.
      4. Initialise the repertoire, then run ``config.num_iterations`` updates,
         logging metrics every ``config.log_freq`` iterations and rendering the
         plotting repertoire every ``config.log_image_freq`` iterations.

    Args:
        config: Hydra configuration (see ``ExperimentConfig``); ``config.dataset``
            is additionally required.

    Raises:
        NotImplementedError: if ``config.descriptor_type`` is not "clip" or
            ``config.init_random`` is False.
    """
    assert config.dataset in ["CIFAR-10", "Imagenet"]

    # Centroids, descriptor parameters and the plotting repertoire are only
    # ever built for CLIP descriptors; fail with a clear message instead of an
    # UnboundLocalError further down.
    if config.descriptor_type != "clip":
        raise NotImplementedError(
            f"descriptor_type={config.descriptor_type!r} is not supported; "
            "only 'clip' descriptors are implemented."
        )
    assert config.clip_descriptors_range is not None
    assert config.clip_descriptors_text is not None
    assert config.clip_scoring_text is not None

    script_dir = sys.path[0]
    os.makedirs("images", exist_ok=True)

    # Setup logging
    logging.basicConfig(level=logging.DEBUG)
    logging.getLogger().handlers[0].setLevel(logging.INFO)
    logger = logging.getLogger(__name__)
    random_key = jax.random.PRNGKey(config.seed)

    print(f"Setting default device to {default_device}")
    # NOTE(review): the first TPU device is deliberately skipped here --
    # confirm this is still intended.
    devices = jax.devices("tpu")[1:]
    print(f"TPU devices: {devices}")

    # Coerce once so tuple concatenation below works for any sequence type.
    latent_map_size = tuple(config.latent_map_size)
    latent_map_full_size = int(np.prod(latent_map_size))

    models_dir = os.path.join(script_dir, "model_weights")
    vqvae_checkpoint = os.path.join(
        models_dir,
        "vqvae",
        f"{config.dataset.lower()}",
        f"vq-vae-{config.num_category}-{latent_map_full_size}.pkl",
    )
    # SECURITY: pickle.load executes arbitrary code; only load trusted
    # checkpoints from the local model_weights directory.
    with open(vqvae_checkpoint, "rb") as f:
        decoder_params, state = pickle.load(f)
    codebook = state["vector_quantizer_ema"]["embeddings"].T

    clip_data = CLIPData(
        image_apply=None,
        text_apply=None,
        similarity_score=None,
        scoring_text=config.clip_scoring_text,
        descriptors_text=config.clip_descriptors_text,
    )

    # Prefer local CLIP weights; otherwise let the loader resolve the model name.
    clip_weights = os.path.join(models_dir, "clip", "ViT-B-32.pt")
    if os.path.exists(clip_weights):
        image_fn, text_fn, similarity_score, clip_params, _ = load_clip(clip_weights)
    else:
        image_fn, text_fn, similarity_score, clip_params, _ = load_clip(
            "ViT-B/32", "cpu"
        )
    clip_data.image_apply = image_fn
    clip_data.text_apply = text_fn
    clip_data.similarity_score = similarity_score

    descriptor_params = clip_params
    # Non-CLIP scorers receive no CLIP parameters.
    # NOTE(review): confirm build_scoring_fn accepts None for those scorers.
    scorer_params = clip_params if config.scorer_type == "clip" else None

    # Square grid with floor(sqrt(num_centroids)) cells per axis.
    num_img = int(np.sqrt(config.num_centroids) // 1)
    # The sub-key is unused; the split is kept so the key sequence stays the
    # same as before.
    random_key, _ = jax.random.split(random_key)
    centroids = compute_euclidean_centroids(
        grid_shape=[num_img] * 2,
        minval=config.clip_descriptors_range[:1] * 2,
        maxval=config.clip_descriptors_range[1:] * 2,
    )

    def decoder_fn(x: jnp.ndarray) -> jnp.ndarray:
        d = Decoder(
            config.num_hiddens, config.num_residual_layers, config.num_residual_hiddens
        )
        return d(x)

    decoder_apply = hk.without_apply_rng(hk.transform(decoder_fn)).apply
    vqvae_data = VQVAEdata(
        codebook=codebook,
        decoder_params=decoder_params,
        decoder_apply=decoder_apply,
        latent_map_size=config.latent_map_size,
    )

    scoring_fun = build_scoring_fn(
        scorer_type=config.scorer_type,
        descriptor_type=config.descriptor_type,
        scorer_params=scorer_params,
        descriptor_params=descriptor_params,
        vqvae=vqvae_data,
        clip_data=clip_data,
        compute_gradients=True,
    )

    def scoring_fn(
        x: Genotype, random_key: RNGKey
    ) -> Tuple[Fitness, Descriptor, Dict[str, jnp.ndarray], RNGKey]:
        # Adapt to the MAP-Elites scoring signature; the key passes through
        # untouched and the scorer's fourth output is dropped.
        fitness, descriptors, extra_scores, _ = scoring_fun(x)
        return fitness, descriptors, extra_scores, random_key

    if not config.init_random:
        raise NotImplementedError()

    # Sample with the fresh sub-key: the carried key is reused by the init
    # function below, and JAX keys must never be consumed twice.
    random_key, subkey = jax.random.split(random_key)
    random_code = jax.random.randint(
        subkey,
        minval=0,
        maxval=config.num_category,
        shape=(config.batch_size,) + latent_map_size,
    )
    # One flat vector of code indices per individual.
    initial_population = random_code.reshape((config.batch_size, -1))

    emitter = OMGMEGAEmitter(
        batch_size=config.batch_size,
        sigma_g=config.sigma_diag,
        num_descriptors=config.num_descriptors,
        centroids=centroids,
        projection_fn=projection_fn,
        num_category=config.num_category,
        embedding_matrix=codebook,
    )

    map_elites = MAPElites(
        scoring_function=scoring_fn, emitter=emitter, metrics_function=metrics_fn
    )

    repertoire, emitter_state, random_key = map_elites.get_distributed_init_fn_v2(
        centroids=centroids,
        devices=devices,
    )(init_genotypes=initial_population, random_key=random_key)

    update_fn = map_elites.get_distributed_update_fn_v2(
        num_iterations=1, devices=devices
    )

    # Coarser repertoire used only for rendering image grids.
    plot_centroids = compute_euclidean_centroids(
        grid_shape=config.num_plot_centroids_per_ax,
        minval=config.clip_descriptors_range[:1] * 2,
        maxval=config.clip_descriptors_range[1:] * 2,
    )
    num_plot_centroids = int(np.prod(config.num_plot_centroids_per_ax))
    plot_repertoire = MapElitesRepertoire.init(
        genotypes=jnp.zeros((num_plot_centroids,) + initial_population.shape[1:]),
        # -inf fitness marks an empty cell.
        fitnesses=jnp.full((num_plot_centroids,), -jnp.inf),
        descriptors=jnp.ones((num_plot_centroids, 2)),
        centroids=plot_centroids,
    )

    start = time.time()
    num_loops = config.num_iterations

    # Iteration 0 only reports the freshly initialised repertoire.
    for i in range(num_loops + 1):

        if i > 0:
            repertoire, emitter_state, random_key, metrics = update_fn(
                repertoire, emitter_state, random_key
            )
        else:
            metrics = metrics_fn(repertoire)
            # jax.tree_map is deprecated/removed in recent JAX releases.
            metrics = jax.tree_util.tree_map(lambda x: jnp.ravel(x[None]), metrics)

        current_step = i * config.batch_size

        if i % config.log_freq == 0:
            end = time.time()

            logger.info(f"Number of steps: {current_step:.0f}")
            logger.info(
                f"Time for {config.log_freq:.0f} iterations: {end - start:.2f}s"
            )
            start = end
            print(metrics)

        if i % config.log_image_freq == 0:
            plot_repertoire = plot_repertoire.add(
                repertoire.genotypes, repertoire.descriptors, repertoire.fitnesses
            )
            # Code indices -> one-hot -> codebook embeddings for the decoder.
            z = (
                jax.nn.one_hot(plot_repertoire.genotypes, config.num_category).reshape(
                    (-1,) + latent_map_size + (config.num_category,)
                )
                @ codebook
            )
            images = jnp.clip(decoder_apply(decoder_params, z) + 0.5, 0, 1)
            # Empty cells (fitness == -inf) are rendered as all-ones images.
            images = jnp.where(
                plot_repertoire.fitnesses[:, None, None, None] == -jnp.inf, 1.0, images
            )
            images = jax.image.resize(
                images, (images.shape[0], 128, 128, 3), method="bilinear"
            )
            images = images.astype(jnp.float32)
            fig, axes = plot_grid_clip(
                images=images,
                positions=plot_repertoire.centroids,
                descriptors_range=config.clip_descriptors_range,
                scoring_text=config.clip_scoring_text,
                descriptors_text=config.clip_descriptors_text,
            )
            # NOTE(review): the figure is closed without being saved here even
            # though an "images" directory is created above -- confirm whether
            # plot_grid_clip writes it to disk or a savefig call is missing.
            plt.close(fig)


# Script entry point: main() is called without arguments because the
# @hydra.main decorator supplies the config.
if __name__ == "__main__":
    main()
