import torch.nn as nn
import torch.nn.functional as F
import torch
import math

# Magnitude of the fill value for masked attention scores: Attention.forward
# writes -MAX_VAL into masked positions before the softmax.
MAX_VAL = 1e4
# NOTE(review): not referenced anywhere in the visible part of this file --
# presumably a small epsilon used by callers elsewhere; confirm before removing.
MIN_VAL = 1e-12


##################
# Self-Attention #
##################

class Attention(nn.Module):
    """Compute scaled dot-product attention.

    The module holds no parameters; everything it needs is passed to
    ``forward``.
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query, key, value: tensors whose last dimension of ``query`` and
                ``key`` is the per-head feature size used for scaling.
            mask: optional tensor; after ``unsqueeze(1)`` every position equal
                to 0 has its score replaced by ``-MAX_VAL`` before the softmax
                (the inserted axis presumably broadcasts over heads -- confirm
                against callers).
            dropout: optional callable applied to the attention weights.

        Returns:
            A tuple ``(weighted_values, attention_weights)``.
        """
        scale = math.sqrt(query.size(-1))
        logits = torch.matmul(query, key.transpose(-2, -1)) / scale

        if mask is not None:
            blocked = mask.unsqueeze(1) == 0
            logits = logits.masked_fill(blocked, -MAX_VAL)

        weights = F.softmax(logits, dim=-1)
        if dropout is not None:
            weights = dropout(weights)

        context = torch.matmul(weights, value)
        return context, weights


class MultiHeadedAttention(nn.Module):
    """Multi-head attention built from a model size and a head count.

    Args:
        h: number of attention heads; must divide ``d_model``.
        d_model: feature size of the inputs and of the output.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0

        # Per-head feature size.
        self.d_k = d_model // h
        self.h = h

        # Projections for query, key and value (in that order), then the
        # projection applied to the re-joined heads.
        self.linear_layers = nn.ModuleList(
            nn.Linear(d_model, d_model) for _ in range(3)
        )
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Project the inputs, attend per head, and merge the heads again.

        ``mask`` is handed unchanged to the inner attention module. Only the
        attended values are returned; the attention weights are discarded.
        """
        n_batch = query.size(0)

        def split_heads(projection, tensor):
            # d_model -> (h, d_k), with the head axis moved in front of the
            # sequence axis.
            projected = projection(tensor)
            return projected.view(n_batch, -1, self.h, self.d_k).transpose(1, 2)

        q_proj, k_proj, v_proj = self.linear_layers
        query = split_heads(q_proj, query)
        key = split_heads(k_proj, key)
        value = split_heads(v_proj, value)

        context, _ = self.attention(
            query, key, value, mask=mask, dropout=self.dropout
        )

        # "Concatenate" the heads back into one feature axis of size h * d_k.
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(n_batch, -1, self.h * self.d_k)

        return self.output_linear(merged)


#########
# Utils #
#########
# class GELU(nn.Module):
#     "Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU"
#     def forward(self, x):
#         return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

# class LayerNorm(nn.Module):
#     "Construct a layernorm module (See citation for details)."
#     def __init__(self, features, eps=1e-12):
#         super(LayerNorm, self).__init__()
#         self.a_2 = nn.Parameter(torch.ones(features))
#         self.b_2 = nn.Parameter(torch.zeros(features))
#         self.eps = eps

#     def forward(self, x):
#         mean = x.mean(-1, keepdim=True)
#         var = x.var(-1, keepdim=True, unbiased=False)
#         inv = (var + self.eps).rsqrt() * self.a_2
#         return x * inv + (self.b_2 - mean * inv)
    
    
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network: linear -> GELU -> linear.

    Args:
        d_model: input and output feature size.
        d_ff: feature size of the hidden layer.
        dropout: accepted for signature compatibility; this module does not
            create or apply a dropout layer.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.activation = nn.GELU()

    def forward(self, x):
        """Expand ``x`` to ``d_ff``, apply GELU, and project back to ``d_model``."""
        hidden = self.w_1(x)
        hidden = self.activation(hidden)
        return self.w_2(hidden)

