from .Abstract import *
from .utils import SchNetFilter


class DistanceAngleMessage(Message):
    """Message layer that mixes vertex features with edge lengths and inter-edge angles.

    Every undirected edge is treated as two directed edges (``u -> v`` and the
    reverse), giving ``2 * n_edge`` directed edges.  For every pair of directed
    edges the layer builds a feature from both edges' endpoint features, their
    distance embeddings and an embedding of the cosine between them; pairs are
    weighted by ``vew_u.t() @ vew_u`` so that only pairs linked through the
    ``u`` side contribute.  The pair features are max-pooled per directed edge,
    then aggregated into vertex messages with a softmax alignment computed from
    the edge features.  Edge messages are produced from the two endpoint
    features plus the distance embedding, summed over both directions.
    """

    # Small constant added to every edge length before it is used as a divisor,
    # so zero-length edges do not divide by zero.  (Acts as an epsilon; the
    # attribute name is kept unchanged for backward compatibility.)
    ESP = 1e-6
    # Size passed to ``SchNetFilter``; used as the distance-embedding width.
    DIS_DIM = 16
    # Output width of the linear layer that embeds each pairwise cosine.
    ANGLE_DIM = 16

    def __init__(self, *args, **kwargs):
        """Build the sub-layers; all arguments are forwarded to the base class.

        Reads ``use_cuda``, ``hv_dim``, ``he_dim``, ``mv_dim`` and ``me_dim``
        from ``self``; these are expected to be set by the base-class
        constructor.
        """
        super().__init__(*args, **kwargs)
        # NOTE: attribute names and creation order are kept stable so that
        # existing checkpoints and seeded initialisation remain compatible.
        self.schnet_filter = SchNetFilter(self.DIS_DIM, use_cuda=self.use_cuda)
        self.angle_encode = nn.Linear(1, self.ANGLE_DIM)

        # Input: [hv_v | dis | hv_u] of the first edge, [hv_v | dis] of the
        # second edge, and the angle embedding.  No bias, so pair features that
        # were zeroed by the pair weight stay zero after the projection.
        self.attend = nn.Linear(3 * self.hv_dim + 2 * self.DIS_DIM + self.ANGLE_DIM, self.mv_dim, bias=False)
        self.at_act = nn.ReLU()
        self.align = nn.Linear(self.he_dim, 1)
        self.al_act = nn.Softmax(dim=-1)
        self.ag_act = nn.ELU()
        self.link = nn.Linear(self.hv_dim + self.DIS_DIM + self.hv_dim, self.me_dim)
        self.l_act = nn.LeakyReLU()

    def forward(self, hv_ftr: torch.Tensor, he_ftr: torch.Tensor, pos_ftr: torch.Tensor,
                mask_matrices: MaskMatrices, return_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """Compute vertex and edge messages.

        Args:
            hv_ftr: vertex features, shape [n_vertex, hv_dim].
            he_ftr: edge features, shape [n_edge, he_dim].
            pos_ftr: vertex positions, shape [n_vertex, pos_dim].
            mask_matrices: provides ``vertex_edge_w1``, ``vertex_edge_w2``,
                ``vertex_edge_b1`` and ``vertex_edge_b2``, each of shape
                [n_vertex, n_edge].  NOTE(review): presumably ``w1``/``w2`` are
                incidence indicators of each edge's two endpoints and
                ``b1``/``b2`` are additive softmax masks -- confirm against
                ``MaskMatrices``.
            return_list: names of optional extras to return; only
                ``'alignment'`` is recognised here.

        Returns:
            ``(mv_ftr, me_ftr, return_dict)`` where ``mv_ftr`` has shape
            [n_vertex, mv_dim], ``me_ftr`` has shape [n_edge, me_dim], and
            ``return_dict['alignment']`` (if requested) is the softmax
            alignment as a NumPy array of shape [n_vertex, 2 * n_edge].
        """
        hv_ftr, he_ftr = self.dropout(hv_ftr), self.dropout(he_ftr)
        n_edge = mask_matrices.vertex_edge_w1.shape[1]
        n_directed = 2 * n_edge
        vew1 = mask_matrices.vertex_edge_w1  # shape [n_vertex, n_edge]
        vew2 = mask_matrices.vertex_edge_w2  # shape [n_vertex, n_edge]
        veb1 = mask_matrices.vertex_edge_b1  # shape [n_vertex, n_edge]
        veb2 = mask_matrices.vertex_edge_b2  # shape [n_vertex, n_edge]

        # Duplicate every edge into two directed edges: the first n_edge columns
        # use (w1 as u, w2 as v), the last n_edge columns the reverse.
        vew_u = torch.cat([vew1, vew2], dim=1)  # shape [n_vertex, 2 * n_edge]
        vew_v = torch.cat([vew2, vew1], dim=1)  # shape [n_vertex, 2 * n_edge]
        veb_v = torch.cat([veb2, veb1], dim=1)  # shape [n_vertex, 2 * n_edge]
        he2_ftr = torch.cat([he_ftr, he_ftr])  # shape [2 * n_edge, he_dim]

        # Pair weight between directed edges, computed through their u side.
        ee = vew_u.t() @ vew_u  # shape [2 * n_edge, 2 * n_edge]
        ee1 = ee.unsqueeze(-1)  # shape [2 * n_edge, 2 * n_edge, 1]

        # Gather endpoint features and positions for every directed edge.
        hv_v_ftr = vew_v.t() @ hv_ftr  # shape [2 * n_edge, hv_dim]
        hv_u_ftr = vew_u.t() @ hv_ftr  # shape [2 * n_edge, hv_dim]
        q_u_ftr = vew_u.t() @ pos_ftr  # shape [2 * n_edge, pos_dim]
        q_v_ftr = vew_v.t() @ pos_ftr  # shape [2 * n_edge, pos_dim]
        q_uv_ftr = q_v_ftr - q_u_ftr  # shape [2 * n_edge, pos_dim]

        # Edge lengths (offset by ESP so the division below is safe) and the
        # cosine between every pair of normalised edge vectors.
        dis_uv = torch.norm(q_uv_ftr, dim=1, keepdim=True) + self.ESP  # shape [2 * n_edge, 1]
        norm_dis_uv = q_uv_ftr / dis_uv  # shape [2 * n_edge, pos_dim]
        angle_ee = norm_dis_uv @ norm_dis_uv.t()  # shape [2 * n_edge, 2 * n_edge]

        dis_ftr = self.schnet_filter(dis_uv)  # shape [2 * n_edge, dis_dim]
        angle_ftr = torch.tanh(self.angle_encode(angle_ee.unsqueeze(dim=-1)))  # shape [2 * n_edge, 2 * n_edge, angle_dim]
        vd_ftr = torch.cat([hv_v_ftr, dis_ftr], dim=1)  # shape [2 * n_edge, hv_dim + dis_dim]

        # Pair features for (edge i, edge j).  ``expand`` creates broadcast
        # views, so the only materialised copy is the one made by ``torch.cat``
        # (the previous ``repeat`` calls allocated two extra full-size tensors).
        first_edge = torch.cat([vd_ftr, hv_u_ftr], dim=1).unsqueeze(1).expand(-1, n_directed, -1)  # [i, j, :] = edge i
        second_edge = vd_ftr.unsqueeze(0).expand(n_directed, -1, -1)  # [i, j, :] = edge j
        v1e1ue2v2 = ee1 * torch.cat([
            first_edge,
            second_edge,
            angle_ftr
        ], dim=2)  # shape [2 * n_edge, 2 * n_edge, 3 * hv_dim + 2 * dis_dim + angle_dim]

        attend_ftr = self.attend(v1e1ue2v2)  # shape [2 * n_edge, 2 * n_edge, mv_dim]
        attend_ftr = self.at_act(attend_ftr)
        # Max-pool over the second edge of each pair.
        attend_ftr = torch.max(attend_ftr, dim=1)[0]  # shape [2 * n_edge, mv_dim]

        align_ftr = self.align(he2_ftr)  # shape [2 * n_edge, 1]
        # Scale column j of vew_v by the alignment score of directed edge j.
        # A broadcast multiply is equivalent to right-multiplying by
        # diag(scores) without building a dense [2 * n_edge, 2 * n_edge] matrix.
        align_ftr = vew_v * align_ftr.reshape(1, -1) + veb_v  # shape [n_vertex, 2 * n_edge]
        align_ftr = self.al_act(align_ftr)  # softmax over the directed-edge axis
        mv_ftr = self.ag_act(align_ftr @ attend_ftr)  # shape [n_vertex, mv_dim]

        me2_ftr = self.link(torch.cat([hv_u_ftr, dis_ftr, hv_v_ftr], dim=1))  # shape [2 * n_edge, me_dim]
        # Sum the two directions of each undirected edge.
        me_ftr = me2_ftr[:n_edge, :] + me2_ftr[n_edge:, :]  # shape [n_edge, me_dim]
        me_ftr = self.l_act(me_ftr)

        return_dict: Dict[str, Any] = {}
        if 'alignment' in return_list:
            # Detach before the host copy so the transfer is not tracked by autograd.
            return_dict['alignment'] = align_ftr.detach().cpu().numpy()

        return mv_ftr, me_ftr, return_dict
