

__all__ = ['TiTR']  # only the TiTR model is exported; per the original author's note, "i" means independent channels

# Cell
from typing import Callable, Optional
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
import numpy as np

#from collections import OrderedDict
from ..models.pos_encoding import *
from ..models.layers import *
from ..models.RevIN import RevIN

class TiTR(nn.Module):
    """Two-stream patch Transformer: a future window cross-attends to a history window.

    ``forward`` takes ``z`` of shape ``[bs x nvars x T]`` and splits it on the
    time axis at ``context_window``:

    * history ``z[:, :, :context_window]`` (all ``c_in`` channels) is patched
      with ``patch_len``/``stride`` and encoded by ``backbone_hist``;
    * future ``z[:, masked:, context_window:]`` (the last ``c_in - masked``
      channels; its length is assumed to be ``target_window``) is patched with
      ``patch_len2``/``stride2`` and encoded by ``backbone``.

    The future tokens query the history tokens through ``cross_attention`` and
    the result is decoded by ``head`` (``'flatten'``, ``'linear'`` or ``'mlp'``,
    or a 1x1-conv head when ``pretrain_head`` is set). With ``probablistic=True``
    the heads emit ``n_quantiles`` values per output step.

    ``norm`` selects the normalisation used inside every encoder layer
    (``'BatchNorm'`` or anything else for LayerNorm).

    Raises:
        ValueError: if ``pretrain_head`` is false and ``head_type`` is not one
            of ``'flatten'``, ``'linear'``, ``'mlp'``.
    """
    def __init__(self, c_in:int, context_window:int, target_window:int, patch_len:int, stride:int, patch_len2:int, 
                 stride2:int, padding_patch = None, padding_patch2 = None, c_out=1, max_seq_len:Optional[int]=1024, mlp_dim_1 = 512,
                 n_layers:int=1, n_layers_hist:int=1, d_model=128, n_heads=16, d_k:Optional[int]=None, d_v:Optional[int]=None,
                 d_ff:int=256, norm:str='BatchNorm', attn_dropout:float=0., dropout:float=0., act:str="gelu", key_padding_mask:bool='auto',
                 padding_var:Optional[int]=None, attn_mask:Optional[Tensor]=None, res_attention:bool=True, pre_norm:bool=False, store_attn:bool=False,
                 pe:str='zeros', learn_pe:bool=True, fc_dropout:float=0., head_dropout = 0, future_patch = False,
                 pretrain_head:bool=False, head_type = 'flatten', revin = True, affine = True, subtract_last = False,
                 verbose:bool=False, probablistic=False, n_quantiles=11, masked=3, **kwargs):
        
        super().__init__()
        
        self.context_window = context_window
        self.target_window = target_window
        self.masked = masked  # number of leading channels dropped from the future stream
        
        # RevIn
        # NOTE(review): these layers are constructed (so they appear in the
        # state_dict) but forward() never applies them - confirm intent.
        self.revin = revin
        if self.revin: 
            self.revin_layer = RevIN(c_in-1, affine=affine)
            self.revin_layer_hist = RevIN(c_in, affine=affine)
        
        # Patching of the history window
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch = padding_patch
        patch_num = int((context_window - patch_len)/stride + 1)
        if padding_patch == 'end': # replicate-pad `stride` steps at the end -> one extra patch
            self.padding_patch_layer = nn.ReplicationPad1d((0, stride)) 
            patch_num += 1
            
        # Patching of the future window
        self.patch_len2 = patch_len2
        self.stride2 = stride2
        self.padding_patch2 = padding_patch2
        patch_num2 = int((target_window - patch_len2)/stride2 + 1)
        if padding_patch2 == 'end': # replicate-pad `stride2` steps at the end -> one extra patch
            self.padding_patch_layer2 = nn.ReplicationPad1d((0, stride2)) 
            patch_num2 += 1

        # Backbones. `norm` is forwarded so the constructor argument actually
        # takes effect (it was previously accepted but ignored).
        self.backbone = TSTiEncoder(c_in-masked, patch_num=patch_num2, patch_len=patch_len2, max_seq_len=max_seq_len,
                                n_layers=n_layers, d_model=d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff,
                                norm=norm, attn_dropout=attn_dropout, dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var,
                                attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,
                                pe=pe, learn_pe=learn_pe, verbose=verbose, **kwargs)
        
        self.backbone_hist = TSTiEncoder(c_in, patch_num=patch_num, patch_len=patch_len, max_seq_len=max_seq_len,
                                n_layers=n_layers_hist, d_model=d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff,
                                norm=norm, attn_dropout=attn_dropout, dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var,
                                attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,
                                pe=pe, learn_pe=learn_pe, verbose=verbose, **kwargs)
        
        # Future tokens (queries) attend to history tokens (keys/values).
        self.cross_attention = CTSTEncoder(q_len=patch_num, d_model=d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff,
                                norm=norm, attn_dropout=attn_dropout, dropout=dropout, activation=act,
                                res_attention=res_attention, n_layers=1, pre_norm=pre_norm, store_attn=store_attn)

        # Head
        self.head_nf = d_model * patch_num2
        self.n_vars = c_in
        self.c_out = c_out
        self.pretrain_head = pretrain_head
        self.head_type = head_type

        # NOTE(review): the 'linear' and 'mlp' heads work per token, so their time
        # axis has patch_num2 entries; it equals target_window only when
        # patch_num2 == target_window (e.g. patch_len2 == stride2 == 1, no padding).
        if self.pretrain_head: 
            self.head = self.create_pretrain_head(self.head_nf, c_in, fc_dropout) # custom head passed as a partial func with all its kwargs
        elif head_type == 'flatten': 
            self.head = Flatten_Head(self.c_out, self.head_nf, target_window, head_dropout=head_dropout, probablistic=probablistic, n_quantiles=n_quantiles)
        elif head_type == 'linear': 
            self.head = Linear_Head(d_model, self.c_out, target_window, head_dropout=head_dropout, probablistic=probablistic, n_quantiles=n_quantiles)
        elif head_type == 'mlp': 
            self.head = MLP_Head(d_model, self.c_out, target_window, mlp_dim_1, head_dropout=head_dropout, probablistic=probablistic, n_quantiles=n_quantiles)
        else:
            # Fail fast instead of leaving `self.head` undefined until forward().
            raise ValueError(f"unknown head_type {head_type!r}; expected 'flatten', 'linear' or 'mlp'")
            
    
    def forward(self, z):                                                                   # z: [bs x nvars x (context_window+target_window)]
        """Encode history and future windows, cross-attend, and decode with the head."""

        # Split along time; the future stream drops the first `masked` channels.
        z_hist = z[:,:,:self.context_window]                                                # z_hist: [bs x nvars x context_window]
        z = z[:,self.masked:,self.context_window:]                                          # z: [bs x (nvars-masked) x target_window]

        # Future stream: patch, then encode (channels are mixed inside each patch token).
        if self.padding_patch2 == 'end':
            z = self.padding_patch_layer2(z)
        z = z.unfold(dimension=-1, size=self.patch_len2, step=self.stride2)                 # z: [bs x (nvars-masked) x patch_num2 x patch_len2]
        z = z.permute(0,1,3,2)                                                              # z: [bs x (nvars-masked) x patch_len2 x patch_num2]
        z = self.backbone(z)                                                                # z: [bs x d_model x patch_num2]
        
        # History stream: same procedure with the history patching parameters.
        if self.padding_patch == 'end':
            z_hist = self.padding_patch_layer(z_hist)
        z_hist = z_hist.unfold(dimension=-1, size=self.patch_len, step=self.stride)         # z_hist: [bs x nvars x patch_num x patch_len]
        z_hist = z_hist.permute(0,1,3,2)                                                    # z_hist: [bs x nvars x patch_len x patch_num]
        z_hist = self.backbone_hist(z_hist)                                                 # z_hist: [bs x d_model x patch_num]
        
        # Cross attention works on token-major tensors.
        z_hist = z_hist.permute(0,2,1)                                                      # z_hist: [bs x patch_num x d_model]
        z = z.permute(0,2,1)                                                                # z: [bs x patch_num2 x d_model]
        z = self.cross_attention(z, z_hist)                                                 # z: [bs x patch_num2 x d_model]
        z = z.permute(0,2,1)                                                                # z: [bs x d_model x patch_num2] 
        
        z = self.head(z)                                                                    # z: [bs x c_out x target_window] (+ n_quantiles axis if probabilistic)
        
        return z
    
    def create_pretrain_head(self, head_nf, vars, dropout):
        """Build a dropout + 1x1 Conv1d head mapping `head_nf` channels to `vars` channels.

        NOTE(review): forward() feeds the head a [bs x d_model x patch_num2]
        tensor, while `head_nf` is d_model * patch_num2; the Conv1d channel
        count only matches when patch_num2 == 1 - verify before using pretrain_head.
        """
        return nn.Sequential(nn.Dropout(dropout),
                    nn.Conv1d(head_nf, vars, 1)
                    )


