from tqdm import tqdm

from utils_proof.graph import frr_to_graph
from utils_proof.file import *
from embedding_graph import GraphEmbedding

def main(args):
    """Build graphs from the stored FRR records and save their embeddings.

    Reads ``frrs/{task}/{llm}/train.json`` and ``frrs/{task}/{llm}/test.json``
    with ``load_json``, converts every entry to a graph with
    ``frr_to_graph``, hands both graph lists to ``GraphEmbedding.encode``
    together with the method name and its argument string, and writes the two
    results with ``save_tensor`` to
    ``embs/{task}/{llm}/{method}_{arg}_train.pt`` and ``..._test.pt``.

    Args:
        args: Object exposing the attributes ``task``, ``llm``, ``method``
            and ``arg`` (the argparse namespace built at the bottom of this
            script).
    """
    dataset = args.task
    infer_model = args.llm
    method = args.method
    arg = args.arg

    # NOTE(review): load_json is assumed to return a sized, iterable
    # collection (``len`` is taken for the progress bar) -- confirm in
    # utils_proof.file.
    train_code_list = load_json(f'frrs/{dataset}/{infer_model}/train.json')
    test_code_list = load_json(f'frrs/{dataset}/{infer_model}/test.json')

    # Constructed before the graphs are built, as in the original ordering.
    encoder = GraphEmbedding()

    # One graph per record; tqdm only provides the progress bar.
    train_graphs = [
        frr_to_graph(code)
        for code in tqdm(train_code_list, total=len(train_code_list))
    ]
    test_graphs = [
        frr_to_graph(code)
        for code in tqdm(test_code_list, total=len(test_code_list))
    ]

    # Both splits go through a single encode call; how `method` and `arg` are
    # interpreted is defined by GraphEmbedding.encode.
    train_embs, test_embs = encoder.encode(train_graphs, test_graphs, method, arg)

    save_tensor(train_embs, f'embs/{dataset}/{infer_model}/{method}_{arg}_train.pt')
    save_tensor(test_embs, f'embs/{dataset}/{infer_model}/{method}_{arg}_test.pt')

if __name__ == '__main__':
    # Script entry point: every option is a string flag with a default, so
    # the parser is populated from a (flag, default) table in declaration
    # order and the parsed namespace is passed straight to main().
    import argparse

    parser = argparse.ArgumentParser()
    for flag, default in (
        ('--task', 'mini_proofwriter'),
        ('--llm', 'llama-3.1-8b-instruct'),
        ('--method', 'bayes_ppr2'),
        ('--arg', '2-mean-0p0'),
    ):
        parser.add_argument(flag, type=str, default=default)

    args = parser.parse_args()
    main(args)