import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import torch.nn as nn
from torch.nn import functional as F
from model import CaracalForCausalLM
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import LlamaForCausalLM, LlamaConfig, JambaForCausalLM, JambaConfig


class CARACAL(nn.Module):
    """Causal-LM wrapper around the project-local ``CaracalForCausalLM``.

    Args:
        d_model: hidden size.
        n_layers: number of layers.
        n_heads: number of attention heads.
        vocab_size: vocabulary size.
        attn_layers: forwarded unchanged to ``CaracalForCausalLM``; presumably
            the indices of layers that use attention -- TODO confirm in model.py.
        window_size: forwarded unchanged to ``CaracalForCausalLM`` (default
            256); presumably an attention window length -- TODO confirm.
        **kwargs: accepted and ignored.
    """

    def __init__(self, d_model, n_layers, n_heads, vocab_size, attn_layers=(), window_size=256, **kwargs):
        super().__init__()
        # MLP width = 2/3 * (4 * d_model), rounded up to a multiple of 128.
        ffn_dim = int(2 * d_model * 4 / 3)
        intermediate_size = (ffn_dim + 127) // 128 * 128
        self.model = CaracalForCausalLM(d_model, n_layers, n_heads, vocab_size, intermediate_size, attn_layers, window_size)

    def forward(self, x, targets=None, attention_mask=None, **kwargs):
        """Run the model and optionally compute the token-level cross-entropy.

        Args:
            x: token ids, [batch_size, seq_len].
            targets: optional target ids, compared position-by-position with
                the logits (no shifting is done here). Entries equal to -100
                are ignored (``F.cross_entropy`` default ``ignore_index``).
            attention_mask: accepted but NOT forwarded to the model.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is None.
        """
        logits = self.model(x)
        # logits: [batch_size, seq_len, vocab_size]
        loss = None
        if targets is not None:
            # reshape (not view): targets are often a non-contiguous slice
            # such as tokens[:, 1:], on which .view(-1) raises.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss


class LLAMA(nn.Module):
    """Causal-LM wrapper around a freshly initialised HF ``LlamaForCausalLM``.

    Uses full multi-head attention (``num_key_value_heads == n_heads``), SiLU
    MLP activation, RMSNorm eps 1e-6, tied input/output embeddings and
    ``max_position_embeddings=131072``.

    Args:
        d_model: hidden size.
        n_layers: number of decoder layers.
        n_heads: number of attention heads.
        vocab_size: vocabulary size.
        **kwargs: accepted and ignored.
    """

    def __init__(self, d_model, n_layers, n_heads, vocab_size, **kwargs):
        super().__init__()
        # MLP width = 2/3 * (4 * d_model), rounded up to a multiple of 128.
        ffn_dim = int(2 * d_model * 4 / 3)
        intermediate_size = (ffn_dim + 127) // 128 * 128
        self.config = LlamaConfig(
            vocab_size=vocab_size,
            hidden_size=d_model,
            num_hidden_layers=n_layers,
            num_attention_heads=n_heads,
            num_key_value_heads=n_heads,
            intermediate_size=intermediate_size,
            max_position_embeddings=131072,
            hidden_act="silu",
            rms_norm_eps=1e-6,
            tie_word_embeddings=True
        )
        self.transformer = LlamaForCausalLM(self.config)

    def forward(self, x, targets=None, attention_mask=None, **kwargs):
        """Run the model and optionally compute the token-level cross-entropy.

        Args:
            x: token ids, [batch_size, src_len].
            targets: optional target ids, compared position-by-position with
                the logits (no shifting is done here; ``labels`` is not passed
                to the HF model). Entries equal to -100 are ignored.
            attention_mask: forwarded to the HF model unchanged.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is None.
        """
        logits = self.transformer(input_ids=x, attention_mask=attention_mask).logits
        # logits: [batch_size, src_len, vocab_size]
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets such as
            # tokens[:, 1:] are accepted.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss


class MAMBA(nn.Module):
    """Causal-LM wrapper around ``mamba_ssm``'s ``MambaLMHeadModel`` (Mamba-1 mixer).

    The model is built with ``2 * n_layers`` blocks. The intent is presumably
    to roughly match the parameter count of an ``n_layers`` transformer (one
    Mamba block being about half of an attention+MLP layer); TODO confirm.

    Args:
        d_model: hidden size.
        n_layers: requested depth; doubled before building the model.
        vocab_size: vocabulary size. ``MambaLMHeadModel`` may pad its
            embedding/head up to ``config.pad_vocab_size_multiple``. The
            padding columns are stripped from the logits in ``forward``.
        d_state, d_conv, expand: forwarded via ``ssm_cfg``.
        **kwargs: accepted and ignored (e.g. ``n_heads``).
    """

    def __init__(self, d_model, n_layers, vocab_size, d_state=16, d_conv=4, expand=2, **kwargs):
        super().__init__()
        self.n_layers = n_layers * 2
        self.config = MambaConfig(
            d_model=d_model,
            n_layer=self.n_layers,
            vocab_size=vocab_size,
            ssm_cfg={'d_state': d_state, 'd_conv': d_conv, 'expand': expand}
        )
        self.model = MambaLMHeadModel(config=self.config)

    def forward(self, x, targets=None, attention_mask=None, **kwargs):
        """Run the model and optionally compute the token-level cross-entropy.

        Args:
            x: token ids, [batch_size, seq_len].
            targets: optional target ids, compared position-by-position with
                the logits (no shifting is done here). Entries equal to -100
                are ignored.
            attention_mask: accepted but NOT used.

        Returns:
            ``(logits, loss)`` with logits of shape
            [batch_size, seq_len, vocab_size]; ``loss`` is ``None`` when
            ``targets`` is None.
        """
        # Drop the vocabulary padding columns that MambaLMHeadModel adds so
        # logits/loss cover only real token ids. config.vocab_size holds the
        # unpadded size.
        logits = self.model(x).logits[..., :self.config.vocab_size]
        loss = None
        if targets is not None:
            # reshape (not view): the slice above and targets such as
            # tokens[:, 1:] can be non-contiguous.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss


class MAMBA2(nn.Module):
    """Causal-LM wrapper around ``mamba_ssm``'s ``MambaLMHeadModel`` using the Mamba-2 mixer.

    ``'layer': 'Mamba2'`` in ``ssm_cfg`` selects the Mamba-2 mixer. The model
    is built with ``2 * n_layers`` blocks. The intent is presumably to roughly
    match the parameter count of an ``n_layers`` transformer; TODO confirm.

    Args:
        d_model: hidden size.
        n_layers: requested depth; doubled before building the model.
        vocab_size: vocabulary size. ``MambaLMHeadModel`` may pad its
            embedding/head up to ``config.pad_vocab_size_multiple``. The
            padding columns are stripped from the logits in ``forward``.
        d_state, d_conv, expand: forwarded via ``ssm_cfg``.
        **kwargs: accepted and ignored (e.g. ``n_heads``).
    """

    def __init__(self, d_model, n_layers, vocab_size, d_state=16, d_conv=4, expand=2, **kwargs):
        super().__init__()
        self.n_layers = n_layers * 2
        self.config = MambaConfig(
            d_model=d_model,
            n_layer=self.n_layers,
            vocab_size=vocab_size,
            ssm_cfg={'d_state': d_state, 'd_conv': d_conv, 'expand': expand, 'layer': 'Mamba2'}
        )
        self.model = MambaLMHeadModel(config=self.config)

    def forward(self, x, targets=None, attention_mask=None, **kwargs):
        """Run the model and optionally compute the token-level cross-entropy.

        Args:
            x: token ids, [batch_size, seq_len].
            targets: optional target ids, compared position-by-position with
                the logits (no shifting is done here). Entries equal to -100
                are ignored.
            attention_mask: accepted but NOT used.

        Returns:
            ``(logits, loss)`` with logits of shape
            [batch_size, seq_len, vocab_size]; ``loss`` is ``None`` when
            ``targets`` is None.
        """
        # Drop the vocabulary padding columns that MambaLMHeadModel adds so
        # logits/loss cover only real token ids. config.vocab_size holds the
        # unpadded size.
        logits = self.model(x).logits[..., :self.config.vocab_size]
        loss = None
        if targets is not None:
            # reshape (not view): the slice above and targets such as
            # tokens[:, 1:] can be non-contiguous.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss


class JAMBA(nn.Module):
    """Causal-LM wrapper around a freshly initialised HF ``JambaForCausalLM``.

    Layer mix: with ``attn_layer_period=6`` and ``attn_layer_offset=5``, HF
    Jamba uses attention in layer ``i`` when ``i % 6 == 5`` and Mamba
    elsewhere. So ``n_layers`` below 6 yields no attention layers at all.
    ``num_experts=1`` keeps the MLPs dense (no mixture-of-experts).

    Args:
        d_model: hidden size.
        n_layers: number of decoder layers.
        n_heads: number of attention heads (``num_key_value_heads == n_heads``).
        vocab_size: vocabulary size.
        **kwargs: accepted and ignored.
    """

    def __init__(self, d_model, n_layers, n_heads, vocab_size, **kwargs):
        super().__init__()
        # MLP width = 2/3 * (4 * d_model), rounded up to a multiple of 128.
        ffn_dim = int(2 * d_model * 4 / 3)
        intermediate_size = (ffn_dim + 127) // 128 * 128
        self.config = JambaConfig(
            vocab_size=vocab_size,
            hidden_size=d_model,
            num_hidden_layers=n_layers,
            num_attention_heads=n_heads,
            num_key_value_heads=n_heads,
            intermediate_size=intermediate_size,
            attn_layer_period=6,
            attn_layer_offset=5,
            num_experts=1,
            tie_word_embeddings=True,
            rms_norm_eps=1e-6
        )
        self.model = JambaForCausalLM(self.config)

    def forward(self, x, targets=None, attention_mask=None, **kwargs):
        """Run the model and optionally compute the token-level cross-entropy.

        Args:
            x: token ids, [batch_size, seq_len].
            targets: optional target ids, compared position-by-position with
                the logits (no shifting is done here; ``labels`` is not passed
                to the HF model). Entries equal to -100 are ignored.
            attention_mask: forwarded to the HF model unchanged.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is None.
        """
        logits = self.model(input_ids=x, attention_mask=attention_mask).logits
        # logits: [batch_size, seq_len, vocab_size]
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets such as
            # tokens[:, 1:] are accepted.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss
