from typing import Dict, Any
import abc
from tqdm import tqdm
from flax import linen as nn
from jax import numpy as jnp
import jax

from latte_trans.config import Config
from latte_trans.evals.losses import cross_entropy_loss_lm, cross_entropy_loss
from latte_trans.models.modules.seq_layers import Decoder
from latte_trans.models.modules.sota_seq_layers import DecoderSota
from latte_trans.models.modules.qual_seq_layers import QualDecoder
from latte_trans.models.modules.pretrain.gemma import GemmaDecoder
from latte_trans.models.modules.pretrain.gemma_mach import (
    GemmaDecoder as GemmaMachiattoDecoder,
)
from recurrentgemma import jax as recurrentgemma

# from latte_trans.models.modules.seq_layers_sharded import Decoder as ShardedDecoder


class LMHeadVanilla(nn.Module):
    """Causal language-model wrapper around the vanilla ``Decoder``.

    The decoder output is used directly as the logits (no separate projection
    head is applied in this module) and, when labels are supplied, a shifted
    next-token cross-entropy loss is computed on top of it.

    Attributes:
        config: model configuration forwarded to ``Decoder``.
        vocab_size: vocabulary size forwarded to ``Decoder``.
        pad_id: padding token id; not read anywhere in this module.
        dtype: dtype forwarded to ``Decoder``.
        sharded: flag forwarded to ``Decoder``.
    """

    config: Config
    vocab_size: int
    pad_id: int
    dtype: jnp.dtype = jnp.float32
    sharded: bool = False

    @nn.compact
    def __call__(
        self,
        input_ids: jnp.array,
        labels: jnp.array = None,
        train: bool = False,
        cache: Dict[str, Any] = None,
        do_inference: bool = False,
    ) -> Dict[str, jnp.array]:
        """
        Args:
            input_ids: jnp.array(BL) - input ids
            labels: jnp.array(BL) - targets; when None no loss is computed
            train: bool - used for dropout
            cache: state forwarded unchanged to the decoder
            do_inference: bool - when True the decoder output is unpacked
                into a (cache, logits) pair and the cache is returned
        Returns:
            out: Dict[str, jnp.array] - "loss" (None without labels),
                "logits" and "cache" (None unless do_inference is True)
        """
        encoder = Decoder(
            vocab_size=self.vocab_size,
            config=self.config,
            dtype=self.dtype,
            sharded=self.sharded,
        )

        decoder_out = encoder(
            input_ids, train=train, do_inference=do_inference, cache=cache
        )

        out_cache = None
        if do_inference:
            # Previously this line unpacked an undefined name (`X`), which
            # raised NameError on every inference call.
            # NOTE(review): the (cache, output) ordering is taken from that
            # stale unpacking - confirm against Decoder's inference return.
            out_cache, logits = decoder_out
        else:
            logits = decoder_out

        if labels is None:
            return {"logits": logits, "cache": out_cache, "loss": None}

        # Next-token objective: the prediction at position t is scored against
        # the label at position t + 1, so the last logit and the first label
        # are dropped. Presumably labels equal to -100 are skipped by the loss
        # helper (see ignore_index) - verify in cross_entropy_loss_lm.
        loss = cross_entropy_loss_lm(
            logits=logits[:, :-1, :], target=labels[:, 1:], ignore_index=-100
        )

        return {"loss": loss, "logits": logits, "cache": out_cache}


class LMHeadSOTA(nn.Module):
    """Language-model wrapper around ``DecoderSota``.

    Intended to use state-of-the-art improvements to the transformer
    architecture (the original note lists Conv, RMSProp, etc.); presumably
    those are implemented inside ``DecoderSota`` - this wrapper only runs the
    decoder and computes the loss. The decoder output is used directly as the
    logits.
    """

    config: Config
    vocab_size: int
    pad_id: int
    dtype: jnp.dtype = jnp.float32
    sharded: bool = False

    @nn.compact
    def __call__(
        self,
        input_ids: jnp.array,
        labels: jnp.array = None,
        train: bool = False,
        cache: Dict[str, Any] = None,
        do_inference: bool = False,
    ) -> Dict[str, jnp.array]:
        """
        Args:
            input_ids: jnp.array(BL) - input ids
            labels: jnp.array(BL) - targets; when None the loss is None
            train: bool - used for dropout
            cache: forwarded unchanged to the decoder
            do_inference: bool - forwarded unchanged to the decoder
        Returns:
            out: Dict[str, jnp.array] - loss and logits
        """
        decoder = DecoderSota(
            vocab_size=self.vocab_size,
            config=self.config,
            dtype=self.dtype,
            sharded=self.sharded,
        )
        logits = decoder(
            input_ids,
            train=train,
            do_inference=do_inference,
            cache=cache,
        )

        if labels is not None:
            # Shifted objective: position t is scored against label t + 1.
            shifted_logits = logits[:, :-1, :]
            shifted_labels = labels[:, 1:]
            next_token_loss = cross_entropy_loss_lm(
                logits=shifted_logits, target=shifted_labels, ignore_index=-100
            )
            return {"loss": next_token_loss, "logits": logits}

        return {"logits": logits, "loss": None}


