import torch
import torch.nn as nn
import math

class Block(nn.Module):
    """Causal transformer block: masked self-attention followed by an MLP.

    The attention module is built with ``batch_first=True`` and the sequence
    length is read from ``x.shape[1]``, so inputs are laid out as
    ``(batch, seq, dim)``.

    Args:
        dim: Feature width of the input and output.
        num_heads: Number of attention heads passed to
            ``nn.MultiheadAttention``.
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        hidden_dim = dim * 4  # the MLP expands to 4x the width, then projects back
        self.ln_1 = nn.LayerNorm(dim)
        self.ln_2 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        """Apply causal self-attention and the MLP, each with a residual add.

        Args:
            x: Input tensor; ``x.shape[1]`` is used as the sequence length.

        Returns:
            Tensor with the same shape as ``x``.
        """
        seq_len = x.shape[1]

        # Additive mask: 0 where a query may look (itself and earlier
        # positions), -inf on strictly later positions. Filling a zero tensor
        # never multiplies -inf by 0, so no NaN can appear on any backend.
        idx = torch.arange(seq_len, device=x.device)
        is_future = idx.unsqueeze(0) > idx.unsqueeze(1)
        causal_mask = torch.zeros(
            seq_len, seq_len, device=x.device, dtype=x.dtype
        ).masked_fill(is_future, -float("Inf"))

        # NOTE: the residual is taken from the layer-normed input, not from
        # the raw ``x``.
        normed = self.ln_1(x)
        attended, _ = self.attn(normed, normed, normed, attn_mask=causal_mask)
        hidden = normed + attended
        return hidden + self.mlp(self.ln_2(hidden))


class Causal_Transformer(nn.Module):
    """Causal Transformer decoder mapping token ids to next-token logits.

    Args:
        vocab_size: Number of rows in the token embedding and width of the
            output logits.
        d_model: Embedding / hidden width used throughout the network.
        num_heads: Attention heads per ``Block``.
        num_layers: Number of stacked ``Block`` layers.
        pretrained_path: Optional path to a checkpoint readable by
            ``torch.load`` whose ``'embedding'`` entry is a state dict for
            the token embedding. ``None`` keeps the random initialisation.
        max_len: Number of learned position embeddings, i.e. the longest
            sequence ``forward`` accepts. Defaults to 4, the value that used
            to be hard-coded.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers,
                 pretrained_path=None, max_len=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.position_embeddings = nn.Embedding(max_len, d_model)
        self.layers = nn.ModuleList(
            Block(d_model, num_heads) for _ in range(num_layers)
        )

        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        # Not used by forward(); kept because it registers a buffer, so
        # removing it would change the keys of this model's state dict.
        self.positional_encoding = PositionalEncoding(d_model)

        if pretrained_path is not None:
            # map_location="cpu" lets a checkpoint saved on a GPU be read on
            # a machine without one; load_state_dict copies the values onto
            # whatever device the embedding lives on.
            # NOTE(security): torch.load unpickles the file -- only load
            # checkpoints from trusted sources.
            state_dict = torch.load(pretrained_path, map_location="cpu")
            # strict=False tolerates missing/unexpected keys in the entry.
            self.embedding.load_state_dict(state_dict['embedding'], strict=False)

    def forward(self, x, mm):
        """Compute logits for every position of ``x``.

        Args:
            x: Integer token ids; ``x.shape[1]`` is the sequence length, and
                it must not exceed the size of the position-embedding table.
            mm: Ignored; accepted only to keep the call signature stable.

        Returns:
            Logits with a trailing dimension of ``vocab_size``.

        Raises:
            IndexError: If the sequence is longer than the number of learned
                position embeddings.
        """
        seq_len = x.shape[1]
        max_positions = self.position_embeddings.num_embeddings
        if seq_len > max_positions:
            # Fail early with a readable message instead of an opaque
            # embedding lookup error (a device-side assert on CUDA).
            raise IndexError(
                f"sequence length {seq_len} exceeds the {max_positions} "
                "available position embeddings"
            )

        positions = torch.arange(seq_len, device=x.device)
        # The (seq, d_model) position table broadcasts over the batch dim.
        h = self.embedding(x) + self.position_embeddings(positions)
        for layer in self.layers:
            h = layer(h)
        return self.head(self.ln_f(h))
    


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    Even feature indices hold ``sin(pos * w_k)`` and odd indices hold
    ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / d_model)``. The table is
    stored as the non-trainable buffer ``pe`` of shape
    ``(max_len, 1, d_model)``.

    Args:
        d_model: Feature width of the encoding. Both even and odd widths are
            supported.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so drop the last frequency to keep the shapes aligned.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Singleton middle dim so the table broadcasts over the second
        # dimension of the input in forward().
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Add the encoding for the first ``x.size(0)`` positions to ``x``.

        Args:
            x: Tensor whose first dimension indexes positions and whose last
                dimension is ``d_model``; ``x.size(0)`` should not exceed
                ``max_len``.

        Returns:
            ``x`` plus the positional encoding, same shape as ``x``.
        """
        return x + self.pe[:x.size(0), :]
    
def custom_weights_init(m, a=0.5):
    """Initialise one module's parameters with scaled defaults.

    Intended to be called once per submodule (e.g. via ``model.apply``).
    Modules of any other type are left untouched.

    * ``nn.Linear``: Xavier-uniform weight multiplied by ``a``; zero bias.
    * ``nn.MultiheadAttention``: Xavier-uniform input-projection weights
      multiplied by ``a``; zero input-projection bias.
    * ``nn.LayerNorm``: weight filled with ``a``; zero bias.

    Parameters that a module does not have (``bias=False``,
    ``elementwise_affine=False``, or separate q/k/v projections in place of
    the packed ``in_proj_weight``) are skipped instead of raising.

    Args:
        m: The module to initialise in place.
        a: Scale factor applied to the freshly initialised weights.
    """

    def _scaled_xavier_(weight):
        # Draw Xavier-uniform values, then shrink them by ``a`` in place.
        nn.init.xavier_uniform_(weight)
        with torch.no_grad():
            weight.mul_(a)

    def _zero_(param):
        # Optional parameters are stored as None when disabled.
        if param is not None:
            nn.init.zeros_(param)

    if isinstance(m, nn.Linear):
        _scaled_xavier_(m.weight)
        _zero_(m.bias)
    elif isinstance(m, nn.MultiheadAttention):
        # Either the packed in_proj_weight exists, or (when kdim/vdim differ
        # from embed_dim) the three separate projection weights do.
        for name in ("in_proj_weight", "q_proj_weight",
                     "k_proj_weight", "v_proj_weight"):
            weight = getattr(m, name, None)
            if weight is not None:
                _scaled_xavier_(weight)
        _zero_(m.in_proj_bias)
    elif isinstance(m, nn.LayerNorm):
        if m.weight is not None:
            nn.init.constant_(m.weight, a)
        _zero_(m.bias)