from typing import Dict
import math
import torch.nn as nn
import torch
from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.masking import LengthMask, TriangularCausalMask
from torchscale.architecture.retnet import RetNetDecoder

from latte_trans.models.modules.layers_pt import PositionalEncoding
from latte_trans.models.modules.seq_layers_pt import Decoder
from latte_trans.models.modules.sota_seq_layers_pt import Decoder as DecoderSota
from latte_trans.config import Config


class ZoologyWrapper(nn.Module):
    """Wrapper for models from the Zoology repo.

    That repo already implements many models such as H3, Hyena and RWKV.
    The wrapped model is called with ``input_ids`` only and is expected to
    return the logits tensor directly.
    """

    def __init__(self, base_model):
        """Store the model to wrap.

        Args:
            base_model: callable accepting ``input_ids=...`` and returning
                logits.
        """
        super().__init__()
        self._base_model = base_model

    def forward(self, input_ids, labels=None):
        """Run the wrapped model and optionally compute the next-token loss.

        Args:
            input_ids: token ids forwarded to the wrapped model.
            labels: optional target ids, indexed as ``labels[:, 1:]``
                (assumed to be (batch, seq)). Entries equal to ``-100`` are
                ignored by the loss.

        Returns:
            ``{"loss": loss, "logits": logits}`` when ``labels`` is given,
            otherwise ``{"logits": logits}``.
        """
        # No attention mask is passed to the wrapped model.
        logits = self._base_model(input_ids=input_ids)

        # Without labels there is nothing to score; previously this path hit
        # an unbound ``loss`` variable at the return statement.
        if labels is None:
            return {"logits": logits}

        # Position t predicts token t + 1: drop the first label and the last
        # logit so the two line up.
        shift_labels = labels[:, 1:].contiguous()
        shift_logits = logits[:, :-1].contiguous()

        # Mean cross-entropy over all non-ignored target tokens.
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )

        return {"loss": loss, "logits": logits}
    
    
class HuggWrapper(nn.Module):
    """Adapter around a model that computes its own loss.

    The wrapped model is called with ``input_ids``, ``attention_mask`` and
    ``labels`` and must return an object exposing ``.loss`` and ``.logits``.
    """

    def __init__(self, base_model, pad_id):
        """Keep references to the wrapped model and the padding token id."""
        super().__init__()
        self._pad_id = pad_id
        self._base_model = base_model

    def forward(self, input_ids, labels):
        """Delegate to the wrapped model and repackage its result as a dict.

        Args:
            input_ids: token ids passed to the wrapped model.
            labels: targets passed through unchanged to the wrapped model.

        Returns:
            Dict with the wrapped model's ``loss`` and ``logits``.
        """
        # Every position is marked as attended; the stored pad id is not
        # used to mask padding here.
        full_mask = torch.ones_like(input_ids)
        result = self._base_model(
            input_ids=input_ids,
            attention_mask=full_mask,
            labels=labels,
        )
        loss, logits = result.loss, result.logits
        return {"loss": loss, "logits": logits}


class MambaWrapper(nn.Module):
    """Wrapper that adds a next-token loss to a model returning ``.logits``.

    The wrapped model is called with ``input_ids`` only; its result must
    expose a ``logits`` attribute.
    """

    def __init__(self, base_model):
        """Store the model to wrap.

        Args:
            base_model: callable accepting ``input_ids=...`` and returning an
                object with a ``logits`` attribute.
        """
        super().__init__()
        self._base_model = base_model

    def forward(self, input_ids, labels=None):
        """Run the wrapped model and optionally compute the next-token loss.

        Args:
            input_ids: token ids forwarded to the wrapped model.
            labels: optional target ids, indexed as ``labels[:, 1:]``
                (assumed to be (batch, seq)). Entries equal to ``-100`` are
                ignored by the loss.

        Returns:
            ``{"loss": loss, "logits": logits}`` when ``labels`` is given,
            otherwise ``{"logits": logits}``.
        """
        # No attention mask is passed to the wrapped model.
        output = self._base_model(input_ids=input_ids)
        logits = output.logits

        # Without labels there is nothing to score; previously this path hit
        # an unbound ``loss`` variable at the return statement.
        if labels is None:
            return {"logits": logits}

        # Position t predicts token t + 1: drop the first label and the last
        # logit so the two line up.
        shift_labels = labels[:, 1:].contiguous()
        shift_logits = logits[:, :-1].contiguous()

        # Mean cross-entropy over all non-ignored target tokens.
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )

        return {"loss": loss, "logits": logits}


