from .Abstract import *
from .utils import SchNetFilter


class DistanceMessage(Message):
    """Attention message passing whose attention logits depend on inter-vertex distances.

    Every edge is processed as two directed copies (u -> v and v -> u). For each
    directed edge the Euclidean distance between the endpoint positions is expanded
    by ``schnet_filter``. Together with the edge features it yields a scalar logit.
    The logits are softmax-normalised per target vertex ``v`` and used to aggregate
    messages built from the endpoint vertex features.

    ``forward`` returns a vertex message of shape [n_vertex, mv_dim], an edge message
    of shape [n_edge, me_dim] and a dict of optional diagnostics.
    """

    ESP = 1e-6  # added to every distance so it is strictly positive (name kept for compatibility)
    DIS_DIM = 16  # width of the distance features produced by ``schnet_filter``

    def __init__(self, *args, **kwargs):
        super(DistanceMessage, self).__init__(*args, **kwargs)
        self.schnet_filter = SchNetFilter(self.DIS_DIM, use_cuda=self.use_cuda)

        # Message content of a directed edge u -> v, computed from both endpoints'
        # features cat([hv_u, hv_v]) -> 2 * hv_dim inputs.
        self.attend = nn.Linear(2 * self.hv_dim, self.mv_dim, bias=False)
        self.at_act = nn.ReLU()
        # Scalar attention logit from the edge features and the distance features.
        self.align = nn.Linear(self.he_dim + self.DIS_DIM, 1)
        self.al_act = nn.Softmax(dim=-1)
        self.ag_act = nn.ELU()
        # Edge update; forward feeds it cat([hv_u, hv_v]), so it needs 2 * hv_dim inputs.
        self.link = nn.Linear(2 * self.hv_dim, self.me_dim)
        self.l_act = nn.LeakyReLU()

    def forward(self, hv_ftr: torch.Tensor, he_ftr: torch.Tensor, pos_ftr: torch.Tensor,
                mask_matrices: MaskMatrices, return_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """Compute distance-aware vertex and edge messages.

        Args:
            hv_ftr: vertex features, shape [n_vertex, hv_dim].
            he_ftr: edge features, shape [n_edge, he_dim].
            pos_ftr: vertex positions, shape [n_vertex, pos_dim].
            mask_matrices: provides ``vertex_edge_w1``/``vertex_edge_w2`` and
                ``vertex_edge_b1``/``vertex_edge_b2``, each of shape [n_vertex, n_edge].
                Presumably w1/w2 are 0/1 incidence matrices of each edge's first/second
                endpoint and b1/b2 are additive masks (0 where incident, large negative
                elsewhere) -- TODO confirm against MaskMatrices.
            return_list: names of optional diagnostics. If it contains ``'alignment'``,
                the [n_vertex, 2 * n_edge] attention matrix is returned as a NumPy array.

        Returns:
            ``(mv_ftr, me_ftr, return_dict)``. ``mv_ftr`` has shape [n_vertex, mv_dim].
            ``me_ftr`` has shape [n_edge, me_dim].
        """
        hv_ftr, he_ftr = self.dropout(hv_ftr), self.dropout(he_ftr)
        vew1 = mask_matrices.vertex_edge_w1  # shape [n_vertex, n_edge]
        vew2 = mask_matrices.vertex_edge_w2  # shape [n_vertex, n_edge]
        veb1 = mask_matrices.vertex_edge_b1  # shape [n_vertex, n_edge]
        veb2 = mask_matrices.vertex_edge_b2  # shape [n_vertex, n_edge]
        n_edge = vew1.shape[1]

        # Directed edges: column k < n_edge is (u = endpoint 1, v = endpoint 2);
        # column n_edge + k is the reverse direction of the same edge.
        vew_u = torch.cat([vew1, vew2], dim=1)  # shape [n_vertex, 2 * n_edge]
        vew_v = torch.cat([vew2, vew1], dim=1)  # shape [n_vertex, 2 * n_edge]
        veb_v = torch.cat([veb2, veb1], dim=1)  # shape [n_vertex, 2 * n_edge]
        he2_ftr = torch.cat([he_ftr, he_ftr])  # shape [2 * n_edge, he_dim]

        # Gather endpoint features and positions for every directed edge.
        hv_u_ftr = vew_u.t() @ hv_ftr  # shape [2 * n_edge, hv_dim]
        hv_v_ftr = vew_v.t() @ hv_ftr  # shape [2 * n_edge, hv_dim]
        hv_uv_ftr = torch.cat([hv_u_ftr, hv_v_ftr], dim=1)  # shape [2 * n_edge, 2 * hv_dim]
        q_u_ftr = vew_u.t() @ pos_ftr  # shape [2 * n_edge, pos_dim]
        q_v_ftr = vew_v.t() @ pos_ftr  # shape [2 * n_edge, pos_dim]
        q_uv_ftr = q_v_ftr - q_u_ftr  # shape [2 * n_edge, pos_dim]
        dis_uv = torch.norm(q_uv_ftr, dim=1, keepdim=True) + self.ESP  # shape [2 * n_edge, 1]
        dis_ftr = self.schnet_filter(dis_uv)  # shape [2 * n_edge, dis_dim]

        # The message must carry the sender u's features. Row i of the attention matrix
        # below is selected by vew_v, i.e. it covers edges whose target v is i. So a
        # message built from hv_v alone would only be vertex i's own features.
        attend_ftr = self.at_act(self.attend(hv_uv_ftr))  # shape [2 * n_edge, mv_dim]
        align_ftr = self.align(torch.cat([he2_ftr, dis_ftr], dim=1))  # shape [2 * n_edge, 1]
        # Scale column k of vew_v by logit k. This equals vew_v @ diag(logits) but avoids
        # building a dense [2 * n_edge, 2 * n_edge] diagonal matrix.
        align_ftr = vew_v * align_ftr.reshape(1, -1) + veb_v  # shape [n_vertex, 2 * n_edge]
        align_ftr = self.al_act(align_ftr)
        mv_ftr = self.ag_act(align_ftr @ attend_ftr)  # shape [n_vertex, mv_dim]

        # Edge message: sum over both directions, hence symmetric in endpoint order.
        me2_ftr = self.link(hv_uv_ftr)  # shape [2 * n_edge, me_dim]
        me_ftr = self.l_act(me2_ftr[:n_edge, :] + me2_ftr[n_edge:, :])  # shape [n_edge, me_dim]

        return_dict = {}
        if 'alignment' in return_list:
            # Detach first so the device-to-host copy is not recorded by autograd.
            return_dict['alignment'] = align_ftr.detach().cpu().numpy()

        return mv_ftr, me_ftr, return_dict
