from typing import Dict
import math
import torch.nn as nn
import torch


class HuggWrapper(nn.Module):
    """Thin adapter around a base model that computes its own loss.

    The wrapper derives an attention mask from the padding id, delegates to
    the base model's ``forward`` and repackages the ``loss`` and ``logits``
    attributes of its output into a plain dict.
    """

    def __init__(self, base_model, pad_id):
        """
        Args:
            base_model: object whose ``forward`` accepts ``input_ids``,
                ``attention_mask`` and ``labels`` keyword arguments and
                returns something exposing ``.loss`` and ``.logits``.
            pad_id: token id treated as padding when building the mask.
        """
        super().__init__()
        self._pad_id = pad_id
        self._base_model = base_model

    def forward(self, input_ids, labels):
        """Run the base model and return ``{"loss": ..., "logits": ...}``.

        Args:
            input_ids: token ids; positions equal to ``pad_id`` are masked out.
            labels: targets handed to the base model unchanged.
        """
        result = self._base_model.forward(
            labels=labels,
            # True wherever the token is NOT padding.
            attention_mask=input_ids != self._pad_id,
            input_ids=input_ids,
        )
        return dict(loss=result.loss, logits=result.logits)


class HuggWrapper2(nn.Module):
    """
    Ties embed and output weights.

    Wraps a headless base model (one whose output exposes
    ``last_hidden_state``) and projects its hidden states back onto the
    vocabulary with the transposed input-embedding matrix, so the input and
    output embeddings share a single weight.
    """

    def __init__(self, base_model, embed, pad_id):
        """
        Args:
            base_model: object whose ``forward`` accepts ``input_ids``,
                ``attention_mask`` and ``output_hidden_states`` keyword
                arguments and returns something with ``.last_hidden_state``.
            embed: embedding module whose ``.weight`` doubles as the output
                projection.
            pad_id: padding token id. Stored but currently not used to build
                a mask (see ``forward``).
        """
        super().__init__()
        self._base_model = base_model
        self._input_embedding = embed
        self._pad_id = pad_id

    def lm_head(self, x):
        """Project hidden states onto the vocabulary using the tied weights."""
        return x @ self._input_embedding.weight.T

    def forward(self, input_ids, labels=None):
        """Compute logits and, when ``labels`` is given, the next-token loss.

        Args:
            input_ids: token ids fed to the base model.
            labels: optional targets aligned with ``input_ids``; entries
                equal to ``-100`` are ignored by the loss.

        Returns:
            ``{"logits": logits}`` when ``labels`` is ``None``, otherwise
            ``{"loss": loss, "logits": logits}``.
        """
        # No padding mask is forwarded to the base model: attention_mask is
        # deliberately None here. To re-enable masking, pass
        # ``(input_ids != self._pad_id).int()`` instead.
        output = self._base_model.forward(
            input_ids=input_ids,
            attention_mask=None,
            output_hidden_states=False,
        )
        logits = self.lm_head(output.last_hidden_state)
        if labels is None:
            return {"logits": logits}

        # Shift by one: the logits at position t are scored against the
        # label at position t + 1.
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
        loss = loss_fct(
            logits[:, :-1, :].reshape(-1, logits.size(-1)),
            labels[:, 1:].reshape(-1),
        )
        return {"loss": loss, "logits": logits}


# Per-layer parameter names that are spelled identically on both sides of the
# mapping; only the layer prefix differs.
_GEMMA_LAYER_PARAMS = (
    "self_attn.o_proj.weight",
    "self_attn.q_proj.weight",
    "self_attn.k_proj.weight",
    "self_attn.v_proj.weight",
    "mlp.down_proj.weight",
    "mlp.gate_proj.weight",
    "mlp.up_proj.weight",
    "input_layernorm.weight",
    "post_attention_layernorm.weight",
    "pre_feedforward_layernorm.weight",
    "post_feedforward_layernorm.weight",
)

# Maps this file's parameter-name templates to the pretrained checkpoint's
# templates. "{x}" is a literal placeholder for the layer index and is
# substituted at lookup time. There is no "lm_head.weight" entry.
MAPPING_GEMMA = {
    "model_embed.weight": "embed_tokens.weight",
    **{
        "model.residual_block_{x}." + suffix: "layers.{x}." + suffix
        for suffix in _GEMMA_LAYER_PARAMS
    },
    "model.norm.weight": "norm.weight",
}

from latte_trans.models.modules.pretrain.gemma_pt import GemmaDecoder
import regex as re


def apply_trans(MAPPING, key_s, pt_state):
    """Fetch the pretrained value that corresponds to one local parameter.

    Expects flattened parameter names in ``key_s`` and a torch ``pt_state``.

    Args:
        MAPPING: dict from local name templates to pretrained name templates.
            Per-layer entries spell the layer index as ``residual_block_{x}``
            on the local side and ``.{x}.`` on the pretrained side.
        key_s: local parameter name. The layer index may be written either
            as ``residual_block.<n>`` or as ``residual_block_<n>``.
        pt_state: mapping from pretrained parameter names to their values.

    Returns:
        ``(value, torch_key)`` on success, or ``(None, None)`` when ``key_s``
        has no entry in ``MAPPING``.

    Raises:
        KeyError: if the translated name is missing from ``pt_state``.
    """
    layer_nr = None
    mapping_key = key_s

    # Normalise the layer naming: pull the index out with a capture group so
    # both the "." and the "_" separator yield just the digits, then swap the
    # concrete index for the "{x}" placeholder used by the MAPPING templates.
    match = re.search(r"residual_block[._]([0-9]+)", key_s)
    if match is not None:
        layer_nr = match.group(1)
        mapping_key = key_s.replace(match.group(0), "residual_block_{x}")

    if mapping_key not in MAPPING:
        return None, None

    torch_key = MAPPING[mapping_key]
    if layer_nr is not None:
        torch_key = torch_key.replace(".{x}.", f".{layer_nr}.")
    return pt_state[torch_key], torch_key