class LMHeadVanilla(nn.Module):
    """Decoder backbone followed by a linear language-modelling head."""

    def __init__(self, config: Config, vocab_size: int, pad_id: int):
        """Build the backbone selected by ``config.block_type`` and the head.

        Args:
            config: model configuration; ``block_type`` selects the decoder
                implementation and ``hidden_dim`` sizes the head input.
            vocab_size: number of output classes of the head.
            pad_id: padding token id forwarded to the decoder.

        Raises:
            ValueError: if ``config.block_type`` is neither ``"transformer"``
                nor ``"transformer-sota"``.
        """
        super().__init__()
        self.config = config
        if config.block_type == "transformer":
            constructor = Decoder
        elif config.block_type == "transformer-sota":
            constructor = DecoderSota
        else:
            # ValueError is a subclass of Exception, so callers that caught
            # the previous bare Exception keep working.
            raise ValueError(f"Not valid model type: {config.block_type!r}")
        self.encoder = constructor(
            vocab_size=vocab_size,
            pad_id=pad_id,
            config=config,
        )
        self.head = nn.Linear(self.config.hidden_dim, vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute logits and, when labels are given, the next-token loss.

        Args:
            input_ids: token ids, annotated as BL (batch, length) in the
                original comments.
            labels: optional target ids with the same layout. Entries equal
                to ``-100`` are ignored by the loss.

        Returns:
            ``{"logits": logits}`` when ``labels`` is ``None``, otherwise
            ``{"loss": loss, "logits": logits}``.
        """
        X = self.encoder(input_ids)  # BLH
        logits = self.head(X)  # BLH -> BLV
        if labels is None:
            return {"logits": logits}

        # Position t predicts token t + 1: drop the first label and the last
        # logit so the two line up.
        shift_labels = labels[:, 1:].contiguous()
        shift_logits = logits[:, :-1].contiguous()

        # Mean cross-entropy over all non-ignored target tokens.
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
        return {"loss": loss, "logits": logits}


class LinearTransWrapper(nn.Module):
    """Causal language model built on a ``fast_transformers`` encoder stack.

    Tokens are embedded, positional encodings are applied, the encoder runs
    under a triangular causal mask, and a linear layer maps to the vocabulary.
    """

    def __init__(
        self,
        d_model,
        sequence_length,
        vocab_size,
        attention_type="full",
        n_layers=4,
        n_heads=4,
        d_query=32,
        dropout=0.1,
        softmax_temp=None,
        attention_dropout=0.1,
        pad_id=0,
    ):
        """Create the embeddings, the encoder stack and the output layer.

        Args:
            d_model: embedding size of tokens and positional encodings.
            sequence_length: ``max_len`` given to the positional encoding.
            vocab_size: number of embeddings and of output classes.
            attention_type: attention implementation name for the builder.
            n_layers: number of encoder layers.
            n_heads: number of attention heads.
            d_query: per-head query and value size.
            dropout: dropout rate given to the builder.
            softmax_temp: softmax temperature given to the builder.
            attention_dropout: attention dropout rate given to the builder.
            pad_id: ``padding_idx`` of the token embedding.
        """
        super().__init__()

        # NOTE(review): the embedding emits d_model features while the output
        # layer consumes n_heads * d_query; presumably the two must be equal.
        # Confirm against callers.
        self.pos_embedding = PositionalEncoding(d_model, max_len=sequence_length)
        self.value_embedding = torch.nn.Embedding(
            vocab_size, d_model, padding_idx=pad_id
        )

        self.transformer = TransformerEncoderBuilder.from_kwargs(
            attention_type=attention_type,
            n_layers=n_layers,
            n_heads=n_heads,
            feed_forward_dimensions=n_heads * d_query * 4,
            query_dimensions=d_query,
            value_dimensions=d_query,
            dropout=dropout,
            softmax_temp=softmax_temp,
            attention_dropout=attention_dropout,
        ).get()

        hidden_size = n_heads * d_query
        self.predictor = torch.nn.Linear(hidden_size, vocab_size)

    def forward(self, x, labels=None):
        """Compute logits and, when labels are given, the next-token loss.

        Args:
            x: token ids; after embedding, dimension 1 is used as the
                sequence length of the causal mask.
            labels: optional target ids, indexed as ``labels[:, 1:]``.
                Entries equal to ``-100`` are ignored by the loss.

        Returns:
            ``{"loss": loss, "logits": logits}`` when ``labels`` is given,
            otherwise ``{"logits": logits}``.
        """
        x = self.value_embedding(x)
        x = self.pos_embedding(x)
        triangular_mask = TriangularCausalMask(x.shape[1], device=x.device)
        y_hat = self.transformer(x, attn_mask=triangular_mask)
        logits = self.predictor(y_hat)

        # Without labels there is nothing to score; previously this path hit
        # an unbound ``loss`` variable at the return statement.
        if labels is None:
            return {"logits": logits}

        # Position t predicts token t + 1: drop the first label and the last
        # logit so the two line up.
        shift_labels = labels[:, 1:].contiguous()
        shift_logits = logits[:, :-1].contiguous()

        # Mean cross-entropy over all non-ignored target tokens.
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
        return {"loss": loss, "logits": logits}


class RetNetWrapper(nn.Module):
    """Causal language model wrapper around torchscale's ``RetNetDecoder``."""

    def __init__(self, pad_id, config):
        """Build the token embedding, output projection and decoder.

        Args:
            pad_id: ``padding_idx`` of the token embedding.
            config: decoder configuration; ``vocab_size`` and
                ``decoder_embed_dim`` are read here and the whole object is
                passed to ``RetNetDecoder``.
        """
        super().__init__()
        self.pad_id = pad_id

        value_embedding = torch.nn.Embedding(
            config.vocab_size, config.decoder_embed_dim, padding_idx=pad_id
        )
        predictor = torch.nn.Linear(config.decoder_embed_dim, config.vocab_size)
        self._base_model = RetNetDecoder(
            config, embed_tokens=value_embedding, output_projection=predictor
        )

    def forward(self, input_ids, labels=None):
        """Compute logits and, when labels are given, the next-token loss.

        Args:
            input_ids: token ids, passed to the decoder as
                ``prev_output_tokens``.
            labels: optional target ids, indexed as ``labels[:, 1:]``.
                Entries equal to ``-100`` are ignored by the loss.

        Returns:
            ``{"loss": loss, "logits": logits}`` when ``labels`` is given,
            otherwise ``{"logits": logits}``.
        """
        # The decoder returns a pair; only the first element (logits) is used.
        logits, _ = self._base_model(prev_output_tokens=input_ids)

        # Without labels there is nothing to score; previously this path hit
        # an unbound ``loss`` variable at the return statement.
        if labels is None:
            return {"logits": logits}

        # Position t predicts token t + 1: drop the first label and the last
        # logit so the two line up.
        shift_labels = labels[:, 1:].contiguous()
        shift_logits = logits[:, :-1].contiguous()

        # Mean cross-entropy over all non-ignored target tokens.
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
        return {"loss": loss, "logits": logits}
