from .Abstract import *


class NaivePositionMessage(Message):
    """Message block combining vertex, edge and position features.

    Each edge is expanded into two directed copies (``u -> v`` followed by the
    reversed copy), so most intermediate tensors have ``2 * n_edge`` rows.
    The block produces:

    * a vertex message: a softmax over the directed edge copies weights the
      transformed ``v``-end vertex features; the softmax logits are scored
      from the relative position ``pos_v - pos_u`` and the edge features;
    * an edge message: a linear map of ``[h_u, pos_v - pos_u, h_v]`` summed
      over both directed copies of the edge.

    ``hv_dim``, ``he_dim``, ``mv_dim``, ``me_dim``, ``pos_dim`` and
    ``dropout`` are read from ``self``; they are expected to be provided by
    ``Message.__init__`` (defined outside this section -- confirm there).
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``Message`` and build the sub-layers."""
        super().__init__(*args, **kwargs)

        # NOTE: attribute names and creation order are kept as-is so that
        # existing checkpoints and seeded initialisation stay compatible.
        self.attend = nn.Linear(self.hv_dim, self.mv_dim)
        self.at_act = nn.LeakyReLU()
        self.align = nn.Linear(self.pos_dim + self.he_dim, 1)
        self.al_act = nn.Softmax(dim=-1)
        self.ag_act = nn.ELU()
        # Input is the concatenation [h_u, pos_v - pos_u, h_v].
        self.link = nn.Linear(self.hv_dim + self.pos_dim + self.hv_dim, self.me_dim)
        self.l_act = nn.LeakyReLU()

    def forward(self, hv_ftr: torch.Tensor, he_ftr: torch.Tensor, pos_ftr: torch.Tensor,
                mask_matrices: MaskMatrices, return_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """Compute vertex and edge messages.

        :param hv_ftr: vertex features, shape [n_vertex, hv_dim]
        :param he_ftr: edge features, shape [n_edge, he_dim]
        :param pos_ftr: vertex positions, shape [n_vertex, pos_dim]
        :param mask_matrices: provides ``vertex_edge_w1``, ``vertex_edge_w2``,
            ``vertex_edge_b1`` and ``vertex_edge_b2``, each [n_vertex, n_edge].
            NOTE(review): the ``w*`` matrices look like vertex-edge incidence
            weights and the ``b*`` matrices like additive softmax masks --
            confirm against the ``MaskMatrices`` definition.
        :param return_list: names of optional diagnostics to return; only
            ``'alignment'`` is recognised here.
        :return: ``(mv_ftr, me_ftr, return_dict)`` where ``mv_ftr`` has shape
            [n_vertex, mv_dim], ``me_ftr`` has shape [n_edge, me_dim] and
            ``return_dict['alignment']`` (if requested) is the softmax output
            as a NumPy array of shape [n_vertex, 2 * n_edge].
        """
        hv_ftr, he_ftr = self.dropout(hv_ftr), self.dropout(he_ftr)

        vew1 = mask_matrices.vertex_edge_w1  # shape [n_vertex, n_edge]
        vew2 = mask_matrices.vertex_edge_w2  # shape [n_vertex, n_edge]
        veb1 = mask_matrices.vertex_edge_b1  # shape [n_vertex, n_edge]
        veb2 = mask_matrices.vertex_edge_b2  # shape [n_vertex, n_edge]
        n_edge = vew1.shape[1]

        # Directed expansion: columns [0, n_edge) are one direction of every
        # edge, columns [n_edge, 2 * n_edge) are the reversed copies.
        vew_u = torch.cat([vew1, vew2], dim=1)  # shape [n_vertex, 2 * n_edge]
        vew_v = torch.cat([vew2, vew1], dim=1)  # shape [n_vertex, 2 * n_edge]
        veb_v = torch.cat([veb2, veb1], dim=1)  # shape [n_vertex, 2 * n_edge]

        # Gather per-directed-edge end-point features.
        hv_u_ftr = vew_u.t() @ hv_ftr  # shape [2 * n_edge, hv_dim]
        hv_v_ftr = vew_v.t() @ hv_ftr  # shape [2 * n_edge, hv_dim]
        pos_u_ftr = vew_u.t() @ pos_ftr  # shape [2 * n_edge, pos_dim]
        pos_v_ftr = vew_v.t() @ pos_ftr  # shape [2 * n_edge, pos_dim]
        pos_uv_ftr = pos_v_ftr - pos_u_ftr  # shape [2 * n_edge, pos_dim]
        he2_ftr = torch.cat([he_ftr, he_ftr])  # shape [2 * n_edge, he_dim]

        # Vertex message.
        attend_ftr = self.at_act(self.attend(hv_v_ftr))  # shape [2 * n_edge, mv_dim]
        scores = self.align(torch.cat([pos_uv_ftr, he2_ftr], dim=1))  # shape [2 * n_edge, 1]
        # Scale column j of vew_v by score j. This equals
        # `vew_v @ torch.diag(scores.reshape(-1))` but avoids materialising
        # the dense [2 * n_edge, 2 * n_edge] diagonal matrix and the matmul.
        # Assumes the mask matrices are dense tensors, as the rest of this
        # method treats them.
        align_ftr = vew_v * scores.reshape(1, -1) + veb_v  # shape [n_vertex, 2 * n_edge]
        align_ftr = self.al_act(align_ftr)
        mv_ftr = self.ag_act(align_ftr @ attend_ftr)  # shape [n_vertex, mv_dim]

        # Edge message: sum the two directed copies of every edge.
        me2_ftr = self.link(torch.cat([hv_u_ftr, pos_uv_ftr, hv_v_ftr], dim=1))  # shape [2 * n_edge, me_dim]
        me_ftr = me2_ftr[:n_edge, :] + me2_ftr[n_edge:, :]  # shape [n_edge, me_dim]
        me_ftr = self.l_act(me_ftr)

        return_dict = {}
        if 'alignment' in return_list:
            # Detach first so the device copy is not tracked by autograd.
            return_dict['alignment'] = align_ftr.detach().cpu().numpy()

        return mv_ftr, me_ftr, return_dict
