import torch.nn as nn
from src.modules.geo_modules import *


class Encoder(nn.Module):
    """Wrap a ``StructureEncoder`` and filter its output by chain encoding.

    The module builds one ``StructureEncoder`` from the hyper-parameters found
    on ``args`` and, in ``forward``, keeps only the output rows whose
    ``chain_encoding`` value is below 1000.
    """

    def __init__(self, args, **kwargs):
        """Build the wrapped structure encoder.

        Args:
            args: Configuration object. The attributes read here are
                ``hidden_dim``, ``k_neighbors``, ``geo_layers``,
                ``edge_layers``, ``enc_layers`` and ``dropout``.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        super().__init__()
        self.args = args
        width = args.hidden_dim
        # Stored on the instance but not read anywhere inside this class.
        self.top_k = args.k_neighbors
        self.encoder = StructureEncoder(
            args.geo_layers,
            args.edge_layers,
            args.enc_layers,
            width,
            dropout=args.dropout,
        )

    def forward(self, chain_encoding, edge_idx, batch_id, V=None, E=None, T_ts=None, batch_id_extend=None, edge_idx_extend=None):
        """Run the structure encoder and drop rows with chain encoding >= 1000.

        Every argument is forwarded to the wrapped encoder; note that the
        encoder receives ``edge_idx`` and ``batch_id`` *before*
        ``chain_encoding``, which differs from this method's own argument
        order.

        Args:
            chain_encoding: Per-row chain identifiers, also used to build the
                output mask (assumed to align with the encoder output along
                its first dimension -- TODO confirm).
            edge_idx: Passed through to the encoder.
            batch_id: Passed through to the encoder.
            V: Passed through to the encoder (optional).
            E: Passed through to the encoder (optional).
            T_ts: Passed through to the encoder (optional).
            batch_id_extend: Passed through to the encoder (optional).
            edge_idx_extend: Passed through to the encoder (optional).

        Returns:
            The encoder output restricted to the rows where
            ``chain_encoding < 1000``.
        """
        encoded = self.encoder(
            edge_idx,
            batch_id,
            chain_encoding,
            V,
            E,
            T_ts,
            batch_id_extend,
            edge_idx_extend,
        )
        # NOTE(review): entries >= 1000 presumably mark auxiliary/virtual
        # nodes that must not reach the caller -- confirm with the data
        # pipeline that produces ``chain_encoding``.
        keep = chain_encoding < 1000
        return encoded[keep]