import torch
import torch.nn as nn
import torch.nn.functional as F
import seaborn as sns
import numpy as np
import math
from math import sqrt
import os
import torch
from models.mra2_kernel.attention import mra2_attention
import math
from performer_pytorch import FastAttention

class AttentionLayer(nn.Module):
    """Multi-head wrapper around a pluggable attention callable.

    The inputs are linearly projected, split into ``n_heads`` heads, handed to
    ``attention`` and the result is projected back to ``d_model`` features.

    Args:
        attention: Callable invoked as ``attention(q, k, v)`` with tensors laid
            out as ``[batch, seq_len, n_heads, head_dim]``. Its result must be
            reshapeable to ``[batch, query_len, -1]``.
        d_model: Feature size of the layer inputs and outputs.
        n_heads: Number of attention heads.
        d_keys: Per-head query/key width; defaults to ``d_model // n_heads``.
        d_values: Per-head value width; defaults to ``d_model // n_heads``.
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None,
                 d_values=None):
        super().__init__()
        default_width = d_model // n_heads
        key_features = (d_keys or default_width) * n_heads
        value_features = (d_values or default_width) * n_heads

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, key_features)
        self.key_projection = nn.Linear(d_model, key_features)
        self.value_projection = nn.Linear(d_model, value_features)
        self.out_projection = nn.Linear(value_features, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values):
        """Project the inputs, run the inner attention and project the result.

        Args:
            queries: Tensor of shape ``[batch, query_len, d_model]``.
            keys: Tensor of shape ``[batch, key_len, d_model]``.
            values: Tensor viewed with the key length, ``[batch, key_len, d_model]``.

        Returns:
            Tensor of shape ``[batch, query_len, d_model]``.
        """
        batch, query_len, _ = queries.shape
        _, key_len, _ = keys.shape
        heads = self.n_heads

        # Split the projected features into heads: [batch, seq, heads, head_dim].
        q_heads = self.query_projection(queries).view(batch, query_len, heads, -1)
        k_heads = self.key_projection(keys).view(batch, key_len, heads, -1)
        v_heads = self.value_projection(values).view(batch, key_len, heads, -1)

        attended = self.inner_attention(q_heads, k_heads, v_heads)
        merged = attended.reshape(batch, query_len, -1)
        return self.out_projection(merged)

class MrsAttentionLayer(nn.Module):
    """Multi-resolution self-attention layer with one projection set per head.

    Head ``h`` average-pools the input sequence with window
    ``group_by_list[h]``, projects the pooled sequence to queries, keys and
    values, runs ``attention`` on that single head, and repeats every pooled
    position ``group_by_list[h]`` times to recover the original length.

    Args:
        attention: Callable invoked as ``attention(q, k, v)`` with tensors of
            shape ``[batch, pooled_len, 1, head_dim]``. Its output must contain
            ``batch * pooled_len * d_values`` elements ordered by batch, then
            position (both ``[batch, pooled_len, 1, d]`` and
            ``[batch, 1, pooled_len, d]`` satisfy this).
        d_model: Feature size of the layer inputs and outputs.
        n_heads: Number of heads; must not exceed ``len(group_by_list)``.
        d_keys: Per-head query/key width; defaults to ``d_model // n_heads``.
        d_values: Per-head value width; defaults to ``d_model // n_heads``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None,
                 d_values=None, **kwargs):
        super(MrsAttentionLayer, self).__init__()
        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)
        self.n_heads = n_heads
        self.inner_attention = attention
        self.query_projection = nn.ModuleList([nn.Linear(d_model, d_keys) for _ in range(self.n_heads)])
        self.key_projection = nn.ModuleList([nn.Linear(d_model, d_keys) for _ in range(self.n_heads)])
        # Values are sized by d_values so they match out_projection's input width.
        self.value_projection = nn.ModuleList([nn.Linear(d_model, d_values) for _ in range(self.n_heads)])
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        # Pooling window per head; head h works on the sequence shortened by this factor.
        self.group_by_list = [1, 1, 2, 2, 4, 4, 8, 8]
        # Second handle to the same attention, kept so existing attribute users keep working.
        self.attn = attention

    def forward(self, queries, keys, values):
        """Run multi-resolution self-attention over ``queries``.

        Args:
            queries: Tensor of shape ``[batch, seq_len, d_model]``. Queries,
                keys and values are all derived from this tensor.
            keys: Unused; present for interface compatibility.
            values: Unused; present for interface compatibility.

        Returns:
            Tensor of shape ``[batch, seq_len, d_model]``.
        """
        X = queries
        batch_size, seq_len, _ = X.shape
        head_outputs = []

        for h in range(self.n_heads):
            group = self.group_by_list[h]
            # Downsample the sequence: [bsz, ceil(seq_len / group), d_model].
            X_ = F.avg_pool1d(
                X.transpose(-1, -2),
                kernel_size=group,
                stride=group,
                ceil_mode=True,
                count_include_pad=False,
            ).transpose(-1, -2)
            # Single-head tensors laid out as [bsz, pooled_len, 1, head_dim].
            q_ = self.query_projection[h](X_).unsqueeze(2)
            k_ = self.key_projection[h](X_).unsqueeze(2)
            v_ = self.value_projection[h](X_).unsqueeze(2)
            attn_out_ = self.attn(q_, k_, v_)
            # Drop the singleton head axis explicitly. A bare squeeze() would
            # also remove a batch (or sequence) axis of size 1.
            attn_out_ = attn_out_.reshape(batch_size, X_.shape[1], -1)
            # Upsample back to the original length and cut the ceil-mode overshoot.
            head_outputs.append(attn_out_.repeat_interleave(group, dim=1)[:, :seq_len, :])

        # [bsz, seq_len, n_heads, head_dim] -> [bsz, seq_len, n_heads * head_dim]
        out = torch.stack(head_outputs, dim=2).reshape(batch_size, seq_len, -1)

        return self.out_projection(out)

