from embedding_text import EmbeddingModel
from utils_proof.encode_graph import *

class GraphEmbedding:
    """Produce embeddings for train/test graphs using a locally stored text encoder.

    An ``EmbeddingModel`` is loaded once at construction time. ``encode`` then
    passes this instance itself to the graph-encoding routine. That routine
    presumably reads ``self.encoder`` / ``self.device``; TODO confirm against
    ``encode_bayes_ppr2`` in ``utils_proof.encode_graph``.
    """

    def __init__(self, encode_model='bert-base-uncased', model_dir='../../models') -> None:
        """Load the text encoder.

        Args:
            encode_model: Name of the model sub-directory inside ``model_dir``.
            model_dir: Directory holding local model checkpoints. The default,
                ``'../../models'``, is a relative path. Assuming
                ``EmbeddingModel`` treats its argument as a filesystem path, it
                resolves against the current working directory, not against this
                file. Callers running from elsewhere should pass an explicit path.
        """
        # Built with '/' exactly as before, so the default value yields the
        # identical path string '../../models/<encode_model>'.
        self.encoder = EmbeddingModel(f'{model_dir}/{encode_model}')
        # Cache the encoder's device on the instance; presumably consumed by
        # the downstream encoding routine (NOTE(review): verify).
        self.device = self.encoder.device

    def encode(self, train_graphs, test_graphs, method, arg=None):
        """Encode training and test graphs into embeddings.

        Args:
            train_graphs: Graphs to embed for the training split.
            test_graphs: Graphs to embed for the test split.
            method: Currently ignored. Every call uses ``encode_bayes_ppr2``
                regardless of this value; kept for interface compatibility.
            arg: Extra argument forwarded unchanged to ``encode_bayes_ppr2``.

        Returns:
            A ``(train_embs, test_embs)`` tuple as produced by ``encode_bayes_ppr2``.
        """
        # NOTE(review): ``method`` is not dispatched on. If other encoders from
        # utils_proof.encode_graph should be selectable, map method names to
        # functions here rather than hard-coding encode_bayes_ppr2.
        # Unpacking (rather than returning the call result directly) keeps the
        # original guarantee that exactly two values come back, as a tuple.
        train_embs, test_embs = encode_bayes_ppr2(self, train_graphs, test_graphs, arg)
        return train_embs, test_embs