from utils import *
import pdb


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional "composev" mode.

    Default mode: standard multi-head attention from a target sequence
    (queries) to a source sequence (keys/values).

    composev mode (``composev=True``): queries carry an extra leading axis
    whose size must equal the number of source tokens S. Query group ``i``
    computes a softmax over all S sources. The attention map is then
    multiplied by an S x S identity, so only source ``i`` contributes to the
    output of group ``i`` (weighted by the softmax probability it received).
    """

    def __init__(self, d_model, num_heads, dropout=0., gain=1., composev=False):
        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads

        self.attn_dropout = nn.Dropout(dropout)
        self.output_dropout = nn.Dropout(dropout)

        # Construction order is kept fixed so seeded weight init is reproducible.
        self.proj_q = linear(d_model, d_model, bias=False)
        self.proj_k = linear(d_model, d_model, bias=False)
        self.proj_v = linear(d_model, d_model, bias=False)
        self.proj_o = linear(d_model, d_model, bias=False, gain=gain)

        self.composev = composev

    def forward(self, q, k, v, attn_mask=None):
        """Attend from ``q`` to ``k``/``v``.

        Default mode:
            q: batch_size x target_len x d_model
            k: batch_size x source_len x d_model
            v: batch_size x source_len x d_model
            return: batch_size x target_len x d_model

        composev mode:
            q: batch_size x source_len x target_len x d_model
            k, v: batch_size x source_len x d_model
            return: batch_size x source_len x target_len x d_model

        attn_mask: optional bool tensor broadcastable to
            target_len x source_len. True entries are excluded from attention
            (filled with -inf before the softmax). A fully masked row yields
            NaN.
        """
        if self.composev:
            B, S, T, _ = q.shape

            q = self.proj_q(q).view(B, S, T, self.num_heads, -1).transpose(2, 3)  # B,S,num_heads,T,head_dim
            k = self.proj_k(k).view(B, S, self.num_heads, -1).transpose(1, 2).unsqueeze(1)  # B,1,num_heads,S,head_dim
            v = self.proj_v(v).view(B, S, self.num_heads, -1).transpose(1, 2).unsqueeze(2)  # B,num_heads,1,S,head_dim

            q = q * (q.shape[-1] ** (-0.5))  # scaled dot-product
            attn = torch.matmul(q, k.transpose(-1, -2))  # B,S,num_heads,T,S
            if attn_mask is not None:
                attn = attn.masked_fill(attn_mask, float('-inf'))
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_dropout(attn)
            attn = attn.permute(0, 2, 3, 1, 4)  # B,num_heads,T,S,S

            # Keep only the diagonal (query group i -> source i). The identity is
            # sized to S and created on attn's device/dtype instead of a fixed
            # 30x30 CUDA tensor, so this works on CPU and for any S.
            diag = torch.eye(S, device=attn.device, dtype=attn.dtype)
            attn = attn * diag

            output = torch.matmul(attn, v).transpose(1, 3).reshape(B, S, T, -1)
        else:
            B, T, _ = q.shape
            _, S, _ = k.shape

            q = self.proj_q(q).view(B, T, self.num_heads, -1).transpose(1, 2)  # B,num_heads,T,head_dim
            k = self.proj_k(k).view(B, S, self.num_heads, -1).transpose(1, 2)  # B,num_heads,S,head_dim
            v = self.proj_v(v).view(B, S, self.num_heads, -1).transpose(1, 2)  # B,num_heads,S,head_dim

            q = q * (q.shape[-1] ** (-0.5))  # scaled dot-product
            attn = torch.matmul(q, k.transpose(-1, -2))  # B,num_heads,T,S
            if attn_mask is not None:
                attn = attn.masked_fill(attn_mask, float('-inf'))
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_dropout(attn)

            output = torch.matmul(attn, v).transpose(1, 2).reshape(B, T, -1)

        output = self.proj_o(output)
        output = self.output_dropout(output)
        return output


class TransformerEncoderBlock(nn.Module):
    """Pre-LayerNorm encoder layer: residual self-attention, then a residual FFN.

    When ``is_first`` is set, the layer-normed input (not the raw input) is
    used as the residual base of the attention sub-layer. ``gain`` is
    forwarded to the attention output projection and to the FFN's final
    linear layer.
    """

    def __init__(self, d_model, num_heads, dropout=0., gain=1., is_first=False):
        super().__init__()

        self.is_first = is_first

        # Self-attention sub-layer.
        self.attn_layer_norm = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads, dropout, gain)

        # Position-wise feed-forward sub-layer: d_model -> 4*d_model -> d_model.
        self.ffn_layer_norm = nn.LayerNorm(d_model)
        hidden_dim = 4 * d_model
        self.ffn = nn.Sequential(
            linear(d_model, hidden_dim, weight_init='kaiming'),
            nn.ReLU(),
            linear(hidden_dim, d_model, gain=gain),
            nn.Dropout(dropout),
        )

    def forward(self, input):
        """Run the block.

        input: batch_size x source_len x d_model
        return: batch_size x source_len x d_model
        """
        normed = self.attn_layer_norm(input)
        residual = normed if self.is_first else input
        hidden = residual + self.attn(normed, normed, normed)
        return hidden + self.ffn(self.ffn_layer_norm(hidden))


class TransformerEncoder(nn.Module):
    """Stack of ``TransformerEncoderBlock`` layers followed by a final LayerNorm.

    Only the first block is built with ``is_first=True``. Every block gets
    ``gain = (2 * num_blocks) ** -0.5``. With ``num_blocks <= 0`` the stack
    is empty and the module reduces to the final LayerNorm.
    """

    def __init__(self, num_blocks, d_model, num_heads, dropout=0.):
        super().__init__()

        layers = []
        if num_blocks > 0:
            gain = (2 * num_blocks) ** (-0.5)
            layers = [
                TransformerEncoderBlock(d_model, num_heads, dropout, gain, is_first=(index == 0))
                for index in range(num_blocks)
            ]
        self.blocks = nn.ModuleList(layers)

        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, input):
        """Encode ``input``.

        input: batch_size x source_len x d_model
        return: batch_size x source_len x d_model
        """
        hidden = input
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.layer_norm(hidden)


class TransformerDecoderBlock(nn.Module):
    """Pre-LayerNorm decoder layer with three residual sub-layers.

    The sub-layers are: self-attention over ``input``, attention from
    ``input`` to ``encoder_output``, and a feed-forward network. When
    ``is_first`` is set, the layer-normed input is used as the residual base
    of the self-attention sub-layer.

    With ``composev=True`` the input has shape B x N x T x d_model.
    Self-attention runs over the B*N sequences independently (the first two
    axes are folded together). The encoder-decoder attention is constructed
    with ``composev`` enabled.

    NOTE(review): ``self_attn_mask`` is an upper-triangular (causal) bool mask
    of size max_len x max_len, but ``forward`` never passes it to
    ``self_attn``. Self-attention is therefore currently unmasked. Confirm
    this is intended.
    """

    def __init__(self, max_len, d_model, num_heads, dropout=0., gain=1., is_first=False, composev=False):
        super().__init__()

        self.is_first = is_first
        self.composev = composev

        # Self-attention sub-layer.
        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout, gain)

        causal = torch.triu(torch.ones((max_len, max_len), dtype=torch.bool), diagonal=1)
        self.self_attn_mask = nn.Parameter(causal, requires_grad=False)

        # Encoder-decoder attention sub-layer.
        self.encoder_decoder_attn_layer_norm = nn.LayerNorm(d_model)
        self.encoder_decoder_attn = MultiHeadAttention(d_model, num_heads, dropout, gain, composev)

        # Position-wise feed-forward sub-layer: d_model -> 4*d_model -> d_model.
        self.ffn_layer_norm = nn.LayerNorm(d_model)
        hidden_dim = 4 * d_model
        self.ffn = nn.Sequential(
            linear(d_model, hidden_dim, weight_init='kaiming'),
            nn.ReLU(),
            linear(hidden_dim, d_model, gain=gain),
            nn.Dropout(dropout),
        )

    def forward(self, input, encoder_output):
        """Run the block.

        input: batch_size x target_len x d_model
            (batch_size x N x target_len x d_model when composev)
        encoder_output: batch_size x source_len x d_model
        return: same shape as ``input``
        """
        folded = self.composev
        if folded:
            # Merge the two leading axes so self-attention sees 3-D input.
            batch, groups, length, width = input.shape
            input = input.flatten(end_dim=1)

        normed = self.self_attn_layer_norm(input)
        base = normed if self.is_first else input
        hidden = base + self.self_attn(normed, normed, normed)

        if folded:
            hidden = hidden.reshape(batch, groups, length, width)

        query = self.encoder_decoder_attn_layer_norm(hidden)
        hidden = hidden + self.encoder_decoder_attn(query, encoder_output, encoder_output)

        return hidden + self.ffn(self.ffn_layer_norm(hidden))


class TransformerDecoder(nn.Module):
    """Stack of ``TransformerDecoderBlock`` layers followed by a final LayerNorm.

    Only the first block is built with ``is_first=True``. Every block gets
    ``gain = (3 * num_blocks) ** -0.5`` and the same ``composev`` flag. With
    ``num_blocks <= 0`` the stack is empty and the module reduces to the
    final LayerNorm (``encoder_output`` is then unused).
    """

    def __init__(self, num_blocks, max_len, d_model, num_heads, dropout=0., composev=False):
        super().__init__()

        layers = []
        if num_blocks > 0:
            gain = (3 * num_blocks) ** (-0.5)
            layers = [
                TransformerDecoderBlock(max_len, d_model, num_heads, dropout, gain,
                                        is_first=(index == 0), composev=composev)
                for index in range(num_blocks)
            ]
        self.blocks = nn.ModuleList(layers)

        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, input, encoder_output):
        """Decode ``input`` while attending to ``encoder_output``.

        input: batch_size x target_len x d_model
        encoder_output: batch_size x source_len x d_model
        return: batch_size x target_len x d_model
        """
        hidden = input
        for layer in self.blocks:
            hidden = layer(hidden, encoder_output)
        return self.layer_norm(hidden)
