from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
from flax.struct import PyTreeNode

from models.clip.download_clip import tokenize
from models.vq_vae import Decoder
from qdax.types import Descriptor, Fitness, Genotype, Gradient, Params


class VQVAEdata(PyTreeNode):
    """Bundle of the VQ-VAE pieces needed to turn token indices into images.

    Declared as a flax ``PyTreeNode`` (dataclass-style fields).
    """

    # Called as ``decoder_apply(decoder_params, latents)`` to decode a latent
    # map into an image.
    decoder_apply: Callable
    # Parameters passed as the first argument of ``decoder_apply``.
    decoder_params: Params
    # Embedding table: a one-hot token vector is matrix-multiplied with it, so
    # row ``i`` is the vector for token ``i`` and ``codebook.shape[0]`` is the
    # number of token categories.
    codebook: jnp.ndarray
    # Spatial shape of the decoder's latent map (presumably ``(height, width)``
    # — confirm). It is concatenated with a shape tuple
    # (``latent_map_size + x.shape[1:]``) when reshaping latents, so it must be
    # a tuple of ints, not a single int.
    latent_map_size: Tuple[int, ...]


@dataclass
class CLIPData:
    """CLIP callables and text prompts consumed by ``build_scoring_fn``."""

    # Called as ``image_apply(params, images)`` with a batch of one image that
    # has been resized to 224x224 and moved to channels-first layout; the
    # result is used as the image embedding.
    image_apply: Callable
    # Called as ``text_apply(params, tokenize(text))``; the result is used as
    # the text embedding.
    text_apply: Callable
    # Called as ``similarity_score(params, image_embedding, text_embedding)``;
    # only element ``[0]`` of its result is used.
    similarity_score: Callable
    # Prompt whose similarity to the decoded image defines the fitness.
    scoring_text: Optional[str]
    # One prompt per descriptor dimension.
    descriptors_text: Optional[List[str]]


def prepare_image(
    x: jnp.ndarray, size: Tuple[int, int] = (128, 128)
) -> jnp.ndarray:
    """Transform an image output by the VQ-VAE decoder into BiT's input format.

    Values in ``[-0.5, 0.5]`` are shifted and scaled to the ``[0, 255]`` pixel
    range (anything outside is clipped). The image is then bilinearly resized
    to ``size`` and mapped to ``[-1, 1]``.

    Args:
        x: a single image of shape ``(height, width, channels)``.
        size: target ``(height, width)``; defaults to ``(128, 128)``.

    Returns:
        the transformed image, of shape ``(*size, channels)`` with values in
        ``[-1, 1]``.
    """
    x = jnp.clip((x + 0.5) * 255, 0, 255)
    # Keep the channel count unchanged: resizing the last axis as well would
    # interpolate between channels.
    x = jax.image.resize(x, (*size, x.shape[-1]), method="bilinear")
    return (x - 127.5) / 127.5


