import torch
from torch_geometric.data import Data
from torch_geometric.utils import k_hop_subgraph
import networkx as nx
import numpy as np
import os

from src.utils.lm_modeling import load_model, load_text2embedding

# Input directory holding the original GLBench datasets (one ``<name>.pt`` per dataset).
path = 'dataset/GLBench'
# Output directory for the re-embedded copies (same file names, ``_lm`` suffix on the dir).
path_lm = f'{path}_lm'

def main(datanames=None, model_name='sbert'):
    """Re-embed the raw node texts of GLBench datasets with a language model.

    For each dataset name, load ``{path}/{dataname}.pt`` and replace
    ``data.x`` with the embeddings of ``data.raw_texts`` produced by the
    encoder selected by ``model_name``. Save the result to
    ``{path_lm}/{dataname}.pt``.

    Args:
        datanames: Dataset names to process. ``None`` (the default) processes
            the full list of GLBench datasets handled by this script.
        model_name: Key into ``load_model`` / ``load_text2embedding`` that
            selects the encoder. Defaults to ``'sbert'``.
    """
    if datanames is None:
        datanames = ['wikics', 'citeseer', 'cora', 'instagram', 'pubmed', 'reddit', 'arxiv']

    model, tokenizer, device = load_model[model_name]()
    text2embedding = load_text2embedding[model_name]

    # Create the output directory once; exist_ok avoids the check-then-create race.
    os.makedirs(path_lm, exist_ok=True)

    for dataname in datanames:
        # Keep the graph on CPU. Only raw_texts is passed to the encoder, and
        # CPU tensors make the saved file loadable on machines without a GPU.
        # weights_only=False is needed to unpickle a full graph object
        # (torch >= 2.6 defaults to weights_only=True; the kwarg exists since
        # torch 1.13). This executes arbitrary pickle code, so only use it on
        # trusted local dataset files.
        data = torch.load(f'{path}/{dataname}.pt', map_location='cpu', weights_only=False)
        x = text2embedding(model, tokenizer, device, data.raw_texts)
        # The embeddings may come back on `device`; store them on CPU with the rest.
        data.x = x.cpu() if isinstance(x, torch.Tensor) else x
        # Drop the (possibly GPU) reference so empty_cache() can reclaim it.
        del x
        torch.save(data, f'{path_lm}/{dataname}.pt')
        torch.cuda.empty_cache()
        print(f"{dataname} processing is finished!")

if __name__ == "__main__":
    # No device setup here: main() obtains its device from load_model.
    main()