class Flatten_Head(nn.Module):
    """Flatten ``[d_model x L]`` features per sample and map them linearly to the forecast.

    ``nf`` must equal ``d_model * L`` for the incoming ``[bs x d_model x L]``
    tensor. The output is ``[bs x c_out x target_window]``. When ``probablistic``
    is set, a trailing ``n_quantiles`` axis is added. Its first entry is taken
    as-is, and each later entry adds a softplus (positive) increment through a
    cumulative sum, so the quantiles are non-decreasing.
    """

    def __init__(self, c_out, nf, target_window, head_dropout=0, probablistic=False, n_quantiles=11):
        super().__init__()

        self.c_out = c_out
        self.target_window = target_window

        self.flatten = nn.Flatten(start_dim=-2)
        self.dropout = nn.Dropout(head_dropout)
        self.probablistic = probablistic

        # One output per (channel, step), times the number of quantiles if probabilistic.
        out_features = c_out * target_window
        if probablistic:
            self.n_quantiles = n_quantiles
            out_features *= n_quantiles
        self.linear = nn.Linear(nf, out_features)

    def forward(self, x):                     # x: [bs x d_model x L]
        y = self.dropout(self.linear(self.flatten(x)))
        batch = y.shape[0]
        if not self.probablistic:
            return torch.reshape(y, (batch, self.c_out, self.target_window))

        y = torch.reshape(y, (batch, self.c_out, self.target_window, self.n_quantiles))
        # Base value followed by strictly positive gaps -> monotone quantiles.
        first, rest = y[..., :1], F.softplus(y[..., 1:])
        return torch.cat((first, rest), dim=-1).cumsum(dim=-1)
    
    
