import torch
import torch.nn as nn
import torch.nn.functional as F
from .Attention import LunaAttentionLayer


class EncoderLayer(nn.Module):
    """Encoder block: an attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer is wrapped in a residual connection and followed by a
    ``LayerNorm`` over the last dimension. The feed-forward part is built from
    two ``Conv1d`` modules with ``kernel_size=1``.

    Args:
        attention: Callable invoked as ``attention(x, x, x)``; its result is
            added back onto ``x``.
        d_model: Size of the last input dimension (normalised shape of both
            layer norms, channel count seen by the convolutions).
        d_ff: Hidden channel count of the feed-forward part. Any falsy value
            (``None`` or ``0``) selects ``4 * d_model``.
        dropout: Probability of the single ``Dropout`` module shared by every
            dropout site in ``forward``.
        activation: ``"relu"`` selects ``F.relu``; any other value selects
            ``F.gelu``.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super().__init__()
        hidden_channels = d_ff if d_ff else 4 * d_model

        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden_channels, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden_channels, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x):
        """Apply self-attention and the feed-forward network to ``x``.

        Args:
            x: Input tensor, expected as ``[bsz, seq_len, d_model]``.

        Returns:
            Tensor produced by the second layer norm, with the same layout
            as ``x``.
        """
        # Self-attention: the same tensor serves as all three arguments.
        attended = self.attention(x, x, x)
        x = self.norm1(x + self.dropout(attended))

        # Conv1d treats axis 1 as channels, so swap it with the feature axis:
        # [bsz, seq_len, d_model] -> [bsz, d_model, seq_len].
        hidden = self.conv1(x.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))

        # Swap back to [bsz, seq_len, d_model] before dropout and the residual add.
        feed_forward = self.dropout(self.conv2(hidden).transpose(-1, 1))
        return self.norm2(x + feed_forward)

class Encoder(nn.Module):
    """Apply a sequence of layers one after another, then an optional norm.

    Args:
        attn_layers: Iterable of modules; each is called with a single tensor
            and its result is fed to the next one.
        norm_layer: Optional module applied to the output of the last layer.
            Skipped when ``None``.
    """

    def __init__(self, attn_layers, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer

    def forward(self, x):
        """Run ``x`` (expected as ``[B, L, D]``) through every layer in order."""
        hidden = x
        for layer in self.attn_layers:
            hidden = layer(hidden)

        if self.norm is None:
            return hidden
        return self.norm(hidden)

class LunaEncoder(nn.Module):
    """Encoder that threads a learned projected sequence ``p`` through its layers.

    A learnable embedding of ``project_embedding_length`` positions is
    broadcast over the batch and passed, together with the input, through
    every layer. Each layer is called as ``attn_layer(x, p)`` and must return
    the updated ``(x, p)`` pair.

    Args:
        attn_layers: Iterable of modules with the ``(x, p) -> (x, p)`` call
            signature described above.
        d_model: Embedding dimension of the projected sequence.
        drop_out: Dropout probability, used both inside the learnable
            embedding and for the extra dropout applied in ``forward``.
        project_embedding_length: Number of positions in the projected
            sequence.
        norm_layer: Optional module applied to ``x`` (not to ``p``) after the
            last layer. Skipped when ``None``.
    """

    def __init__(self, attn_layers, d_model, drop_out, project_embedding_length, norm_layer=None):
        super(LunaEncoder, self).__init__()
        self.project_embedding_length = project_embedding_length
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm_1 = norm_layer
        self.project_embedding = LunaLearnablePositionalEncoding(
            d_model=d_model, dropout=drop_out, max_len=project_embedding_length
        )
        self.drop_out = nn.Dropout(p=drop_out)

    def forward(self, x):
        """Encode ``x`` and return it together with the projected sequence.

        Args:
            x: Input tensor of shape ``[B, L, D]``.

        Returns:
            Tuple ``(x, p)``: ``x`` after all layers (and ``norm_1`` if set),
            and ``p`` as returned by the last layer. With no layers, ``p`` is
            the batch-expanded embedding of shape
            ``[B, project_embedding_length, d_model]``.
        """
        # Unpacking three values keeps the original requirement that x is 3-D.
        batch_size, _, _ = x.shape

        # The ids are only used by the embedding module to decide how many
        # positions to return.
        position_ids = torch.arange(
            self.project_embedding_length, dtype=torch.long, device=x.device
        )

        # The embedding comes back as [P, 1, D]. Squeeze ONLY axis 1: a bare
        # squeeze() would also drop P or D when either equals 1, leaving a
        # tensor that is no longer [P, D].
        projected = self.project_embedding(position_ids).squeeze(1)

        # Broadcast over the batch without copying: [P, D] -> [B, P, D].
        projected = projected.unsqueeze(0).expand(batch_size, -1, -1)

        # NOTE(review): the embedding module already applies dropout, so the
        # projected sequence is dropped out twice. Kept as-is to preserve
        # behaviour; confirm whether this is intended.
        p = self.drop_out(projected)

        for attn_layer in self.attn_layers:
            x, p = attn_layer(x, p)

        if self.norm_1 is not None:
            x = self.norm_1(x)

        return x, p

class LunaLearnablePositionalEncoding(nn.Module):
    """Learnable table of position embeddings.

    The table ``pe`` has shape ``[max_len, 1, d_model]`` and is initialised
    uniformly in ``[-0.02, 0.02]``. Positions are always ``0 .. n-1``, so the
    embeddings are obtained by slicing the table rather than by a lookup.

    Args:
        d_model: Embedding dimension of each position.
        dropout: Dropout probability applied to the returned embeddings.
        max_len: Number of positions stored in the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=1024):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # One embedding row per position; wrapping the initialised tensor in
        # nn.Parameter makes it trainable (requires_grad=True).
        table = torch.empty(max_len, 1, d_model)
        nn.init.uniform_(table, -0.02, 0.02)
        self.pe = nn.Parameter(table)

    def forward(self, x):
        r"""Return the embeddings of the first ``x.size(0)`` positions.

        Args:
            x: Tensor whose first dimension gives the number of positions
                requested. Its values are not read.
        Shape:
            x: [sequence length, ...]
            output: [min(sequence length, max_len), 1, embed dim]
        """
        num_positions = x.size(0)
        return self.dropout(self.pe[:num_positions])