"""
Jax implementation of lava. With bidirectional encoder and bidirectional decoder
"""

from typing import Dict
from functools import partial
from flax import linen as nn
from jax import numpy as jnp
import torch
import jax
from latte_trans.config import LMTaskConfig

from latte_trans.models.modules.pretrain.gemma import GemmaDecoder
from latte_trans.models.modules.pretrain.gemma_mach import (
    GemmaDecoder as GemmaMachiattoDecoder,
)
from latte_trans.models.modules.layers import Embedder
from latte_trans.evals.losses import cross_entropy_loss_lm
from .pretrain_utils import apply_trans, do_transpose

# Key-translation table for the CLIP vision embeddings of a LLaVA-style
# checkpoint. Not referenced by any code visible in this file.
# Each value is either a plain checkpoint key or a ``(checkpoint_key, fn)``
# tuple where ``fn`` re-orders the axes of the loaded array.
# NOTE(review): by analogy with the Gemma tables below, the dict keys are
# presumably Flax parameter paths and the values PyTorch state-dict keys --
# confirm against the consumer of this table.
# The commented-out strings appear to be the remaining checkpoint keys that
# have no mapping yet; ``{x}`` stands for a layer index.
MAPPING_LLAMA = {
    "vision_model.CLIPVisionEmbedding_0.class_embedding": "vision_tower.vision_model.embeddings.class_embedding",
    "vision_model.CLIPVisionEmbedding_0.patch_embedding": (
        "vision_tower.vision_model.embeddings.patch_embedding.weight",
        lambda x: x.transpose(2, 3, 0, 1),
    ),
    "vision_model.CLIPVisionEmbedding_0.position_embedding.embedding": "vision_tower.vision_model.embeddings.position_embedding.weight",
    # "vision_tower.vision_model.pre_layrnorm.weight",
    # "vision_tower.vision_model.pre_layrnorm.bias",
    # "vision_tower.vision_model.encoder.layers.{x}.self_attn.k_proj.weight",
    # "vision_tower.vision_model.encoder.layers.{x}.self_attn.k_proj.bias",
    # "vision_tower.vision_model.encoder.layers.{x}.self_attn.v_proj.weight",
    # "vision_tower.vision_model.encoder.layers.{x}.self_attn.v_proj.bias",
    # "vision_tower.vision_model.encoder.layers.{x}.self_attn.q_proj.weight",
    # "vision_tower.vision_model.encoder.layers.{x}.self_attn.q_proj.bias",
    # "vision_tower.vision_model.encoder.layers.{x}.self_attn.out_proj.weight",
    # "vision_tower.vision_model.encoder.layers.{x}.self_attn.out_proj.bias",
    # "vision_tower.vision_model.encoder.layers.{x}.layer_norm1.weight",
    # "vision_tower.vision_model.encoder.layers.{x}.layer_norm1.bias",
    # "vision_tower.vision_model.encoder.layers.{x}.mlp.fc1.weight",
    # "vision_tower.vision_model.encoder.layers.{x}.mlp.fc1.bias",
    # "vision_tower.vision_model.encoder.layers.{x}.mlp.fc2.weight",
    # "vision_tower.vision_model.encoder.layers.{x}.mlp.fc2.bias",
    # "vision_tower.vision_model.encoder.layers.{x}.layer_norm2.weight",
    # "vision_tower.vision_model.encoder.layers.{x}.layer_norm2.bias",
    # "multi_modal_projector.linear_1.weight",
    # "multi_modal_projector.linear_1.bias",
    # "multi_modal_projector.linear_2.weight",
    # "multi_modal_projector.linear_2.bias",
    # "model.embed_tokens.weight",
    # "model.layers.{x}.self_attn.q_proj.weight",
    # "model.layers.{x}.self_attn.k_proj.weight",
    # "model.layers.{x}.self_attn.v_proj.weight",
    # "model.layers.{x}.self_attn.o_proj.weight",
    # "model.layers.{x}.mlp.gate_proj.weight",
    # "model.layers.{x}.mlp.up_proj.weight",
    # "model.layers.{x}.mlp.down_proj.weight",
    # "model.layers.{x}.input_layernorm.weight",
    # "model.layers.{x}.post_attention_layernorm.weight",
    # "model.norm.weight",
    # "model.lm_head.weight",
}

