import torch
from lru import LRU
import math
from mamba_ssm import Mamba

class TransformerLanguageModel(torch.nn.Module):
    """Causal language model: embedding -> (optional PE) -> TransformerEncoder -> vocab.

    Args:
        embed_dim: Token embedding width (also the encoder ``d_model``).
        vocab_size: Number of tokens in the vocabulary.
        enc_layers: Number of stacked ``TransformerEncoderLayer``s.
        num_heads: Attention heads per layer; must divide ``embed_dim``.
        ff_dim: Hidden width of each layer's feed-forward sublayer.
        pad_idx: Token id treated as padding (excluded as an attention key).
        device: Device the positional-encoding module is moved to at
            construction. Masks are built on the input's device instead, so
            the model keeps working after ``.to()``/``.cuda()``.
        pos_enc: If True, add sinusoidal position encodings (with dropout)
            to the embeddings via ``PositionalEncoding``.
    """

    def __init__(self, embed_dim, vocab_size, enc_layers, num_heads,
        ff_dim, pad_idx, device, pos_enc=False):
        super(TransformerLanguageModel, self).__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embed_dim)
        self.num_heads = num_heads
        self.transformer_enc = torch.nn.TransformerEncoder(
            encoder_layer=torch.nn.TransformerEncoderLayer(d_model=embed_dim,
                nhead=num_heads, dim_feedforward=ff_dim, batch_first=True),
            num_layers=enc_layers)
        self.output_layer = torch.nn.Linear(embed_dim, vocab_size)
        self.pad_idx = pad_idx
        self.device = device
        self.pos_enc = pos_enc
        if pos_enc:
            self.pe = PositionalEncoding(embed_dim).to(device)
        else:
            self.pe = None

    def forward(self, x):
        """Return next-token log-probabilities for every position of ``x``.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab_size)`` holding
            log-softmax scores over the vocabulary.

        NOTE(review): with left padding, position 0 may be a pad token. That
        query can then see no unmasked key, and attention yields NaN for that
        row. Right padding avoids this.
        """
        embed = self.embedding(x)
        if self.pos_enc:
            embed = self.pe(embed)
        seq_len = x.shape[1]
        # True above the diagonal = "may not attend": position i sees only
        # positions <= i. Built on x's device rather than the device captured
        # in __init__, which goes stale when the module is moved.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        src_padding_mask = (x == self.pad_idx)
        encoding = self.transformer_enc(embed, mask=causal_mask,
            src_key_padding_mask=src_padding_mask)
        final_layer = self.output_layer(encoding)
        return torch.nn.functional.log_softmax(final_layer, dim=2)

class LSTMLanguageModel(torch.nn.Module):
    """Language model: embedding -> LSTM -> optional causal self-attention -> vocab.

    Args:
        embed_dim: Token embedding width (LSTM input size).
        vocab_size: Number of tokens in the vocabulary.
        hidden_size: LSTM hidden width (also the attention embed dim).
        layers: Number of stacked LSTM layers.
        pad_idx: Token id excluded as an attention key.
        device: Stored on the instance. Masks are built on the input's
            device, so moving the model with ``.to()`` keeps working.
        use_attention: If True, apply single-head causal self-attention over
            the LSTM outputs before the output projection.
    """

    def __init__(self, embed_dim, vocab_size, hidden_size, layers, pad_idx,
        device, use_attention=True):
        super(LSTMLanguageModel, self).__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embed_dim)
        self.lstm = torch.nn.LSTM(input_size=embed_dim, hidden_size=hidden_size,
            num_layers=layers, batch_first=True)
        self.use_attention = use_attention
        if self.use_attention:
            self.attention = torch.nn.MultiheadAttention(hidden_size, num_heads=1,
                bias=False, batch_first=True)
        self.output_layer = torch.nn.Linear(hidden_size, vocab_size)
        self.pad_idx = pad_idx
        self.device = device

    def forward(self, x):
        """Return next-token log-probabilities for every position of ``x``.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab_size)`` of log-softmax
            scores.

        NOTE(review): with attention enabled and left padding, a pad query at
        position 0 has no unmasked key and produces NaN. Right padding avoids
        this.
        """
        out, _ = self.lstm(self.embedding(x))
        if self.use_attention:
            seq_len = x.shape[1]
            # Masks are only needed for attention. Build them on x's device
            # because self.device is fixed at construction time.
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
            src_padding_mask = (x == self.pad_idx)
            out = self.attention(out, out, out, attn_mask=causal_mask,
                key_padding_mask=src_padding_mask, need_weights=False)[0]
        final_layer = self.output_layer(out)
        return torch.nn.functional.log_softmax(final_layer, dim=2)