class Linear_Head(nn.Module):
    """Per-time-step linear projection from ``d_model`` to ``c_out`` outputs.

    Input ``x`` is ``[bs x d_model x L]``. Without ``probablistic`` the output is
    ``[bs x c_out x L]``. With ``probablistic`` the output is
    ``[bs x c_out x target_window x n_quantiles]``; the reshape requires
    ``L == target_window``. The quantile axis is made non-decreasing: the first
    value is kept and softplus (positive) increments are accumulated with a
    cumulative sum.
    """
    def __init__(self, d_model, c_out, target_window, head_dropout=0, probablistic=False, n_quantiles=11):
        super().__init__()
        
        self.c_out = c_out
        self.target_window = target_window
        
        self.dropout = nn.Dropout(head_dropout)
        self.probablistic = probablistic
        
        if self.probablistic:
            self.n_quantiles = n_quantiles
            self.linear = nn.Linear(d_model, c_out*self.n_quantiles)
        else:
            self.linear = nn.Linear(d_model, c_out)
            
    def forward(self, x):                     # x: [bs x d_model x L]
        # Debug prints and the stray `stop` (which raised NameError on every
        # probabilistic call) have been removed.
        x = x.permute(0,2,1)                  # x: [bs x L x d_model]
        x = self.linear(x)                    # x: [bs x L x c_out(*n_quantiles)]
        x = self.dropout(x)
        if not self.probablistic:
            return x.permute(0,2,1)           # x: [bs x c_out x L]

        x = torch.reshape(x, (x.shape[0], self.target_window, self.c_out, self.n_quantiles))
        x = x.permute(0,2,1,3)                # x: [bs x c_out x target_window x n_quantiles]
        base = x[..., :1]
        increments = F.softplus(x[..., 1:])   # positive gaps -> monotone quantiles
        return torch.cumsum(torch.cat((base, increments), dim=-1), dim=-1)

    