class LunaEncoderLayer(nn.Module):
    """Encoder layer built from a Luna pack/unpack attention and a 1x1-conv feed-forward.

    Args:
        n_heads: Number of attention heads.
        d_model: Feature size of the inputs and outputs.
        d_ff: Hidden width of the feed-forward block; defaults to ``4 * d_model``.
        dropout: Dropout probability used in the attention and in this layer.
        activation: ``"relu"`` selects ReLU; any other value selects GELU.
    """

    def __init__(self, n_heads, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super().__init__()
        hidden_width = d_ff if d_ff else 4 * d_model
        self.luna_attention = LunaAttentionLayer(
            d_model=d_model,
            drop_out=dropout,
            n_heads=n_heads,
            d_keys=d_model // n_heads,
        )
        # Position-wise feed-forward implemented as two kernel-size-1 convolutions.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden_width, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden_width, out_channels=d_model, kernel_size=1)
        self.packed_context_layer_norm = nn.LayerNorm(d_model)
        # Not used in forward(); kept so the module's parameter set is unchanged.
        self.unpacked_context_layer_norm = nn.LayerNorm(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, p):
        """Apply the layer to the sequence ``x`` and the packed sequence ``p``.

        Args:
            x: Input sequence, ``[batch, seq_len, d_model]``.
            p: Packed (auxiliary) sequence, ``[batch, pack_len, d_model]``.

        Returns:
            Tuple ``(output, packed_context)``: the transformed sequence with
            the shape of ``x`` and the updated packed sequence.
        """
        unpacked, packed = self.luna_attention(queries=x, keys=x, values=x, p=p)

        # Residual + norm on the packed branch.
        packed = self.packed_context_layer_norm(packed + p)
        packed = self.dropout(packed)

        # Residual + norm on the main branch.
        normed = self.norm1(x + self.dropout(unpacked))

        # Feed-forward: Conv1d expects channels first, hence the transposes.
        hidden = self.conv1(normed.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        ff_out = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm2(normed + ff_out), packed

class MrsLunaEncoderLayer(nn.Module):
    """Luna-style encoder layer that uses the multi-resolution pack/unpack attention.

    Args:
        n_heads: Number of attention heads.
        d_model: Feature size of the inputs and outputs.
        d_ff: Hidden width of the feed-forward block; defaults to ``4 * d_model``.
        dropout: Dropout probability used in the attention and in this layer.
        activation: ``"relu"`` selects ReLU; any other value selects GELU.
    """

    def __init__(self, n_heads, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(MrsLunaEncoderLayer, self).__init__()
        self.d_ff = d_ff or 4 * d_model
        self.n_heads = n_heads
        self.luna_attention = MrsLunaAttentionLayer(d_model=d_model, drop_out=dropout, n_heads=n_heads, d_keys=d_model//n_heads)
        # Position-wise feed-forward implemented as two kernel-size-1 convolutions.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=self.d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=self.d_ff, out_channels=d_model, kernel_size=1)
        self.packed_context_layer_norm = nn.LayerNorm(d_model)
        # Not used in forward(); kept so the module's parameter set is unchanged.
        self.unpacked_context_layer_norm = nn.LayerNorm(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu
        # Not read by forward(); kept as a public attribute.
        self.group_by_list = [1, 1, 2, 2, 4, 4, 8, 8]

    def forward(self, x, p):
        """Apply the layer to the sequence ``x`` and the packed sequence ``p``.

        Args:
            x: Input sequence, ``[batch, seq_len, d_model]``.
            p: Packed (auxiliary) sequence, ``[batch, pack_len, d_model]``.

        Returns:
            Tuple ``(output, packed_context)``: the transformed sequence with
            the shape of ``x`` and the updated packed sequence.
        """
        # The attention allocates its own outputs, so no buffers are
        # pre-allocated here (the former torch.empty placeholders were
        # overwritten immediately).
        unpacked_context, packed_context = self.luna_attention(
            queries=x,
            keys=x,
            values=x,
            p=p,
        )
        packed_context = self.dropout(self.packed_context_layer_norm(packed_context + p))
        x = x + self.dropout(unpacked_context)
        y = x = self.norm1(x)
        # Conv1d expects channels first: [bsz, hidden_dim, seq_len]
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        # [bsz, d_ff, seq_len] -> [bsz, seq_len, hidden_dim]
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        return self.norm2(x + y), packed_context

class LunaAttentionLayer(nn.Module):
    """Two-stage pack/unpack attention.

    The packed sequence ``p`` first attends to the input (pack step); the
    input queries then attend to that packed result (unpack step).

    Args:
        d_model: Feature size of all inputs and outputs.
        drop_out: Dropout probability passed to both attention sub-layers.
        n_heads: Number of attention heads.
        d_keys: Per-head query/key width passed to both sub-layers.
        d_values: Per-head value width passed to both sub-layers. Previously
            this argument was accepted but silently dropped.
    """

    def __init__(self, d_model, drop_out, n_heads, d_keys=None,
                 d_values=None):
        super(LunaAttentionLayer, self).__init__()
        self.pack_attention = MultiheadAttentionLayer(d_model=d_model, drop_out=drop_out, n_heads=n_heads, d_keys=d_keys, d_values=d_values)
        self.unpack_attention = MultiheadAttentionLayer(d_model=d_model, drop_out=drop_out, n_heads=n_heads, d_keys=d_keys, d_values=d_values)

    def forward(self, queries, keys, values, p):
        """Run the pack step followed by the unpack step.

        Args:
            queries: Queries for the unpack step.
            keys: Keys for the pack step.
            values: Values for the pack step.
            p: Packed sequence used as the queries of the pack step.

        Returns:
            Tuple ``(unpacked_context, packed_context)``.
        """
        # Pack: p queries the input sequence.
        packed_context = self.pack_attention(p, keys, values)
        # Unpack: the input queries read from the packed summary.
        unpacked_context = self.unpack_attention(queries, packed_context, packed_context)

        return unpacked_context, packed_context

class MrsLunaAttentionLayer(nn.Module):
    """Two-stage pack/unpack attention built from multi-resolution sub-layers.

    The pack step uses a sub-layer in ``"kv"`` downsampling mode and the
    unpack step uses one in ``"q"`` downsampling mode.

    Args:
        d_model: Feature size of all inputs and outputs.
        drop_out: Dropout probability passed to both attention sub-layers.
        n_heads: Number of attention heads.
        d_keys: Per-head query/key width passed to both sub-layers.
        d_values: Per-head value width passed to both sub-layers. Previously
            this argument was accepted but silently dropped.
    """

    def __init__(self, d_model, drop_out, n_heads, d_keys=None,
                 d_values=None):
        super(MrsLunaAttentionLayer, self).__init__()
        self.pack_attention = MrsMultiheadAttentionLayer(d_model=d_model, drop_out=drop_out, n_heads=n_heads, d_keys=d_keys, d_values=d_values, downsampling_mode="kv")
        self.unpack_attention = MrsMultiheadAttentionLayer(d_model=d_model, drop_out=drop_out, n_heads=n_heads, d_keys=d_keys, d_values=d_values, downsampling_mode="q")

    def forward(self, queries, keys, values, p):
        """Run the pack step followed by the unpack step.

        Args:
            queries: Queries for the unpack step.
            keys: Keys for the pack step.
            values: Values for the pack step.
            p: Packed sequence used as the queries of the pack step.

        Returns:
            Tuple ``(unpacked_context, packed_context)``.
        """
        # Pack: p queries the (downsampled) input sequence.
        packed_context = self.pack_attention(p, keys, values)
        # Unpack: the (downsampled) input queries read from the packed summary.
        unpacked_context = self.unpack_attention(queries, packed_context, packed_context)

        return unpacked_context, packed_context

class Flow_Attention(nn.Module):
    """Linear-complexity attention using sigmoid feature maps and flow-style normalizers.

    Args:
        attention_dropout: Probability for ``self.dropout``. The dropout module
            is created but not applied in ``forward``.
    """

    def __init__(self, attention_dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(attention_dropout)

    def kernel_method(self, x):
        """Map features to positive values with a sigmoid."""
        return torch.sigmoid(x)

    def forward(self, queries, keys, values):
        """Compute the attention output.

        Args:
            queries: ``[batch, query_len, n_heads, head_dim]``.
            keys: ``[batch, key_len, n_heads, head_dim]``.
            values: ``[batch, key_len, n_heads, value_dim]``.

        Returns:
            Contiguous tensor of shape ``[batch, query_len, n_heads, value_dim]``.
        """
        eps = 1e-6

        # Move heads in front of the sequence axis and apply the kernel.
        q = self.kernel_method(queries.transpose(1, 2))
        k = self.kernel_method(keys.transpose(1, 2))
        v = values.transpose(1, 2)
        q_eps = q + eps
        k_eps = k + eps

        # Incoming (per query) and outgoing (per key) normalizers.
        row_norm = 1.0 / torch.einsum("nhld,nhd->nhl", q_eps, k.sum(dim=2) + eps)
        col_norm = 1.0 / torch.einsum("nhsd,nhd->nhs", k_eps, q.sum(dim=2) + eps)

        # Re-weighted totals using the opposite side's normalizer.
        row_refined = torch.einsum(
            "nhld,nhd->nhl", q_eps, (k * col_norm[:, :, :, None]).sum(dim=2) + eps)
        col_refined = torch.einsum(
            "nhsd,nhd->nhs", k_eps, (q * row_norm[:, :, :, None]).sum(dim=2) + eps)

        # Allocation gate on queries, competition (softmax) among keys.
        length_ratio = float(q.shape[2]) / float(k.shape[2])
        row_gate = torch.sigmoid(row_refined * length_ratio)
        col_weight = torch.softmax(col_refined, dim=-1) * k.shape[2]

        # Aggregate keys/values first so the cost stays linear in sequence length.
        kv = k.transpose(-2, -1) @ (v * col_weight[:, :, :, None])
        out = ((q @ kv) * row_norm[:, :, :, None]) * row_gate[:, :, :, None]
        return out.transpose(1, 2).contiguous()

class MRA_head_Attention(nn.Module):
    """Softmax attention whose scores are computed at a per-head resolution.

    For head ``h`` queries and keys are averaged over consecutive groups of
    ``group_by_list[h]`` positions. The coarse score of a (query group, key
    group) pair is shared by every position pair inside it, and a regular
    softmax over the full-length score matrix is applied to the values.

    Args:
        attention_dropout: Probability for ``self.dropout``. The dropout module
            is created but not applied in ``forward``.
    """

    def __init__(self, attention_dropout=0.1):
        super(MRA_head_Attention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        # Group size per head; limits the layer to at most len(group_by_list) heads.
        self.group_by_list = [1, 1, 2, 2, 4, 4, 8, 8]

    def forward(self, queries, keys, values):
        """Compute multi-resolution attention.

        Args:
            queries: ``[bsz, seq_len, n_heads, head_dim]``.
            keys: Same layout as ``queries``. The score buffer is square, so
                keys are expected to have the same sequence length.
            values: ``[bsz, seq_len, n_heads, value_dim]``.

        Returns:
            Tensor of shape ``[bsz, n_heads, seq_len, value_dim]`` (heads are
            NOT moved back behind the sequence axis).
        """
        # [bsz, seq_len, no_heads, hidden_dim] -> [bsz, no_heads, seq_len, hidden_dim]
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        bsz, n_heads, seq_len, head_dim = queries.shape
        scale = math.sqrt(head_dim)

        # Allocate in the input dtype so non-float32 inputs work in the final matmul.
        attn = torch.empty((bsz, n_heads, seq_len, seq_len), dtype=queries.dtype, device=queries.device)
        for head in range(n_heads):
            group_by = int(self.group_by_list[head])
            remainder = seq_len % group_by
            query = queries[:, head, :, :]
            key = keys[:, head, :, :]
            # Pad only when the last group is incomplete. The old check
            # compared the whole remainder list with 0 and therefore always padded.
            if remainder != 0:
                pad = group_by - remainder
                query = F.pad(query, (0, 0, 0, pad), mode='constant')
                key = F.pad(key, (0, 0, 0, pad), mode='constant')
            # Group averages of Q and K: [bsz, n_groups, hidden_dim]
            _query = F.avg_pool1d(query.transpose(1, 2), kernel_size=group_by,
                                  stride=group_by, count_include_pad=False).transpose(1, 2)
            _key = F.avg_pool1d(key.transpose(1, 2), kernel_size=group_by,
                                stride=group_by, count_include_pad=False).transpose(1, 2)
            if remainder != 0:
                # The explicit zero padding was counted in the last group's mean;
                # rescale so it is the mean of the real positions only.
                _query[:, -1, :] = _query[:, -1, :] * group_by / remainder
                _key[:, -1, :] = _key[:, -1, :] * group_by / remainder
            qk = torch.einsum("bqe, bke -> bqk", _query, _key)
            qk = qk / scale
            # Broadcast each coarse score to every position pair of its groups.
            qk = torch.repeat_interleave(qk, group_by, dim=1)
            qk = torch.repeat_interleave(qk, group_by, dim=2)
            attn[:, head, :, :] = qk[:, :seq_len, :keys.shape[2]]
        attn = nn.functional.softmax(attn, dim=-1)
        X = torch.matmul(attn, values)
        # [bsz, n_head, seq_len, dim]
        return X

class MRA2_Attention(nn.Module):
    """Adapter that feeds head-split tensors to the ``mra2_attention`` kernel.

    Args:
        attention_dropout: Probability for ``self.dropout``. The dropout module
            is created but not applied in ``forward``.
        mode: Forwarded to the kernel as ``approx_mode``.
    """

    def __init__(self, attention_dropout=0.1, mode='sparse'):
        super().__init__()
        self.dropout = nn.Dropout(attention_dropout)
        # Not read by forward(), which hard-codes the block size 32.
        self.group_by_list = [32] * 8
        self.approx_mode = mode

    def kernel_method(self, x):
        """Return the element-wise sigmoid of ``x`` (not used by ``forward``)."""
        return torch.sigmoid(x)

    def forward(self, queries, keys, values):
        """Pad the sequence to a multiple of 32, call the kernel and crop the result.

        Args:
            queries: ``[bsz, seq_len, n_heads, head_dim]``.
            keys: Same layout as ``queries``.
            values: Same layout as ``queries``.

        Returns:
            The kernel output cropped to the first ``seq_len`` positions along
            dimension 2.
        """
        # [bsz, seq_len, n_heads, head_dim] -> [bsz, n_heads, seq_len, head_dim]
        q = queries.transpose(1, 2)
        k = keys.transpose(1, 2)
        v = values.transpose(1, 2)

        seq_len = q.shape[2]
        leftover = seq_len % 32
        # 1 marks real positions, 0 marks padding: [bsz, padded_len]
        mask = torch.ones([q.shape[0], seq_len], device=q.device)

        if leftover > 0:
            tail = 32 - leftover
            q = F.pad(q, (0, 0, 0, tail), 'constant', 0)
            k = F.pad(k, (0, 0, 0, tail), 'constant', 0)
            v = F.pad(v, (0, 0, 0, tail), 'constant', 0)
            mask = F.pad(mask, (0, tail), 'constant', 0)

        # Block budget derived from the UNPADDED length (about 25% of the
        # attention matrix, per the original comment).
        # NOTE(review): this is 0 for seq_len < 32 - confirm the kernel accepts that.
        num_block = (seq_len // 32) * math.ceil(seq_len / (32 * 4.0))

        # The kernel is called in full precision with autocast disabled.
        with torch.cuda.amp.autocast(enabled=False):
            attn_out = mra2_attention(
                q.float(),
                k.float(),
                v.float(),
                mask.float(),
                num_block,
                approx_mode=self.approx_mode,
                initial_prior_first_n_blocks=0,
                initial_prior_diagonal_n_blocks=0,
            )
        return attn_out[:, :, :seq_len, :]

class Linear_transfomrer_Attention(nn.Module):
    """Kernelized linear attention with the ``elu(x) + 1`` feature map.

    Args:
        attention_dropout: Probability for ``self.dropout``. The dropout module
            is created but not applied in ``forward``.
    """

    def __init__(self, attention_dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, Q, K, V):
        """Compute ``phi(Q) @ (phi(K)^T @ V)`` without a normalizing denominator.

        Args:
            Q: ``[bsz, query_len, n_heads, head_dim]``.
            K: ``[bsz, key_len, n_heads, head_dim]``.
            V: ``[bsz, key_len, n_heads, value_dim]``.

        Returns:
            Tensor of shape ``[bsz, n_heads, query_len, value_dim]``.
        """
        # [bsz, seq_len, n_heads, dim] -> [bsz, n_heads, seq_len, dim]
        q_heads = Q.transpose(1, 2)
        k_heads = K.transpose(1, 2)
        v_heads = V.transpose(1, 2)

        # Positive feature map, scaled by the fourth root of the sequence length.
        q_feat = (F.elu(q_heads) + 1) / math.sqrt(math.sqrt(q_heads.size(2)))
        k_feat = (F.elu(k_heads) + 1) / math.sqrt(math.sqrt(k_heads.size(2)))

        # Contract keys with values first; the cost is linear in sequence length.
        kv_summary = torch.matmul(k_feat.transpose(-2, -1), v_heads)
        return torch.matmul(q_feat, kv_summary)

class FMM_transfomrer_Attention(nn.Module):
    """Blend of banded softmax attention and kernelized low-rank attention.

    ``forward`` computes a "sparse" term (softmax attention restricted to the
    band ``|i - j| < diag_size``) and a "low-rank" term (linear attention with
    the feature maps named in ``kernels``) and combines them according to
    ``type_blend``.

    ``type_blend`` is derived from the constructor's ``sparse_ratio``: values
    below 1 give type 0, otherwise the integer part is used (1.5 -> 1, ...,
    19.5 -> 19). After construction ``self.sparse_ratio`` no longer holds that
    number; it holds the blend weight (a learnable parameter, or the constant
    1.0 / 0.0 for types 6 / 7).

    Blend formulas (S = sparse term, L = low-rank term, r = ``sparse_ratio``,
    r2 = ``sparse_ratio2``):

    * types 0, 3, 8, 18, 19:       ``r * S + L``
    * types 1, 4, 9, 12, 16, 17:   ``S + r * L``
    * types 2, 5, 14, 15:          ``r * S + (1 - r) * L``
    * type 6:                      ``S`` (sparse only)
    * type 7:                      ``L`` (low-rank only)
    * types 10, 11, 13:            ``r * S + r2 * L``

    Args:
        attention_dropout: Dropout probability applied to the sparse attention weights.
        head_dim: Per-head feature size; used for the softmax scale and for
            per-channel weight shapes.
        diag_size: Half-width of the band kept by the sparse term.
        num_head: Number of heads; used for weight shapes.
        kernels: Sequence of feature-map names for the low-rank term. The
            first one or two entries are used.
        sparse_ratio: Number selecting the blend type (see above).

    Raises:
        ValueError: If ``sparse_ratio`` is not below 20 (previously this
            surfaced as an ``AttributeError`` on ``type_blend``).
    """

    # Blend types grouped by the formula forward() applies to them.
    _SCALE_SPARSE_TYPES = frozenset({0, 3, 8, 18, 19})
    _SCALE_LOWRANK_TYPES = frozenset({1, 4, 9, 12, 16, 17})
    _CONVEX_TYPES = frozenset({2, 5, 14, 15})
    _SPARSE_ONLY_TYPE = 6
    _LOWRANK_ONLY_TYPE = 7
    # Blend types grouped by how two low-rank kernels are mixed.
    _LOWRANK_TWO_WEIGHT_TYPES = frozenset({8, 9, 13, 14, 16, 19})
    _LOWRANK_ONE_WEIGHT_TYPES = frozenset({15, 17})

    def __init__(self, attention_dropout, head_dim, diag_size, num_head, kernels, sparse_ratio):
        super(FMM_transfomrer_Attention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.head_dim = head_dim
        self.diag_size = diag_size
        self.num_head = num_head

        self.kernels = kernels
        self.rank_k = len(self.kernels)

        self.type_blend = self._resolve_type_blend(sparse_ratio)
        self._init_blend_weights()

    @staticmethod
    def _resolve_type_blend(sparse_ratio):
        """Map the ``sparse_ratio`` selector to an integer blend type in 0..19."""
        if sparse_ratio < 1.0:
            return 0
        if sparse_ratio < 20.0:
            return int(sparse_ratio)
        raise ValueError(
            "sparse_ratio must be below 20.0 to select a blend type, got %r" % (sparse_ratio,)
        )

    def _init_blend_weights(self):
        """Create the blend weights required by ``self.type_blend``.

        Parameters are registered in the order sparse_ratio, sparse_ratio2,
        lowrank_ratio, lowrank_ratio2.
        """
        blend = self.type_blend
        per_channel = (1, self.num_head, 1, self.head_dim)
        per_head = (1, self.num_head, 1, 1)

        if blend == self._SPARSE_ONLY_TYPE:
            self.sparse_ratio = 1.0  # sparse only
        elif blend == self._LOWRANK_ONLY_TYPE:
            self.sparse_ratio = 0.0  # lowrank only
        elif blend in (0, 1, 2, 14, 15):
            self.sparse_ratio = nn.Parameter(torch.tensor([0.5]))
        elif blend in (11, 13):
            self.sparse_ratio = nn.Parameter(torch.ones(per_head))
        elif blend in (10, 12, 18, 19):
            self.sparse_ratio = nn.Parameter(torch.zeros(per_head))
        else:  # 3, 4, 5, 8, 9, 16, 17
            self.sparse_ratio = nn.Parameter(torch.zeros(per_channel))

        if blend == 10:
            self.sparse_ratio2 = nn.Parameter(torch.ones(per_head))
        elif blend in (11, 13):
            self.sparse_ratio2 = nn.Parameter(torch.zeros(per_head))

        if blend in (8, 9):
            self.lowrank_ratio = nn.Parameter(torch.tensor([0.5]))
            self.lowrank_ratio2 = nn.Parameter(torch.tensor([0.5]))
        elif blend in (13, 14, 16, 19):
            self.lowrank_ratio = nn.Parameter(torch.tensor([1.0]))
            self.lowrank_ratio2 = nn.Parameter(torch.tensor([0.0]))
        elif blend in (15, 17):
            self.lowrank_ratio = nn.Parameter(torch.tensor([0.0]))

    def forward(self, Q, K, V):
        """Compute the blended attention output.

        Args:
            Q: ``[bsz, seq_len, n_heads, head_dim]``.
            K: Same layout as ``Q``.
            V: Same layout as ``Q``.

        Returns:
            Tensor of shape ``[bsz, n_heads, seq_len, head_dim]``.
        """
        # [bsz, seq_len, no_heads, hidden_dim] -> [bsz, no_heads, seq_len, hidden_dim]
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        mask = None  # no padding mask is available at this call site
        blend = self.type_blend

        # Each term is skipped only when the blend type ignores it entirely.
        attn_vec_sparse = None
        attn_vec_lowrank = None
        if blend != self._LOWRANK_ONLY_TYPE:
            attn_vec_sparse = self._forward_sparse(Q, K, V, mask)
        if blend != self._SPARSE_ONLY_TYPE:
            attn_vec_lowrank = self._forward_lowrank(Q, K, V, mask)

        if blend == self._SPARSE_ONLY_TYPE:
            return attn_vec_sparse
        if blend == self._LOWRANK_ONLY_TYPE:
            return attn_vec_lowrank
        if blend in self._SCALE_SPARSE_TYPES:
            return self.sparse_ratio * attn_vec_sparse + attn_vec_lowrank
        if blend in self._SCALE_LOWRANK_TYPES:
            return attn_vec_sparse + self.sparse_ratio * attn_vec_lowrank
        if blend in self._CONVEX_TYPES:
            return self.sparse_ratio * attn_vec_sparse + (1.0 - self.sparse_ratio) * attn_vec_lowrank
        # Remaining types (10, 11, 13): independent weights for both terms.
        return self.sparse_ratio * attn_vec_sparse + self.sparse_ratio2 * attn_vec_lowrank

    def _project_features(self, features, kernel_name):
        """Apply the feature map called ``kernel_name``; unknown names are the identity."""
        if kernel_name == 'elu':
            return F.elu(features, 1., False) + 1.
        if kernel_name == 'tanh':
            return torch.tanh(features) + 1.
        if kernel_name == 'relu':
            return F.relu(features, False)
        if kernel_name == 'celu':
            return F.celu(features, 1., False) + 1.
        if kernel_name == 'sigmoid':
            return torch.sigmoid(features)
        if kernel_name == 'leaky_relu':
            return F.leaky_relu(features) + 1.
        if kernel_name == 'softplus':
            return F.softplus(features)
        if kernel_name == 'tanh_orthogonal':
            return 1. - torch.tanh(features)
        if kernel_name == 'elu_flip':
            return F.elu(-features, 1., False) + 1.
        return features

    def _lowrank_term(self, Q, K, V, mask, kernel_name):
        """Linear attention ``phi(Q) @ (phi(K)^T @ V)`` for a single feature map."""
        q_feat = self._project_features(Q, kernel_name=kernel_name) / math.sqrt(math.sqrt(Q.size(2)))
        k_feat = self._project_features(K, kernel_name=kernel_name)
        if mask is not None:
            k_feat = k_feat * mask[:, None, :, None]
        k_feat = k_feat / math.sqrt(math.sqrt(K.size(2)))
        return torch.matmul(q_feat, torch.matmul(torch.transpose(k_feat, -2, -1), V))

    def _forward_lowrank(self, Q, K, V, mask):
        """Low-rank term; mixes the first two kernels when more than one is configured."""
        if mask is not None:
            V = V * mask[:, None, :, None]

        X = self._lowrank_term(Q, K, V, mask, self.kernels[0])

        if self.rank_k > 1:
            X2 = self._lowrank_term(Q, K, V, mask, self.kernels[1])
            if self.type_blend in self._LOWRANK_TWO_WEIGHT_TYPES:
                X = self.lowrank_ratio * X + self.lowrank_ratio2 * X2
            elif self.type_blend in self._LOWRANK_ONE_WEIGHT_TYPES:
                X = X + self.lowrank_ratio * X2
            else:
                X = 0.5 * X + 0.5 * X2

        return X

    def _forward_sparse(self, Q, K, V, mask):
        """Softmax attention restricted to the band ``|i - j| < diag_size``."""
        dot = torch.matmul(Q, torch.transpose(K, -2, -1))
        dot = dot / math.sqrt(self.head_dim)
        if mask is not None:
            dot = dot - 1e6 * (1 - mask[:, None, None, :])

        qlen, klen = dot.shape[-2], dot.shape[-1]

        # True where a (query, key) pair lies outside the band. Built directly
        # on the target device instead of on the CPU.
        q_pos = torch.arange(qlen, device=dot.device)
        k_pos = torch.arange(klen, device=dot.device)
        outside_band = (q_pos[:, None] - k_pos[None, :]).abs() >= self.diag_size

        dot = dot.masked_fill(outside_band[None, None, :, :], -float('inf'))

        attn = nn.functional.softmax(dot, dim=-1)
        attn = self.dropout(attn)

        return torch.matmul(attn, V)

class Performer_Attention(nn.Module):
    """Thin adapter around ``performer_pytorch.FastAttention``.

    Args:
        attention_dropout: Probability for ``self.dropout``. The dropout module
            is created but not applied in ``forward``.
        head_dim: Per-head feature size.
        rp_dim: Passed to ``FastAttention`` as ``nb_features``.
    """

    def __init__(self, attention_dropout, head_dim, rp_dim):
        super().__init__()
        self.head_dim = head_dim
        self.rp_dim = rp_dim
        self.dropout = nn.Dropout(attention_dropout)
        self.attn_fn = FastAttention(
            dim_heads=self.head_dim,
            nb_features=self.rp_dim,
            causal=False,
            kernel_fn=torch.exp,
        )

    def forward(self, Q, K, V):
        """Scale queries and keys, then delegate to ``self.attn_fn``.

        Args:
            Q: ``[bsz, seq_len, n_heads, head_dim]``.
            K: Same layout as ``Q``.
            V: Same layout as ``Q``.

        Returns:
            Whatever ``self.attn_fn`` returns for the head-first tensors.
        """
        # Split the usual 1/sqrt(head_dim) scale evenly between queries and keys.
        quarter_scale = math.sqrt(math.sqrt(self.head_dim))
        # [bsz, seq_len, n_heads, dim] -> [bsz, n_heads, seq_len, dim]
        q_heads = Q.transpose(1, 2)
        k_heads = K.transpose(1, 2)
        v_heads = V.transpose(1, 2)
        return self.attn_fn(q_heads / quarter_scale, k_heads / quarter_scale, v_heads)

class Linformer_Attention(nn.Module):
    """Softmax attention over keys/values projected along the sequence axis.

    A learned matrix ``E`` of shape ``[num_head, linformer_k, max_seq_len]``
    compresses the (zero-padded) key and value sequences to ``linformer_k``
    rows before regular scaled dot-product attention is applied.

    Args:
        attention_dropout: Probability for ``self.dropout``. The dropout module
            is created but not applied in ``forward``.
        num_head: Number of heads.
        head_dim: Per-head feature size, used for the softmax scale.
        linformer_k: Number of projected key/value rows.
        max_seq_len: Sequence length the projection matrix is built for.
    """

    def __init__(self, attention_dropout, num_head, head_dim, linformer_k, max_seq_len):
        super(Linformer_Attention, self).__init__()
        self.head_dim = head_dim
        self.num_head = num_head
        self.linformer_k = linformer_k
        self.max_seq_len = max_seq_len
        self.dropout = nn.Dropout(attention_dropout)
        projection_matrix = nn.Parameter(torch.randn(self.num_head, self.linformer_k, self.max_seq_len) / math.sqrt(self.linformer_k))
        self.E = projection_matrix

    def forward(self, Q, K, V):
        """Compute Linformer attention.

        Args:
            Q: ``[bsz, seq_len, n_heads, head_dim]``.
            K: Same layout as ``Q``.
            V: Same layout as ``Q``.

        Returns:
            Tensor of shape ``[bsz, n_heads, seq_len, head_dim]``. Inputs
            longer than ``max_seq_len`` are cropped by the negative padding.
        """
        # [bsz, seq_len, no_heads, hidden_dim] -> [bsz, no_heads, seq_len, hidden_dim]
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        seq_len = Q.shape[2]

        # Pad everything to max_seq_len so it lines up with the projection matrix.
        mask = torch.ones([Q.shape[0], seq_len], device=Q.device)
        mask = F.pad(mask, (0, self.max_seq_len - seq_len), 'constant', 0)
        Q = F.pad(Q, (0, 0, 0, self.max_seq_len - Q.shape[2]), 'constant', 0)
        K = F.pad(K, (0, 0, 0, self.max_seq_len - K.shape[2]), 'constant', 0)
        V = F.pad(V, (0, 0, 0, self.max_seq_len - V.shape[2]), 'constant', 0)

        # Compress the sequence axis: [bsz, no_heads, linformer_k, hidden_dim]
        K = torch.matmul(self.E, K * mask[:, None, :, None])
        V = torch.matmul(self.E, V * mask[:, None, :, None])

        dot = torch.matmul(Q, torch.transpose(K, -2, -1))
        dot = dot / math.sqrt(self.head_dim)

        attn = nn.functional.softmax(dot, dim=-1)

        X = torch.matmul(attn, V)

        # Drop the padded query rows. The sequence axis is dim 2; slicing
        # dim 1 (heads), as before, left the padding in the result.
        return X[:, :, :seq_len, :]

class Softmax_Attention(nn.Module):
    """Standard scaled dot-product attention with dropout on the weights.

    Args:
        attention_dropout: Dropout probability applied to the attention weights.
        head_dim: Value whose square root divides the raw scores.
    """

    def __init__(self, attention_dropout, head_dim):
        super().__init__()
        self.head_dim = head_dim
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, Q, K, V):
        """Return ``dropout(softmax(Q K^T / sqrt(head_dim))) V``.

        Args:
            Q: ``[bsz, n_heads, query_len, head_dim]``.
            K: ``[bsz, n_heads, key_len, head_dim]``.
            V: ``[bsz, n_heads, key_len, value_dim]``.

        Returns:
            Tensor of shape ``[bsz, n_heads, query_len, value_dim]``.
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        weights = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(weights, V)

class MultiheadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention with input and output projections.

    Args:
        d_model: Feature size of the inputs and outputs.
        drop_out: Dropout probability applied to the attention weights.
        n_heads: Number of attention heads.
        d_keys: Per-head query/key width; defaults to ``d_model // n_heads``.
        d_values: Per-head value width; defaults to ``d_model // n_heads``.
    """

    def __init__(self, d_model, drop_out, n_heads, d_keys=None,
                 d_values=None):
        super(MultiheadAttentionLayer, self).__init__()
        self.d_keys = d_keys or (d_model // n_heads)
        self.d_values = d_values or (d_model // n_heads)
        # Scale by the actual per-head key width. With the default d_keys this
        # equals d_model // n_heads, which is what was always used before.
        self.inner_attention = Softmax_Attention(attention_dropout=drop_out, head_dim=self.d_keys)
        self.query_projection = nn.Linear(d_model, self.d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, self.d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, self.d_values * n_heads)
        self.out_projection = nn.Linear(self.d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values):
        """Attend from ``queries`` to ``keys``/``values``.

        Args:
            queries: ``[batch, query_len, d_model]``.
            keys: ``[batch, key_len, d_model]``.
            values: ``[batch, key_len, d_model]``.

        Returns:
            Tensor of shape ``[batch, query_len, d_model]``.
        """
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        # Project and split into heads: [bsz, n_heads, seq_len, head_dim]
        queries = self.query_projection(queries).view(B, L, H, -1).transpose(1, 2)
        keys = self.key_projection(keys).view(B, S, H, -1).transpose(1, 2)
        values = self.value_projection(values).view(B, S, H, -1).transpose(1, 2)

        context = self.inner_attention(
            queries,
            keys,
            values
        )
        # Merge heads back: [bsz, query_len, n_heads * d_values]
        context = context.transpose(1, 2).reshape(B, -1, self.n_heads * self.d_values)

        return self.out_projection(context)

class MrsMultiheadAttentionLayer(nn.Module):
    """Multi-head attention in which each head works at its own sequence resolution.

    Head ``h`` average-pools either its queries (``downsampling_mode="q"``) or
    its keys and values (``downsampling_mode="kv"``) with window
    ``group_by_list[h]`` before attending. In ``"q"`` mode every pooled output
    row is repeated ``group_by_list[h]`` times to restore the query length.

    Args:
        d_model: Feature size of the inputs and outputs.
        drop_out: Dropout probability applied to the attention weights.
        n_heads: Number of heads; must not exceed ``len(group_by_list)``.
        d_keys: Per-head query/key width; defaults to ``d_model // n_heads``.
        d_values: Per-head value width; defaults to ``d_model // n_heads``.
        downsampling_mode: ``"q"``, ``"kv"`` or ``None``. ``None`` disables
            pooling (previously the default ``None`` crashed in ``forward``).

    Raises:
        ValueError: If ``downsampling_mode`` is not one of the values above.
    """

    def __init__(self, d_model, drop_out, n_heads, d_keys=None,
                 d_values=None, downsampling_mode=None):
        super(MrsMultiheadAttentionLayer, self).__init__()
        if downsampling_mode not in (None, "q", "kv"):
            raise ValueError(
                "downsampling_mode must be 'q', 'kv' or None, got %r" % (downsampling_mode,)
            )
        self.d_keys = d_keys or (d_model // n_heads)
        self.d_values = d_values or (d_model // n_heads)
        # Scale by the actual per-head key width. With the default d_keys this
        # equals d_model // n_heads, which is what was always used before.
        self.inner_attention = Softmax_Attention(attention_dropout=drop_out, head_dim=self.d_keys)
        self.query_projection = nn.Linear(d_model, self.d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, self.d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, self.d_values * n_heads)
        self.out_projection = nn.Linear(self.d_values * n_heads, d_model)
        self.n_heads = n_heads
        # Pooling window per head.
        self.group_by_list = [1, 1, 2, 2, 4, 4, 8, 8]
        self.downsampling_mode = downsampling_mode

    @staticmethod
    def _downsample(x, group):
        """Average-pool ``x`` of shape ``[bsz, seq_len, dim]`` along the sequence axis."""
        return F.avg_pool1d(
            x.transpose(-1, -2),
            kernel_size=group,
            stride=group,
            ceil_mode=True,
            count_include_pad=False,
        ).transpose(-1, -2)

    def forward(self, queries, keys, values):
        """Attend from ``queries`` to ``keys``/``values`` at per-head resolutions.

        Args:
            queries: ``[batch, query_len, d_model]``.
            keys: ``[batch, key_len, d_model]``.
            values: ``[batch, key_len, d_model]``.

        Returns:
            Tensor of shape ``[batch, query_len, d_model]``.
        """
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads
        mode = self.downsampling_mode

        # Project and split into heads: [bsz, n_heads, seq_len, head_dim]
        queries = self.query_projection(queries).view(B, L, H, -1).transpose(1, 2)
        keys = self.key_projection(keys).view(B, S, H, -1).transpose(1, 2)
        values = self.value_projection(values).view(B, S, H, -1).transpose(1, 2)

        head_contexts = []
        for h in range(H):
            group = self.group_by_list[h]
            _queries = queries[:, h, :, :]
            _keys = keys[:, h, :, :]
            _values = values[:, h, :, :]
            if mode == "q":
                _queries = self._downsample(_queries, group)
            elif mode == "kv":
                _keys = self._downsample(_keys, group)
                _values = self._downsample(_values, group)
            context_ = self.inner_attention(
                _queries.unsqueeze(1),
                _keys.unsqueeze(1),
                _values.unsqueeze(1)
            )
            # Remove only the singleton head axis. A bare squeeze() would also
            # drop a batch or sequence axis of size 1.
            context_ = context_.squeeze(1)
            if mode == "q":
                # Upsample back to the query length and cut the ceil-mode overshoot.
                context_ = context_.repeat_interleave(group, dim=1)[:, :L, :]
            head_contexts.append(context_)

        # [bsz, n_heads, query_len, d_values]; stacking keeps dtype and device of the heads.
        context = torch.stack(head_contexts, dim=1)
        context = context.transpose(1, 2).reshape(B, -1, self.n_heads * self.d_values)

        return self.out_projection(context)