from src.models.structgnn_model import StructGNN_Model
from src.modules.graphtrans_module import Struct2Seq


class GraphTrans_Model(StructGNN_Model):
    def __init__(self, args):
        StructGNN_Model.__init__(self, args)
        
        self.model = Struct2Seq(
            num_letters=args.vocab_size,
            node_features=args.hidden,
            edge_features=args.hidden, 
            hidden_dim=args.hidden,
            k_neighbors=args.k_neighbors,
            protein_features=args.features,
            dropout=args.dropout,
            use_mpnn=False)