class MLP_Head(nn.Module):
    """Two-layer per-time-step MLP head: ``d_model -> mlp_dim_1 (ReLU, dropout) -> c_out``.

    Input ``x`` is ``[bs x d_model x L]``. Without ``probablistic`` the output is
    ``[bs x c_out x L]``. With ``probablistic`` the output is
    ``[bs x c_out x target_window x n_quantiles]`` (requires
    ``L == target_window``), with quantiles made non-decreasing by accumulating
    softplus increments.
    """

    def __init__(self, d_model, c_out, target_window, mlp_dim_1=64, head_dropout=0, probablistic=False, n_quantiles=11):
        super().__init__()

        self.c_out = c_out
        self.target_window = target_window

        self.linear = nn.Linear(d_model, mlp_dim_1)
        self.dropout = nn.Dropout(head_dropout)
        self.probablistic = probablistic

        n_out = c_out
        if probablistic:
            self.n_quantiles = n_quantiles
            n_out = c_out * n_quantiles
        self.linear1 = nn.Linear(mlp_dim_1, n_out)

    def forward(self, x):                     # x: [bs x d_model x L]
        hidden = self.dropout(F.relu(self.linear(x.permute(0, 2, 1))))
        proj = self.linear1(hidden)           # [bs x L x c_out(*n_quantiles)]
        if not self.probablistic:
            return proj.permute(0, 2, 1)      # [bs x c_out x L]

        proj = torch.reshape(proj, (proj.shape[0], self.target_window, self.c_out, self.n_quantiles))
        proj = proj.permute(0, 2, 1, 3)       # [bs x c_out x target_window x n_quantiles]
        lowest, gaps = proj[..., :1], F.softplus(proj[..., 1:])
        return torch.cat((lowest, gaps), dim=-1).cumsum(dim=-1)
    
    
class TSTiEncoder(nn.Module):
    """Patch encoder that turns each patch position (all channels together) into one token.

    Each of the ``patch_num`` positions contributes ``c_in * patch_len`` values,
    which ``W_P`` projects to ``d_model``. ``W_pos`` is added, dropout is
    applied, and the token sequence is passed through a ``TSTEncoder``.

    NOTE: ``max_seq_len``, ``key_padding_mask``, ``padding_var``, ``attn_mask``,
    ``verbose`` and ``**kwargs`` are accepted but not used by this class.
    """

    def __init__(self, c_in, patch_num, patch_len, max_seq_len=1024,
                 n_layers=3, d_model=128, n_heads=16, d_k=None, d_v=None,
                 d_ff=256, norm='BatchNorm', attn_dropout=0., dropout=0., act="gelu", store_attn=False,
                 key_padding_mask='auto', padding_var=None, attn_mask=None, res_attention=True, pre_norm=False,
                 pe='zeros', learn_pe=True, verbose=False, **kwargs):
        super().__init__()

        self.patch_num = patch_num
        self.patch_len = patch_len

        # Eq 1: project each concatenated multi-channel patch onto d_model.
        self.W_P = nn.Linear(c_in * patch_len, d_model)
        self.seq_len = patch_num

        self.W_pos = positional_encoding(pe, learn_pe, patch_num, d_model)

        # Residual dropout applied to the embedded tokens.
        self.dropout = nn.Dropout(dropout)

        self.encoder = TSTEncoder(patch_num, d_model, n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm,
                                  attn_dropout=attn_dropout, dropout=dropout, pre_norm=pre_norm,
                                  activation=act, res_attention=res_attention, n_layers=n_layers,
                                  store_attn=store_attn)

    def forward(self, x) -> Tensor:            # x: [bs x nvars x patch_len x patch_num]
        # Merge (nvars, patch_len) into one feature axis, then make tokens the middle axis.
        tokens = x.flatten(start_dim=1, end_dim=2).permute(0, 2, 1)   # [bs x patch_num x nvars*patch_len]
        embedded = self.dropout(self.W_P(tokens) + self.W_pos)      # [bs x patch_num x d_model]
        encoded = self.encoder(embedded)                            # [bs x patch_num x d_model]
        return encoded.permute(0, 2, 1)                             # [bs x d_model x patch_num]
            
            
    