class LMHeadQual(nn.Module):
    """Language-model wrapper: ``QualDecoder`` followed by a vocabulary projection.

    The decoder output is passed through a bias-free ``nn.Dense`` layer of
    width ``vocab_size`` whose kernel is drawn from a normal distribution with
    standard deviation ``config.initializer_range``.
    """

    config: Config
    vocab_size: int
    pad_id: int
    dtype: jnp.dtype = jnp.float32
    sharded: bool = False

    @nn.compact
    def __call__(
        self,
        input_ids: jnp.array,
        labels: jnp.array = None,
        train: bool = False,
        cache: Dict[str, Any] = None,
        do_inference: bool = False,
    ) -> Dict[str, jnp.array]:
        """
        Args:
            input_ids: jnp.array(BL) - input ids
            labels: jnp.array(BL) - targets; when None the loss is None
            train: bool - used for dropout
            cache: forwarded unchanged to the decoder
            do_inference: bool - forwarded unchanged to the decoder
        Returns:
            out: Dict[str, jnp.array] - loss and logits
        """
        decoder = QualDecoder(
            vocab_size=self.vocab_size,
            config=self.config,
            dtype=self.dtype,
            sharded=self.sharded,
        )
        kernel_init = jax.nn.initializers.normal(
            stddev=self.config.initializer_range
        )
        vocab_projection = nn.Dense(
            self.vocab_size,
            dtype=self.dtype,
            kernel_init=kernel_init,
            use_bias=False,
        )

        hidden = decoder(
            input_ids,
            train=train,
            do_inference=do_inference,
            cache=cache,
        )  # BLH
        logits = vocab_projection(hidden)

        if labels is not None:
            # Shifted objective: position t is scored against label t + 1.
            next_token_loss = cross_entropy_loss_lm(
                logits=logits[:, :-1, :],
                target=labels[:, 1:],
                ignore_index=-100,
            )
            return {"loss": next_token_loss, "logits": logits}

        return {"logits": logits, "loss": None}


class RecurrentGemmaWrapper(nn.Module):
    """
    Used with the recurrentgemma repo.
    Wrapper for models like RecurrentGemma, Griffin, etc. that exposes the
    same call signature and output dictionary as the other LM wrappers here.
    """

    config: recurrentgemma.GriffinConfig
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        input_ids: jnp.array,
        labels: jnp.array = None,
        train: bool = False,
        cache: Dict[str, Any] = None,
        do_inference: bool = False,
    ) -> Dict[str, jnp.array]:
        """
        Args:
            input_ids: 2-D array of token ids (unpacked as batch x length)
            labels: targets; when None the loss is None
            train: accepted for interface parity; not read in this method
            cache: accepted for interface parity; not read in this method
            do_inference: accepted for interface parity; not read in this method
        Returns:
            out: Dict[str, jnp.array] - loss and logits
        """
        griffin = recurrentgemma.Griffin(config=self.config, dtype=self.dtype)

        # Per-row positions 0..L-1, replicated for every row of the batch.
        num_rows, num_steps = input_ids.shape
        step_index = jnp.arange(num_steps)[None]
        segment_pos = jnp.repeat(step_index, num_rows, axis=0)

        # The second element of the model output is discarded.
        logits, _ = griffin(
            tokens=input_ids,
            segment_pos=segment_pos,
            return_cache=False,
        )

        if labels is not None:
            # Shifted objective: position t is scored against label t + 1.
            next_token_loss = cross_entropy_loss_lm(
                logits=logits[:, :-1, :],
                target=labels[:, 1:],
                ignore_index=-100,
            )
            return {"loss": next_token_loss, "logits": logits}

        return {"logits": logits, "loss": None}


class Gemma(nn.Module):
    config: Config
    vocab_size: int
    pad_id: int
    dtype: jnp.dtype = jnp.float32
    sharded: bool = False
    final_logit_softcapping: float = 30.0

    def lm_head(self, x, embeds):
        return x @ embeds.embedding.T

    @nn.compact
    def __call__(
        self,
        input_ids: jax.Array,
        # attention_mask: jax.Array,
        labels: jax.Array = None,
        train: bool = False,
    ) -> Dict[str, jnp.array]:
        """
        Args:
            input_ids: jnp.array(BL) - input ids
            labels: jnp.array(BL)
            train: bool - used for dropout
        Returns:
            out: Dict[str, jnp.array] - loss and logits
        """

        text_embed = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.config.hidden_dim,
            dtype=self.dtype,
            embedding_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            name="model_embed",
        )

        match self.config.attention_type:
            case "standard_causal":
                decoder = GemmaDecoder(
                    self.config,
                    sharded=self.sharded,
                    dtype=self.dtype,
                    name="model",
                )
            case "latte_mach_sliding_causal":
                decoder = GemmaMachiattoDecoder(
                    self.config,
                    sharded=self.sharded,
                    dtype=self.dtype,
                    name="model",
                )
            case _:
                raise Exception("Attention type not implemented for Gemma ")

        attention_mask = None
        # print(num_patches)
        input_embeds = text_embed(input_ids)

        normalizer = jnp.array(self.config.hidden_dim**0.5, dtype=input_embeds.dtype)
        input_embeds = input_embeds * normalizer

        logits = decoder(input_embeds, attention_mask=None, train=train)
        # jax.debug.print("logits {}", jnp.any(jnp.isnan(logits)))
        logits = self.lm_head(logits, text_embed)
        if self.final_logit_softcapping is not None:
            logits = logits / self.final_logit_softcapping
            logits = jax.nn.tanh(logits)
            logits = logits * self.final_logit_softcapping

        if labels is None:
            return {"logits": logits, "loss": None}

        loss = cross_entropy_loss_lm(
            logits=logits[:, :-1, :], target=labels[:, 1:], ignore_index=-100
        )
        return {"loss": loss, "logits": logits}
