import torch
from tqdm import tqdm

from retriever import Retriever
from graph_dataset import GraphDataset
from utils_proof.file import write_json

def main(args):
    """Build one prompt per test embedding and write them all to a JSON file.

    Reads the following attributes from ``args``:

    * ``task``   -- dataset name; used in the embedding/result paths and
      passed to ``GraphDataset``.
    * ``llm``    -- inference model name; used in the paths and passed to
      ``GraphDataset``.
    * ``method`` -- used in the embedding and result file names.
    * ``arg``    -- used in the embedding and result file names.
    * ``sim``    -- passed to ``Retriever`` and used in the result file name.
    * ``k``      -- optional; number of items requested from the retriever
      per test embedding. Defaults to 8 when the attribute is missing or
      ``None``.

    Inputs are loaded from
    ``embs/{task}/{llm}/{method}_{arg}_train.pt`` and ``..._test.pt``;
    output goes to ``results/{task}/{llm}/{method}_{arg}_{sim}.json``.

    Raises:
        ValueError: if ``k`` is not positive.
    """
    # Function-scope import so this block does not depend on a file-level one.
    from pathlib import Path

    dataset = args.task
    infer_model = args.llm
    method = args.method
    arg = args.arg
    sim = args.sim

    # Previously hard-coded to 8; still 8 unless the caller supplies ``k``.
    k = getattr(args, 'k', None)
    if k is None:
        k = 8
    if k <= 0:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    graph_dataset = GraphDataset(dataset, infer_model)

    # NOTE(review): torch.load unpickles by default on older PyTorch versions;
    # only point this at embedding files you trust.
    emb_prefix = f'embs/{dataset}/{infer_model}/{method}_{arg}'
    train_embs = torch.load(f'{emb_prefix}_train.pt')
    test_embs = torch.load(f'{emb_prefix}_test.pt')

    retriever = Retriever(train_embs, sim)

    # The position ``i`` of each test embedding is forwarded to
    # ``construct_prompt`` together with the retrieved indices.
    with torch.no_grad():
        prompt_list = [
            graph_dataset.construct_prompt(retriever.retrieve(test_emb, k), i)
            for i, test_emb in tqdm(enumerate(test_embs), total=len(test_embs))
        ]

    out_path = f"results/{dataset}/{infer_model}/{method}_{arg}_{sim}.json"
    # Make sure the target directory exists so a first run for a new
    # dataset/model pair does not fail at write time. This is a no-op if
    # write_json (not visible here) already creates it.
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    write_json(out_path, prompt_list)

if __name__ == '__main__':
    import argparse

    # Command-line flags with their string defaults, registered in this order.
    cli_options = (
        ('--task', 'mini_proofwriter'),
        ('--llm', 'llama-3.1-8b-instruct'),
        ('--method', 'bayes_ppr2'),
        ('--sim', 'ppr2'),
        ('--arg', '2-mean-0p1'),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value in cli_options:
        parser.add_argument(flag, type=str, default=default_value)

    args = parser.parse_args()
    main(args)