import torch

from .smoothing import ApproxSmoothingLayer

# Implementation follows SimpleViT from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/simple_vit.py

class MultiheadSelfAttention(torch.nn.Module):
    """Multi-head self-attention with an optional smoothing step.

    The input is projected to queries, keys and values by one bias-free linear
    layer, attention is computed per head, the heads are concatenated and,
    when ``smooth_steps > 0``, passed through an ``ApproxSmoothingLayer``
    before the bias-free output projection back to ``in_dim``.

    input:
        in_dim: size of the last dimension of the input and of the output
        att_dim: total attention width (split evenly across the heads)
        num_heads: number of attention heads; must divide att_dim
        dropout: attention dropout probability (only applied in training mode)
        smooth_steps: ``num_steps`` handed to ApproxSmoothingLayer; 0 disables smoothing
    """

    def __init__(self, in_dim, att_dim, num_heads, dropout=0.0, smooth_steps=0):
        super().__init__()
        # Fail early with a clear message instead of an opaque view() error in forward().
        if num_heads <= 0 or att_dim % num_heads != 0:
            raise ValueError(
                f"att_dim ({att_dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.in_dim = in_dim
        self.att_dim = att_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.smooth_steps = smooth_steps
        self.head_dim = att_dim // num_heads
        self.qkv_nn = torch.nn.Linear(in_dim, 3 * att_dim, bias=False)
        self.to_out = torch.nn.Linear(att_dim, in_dim, bias=False)

        if self.smooth_steps > 0:
            self.smooth_layer = ApproxSmoothingLayer(alpha='trainable', num_steps=self.smooth_steps)

    def _scaled_dot_product_attention(self, query, key, value, mask=None):
        """
        input:
            query: (batch_size, num_heads, seq_len, head_dim)
            key: (batch_size, num_heads, seq_len, head_dim)
            value: (batch_size, num_heads, seq_len, head_dim)
            mask: (batch_size, seq_len) or None; non-zero entries mark the
                positions that may be attended to, None means "attend to all"
        output:
            x: (batch_size, num_heads, seq_len, head_dim)
        """
        attn_mask = None
        if mask is not None:
            # (batch_size, 1, 1, seq_len): broadcasts over heads and query positions,
            # so no (seq_len x seq_len) copy has to be materialised.
            attn_mask = mask[:, None, None, :].bool()

        # The functional API applies dropout whenever dropout_p > 0, regardless of
        # the module mode, so it has to be switched off explicitly for evaluation.
        dropout_p = self.dropout if self.training else 0.0

        # Backend selection is left to PyTorch's global SDPA settings.
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=dropout_p
        ) # (batch_size, num_heads, seq_len, head_dim)

    def _qkv(self, x):
        """
        input:
            x: (batch_size, seq_len, in_dim)
        output:
            query: (batch_size, seq_len, att_dim)
            key: (batch_size, seq_len, att_dim)
            value: (batch_size, seq_len, att_dim)
        """
        # A single projection produces all three tensors; split it into equal thirds.
        q, k, v = self.qkv_nn(x).chunk(3, dim=-1)
        return q, k, v

    def _split_heads(self, t, batch_size, seq_len):
        """(batch_size, seq_len, att_dim) -> (batch_size, num_heads, seq_len, head_dim)"""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, x, adj_mat=None, mask=None):
        """
        input:
            x: (batch_size, seq_len, in_dim)
            adj_mat: sparse coo tensor (batch_size, seq_len, seq_len); only
                forwarded to the smoothing layer when smooth_steps > 0
            mask: (batch_size, seq_len) or None (no masking)
        output:
            y: (batch_size, seq_len, in_dim)
        """
        batch_size, seq_len, _ = x.size()
        query, key, value = self._qkv(x) # each (batch_size, seq_len, att_dim)
        query = self._split_heads(query, batch_size, seq_len) # (batch_size, num_heads, seq_len, head_dim)
        key = self._split_heads(key, batch_size, seq_len) # (batch_size, num_heads, seq_len, head_dim)
        value = self._split_heads(value, batch_size, seq_len) # (batch_size, num_heads, seq_len, head_dim)

        y = self._scaled_dot_product_attention(query, key, value, mask) # (batch_size, num_heads, seq_len, head_dim)
        # Concatenate the heads back into a single feature dimension.
        y = y.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.att_dim) # (batch_size, seq_len, att_dim)

        if self.smooth_steps > 0:
            y = self.smooth_layer(y, adj_mat)

        y = self.to_out(y) # (batch_size, seq_len, in_dim)
        return y
    