# Cell
class TSTEncoder(nn.Module):
    """Apply ``n_layers`` ``TSTEncoderLayer`` modules in sequence.

    When ``res_attention`` is true, each layer's pre-softmax attention scores
    are passed as ``prev`` to the next layer. Only the final hidden states are
    returned.
    """

    def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=None, 
                        norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',
                        res_attention=False, n_layers=1, pre_norm=False, store_attn=False):
        super().__init__()

        self.layers = nn.ModuleList([
            TSTEncoderLayer(q_len, d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm,
                            attn_dropout=attn_dropout, dropout=dropout, activation=activation,
                            res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn)
            for _ in range(n_layers)
        ])
        self.res_attention = res_attention

    def forward(self, src:Tensor, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        hidden = src
        carried_scores = None
        for layer in self.layers:
            if self.res_attention:
                hidden, carried_scores = layer(hidden, prev=carried_scores,
                                               key_padding_mask=key_padding_mask, attn_mask=attn_mask)
            else:
                hidden = layer(hidden, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return hidden


class CTSTEncoder(nn.Module):
    """Wrap a single ``CrossAttentionLayer``: ``src`` attends to ``src_hist``.

    NOTE: ``n_layers`` is accepted but ignored; exactly one layer (``self.layer``)
    is always built. ``q_len`` is only forwarded to that layer.
    """

    def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=None, 
                        norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',
                        res_attention=False, n_layers=1, pre_norm=False, store_attn=False):
        super().__init__()

        self.layer = CrossAttentionLayer(q_len, d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff,
                                         norm=norm, attn_dropout=attn_dropout, dropout=dropout,
                                         activation=activation, res_attention=res_attention,
                                         pre_norm=pre_norm, store_attn=store_attn)
        self.res_attention = res_attention

    def forward(self, src:Tensor, src_hist:Tensor, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        masks = dict(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        if not self.res_attention:
            return self.layer(src, src_hist, **masks)
        # With residual attention the layer also returns scores; no later layer consumes them.
        output, _ = self.layer(src, src_hist, prev=None, **masks)
        return output


class CrossAttentionLayer(nn.Module):
    """Transformer layer where queries come from ``src`` and keys/values from ``src_hist``.

    It has an attention sublayer and a position-wise feed-forward sublayer, each
    with a residual connection. Normalisation is applied after the residual
    (``pre_norm=False``) or before the sublayer (``pre_norm=True``). In pre-norm
    mode the same ``norm_attn`` module normalises both ``src`` and ``src_hist``.
    ``q_len`` is accepted but not used.
    """

    def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=256, store_attn=False,
                 norm='BatchNorm', attn_dropout=0, dropout=0., bias=True, activation="gelu", res_attention=False, pre_norm=False):
        super().__init__()
        assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        head_dim = d_model // n_heads
        if d_k is None:
            d_k = head_dim
        if d_v is None:
            d_v = head_dim

        # Attention sublayer (attribute name kept as `self_attn` for state_dict compatibility).
        self.res_attention = res_attention
        self.self_attn = _MultiheadAttention(d_model, n_heads, d_k, d_v, attn_dropout=attn_dropout,
                                             proj_dropout=dropout, res_attention=res_attention)
        self.dropout_attn = nn.Dropout(dropout)
        self.norm_attn = self._make_norm(norm, d_model)

        # Feed-forward sublayer.
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=bias),
            get_activation_fn(activation),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model, bias=bias),
        )
        self.dropout_ffn = nn.Dropout(dropout)
        self.norm_ffn = self._make_norm(norm, d_model)

        self.pre_norm = pre_norm
        self.store_attn = store_attn

    @staticmethod
    def _make_norm(norm, d_model):
        """BatchNorm1d over the d_model axis if `norm` mentions 'batch', otherwise LayerNorm."""
        if "batch" in norm.lower():
            return nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
        return nn.LayerNorm(d_model)

    def forward(self, src:Tensor, src_hist:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None) -> Tensor:
        if self.pre_norm:
            src = self.norm_attn(src)
            src_hist = self.norm_attn(src_hist)

        attn_kwargs = dict(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        if self.res_attention:
            attn_out, attn, scores = self.self_attn(src, src_hist, src_hist, prev, **attn_kwargs)
        else:
            attn_out, attn = self.self_attn(src, src_hist, src_hist, **attn_kwargs)
        if self.store_attn:
            self.attn = attn

        # Residual around attention.
        src = src + self.dropout_attn(attn_out)
        if not self.pre_norm:
            src = self.norm_attn(src)

        # Residual around the feed-forward block.
        if self.pre_norm:
            src = self.norm_ffn(src)
        src = src + self.dropout_ffn(self.ff(src))
        if not self.pre_norm:
            src = self.norm_ffn(src)

        return (src, scores) if self.res_attention else src


class TSTEncoderLayer(nn.Module):
    """Self-attention Transformer encoder layer with a position-wise feed-forward block.

    Both sublayers use residual connections. Normalisation (BatchNorm over
    ``d_model`` or LayerNorm, chosen by ``norm``) is applied after the residual
    when ``pre_norm`` is false, and before the sublayer when it is true. With
    ``res_attention`` the layer also returns its pre-softmax scores.
    ``q_len`` is accepted but not used.
    """

    def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=256, store_attn=False,
                 norm='BatchNorm', attn_dropout=0, dropout=0., bias=True, activation="gelu", res_attention=False, pre_norm=False):
        super().__init__()
        assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        head_dim = d_model // n_heads
        if d_k is None:
            d_k = head_dim
        if d_v is None:
            d_v = head_dim

        # Attention sublayer.
        self.res_attention = res_attention
        self.self_attn = _MultiheadAttention(d_model, n_heads, d_k, d_v, attn_dropout=attn_dropout,
                                             proj_dropout=dropout, res_attention=res_attention)
        self.dropout_attn = nn.Dropout(dropout)
        self.norm_attn = self._make_norm(norm, d_model)

        # Feed-forward sublayer.
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=bias),
            get_activation_fn(activation),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model, bias=bias),
        )
        self.dropout_ffn = nn.Dropout(dropout)
        self.norm_ffn = self._make_norm(norm, d_model)

        self.pre_norm = pre_norm
        self.store_attn = store_attn

    @staticmethod
    def _make_norm(norm, d_model):
        """BatchNorm1d over the d_model axis if `norm` mentions 'batch', otherwise LayerNorm."""
        if "batch" in norm.lower():
            return nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
        return nn.LayerNorm(d_model)

    def forward(self, src:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None) -> Tensor:
        x = self.norm_attn(src) if self.pre_norm else src

        attn_kwargs = dict(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        if self.res_attention:
            attn_out, attn, scores = self.self_attn(x, x, x, prev, **attn_kwargs)
        else:
            attn_out, attn = self.self_attn(x, x, x, **attn_kwargs)
        if self.store_attn:
            self.attn = attn

        # Residual around attention.
        x = x + self.dropout_attn(attn_out)
        if not self.pre_norm:
            x = self.norm_attn(x)

        # Residual around the feed-forward block.
        if self.pre_norm:
            x = self.norm_ffn(x)
        x = x + self.dropout_ffn(self.ff(x))
        if not self.pre_norm:
            x = self.norm_ffn(x)

        return (x, scores) if self.res_attention else x


class _MultiheadAttention(nn.Module):
    """Multi-head attention built on ``_ScaledDotProductAttention``.

    Inputs:
        Q:     [bs x q_len x d_model]
        K, V:  [bs x kv_len x d_model]; each defaults to Q when omitted
    Returns ``(output, attn_weights)``, plus ``attn_scores`` when
    ``res_attention`` is set. ``output`` is ``[bs x q_len x d_model]``.
    """

    def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False):
        super().__init__()
        head_dim = d_model // n_heads
        d_k = head_dim if d_k is None else d_k
        d_v = head_dim if d_v is None else d_v

        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v

        # Per-head projections, fused into one Linear each.
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)

        self.res_attention = res_attention
        self.sdp_attn = _ScaledDotProductAttention(d_model, n_heads, attn_dropout=attn_dropout,
                                                   res_attention=self.res_attention, lsa=lsa)

        # Merge heads back to d_model.
        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout))

    def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,
                key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        bs = Q.size(0)
        K = Q if K is None else K
        V = Q if V is None else V

        heads = self.n_heads
        q_s = self.W_Q(Q).view(bs, -1, heads, self.d_k).transpose(1, 2)     # [bs x heads x q_len x d_k]
        k_s = self.W_K(K).view(bs, -1, heads, self.d_k).permute(0, 2, 3, 1)  # [bs x heads x d_k x kv_len]
        v_s = self.W_V(V).view(bs, -1, heads, self.d_v).transpose(1, 2)     # [bs x heads x kv_len x d_v]

        masks = dict(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        if self.res_attention:
            output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev, **masks)
        else:
            output, attn_weights = self.sdp_attn(q_s, k_s, v_s, **masks)

        # [bs x heads x q_len x d_v] -> [bs x q_len x heads*d_v] -> [bs x q_len x d_model]
        output = self.to_out(output.transpose(1, 2).reshape(bs, -1, heads * self.d_v))

        if self.res_attention:
            return output, attn_weights, attn_scores
        return output, attn_weights