# Parameter-path -> pretrained state-dict key table used by ``Gemma.load_pret``
# (the non-scan layout, where every decoder layer has its own
# ``residual_block_{x}`` entry). ``{x}`` is a layer-index placeholder on both
# sides; how it is matched against concrete parameter names is decided by the
# ``apply_trans`` helper imported from ``pretrain_utils``.
# There is deliberately no active ``lm_head`` entry (see the commented line).
# NOTE(review): presumably the output projection reuses the embedding matrix
# -- confirm against ``Embedder.decode``.
MAPPING_GEMMA = {
    "model_embed.embedding": "embed_tokens.weight",
    "model.residual_block_{x}.self_attn.out_proj.kernel": "layers.{x}.self_attn.o_proj.weight",
    "model.residual_block_{x}.self_attn.q_proj.kernel": "layers.{x}.self_attn.q_proj.weight",
    "model.residual_block_{x}.self_attn.k_proj.kernel": "layers.{x}.self_attn.k_proj.weight",
    "model.residual_block_{x}.self_attn.v_proj.kernel": "layers.{x}.self_attn.v_proj.weight",
    "model.residual_block_{x}.mlp.down_proj.kernel": "layers.{x}.mlp.down_proj.weight",
    "model.residual_block_{x}.mlp.gate_proj.kernel": "layers.{x}.mlp.gate_proj.weight",
    "model.residual_block_{x}.mlp.up_proj.kernel": "layers.{x}.mlp.up_proj.weight",
    "model.residual_block_{x}.input_layernorm.scale": "layers.{x}.input_layernorm.weight",
    "model.residual_block_{x}.post_attention_layernorm.scale": "layers.{x}.post_attention_layernorm.weight",
    "model.residual_block_{x}.pre_feedforward_layernorm.scale": "layers.{x}.pre_feedforward_layernorm.weight",
    "model.residual_block_{x}.post_feedforward_layernorm.scale": "layers.{x}.post_feedforward_layernorm.weight",
    "model.norm.scale": "norm.weight",
    # "lm_head.kernel": "lm_head.weight",
}


# Parameter-path -> pretrained state-dict key table used by
# ``Gemma.load_pret_scan`` (the scan layout). Here all decoder layers share a
# single ``residual_block`` parameter whose leading axis indexes the layer, so
# the keys carry no layer index; ``Gemma.apply_trans`` replaces ``.{x}.`` in
# the value with the concrete layer number when fetching each layer's weights.
# There is deliberately no active ``lm_head`` entry (see the commented line).
# NOTE(review): presumably the output projection reuses the embedding matrix
# -- confirm against ``Embedder.decode``.
MAPPING_GEMMA2 = {
    "model_embed.embedding": "embed_tokens.weight",
    "model.residual_block.self_attn.out_proj.kernel": "layers.{x}.self_attn.o_proj.weight",
    "model.residual_block.self_attn.q_proj.kernel": "layers.{x}.self_attn.q_proj.weight",
    "model.residual_block.self_attn.k_proj.kernel": "layers.{x}.self_attn.k_proj.weight",
    "model.residual_block.self_attn.v_proj.kernel": "layers.{x}.self_attn.v_proj.weight",
    "model.residual_block.mlp.down_proj.kernel": "layers.{x}.mlp.down_proj.weight",
    "model.residual_block.mlp.gate_proj.kernel": "layers.{x}.mlp.gate_proj.weight",
    "model.residual_block.mlp.up_proj.kernel": "layers.{x}.mlp.up_proj.weight",
    "model.residual_block.input_layernorm.scale": "layers.{x}.input_layernorm.weight",
    "model.residual_block.post_attention_layernorm.scale": "layers.{x}.post_attention_layernorm.weight",
    "model.residual_block.pre_feedforward_layernorm.scale": "layers.{x}.pre_feedforward_layernorm.weight",
    "model.residual_block.post_feedforward_layernorm.scale": "layers.{x}.post_feedforward_layernorm.weight",
    "model.norm.scale": "norm.weight",
    # "lm_head.kernel": "lm_head.weight",
}


class Gemma(nn.Module):
    config: LMTaskConfig
    pad_id: int
    dtype: jnp.dtype = jnp.float32
    sharded: bool = False
    final_logit_softcapping: float = 30.0

    @staticmethod
    def load_pret(pt_state, params):
        """Loading for models in which sequence of layers is not implemented with scan (just an array)"""
        seen = set()
        for key_s in params:
            tmp, torch_key = apply_trans(MAPPING_GEMMA, key_s, pt_state)
            seen.add(torch_key)
            if tmp is None:
                print(f"warning, {key_s} has no pretrained mapping")
                continue
            if isinstance(params[key_s], nn.Partitioned):
                orig = params[key_s].value
                tmp = jax.device_put(jnp.asarray(tmp), orig.sharding)
                tmp = params[key_s].replace(value=tmp)
            else:
                orig = params[key_s]
                tmp = jax.device_put(jnp.asarray(tmp), orig.sharding)
            params[key_s] = tmp
        return params

    @staticmethod
    def apply_trans(MAPPING, key_s, pt_state, layer_nr=None):
        get_torch_layer = lambda k, nr: (
            k.replace(".{x}.", f".{nr}.") if nr is not None else k
        )
        if not key_s in MAPPING:
            return None, None
        y = MAPPING[key_s]
        transpose = do_transpose(key_s)

        if transpose:
            torch_key = get_torch_layer(y, layer_nr)
            tmp = pt_state[torch_key]
            tmp = tmp.T
        else:
            torch_key = get_torch_layer(y, layer_nr)
            tmp = pt_state[torch_key]
        return tmp, torch_key

    @classmethod
    def load_pret_scan(cls, pt_state, params):
        """Load from base pre-trained model, when we use scan"""
        seen = set()
        for key_s in params:
            if "residual_block" in key_s:
                # 0'th dimension is the layers collapsed
                if isinstance(params[key_s], nn.Partitioned):
                    nr_layers = params[key_s].value.shape[0]
                    value = params[key_s].value
                else:
                    nr_layers = params[key_s].shape[0]
                    value = params[key_s]
                orig_sharding = value.sharding
                for layer_nr in range(nr_layers):
                    tmp, torch_key = Gemma.apply_trans(
                        MAPPING_GEMMA2, key_s, pt_state, layer_nr=layer_nr
                    )

                    if tmp is not None:
                        seen.add(torch_key)
                        value = value.at[layer_nr].set(jnp.asarray(tmp))
                    else:
                        print(f"warning, {key_s} has no pretrained mapping")
                        continue
                value = jax.device_put(value, orig_sharding)
                if isinstance(params[key_s], nn.Partitioned):
                    params[key_s] = params[key_s].replace(value=value)
                else:
                    params[key_s] = value
            else:
                tmp, torch_key = Gemma.apply_trans(
                    MAPPING_GEMMA2, key_s, pt_state, layer_nr=None
                )
                if tmp is None:
                    print(f"warning, {key_s} has no pretrained mapping")
                    continue

                seen.add(torch_key)
                if isinstance(params[key_s], nn.Partitioned):
                    orig = params[key_s].value
                    tmp = jax.device_put(jnp.asarray(tmp), orig.sharding)
                    tmp = params[key_s].replace(value=tmp)
                else:
                    orig = params[key_s]
                    tmp = jax.device_put(jnp.asarray(tmp), orig.sharding)
                params[key_s] = tmp

        return params

    def lm_head(self, x, embeds):
        return x @ embeds.embedding.T

    @nn.compact
    def __call__(
        self,
        input_ids: jax.Array,
        # attention_mask: jax.Array,
        labels: jax.Array = None,
        train: bool = False,
    ) -> Dict[str, jnp.array]:
        """
        Args:
            input_ids: jnp.array(BL) - input ids
            labels: jnp.array(BL)
            train: bool - used for dropout
        Returns:
            out: Dict[str, jnp.array] - loss and logits
        """

        # text_embed = nn.Embed(
        #     num_embeddings=self.config.text_vocab_size,
        #     features=self.config.hidden_dim,
        #     dtype=self.dtype,
        #     embedding_init=jax.nn.initializers.normal(
        #         stddev=self.config.initializer_range
        #     ),
        #     name="model_embed",
        # )

        text_embed = Embedder(
            vocab_size=self.config.text_vocab_size,
            embed_dim=self.config.hidden_dim,
            scale_by_sqrt_dim=False,
            dtype=self.dtype,
            name="model_embed",
        )

        match self.config.attention_type:
            case "standard_causal":
                decoder = GemmaDecoder(
                    self.config,
                    sharded=self.sharded,
                    dtype=self.dtype,
                    name="model",
                )
            case "latte_mach_sliding_causal":
                decoder = GemmaMachiattoDecoder(
                    self.config,
                    sharded=self.sharded,
                    dtype=self.dtype,
                    name="model",
                )
            case _:
                raise Exception("Attention type not implemented for Gemma ")

        attention_mask = None
        # print(num_patches)
        input_embeds = text_embed.encode(input_ids)

        normalizer = jnp.array(self.config.hidden_dim**0.5, dtype=input_embeds.dtype)
        input_embeds = input_embeds * normalizer

        logits = decoder(input_embeds, attention_mask=None, train=train)
        # jax.debug.print("logits {}", jnp.any(jnp.isnan(logits)))
        logits = text_embed.decode(logits)  # self.lm_head(logits, text_embed)
        if self.final_logit_softcapping is not None:
            logits = logits / self.final_logit_softcapping
            logits = jax.nn.tanh(logits)
            logits = logits * self.final_logit_softcapping

        labels = jnp.where(labels == self.pad_id, -100, labels)
        if labels is None:
            return {"logits": logits, "loss": None}

        loss = cross_entropy_loss_lm(
            logits=logits[:, :-1, :], target=labels[:, 1:], ignore_index=-100
        )
        return {"loss": loss, "logits": logits}