class SublayerConnection(nn.Module):
    """Residual connection followed by layer normalisation (post-norm).

    Args:
        size: feature size handed to ``nn.LayerNorm``.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Return ``norm(x + dropout(sublayer(x)))``.

        ``sublayer`` is any callable (one layer or a composition of layers)
        that maps ``x`` to a tensor that can be added to ``x``.
        """
        branch = self.dropout(sublayer(x))
        return self.norm(x + branch)
    
class SublayerConnection1(nn.Module):
    """Plain residual connection with dropout and no normalisation.

    Args:
        size: accepted so the signature mirrors ``SublayerConnection``; it is
            not used by this class.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Return ``x + dropout(sublayer(x))``.

        ``sublayer`` is any callable (one layer or a composition of layers)
        that maps ``x`` to a tensor that can be added to ``x``.
        """
        branch = sublayer(x)
        branch = self.dropout(branch)
        return x + branch

class OutputLayer(nn.Module):
    """Output layer for the BERT model: linear -> GELU -> layer norm.

    Args:
        hidden_dim: feature size of both the input and the output.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, hidden_dim)
        self.activation = nn.GELU()
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Project ``x``, apply GELU, then normalise over the feature axis."""
        projected = self.linear(x)
        activated = self.activation(projected)
        return self.layer_norm(activated)

####################
# TransformerBlock #
####################

class TransformerBlock(nn.Module):
    """Self-attention followed by a position-wise feed-forward network.

    Each of the two sublayers is wrapped in a ``SublayerConnection``
    (residual add, then layer norm).

    Args:
        hidden: model feature size.
        attn_heads: number of attention heads.
        feed_forward_hidden: hidden size of the feed-forward network.
        dropout: dropout probability passed to every sub-module.
    """

    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        super().__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)

    def forward(self, x, attn_mask=None):
        """Run self-attention then the feed-forward network over ``x``.

        ``attn_mask`` is forwarded unchanged as the attention ``mask``.
        """
        # Call the module itself rather than ``.forward`` so that registered
        # forward hooks run.
        x = self.input_sublayer(x, lambda _x: self.attention(_x, _x, _x, mask=attn_mask))

        x = self.output_sublayer(x, self.feed_forward)

        return x


class TransformerBlock1(nn.Module):
    """Self-attention only, wrapped in a residual connection plus layer norm.

    Args:
        hidden: model feature size.
        attn_heads: number of attention heads.
        dropout: dropout probability passed to every sub-module.
    """

    def __init__(self, hidden, attn_heads, dropout):
        super().__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)

    def forward(self, x, attn_mask=None):
        """Run self-attention over ``x``.

        ``attn_mask`` is forwarded unchanged as the attention ``mask``.
        """
        # Call the module itself rather than ``.forward`` so that registered
        # forward hooks run.
        x = self.input_sublayer(x, lambda _x: self.attention(_x, _x, _x, mask=attn_mask))

        return x
    

class TransformerBlock3(nn.Module):
    """Self-attention only, wrapped in a plain residual connection.

    Unlike ``TransformerBlock1`` the wrapper is ``SublayerConnection1``, so no
    layer norm is applied after the residual add.

    Args:
        hidden: model feature size.
        attn_heads: number of attention heads.
        dropout: dropout probability passed to every sub-module.
    """

    def __init__(self, hidden, attn_heads, dropout):
        super().__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection1(size=hidden, dropout=dropout)

    def forward(self, x, attn_mask=None):
        """Run self-attention over ``x``.

        ``attn_mask`` is forwarded unchanged as the attention ``mask``.
        """
        # Call the module itself rather than ``.forward`` so that registered
        # forward hooks run.
        x = self.input_sublayer(x, lambda _x: self.attention(_x, _x, _x, mask=attn_mask))

        return x


class TransformerBlock2(nn.Module):
    """Self-attention followed by a position-wise feed-forward network.

    Structurally the same as ``TransformerBlock``; it is kept as its own class
    so existing references to this name keep working.

    Args:
        hidden: model feature size.
        attn_heads: number of attention heads.
        feed_forward_hidden: hidden size of the feed-forward network.
        dropout: dropout probability passed to every sub-module.
    """

    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        super().__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)

    def forward(self, x, attn_mask=None):
        """Run self-attention then the feed-forward network over ``x``.

        ``attn_mask`` is forwarded unchanged as the attention ``mask``.
        """
        # Call the module itself rather than ``.forward`` so that registered
        # forward hooks run.
        x = self.input_sublayer(x, lambda _x: self.attention(_x, _x, _x, mask=attn_mask))

        x = self.output_sublayer(x, self.feed_forward)

        return x
    
    
class CrossAttentionBlock(nn.Module):
    """Cross-attention followed by a position-wise feed-forward network.

    Each of the two sublayers is wrapped in a ``SublayerConnection``
    (residual add, then layer norm).

    Args:
        hidden: model feature size.
        attn_heads: number of attention heads.
        feed_forward_hidden: hidden size of the feed-forward network.
        dropout: dropout probability passed to every sub-module.
    """

    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        super().__init__()
        self.cross_attn = MultiHeadedAttention(
            h=attn_heads, d_model=hidden, dropout=dropout
        )
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(
            d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout
        )
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)

    def forward(self, query, key, value, attn_mask=None):
        """Attend from ``query`` over ``key``/``value``, then apply the FFN.

        The residual path of the attention sublayer carries ``query``;
        ``attn_mask`` is forwarded unchanged as the attention ``mask``.
        """

        def attend(q):
            return self.cross_attn(q, key, value, mask=attn_mask)

        attended = self.input_sublayer(query, attend)
        return self.output_sublayer(attended, self.feed_forward)
    

class CrossAttentionBlock1(nn.Module):
    """Cross-attention, then self-attention, then a feed-forward network.

    Every sublayer is wrapped in a ``SublayerConnection`` (residual add, then
    layer norm). The original ``query`` is additionally added back in after
    the self-attention sublayer.

    Args:
        hidden: model feature size.
        attn_heads: number of attention heads.
        feed_forward_hidden: hidden size of the feed-forward network.
        dropout: dropout probability passed to every sub-module.
    """

    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        super().__init__()
        self.cross_attn = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)
        self.mid_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)

    def forward(self, query, key, value, attn_mask=None):
        """Run cross-attention, self-attention and the FFN in sequence.

        ``attn_mask`` is applied to the cross-attention only; the
        self-attention step is called without a mask.
        """
        # Cross-Attention
        x = self.input_sublayer(query, lambda _q: self.cross_attn(_q, key, value, mask=attn_mask))

        # Self-Attention
        x = self.mid_sublayer(x, lambda _x: self.attention(_x, _x, _x))
        # Out-of-place add: leaves the tensor returned by ``mid_sublayer``
        # untouched, so anything holding a reference to it (e.g. a forward
        # hook) does not see it change.
        x = x + query

        # Feed-Forward Network
        x = self.output_sublayer(x, self.feed_forward)

        return x
    
    
class CrossAttention(nn.Module):
    """Cross-attention wrapped in a residual connection plus layer norm.

    Args:
        hidden: model feature size.
        attn_heads: number of attention heads.
        dropout: dropout probability passed to every sub-module.
    """

    def __init__(self, hidden, attn_heads, dropout):
        super().__init__()
        self.cross_attn = MultiHeadedAttention(
            h=attn_heads, d_model=hidden, dropout=dropout
        )
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)

    def forward(self, query, key, value, attn_mask=None):
        """Attend from ``query`` over ``key``/``value``.

        The residual path carries ``query``; ``attn_mask`` is forwarded
        unchanged as the attention ``mask``.
        """

        def attend(q):
            return self.cross_attn(q, key, value, mask=attn_mask)

        return self.input_sublayer(query, attend)