class LRULanguageModel(torch.nn.Module):
    """Language model: embedding -> LRU stack -> optional causal self-attention -> vocab.

    Args:
        embed_dim: Token embedding width (LRU input size).
        vocab_size: Number of tokens in the vocabulary.
        hidden_size: LRU output width (also the attention embed dim).
        layers: Number of LRU layers (forwarded as ``num_layers``).
        pad_idx: Token id excluded as an attention key.
        device: Forwarded to ``LRU`` and stored on the instance. The masks in
            ``forward`` are built on the input's device instead.
        r_min, r_max: Forwarded unchanged to ``LRU``.
        use_attention: If True, apply single-head causal self-attention over
            the LRU outputs before the output projection.
    """

    def __init__(self, embed_dim, vocab_size, hidden_size, layers, pad_idx,
        device, r_min, r_max, use_attention=True):
        super(LRULanguageModel, self).__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embed_dim)
        self.lru = LRU(embed_dim=embed_dim, hidden_size=hidden_size,
            num_layers=layers, device=device, r_min=r_min, r_max=r_max)
        self.use_attention = use_attention
        if self.use_attention:
            self.attention = torch.nn.MultiheadAttention(hidden_size, num_heads=1,
                bias=False, batch_first=True)
        self.output_layer = torch.nn.Linear(hidden_size, vocab_size)
        self.pad_idx = pad_idx
        self.device = device

    def forward(self, x):
        """Return next-token log-probabilities for every position of ``x``.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab_size)`` of log-softmax
            scores.
        """
        # Assumes self.lru returns (features, hidden) with features shaped
        # (batch, seq_len, hidden_size). TODO confirm against lru.LRU.
        out, _ = self.lru(self.embedding(x))
        if self.use_attention:
            seq_len = x.shape[1]
            # Only needed for attention. Built on x's device because
            # self.device is fixed at construction time.
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
            src_padding_mask = (x == self.pad_idx)
            out = self.attention(out, out, out, attn_mask=causal_mask,
                key_padding_mask=src_padding_mask, need_weights=False)[0]
        final_layer = self.output_layer(out)
        return torch.nn.functional.log_softmax(final_layer, dim=2)

class PositionalEncoding(torch.nn.Module):
    """Add fixed sinusoidal position encodings to an embedded sequence, then dropout.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: Feature width of the inputs. Odd widths are supported; the
            last (even-indexed) column then has no cosine partner.
        dropout: Dropout probability applied after the addition.
        max_len: Number of positions precomputed. ``offset + seq_len`` passed
            to ``forward`` must not exceed it.
    """

    def __init__(self, d_model, dropout=0.1, max_len=256):
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.double).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).double()
            * (-math.log(10000.0) / d_model))
        # Shape (max_len, ceil(d_model / 2)).
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # There are only d_model // 2 odd columns, one fewer than the even
        # columns when d_model is odd, so trim the angles to fit.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Leading singleton dim broadcasts over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x, offset=0):
        """Return ``dropout(x + pe[offset:offset + seq_len])``.

        Args:
            x: Tensor whose dim 1 is the sequence axis and last dim is
                ``d_model``, e.g. ``(batch, seq_len, d_model)``.
            offset: Position index of ``x[:, 0]``.

        Raises:
            ValueError: If ``offset`` is negative or ``offset + seq_len``
                exceeds the precomputed ``max_len``. Without this check the
                slice could come back empty and silently broadcast to an
                empty result.
        """
        end = offset + x.shape[1]
        if offset < 0 or end > self.pe.shape[1]:
            raise ValueError(
                f"positions [{offset}, {end}) out of range for "
                f"max_len={self.pe.shape[1]}")
        x = x + self.pe[:, offset:end, :]
        return self.dropout(x)

class MambaLanguageModel(torch.nn.Module):
    """Language model: token embedding -> one Mamba block -> vocabulary projection.

    Args:
        embed_dim: Token embedding width (``d_model`` of the Mamba block).
        vocab_size: Number of tokens in the vocabulary.
        d_conv, d_state, expand: Forwarded unchanged to ``Mamba``.
        pad_idx: Stored on the instance; not read by ``forward``.
        device: Stored on the instance; not read by ``forward``.
    """

    def __init__(self, embed_dim, vocab_size, d_conv, d_state, expand,
        pad_idx, device):
        super().__init__()
        # Submodules are created in the same order as before, so parameter
        # initialisation and state_dict key order are unchanged.
        self.embedding = torch.nn.Embedding(vocab_size, embed_dim)
        self.mamba = Mamba(
            d_model=embed_dim,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
        )
        self.output_layer = torch.nn.Linear(embed_dim, vocab_size)
        self.pad_idx = pad_idx
        self.device = device

    def forward(self, x):
        """Map token ids to log-softmax scores over the vocabulary (taken on dim 2)."""
        mixed = self.mamba(self.embedding(x))
        return self.output_layer(mixed).log_softmax(dim=2)