class Gemma(nn.Module):
    """Decoder language model with tied input/output embeddings.

    Token ids are embedded, scaled by ``sqrt(hidden_dim)``, passed through a
    ``GemmaDecoder`` under a causal mask, and projected back onto the
    vocabulary with the transposed embedding matrix. The logits are then
    optionally soft-capped with ``tanh``.
    """

    def __init__(self, config, pad_id: int):
        """
        Args:
            config: configuration object; this class reads
                ``text_vocab_size`` and ``hidden_dim`` from it and hands the
                whole object to ``GemmaDecoder``.
            pad_id: padding token id, used as ``padding_idx`` of the
                embedding table.
        """
        super().__init__()
        self.config = config
        self.pad_id = pad_id

        self.model_embed = nn.Embedding(
            num_embeddings=self.config.text_vocab_size,
            embedding_dim=self.config.hidden_dim,
            padding_idx=self.pad_id,
        )

        self.model = GemmaDecoder(
            self.config,
        )
        # Logits are squashed into (-cap, cap) in forward(); None disables it.
        self.final_logit_softcapping = 30.0

    @staticmethod
    def _load(pt_state, params):
        """Overwrite entries of ``params`` with their pretrained counterparts.

        Args:
            pt_state: mapping from pretrained parameter names to values.
            params: mapping keyed by this model's parameter names; it is
                updated in place.

        Returns:
            The same ``params`` object. Keys with no entry in
            ``MAPPING_GEMMA`` keep their existing value.
        """
        # Only existing keys are reassigned, so mutating while iterating is
        # safe (the dict never changes size).
        for key_s in params:
            tmp, _ = apply_trans(MAPPING_GEMMA, key_s, pt_state)
            if tmp is None:
                # No pretrained counterpart: keep the current value rather
                # than clobbering it with None.
                continue
            params[key_s] = tmp
        return params

    def get_causal(self, input_tensor, attention_mask=None):
        """Build an additive causal attention mask.

        Args:
            input_tensor: tensor whose first two dimensions are read as
                batch size and sequence length; its dtype and device are
                used for the mask.
            attention_mask: optional 2-D mask indexed as (batch, position).
                Positions where it is 0 are blocked in addition to the
                causal constraint.

        Returns:
            Tensor of shape ``(batch, 1, seq, seq)`` holding 0 where
            attention is allowed and the dtype's minimum value elsewhere.
        """
        dtype, device = input_tensor.dtype, input_tensor.device
        batch_size = input_tensor.shape[0]
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        target_length = input_tensor.shape[1]

        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )
        if sequence_length != 1:
            # Keep the fill value strictly above the diagonal (future
            # positions) and zero everything else.
            causal_mask = torch.triu(causal_mask, diagonal=1)

        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            # expand() returns a view; clone it so the in-place edit below
            # works on real, contiguous memory.
            causal_mask = causal_mask.clone()
            mask_length = attention_mask.shape[-1]
            # The sum is 0 exactly where the causal mask allows attention
            # (0) but the attention mask marks the position as padding (0).
            padding_mask = (
                causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            )
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[
                :, :, :, :mask_length
            ].masked_fill(padding_mask, min_dtype)

        return causal_mask

    def lm_head(self, x, embeds):
        """Project hidden states ``x`` onto the vocabulary using ``embeds.weight``."""
        return x @ embeds.weight.T

    def forward(self, input_ids, labels=None) -> Dict[str, torch.Tensor]:
        """
        Args:
            input_ids: torch.Tensor(BL) - input ids
            labels: optional torch.Tensor(BL) - targets aligned with
                input_ids; entries equal to -100 are ignored by the loss
        Returns:
            out: Dict[str, torch.Tensor] - "logits" always, plus "loss" when
                labels are given
        """
        input_embeds = self.model_embed(input_ids)
        # Only the causal constraint is applied; no padding mask is passed.
        attention_mask = self.get_causal(input_embeds, attention_mask=None)

        # Scale embeddings by sqrt(hidden_dim); the scalar is created in the
        # embedding dtype so it is rounded to that dtype before multiplying.
        normalizer = torch.tensor(self.config.hidden_dim**0.5, dtype=input_embeds.dtype)
        input_embeds = input_embeds * normalizer

        hidden = self.model(input_embeds, attention_mask=attention_mask)
        logits = self.lm_head(hidden, self.model_embed)

        if self.final_logit_softcapping is not None:
            # Soft cap: cap * tanh(logits / cap).
            logits = logits / self.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.final_logit_softcapping

        if labels is None:
            return {"logits": logits}

        # Shift by one: the logits at position t are scored against the
        # label at position t + 1.
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
        loss = loss_fct(
            logits[:, :-1, :].reshape(-1, logits.size(-1)),
            labels[:, 1:].reshape(-1),
        )
        return {"loss": loss, "logits": logits}