class Identity(torch.nn.Module):
    """Pass-through module: ``forward`` hands back its argument untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        input:
            x: any object
        output:
            the very same object, unmodified
        """
        return x

class TransformerEncoderLayer(torch.nn.Module):
    """One encoder block: self-attention and an optional feed-forward part.

    Each part is wrapped as ``norm(sublayer(Z) + Z)``. Note that a single
    LayerNorm instance (``self.norm``) is used for both residual branches.
    """

    def __init__(self, in_dim, att_dim, num_heads, use_ff=True, dropout=0.0, smooth_steps=0):
        super().__init__()
        self.in_dim = in_dim
        self.att_dim = att_dim
        self.num_heads = num_heads
        self.use_ff = use_ff
        self.dropout = dropout
        self.smooth_steps = smooth_steps

        # Submodules are created in a fixed order: attention, feed-forward, norm.
        self.mha_layer = MultiheadSelfAttention(
            in_dim,
            att_dim,
            num_heads,
            dropout=dropout,
            smooth_steps=smooth_steps,
        )

        if not self.use_ff:
            # Placeholder only; forward() never calls it when use_ff is false.
            self.ff_layer = Identity()
        else:
            # in_dim -> att_dim -> in_dim with a GELU in between.
            self.ff_layer = torch.nn.Sequential(
                torch.nn.Linear(in_dim, att_dim),
                torch.nn.GELU(),
                torch.nn.Linear(att_dim, in_dim),
            )

        self.norm = torch.nn.LayerNorm(in_dim)

    def forward(self, X, adj_mat=None, mask=None, **kwargs):
        """
        input:
            X: (batch_size, bag_size, in_dim)
            adj_mat: (batch_size, bag_size, bag_size)
            mask: (batch_size, bag_size)
            **kwargs: accepted and ignored
        output:
            Y: (batch_size, bag_size, in_dim)
        """
        # Attention branch with residual connection, then normalisation.
        attended = self.mha_layer(X, adj_mat=adj_mat, mask=mask)
        Y = self.norm(attended + X) # (batch_size, bag_size, in_dim)

        if not self.use_ff:
            return Y

        # Feed-forward branch with residual connection, normalised by the same LayerNorm.
        return self.norm(self.ff_layer(Y) + Y) # (batch_size, bag_size, in_dim)

class TransformerEncoder(torch.nn.Module):
    """Stack of ``num_layers`` TransformerEncoderLayer blocks applied one after another."""

    def __init__(self, in_dim, att_dim, num_heads, num_layers, use_ff=False, dropout=0.0, smooth_steps=0):
        super().__init__()
        self.in_dim = in_dim
        self.att_dim = att_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.use_ff = use_ff
        self.dropout = dropout
        self.smooth_steps = smooth_steps

        # Every block receives the same configuration.
        block_options = dict(use_ff=use_ff, dropout=dropout, smooth_steps=smooth_steps)
        stack = torch.nn.ModuleList()
        for _ in range(num_layers):
            stack.append(TransformerEncoderLayer(in_dim, att_dim, num_heads, **block_options))
        self.layers = stack

    def forward(self, X, adj_mat=None, mask=None, **kwargs):
        """
        input:
            X: (batch_size, bag_size, in_dim)
            adj_mat: (batch_size, bag_size, bag_size)
            mask: (batch_size, bag_size)
            **kwargs: passed through unchanged to every layer
        output:
            Y: (batch_size, bag_size, in_dim)
        """
        hidden = X
        # The same adj_mat / mask are handed to each block in turn.
        for block in self.layers:
            hidden = block(hidden, adj_mat=adj_mat, mask=mask, **kwargs) # (batch_size, bag_size, in_dim)
        return hidden