def build_scoring_fn(
    descriptor_params: Params,
    scorer_params: Params,
    vqvae: VQVAEdata,
    compute_gradients: bool,
    descriptor_type: Optional[str] = "clip",
    scorer_type: Optional[str] = "clip",
    clip_data: Optional[CLIPData] = None,
) -> Callable[[Genotype], Tuple[Fitness, Descriptor, Gradient, Any]]:
    """Build a jitted function scoring batches of VQ-VAE token genotypes with CLIP.

    Each genotype holds codebook indices. The returned function embeds the
    indices with ``vqvae.codebook`` and decodes them into an image. It then
    embeds the image with CLIP and compares it to text prompts:

    * fitness: ``(10 - similarity(image, scoring_text)) * 10``;
    * descriptors: ``similarity(image, text)`` for every text in
      ``clip_data.descriptors_text``.

    Only the ``"clip"`` scorer / ``"clip"`` descriptor combination is
    implemented.

    Args:
        descriptor_params: parameters passed to ``clip_data.text_apply`` and
            ``clip_data.similarity_score`` for the descriptor prompts.
        scorer_params: parameters passed to ``clip_data.image_apply``, and to
            ``clip_data.text_apply`` / ``clip_data.similarity_score`` for the
            fitness prompt.
        vqvae: decoder, decoder parameters, codebook and latent map size.
        compute_gradients: if True, the Jacobian of ``[fitness, descriptors]``
            w.r.t. the embedded genotype is added to ``extra_scores``.
        descriptor_type: must be ``"clip"``.
        scorer_type: must be ``"clip"``.
        clip_data: CLIP callables and prompts; required.

    Returns:
        A jitted function mapping a batch of genotypes to
        ``(fitness, descriptors, extra_scores, None)``. When
        ``compute_gradients`` is True, ``extra_scores["gradients"]`` is a
        bfloat16 array of shape
        ``(batch, num_tokens, embed_dim, 1 + num_descriptors)``. Otherwise
        ``extra_scores`` is empty.

    Raises:
        NotImplementedError: if ``scorer_type`` or ``descriptor_type`` is not
            ``"clip"``.
        ValueError: if ``clip_data``, its scoring text, or its descriptor
            texts are missing.
    """
    # Without this check, `forward`/`grad` would be undefined for other types
    # and the failure would only surface as a NameError at jit-trace time.
    if scorer_type != "clip" or descriptor_type != "clip":
        raise NotImplementedError(
            f"Unsupported combination scorer_type={scorer_type!r}, "
            f"descriptor_type={descriptor_type!r}; only 'clip' is implemented."
        )
    if clip_data is None:
        raise ValueError("clip_data is required for CLIP scoring.")
    if clip_data.scoring_text is None:
        raise ValueError("clip_data.scoring_text is required for CLIP scoring.")
    if not clip_data.descriptors_text:
        raise ValueError(
            "clip_data.descriptors_text must contain at least one prompt."
        )

    num_category = vqvae.codebook.shape[0]
    latent_shape = tuple(vqvae.latent_map_size)

    def clip_compute_image_embedding(x: Genotype) -> jnp.ndarray:
        # (num_tokens, embed_dim) -> (*latent_map_size, embed_dim) for the decoder.
        x = x.reshape(latent_shape + x.shape[1:])

        image = vqvae.decoder_apply(vqvae.decoder_params, x)
        image = jnp.clip(image, -0.5, 0.5)

        # Add a batch axis of 1, resize to 224x224, then go channels-first.
        image = jax.image.resize(
            image[None],
            shape=(1, 224, 224, image.shape[-1]),
            method="bilinear",
        ).transpose(0, 3, 1, 2)
        return clip_data.image_apply(scorer_params, image)

    # Text prompts are constant, so embed them once at build time.
    scoring_text_embedding = clip_data.text_apply(
        scorer_params, tokenize(clip_data.scoring_text)
    )
    descriptor_text_embeddings = [
        clip_data.text_apply(descriptor_params, tokenize(text))
        for text in clip_data.descriptors_text
    ]

    def forward_fn(x: Genotype) -> jnp.ndarray:
        """Score one embedded genotype: returns [fitness, desc_0, desc_1, ...]."""
        image_embedding = clip_compute_image_embedding(x)
        similarity_fit = (
            10
            - clip_data.similarity_score(
                scorer_params, image_embedding, scoring_text_embedding
            )[0]
        ) * 10  # TODO: remove manual scale

        # NOTE(review): the image embedding is computed with scorer_params but
        # scored here with descriptor_params; this presumably assumes both
        # share the same CLIP model — confirm.
        similarity_desc = jnp.array(
            [
                clip_data.similarity_score(
                    descriptor_params, image_embedding, text_embedding
                )[0]
                for text_embedding in descriptor_text_embeddings
            ]
        ).squeeze(axis=-1)

        return jnp.concatenate([similarity_fit, similarity_desc]).squeeze()

    forward = jax.vmap(forward_fn)

    def grad(x: Genotype) -> Gradient:
        # Per individual the Jacobian is (1 + n_desc, num_tokens, embed_dim);
        # move the output axis last.
        grads = jax.vmap(jax.jacrev(forward_fn))(x)
        return grads.transpose(0, 2, 3, 1)

    @jax.jit
    def scoring_fn(
        x: Genotype,
    ) -> Tuple[Fitness, Descriptor, Gradient, Any]:
        # Token indices -> codebook vectors: (batch, num_tokens, embed_dim).
        x = jax.nn.one_hot(x, num_category) @ vqvae.codebook
        x = x.astype(jnp.float32)

        fitness_and_descriptors = forward(x)
        fitness = fitness_and_descriptors[..., 0]
        descriptors = fitness_and_descriptors[..., 1:]

        extra_scores = {}
        if compute_gradients:
            # Only the gradients are stored in reduced precision.
            extra_scores["gradients"] = grad(x).astype(jnp.bfloat16)
        return fitness, descriptors, extra_scores, None

    return scoring_fn
