from torch.nn.utils.rnn import pad_sequence

from .Abstract import *
from net.utils.components import GCN
from net.utils import normalize_adj_r


class LSTMGenerator(Generator):
    """Generator that emits per-vertex output features with a GCN followed by an LSTM.

    Vertex features are refined by a residual GCN over a symmetric adjacency
    matrix whose edge weights are learned gates in (0, 1) computed from the edge
    features. The refined vertices of each molecule are then fed, as one
    sequence per molecule, through a 2-layer LSTM. The LSTM outputs are the
    generated features: ``q`` alone, or ``p`` and ``q`` side by side when
    ``need_momentum`` is set.
    """
    # Hidden-layer widths passed to the GCN.
    GCN_HID_DIMS = [64]
    # Output width passed to the GCN.
    GCN_OUT_DIM = 64

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One scalar gate per edge, squashed to (0, 1), used as that edge's adjacency weight.
        self.a_linear = nn.Linear(self.he_dim, 1, bias=True)
        self.a_act = nn.Sigmoid()
        self.gcn = GCN(self.hv_dim, self.GCN_OUT_DIM, self.GCN_HID_DIMS, use_cuda=self.use_cuda, residual=True)
        # The LSTM consumes the GCN output directly, so this input width must equal
        # the GCN's output width. NOTE(review): the sum below presumes the residual
        # GCN concatenates its input, every hidden layer and its output - TODO
        # confirm against net.utils.components.GCN.
        rnn_in_dim = self.hv_dim + sum(self.GCN_HID_DIMS) + self.GCN_OUT_DIM
        # With momentum, p and q are produced side by side, doubling the hidden size.
        rnn_out_dim = self.pos_dim * 2 if self.need_momentum else self.pos_dim
        self.rnn = nn.LSTM(rnn_in_dim, rnn_out_dim, num_layers=2)

    def forward(self, hv_ftr: torch.Tensor, he_ftr: torch.Tensor, mask_matrices: MaskMatrices,
                return_list: List[str], **kwargs) -> Tuple[Union[torch.Tensor, None], torch.Tensor, Dict[str, Any]]:
        """Generate per-vertex (p, q) features for a batch of molecules.

        :param hv_ftr: vertex features, the GCN input.
        :param he_ftr: edge features; each row is mapped to one scalar edge weight.
        :param mask_matrices: batch matrices. ``vertex_edge_w1`` and ``vertex_edge_w2``
            relate vertices to edges. ``mol_vertex_w`` marks with 1 the vertices
            belonging to each molecule (one row per molecule).
        :param return_list: not used by this generator.
        :return: ``(p_ftr, q_ftr, {})`` when ``need_momentum`` is set, otherwise
            ``(None, q_ftr, {})``. Rows are grouped molecule by molecule, in the row
            order of ``mol_vertex_w``. NOTE(review): this matches ``hv_ftr`` row order
            only if each molecule's vertices are contiguous and molecules are stored
            in order; a vertex marked in no row is dropped - confirm with callers.
        """
        vew1 = mask_matrices.vertex_edge_w1
        vew2 = mask_matrices.vertex_edge_w2
        edge_w = self.a_act(self.a_linear(he_ftr)).reshape(-1)
        # vew1 @ diag(edge_w) just scales column j of vew1 by edge_w[j]. Broadcasting
        # gives the same result without materialising an (n_edge, n_edge) diagonal
        # matrix and its O(V * E^2) matmul.
        # NOTE(review): assumes vew1 is a dense 2-D tensor with one column per edge.
        adj_d = (vew1 * edge_w.unsqueeze(0)) @ vew2.t()
        # Symmetrise so the adjacency is undirected.
        adj = adj_d + adj_d.t()
        hv_neighbor_ftr = self.gcn(hv_ftr, normalize_adj_r(adj))

        # Gather each molecule's vertices into its own variable-length sequence.
        seqs = [hv_neighbor_ftr[row == 1, :] for row in mask_matrices.mol_vertex_w]
        lengths = [seq.shape[0] for seq in seqs]
        # pad_sequence's default layout is (max_len, n_mol, dim), which matches
        # nn.LSTM's default batch_first=False.
        output, _ = self.rnn(pad_sequence(seqs))
        # Strip the padded time steps. The LSTM is unidirectional, so outputs before
        # a sequence's end do not depend on its padding.
        pq_ftr = torch.cat([output[:length, i, :] for i, length in enumerate(lengths)])

        return_dict = {}
        if self.need_momentum:
            p_ftr, q_ftr = pq_ftr[:, :self.pos_dim], pq_ftr[:, self.pos_dim:]
            return p_ftr, q_ftr, return_dict
        return None, pq_ftr, return_dict
