from .Abstract import *


class RDKitGenerator(Generator):
    """Generator whose positions come from precomputed RDKit coordinates.

    Instead of deriving positions from the node/edge features, this generator
    reads the ``rdkit_pos_ftr`` keyword argument passed to :meth:`forward` and
    maps it to ``self.pos_dim`` dimensions. ``self.pos_dim`` and
    ``self.need_momentum`` are not set here; presumably they are set by
    ``Generator.__init__`` (confirm).
    """

    def __init__(self, *args, **kwargs):
        super(RDKitGenerator, self).__init__(*args, **kwargs)
        if self.pos_dim != 3:
            # Project the 3 input coordinates to pos_dim with a linear map that
            # has no bias term.
            self.linear = nn.Linear(3, self.pos_dim, bias=False)
        else:
            # Use nn.Identity rather than ``lambda x: x``. A lambda attribute
            # makes the module unpicklable (e.g. torch.save(model)). Identity
            # has no parameters, so state_dict keys are unchanged.
            self.linear = nn.Identity()

    def forward(self, hv_ftr: torch.Tensor, he_ftr: torch.Tensor, mask_matrices: MaskMatrices,
                return_list: List[str], **kwargs) -> Tuple[Union[torch.Tensor, None], torch.Tensor, Dict[str, Any]]:
        """Return (momentum, position, extras) built from RDKit coordinates.

        Args:
            hv_ftr: Unused; accepted for interface compatibility.
            he_ftr: Unused; accepted for interface compatibility.
            mask_matrices: Unused; accepted for interface compatibility.
            return_list: Unused; accepted for interface compatibility.
            **kwargs: Must contain ``rdkit_pos_ftr``, the RDKit coordinate
                tensor. When ``pos_dim != 3``, its last dimension must be 3
                because of the ``nn.Linear(3, pos_dim)`` projection.
                Presumably it has one row per atom (confirm against callers).

        Returns:
            A tuple ``(p_ftr, q_ftr, return_dict)``:

            - ``q_ftr``: the position tensor, projected if ``pos_dim != 3``.
            - ``p_ftr``: a zero tensor with ``q_ftr``'s shape, dtype and device
              when ``self.need_momentum`` is truthy, otherwise ``None``.
            - ``return_dict``: always an empty dict.

        Raises:
            KeyError: If ``rdkit_pos_ftr`` is not supplied in ``kwargs``.
        """
        if 'rdkit_pos_ftr' not in kwargs:
            raise KeyError("RDKitGenerator.forward requires the 'rdkit_pos_ftr' keyword argument")
        q_ftr = self.linear(kwargs['rdkit_pos_ftr'])
        return_dict = {}
        p_ftr = torch.zeros_like(q_ftr) if self.need_momentum else None
        return p_ftr, q_ftr, return_dict