class _ScaledDotProductAttention(nn.Module):
    r"""Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer
    (Realformer: Transformer likes residual attention by He et al, 2020) and locality self sttention (Vision Transformer for Small-Size Datasets
    by Lee et al, 2021)"""

    def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False):
        super().__init__()
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.res_attention = res_attention
        head_dim = d_model // n_heads
        self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)
        self.lsa = lsa

    def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        '''
        Input shape:
            q               : [bs x n_heads x max_q_len x d_k]
            k               : [bs x n_heads x d_k x seq_len]
            v               : [bs x n_heads x seq_len x d_v]
            prev            : [bs x n_heads x q_len x seq_len]
            key_padding_mask: [bs x seq_len]
            attn_mask       : [1 x seq_len x seq_len]
        Output shape:
            output:  [bs x n_heads x q_len x d_v]
            attn   : [bs x n_heads x q_len x seq_len]
            scores : [bs x n_heads x q_len x seq_len]
        '''

        # Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence
        attn_scores = torch.matmul(q, k) * self.scale      # attn_scores : [bs x n_heads x max_q_len x q_len]

        # Add pre-softmax attention scores from the previous layer (optional)
        if prev is not None: attn_scores = attn_scores + prev

        # Attention mask (optional)
        if attn_mask is not None:                                     # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len
            if attn_mask.dtype == torch.bool:
                attn_scores.masked_fill_(attn_mask, -np.inf)
            else:
                attn_scores += attn_mask

        # Key padding mask (optional)
        if key_padding_mask is not None:                              # mask with shape [bs x q_len] (only when max_w_len == q_len)
            attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)

        # normalize the attention weights
        attn_weights = F.softmax(attn_scores, dim=-1)                 # attn_weights   : [bs x n_heads x max_q_len x q_len]
        attn_weights = self.attn_dropout(attn_weights)

        # compute the new values given the attention weights
        output = torch.matmul(attn_weights, v)                        # output: [bs x n_heads x max_q_len x d_v]

        if self.res_attention: return output, attn_weights, attn_scores
        else: return output, attn_weights

