
import torch
from torch import nn
import torch.nn.functional as F

def get_torsionangle_layer(layer_type):
    """Look up the torsion-angle layer class registered under ``layer_type``.

    The only recognized name is ``'torsion_attention'``, which maps to the
    ``TorsionAngleAttention`` class (the class itself is returned, not an
    instance).

    Raises:
        ValueError: if ``layer_type`` is not a recognized name.
    """
    if layer_type != 'torsion_attention':
        raise ValueError(f'Invalid layer_type: {layer_type}')
    return TorsionAngleAttention


class TorsionAngleAttention(nn.Module):
    """Sliding-window, gated multi-head self-attention over angle embeddings.

    Per forward call:
      1. LayerNorm the input (``tri_ln_p``), project it to Q, K, V (each passed
         through the shared LayerNorm ``norm_QKV_in``) and to a per-head logit
         bias ``E`` and gate ``G`` (both passed through ``norm_EG_in``).
      2. Split the feature dim into heads by viewing it as
         ``(dot_dim, num_heads)``, so head ``h`` owns features
         ``h, h + num_heads, h + 2 * num_heads, ...`` (strided, not contiguous).
         Q is scaled by ``dot_dim ** -0.5`` and the logits are
         ``H[b, i, j, h] = <Q[b, i, :, h], K[b, j, :, h]> + E``.
      3. Add the sliding-window mask (plus the optional additive ``mask``) and
         softmax over the key axis ``j``.
      4. Multiply the weights by ``sigmoid(G + mask)``, apply optional
         dropout, aggregate V, then apply ``lin_O`` followed by ``tri_ln_o``.

    There is no residual connection inside this module: the result is
    ``tri_ln_o(lin_O(attention_output))``.

    Args:
        angle_width: feature size of the input and output; must be divisible
            by ``num_heads``.
        num_heads: number of attention heads.
        attention_dropout: dropout probability on the attention weights while
            training (0 disables dropout).
        window_size: query ``i`` may attend to key ``j`` only when
            ``|i - j| <= window_size // 2``.
    """

    def __init__(self,
                 angle_width,
                 num_heads,
                 attention_dropout=0,
                 window_size=32,
                 ):
        super().__init__()
        self.angle_width = angle_width
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.window_size = window_size

        assert not (self.angle_width % self.num_heads),\
                'angle_width must be divisible by num_heads'
        self._dot_dim = self.angle_width // self.num_heads
        self._scale_factor = self._dot_dim ** -0.5

        # Registration order is kept stable so state_dict keys and seeded
        # initialization are unaffected.
        self.tri_ln_p = nn.LayerNorm(self.angle_width, eps=1e-5)

        self.lin_QKV_in = nn.Linear(self.angle_width, self.angle_width * 3)
        self.norm_QKV_in = nn.LayerNorm(self.angle_width)
        # Produces num_heads logit-bias values (E) and num_heads gate values (G).
        self.lin_EG_in = nn.Linear(self.angle_width, self.num_heads * 2)
        self.norm_EG_in = nn.LayerNorm(self.num_heads)

        self.lin_O = nn.Linear(self.angle_width, self.angle_width)
        self.tri_ln_o = nn.LayerNorm(self.angle_width, eps=1e-5)

        nn.init.xavier_uniform_(self.lin_QKV_in.weight)
        nn.init.xavier_uniform_(self.lin_EG_in.weight)
        nn.init.xavier_uniform_(self.lin_O.weight)

    def _create_sliding_window_mask(self, seq_len, device):
        """Build an additive ``(seq_len, seq_len)`` band mask.

        Entry ``[i, j]`` is 0 when ``|i - j| <= window_size // 2`` and ``-inf``
        otherwise. Each query therefore sees at most
        ``2 * (window_size // 2) + 1`` keys, clipped at the sequence edges.
        The band is computed with one vectorized comparison instead of a
        per-row Python loop, which launched ``seq_len`` slice writes.
        """
        half_window = self.window_size // 2
        positions = torch.arange(seq_len, device=device)
        distance = (positions.unsqueeze(1) - positions.unsqueeze(0)).abs()
        band = torch.zeros((seq_len, seq_len), device=device)
        return band.masked_fill(distance > half_window, float('-inf'))

    def forward(self, p, mask=None):
        """Apply windowed, gated multi-head attention to ``p``.

        Args:
            p: tensor of shape ``(batch, num_angles, angle_width)``.
            mask: optional additive mask. It is unsqueezed at dim 1 and added
                both to the ``(b, i, j, h)`` attention logits and to the gate
                logits. Presumably it has shape ``(batch, num_angles, 1)``,
                with 0 for valid keys and a large negative value for padded
                ones — TODO confirm against the caller. ``None`` means no
                padding mask, which is equivalent to passing all zeros.

        Returns:
            Tensor of shape ``(batch, num_angles, angle_width)``.
        """
        bsize, num_angles, embed_dim = p.shape

        p_ln = self.tri_ln_p(p)

        # Projections.
        Q_in, K_in, V_in = self.lin_QKV_in(p_ln).chunk(3, dim=-1)
        Q_in, K_in, V_in = self.norm_QKV_in(Q_in), self.norm_QKV_in(K_in), self.norm_QKV_in(V_in)
        # E_in / G_in: (b, i, 1, h), one value per query position and head.
        E_in, G_in = self.lin_EG_in(p_ln).unsqueeze(2).chunk(2, dim=-1)
        E_in, G_in = self.norm_EG_in(E_in), self.norm_EG_in(G_in)

        head_shape = (bsize, num_angles, self._dot_dim, self.num_heads)
        Q_in = Q_in.view(head_shape)
        K_in = K_in.view(head_shape)
        V_in = V_in.view(head_shape)

        Q_in = Q_in * self._scale_factor
        # NOTE(review): E_in is constant along the key axis j. The softmax
        # below normalizes over j (dim=-2), so adding E_in does not change
        # the attention weights, and the E half of lin_EG_in gets no
        # gradient. If a per-key bias was intended, E should broadcast over
        # i instead. Confirm before changing: it alters trained-model outputs.
        H_in = torch.einsum('bidh,bjdh->bijh', Q_in, K_in) + E_in  # (b, i, j, h)

        # (n, n) band -> (1, i, j, 1) so it broadcasts over batch and heads.
        window_mask = self._create_sliding_window_mask(num_angles, H_in.device)
        window_mask = window_mask.unsqueeze(0).unsqueeze(-1)

        if mask is None:
            combined_mask = window_mask
            gates_in = torch.sigmoid(G_in)
        else:
            mask_in = mask.unsqueeze(1)
            combined_mask = mask_in + window_mask
            # Masked keys also drive the gate towards 0.
            gates_in = torch.sigmoid(G_in + mask_in)

        # NOTE(review): if `mask` puts -inf on every key inside a query's
        # window, that row's softmax is NaN; confirm the caller's convention.
        A_in = torch.softmax(H_in + combined_mask, dim=-2) * gates_in
        if self.attention_dropout > 0:
            A_in = F.dropout(A_in, p=self.attention_dropout, training=self.training, inplace=True)

        # Merge heads back using the same (dot_dim, num_heads) layout as the split.
        Va = torch.einsum('bijh,bjdh->bidh', A_in, V_in).contiguous().view(bsize, num_angles, embed_dim)

        Va = self.lin_O(Va)
        return self.tri_ln_o(Va)
