"""
Gemma language-modelling head that is compatible with caching, plus a greedy sampler built on it.
"""

from typing import Dict, List
from flax import linen as nn
from jax import numpy as jnp
import jax
from tqdm import tqdm
from latte_trans.config import LMTaskConfig
from .gemma_mach import GemmaDecoder as GemmaMachiattoDecoder, GemmaMachiattoCache
from latte_trans.models.modules.layers import Embedder
from latte_trans.evals.losses import cross_entropy_loss_lm

# Cache passed into and returned from the model: a list of GemmaMachiattoCache
# objects (presumably one entry per decoder layer — TODO confirm in gemma_mach).
GemmaCache = List[GemmaMachiattoCache]


class Gemma(nn.Module):
    """Causal language model: embedder + Gemma decoder + tied output projection.

    The same ``Embedder`` instance is used both to embed the input ids and to
    project the decoder output back to vocabulary logits. A decoding cache can
    be passed through to (and is returned from) the decoder.
    """

    # Task configuration; ``text_vocab_size`` and ``hidden_dim`` are read here,
    # and the whole object is forwarded to the decoder.
    config: LMTaskConfig
    # Padding token id (not used inside this module).
    pad_id: int
    # dtype handed to the embedder and the decoder.
    dtype: jnp.dtype = jnp.float32
    # Forwarded unchanged to the decoder.
    sharded: bool = False
    # Logits are squashed as ``cap * tanh(logits / cap)``; ``None`` disables it.
    final_logit_softcapping: float = 30.0

    @nn.compact
    def __call__(
        self,
        input_ids: jax.Array,
        labels: jax.Array = None,
        train: bool = False,
        do_inference: bool = False,
        cache: GemmaCache = None,
    ) -> Dict[str, jnp.array]:
        """
        Run the model and, when labels are given, compute the LM loss.

        Args:
            input_ids: jnp.array(BL) - input ids
            labels: jnp.array(BL) - targets; when None no loss is computed
            train: bool - forwarded to the decoder (used for dropout)
            do_inference: bool - forwarded to the decoder
            cache: GemmaCache - forwarded to the decoder, may be None
        Returns:
            out: Dict[str, jnp.array] - "loss" (None without labels),
                "logits" and the "cache" returned by the decoder
        """
        cfg = self.config

        embedder = Embedder(
            vocab_size=cfg.text_vocab_size,
            embed_dim=cfg.hidden_dim,
            scale_by_sqrt_dim=False,
            dtype=self.dtype,
            name="model_embed",
        )
        backbone = GemmaMachiattoDecoder(
            cfg,
            sharded=self.sharded,
            dtype=self.dtype,
            name="model",
        )

        # The embedder is built with scale_by_sqrt_dim=False, so the
        # sqrt(hidden_dim) scaling is applied explicitly here.
        hidden = embedder.encode(input_ids)
        hidden = hidden * jnp.array(cfg.hidden_dim**0.5, dtype=hidden.dtype)

        hidden, new_cache = backbone(
            hidden,
            attention_mask=None,
            train=train,
            do_inference=do_inference,
            cache=cache,
        )

        # Project back to the vocabulary with the same embedder instance.
        logits = embedder.decode(hidden)

        cap = self.final_logit_softcapping
        if cap is not None:
            # tanh soft-capping keeps logits inside (-cap, cap).
            logits = jax.nn.tanh(logits / cap) * cap

        if labels is not None:
            # Next-token objective: position t is scored against label t + 1.
            loss = cross_entropy_loss_lm(
                logits=logits[:, :-1, :], target=labels[:, 1:], ignore_index=-100
            )
            return {"loss": loss, "logits": logits, "cache": new_cache}

        return {"logits": logits, "loss": None, "cache": new_cache}


class LanguageSampler:
    """Greedy (argmax) sampler that threads a decoding cache through a model.

    Attributes:
        compile: when truthy, ``generate`` uses the ``jax.jit``-compiled step
            function, otherwise the plain Python one.
        model: object exposing ``apply(variables, input_ids=..., labels=...,
            train=..., do_inference=..., cache=...)`` and returning a mapping
            with ``"logits"`` and ``"cache"`` entries.
        params: parameters passed to ``model.apply`` under the ``"params"`` key.
    """

    def __init__(self, compile, model, params):
        self.compile = compile
        self.model = model
        self.params = params
        # Argument 1 (the cache) is donated, so a cache passed to the compiled
        # step must not be reused by the caller afterwards.
        self._compiled_sample_fn = jax.jit(
            self._sample_fn,
            donate_argnums=[1],
        )

    def _sample_fn(self, params, cache, X):
        """Run one inference step and pick the arg-max token at the last position.

        Args:
            params: model parameters.
            cache: cache returned by the previous step, or None for the first.
            X: jax.Array[B,T] - token ids fed to the model.
        Returns:
            Dict with "pred" (arg-max ids of the last position, one per row),
            the updated "cache" and the full "logits" of this call.
        """
        out = self.model.apply(
            {"params": params},
            input_ids=X,
            labels=None,
            train=False,
            do_inference=True,
            cache=cache,
        )
        pred = jnp.argmax(out["logits"][:, -1, :], axis=-1)
        return {"pred": pred, "cache": out["cache"], "logits": out["logits"]}

    def generate(self, input_ids: jax.Array, gen_length):
        """
        Greedily extend the prompt by ``gen_length`` tokens.

        Args:
            input_ids: jax.Array[B,T] - prompt token ids
            gen_length: int - number of tokens to generate
        Returns:
            Dict with "promt" (the original T prompt tokens) and "generated"
            (exactly ``gen_length`` new tokens per row).
        """
        sample = self._compiled_sample_fn if self.compile else self._sample_fn

        promt_len = input_ids.shape[1]
        # Build the cache from the prompt, keeping the last-position logits of
        # every step so they can be compared with an uncached forward pass.
        cache = None
        logits = []
        prefix = None
        jax.debug.print("--------", ordered=True)
        # NOTE(review): this warm-up feeds prefixes of length 1..promt_len-2,
        # while generation starts from the prefix of length promt_len, so the
        # prefix of length promt_len-1 is never fed. Confirm against the
        # decoder's cache semantics whether the bound should be promt_len.
        for i in range(1, promt_len - 1):
            prefix = input_ids[:, :i]
            out = sample(self.params, cache, prefix)
            cache = out["cache"]
            logits.append(out["logits"][:, -1, :])

        # Debug comparison of cached vs. uncached logits. It is skipped when
        # the prompt is too short for any warm-up step, because there is then
        # no prefix to run and nothing to stack.
        if logits:
            logits_true = self.model.apply(
                {"params": self.params},
                input_ids=prefix,
                labels=None,
                train=False,
                do_inference=False,
                cache=None,
            )["logits"]
            jax.debug.print("Logits are {x}: ", x=logits_true)
            jax.debug.print("Logits are {x}: ", x=jnp.stack(logits, axis=1))

        # Actual generation: input_ids grows by one token per step and is fed
        # whole each time. With the compiled step, every new input length
        # triggers a fresh trace.
        for _ in tqdm(range(gen_length)):
            out = sample(self.params, cache, input_ids)
            cache = out["cache"]
            input_ids = jnp.concatenate([input_ids, out["pred"][:, None]], axis=1)

        # Slice the accumulated sequence rather than the last model input,
        # which would miss the final predicted token.
        return {
            "promt": input_ids[:, :promt_len],
            "generated": input_ids[:, promt_len:],
        }
