import argparse
import os
import json
import numpy as np
from utils.model import ModelWrapper
from utils.load_data import DataLoader, extract_answer
from transformers import set_seed
from tqdm import tqdm
import random
from utils.reward import reward_factory
from utils.eval import is_equiv


if __name__ == '__main__':
    # Command-line configuration for this run. Help text only describes how
    # each value is consumed in this script; see the utils.* helpers for the
    # full semantics.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='Llama3_1_8b_chat')
    parser.add_argument('--n_samples', type=int, default=200,
                        help='Passed to DataLoader as n_samples.')
    parser.add_argument('--n_examples', type=int, default=3,
                        help='Passed to DataLoader.load_data as n_examples.')
    parser.add_argument('--dataset', type=str, default='proofwriter',
                        help='Dataset name passed to DataLoader.')
    parser.add_argument('--method', type=str, default='cot',
                        help="Method name; 'cot', 'sc' and 'bestn' all load "
                             "data with method='cot'.")
    parser.add_argument('--roll_num', type=int, default=None)
    parser.add_argument('--reward', type=str, default=None,
                        help='Reward name passed to reward_factory; omit to '
                             'skip building a reward model.')
    parser.add_argument('--remote', action='store_true',
                        help='Passed to reward_factory as its remote flag.')
    # The seed used to be hard-coded; the default keeps the original value so
    # existing runs are reproduced exactly.
    parser.add_argument('--seed', type=int, default=17,
                        help='Random seed (default: 17).')
    args = parser.parse_args()
    # transformers.set_seed also seeds Python's `random`; the explicit call
    # keeps the stdlib RNG seeding visible and independent of that helper.
    set_seed(args.seed)
    random.seed(args.seed)

    model_name = args.model
    n_samples = args.n_samples
    n_examples = args.n_examples
    dataset = args.dataset
    method = args.method
    roll_num = args.roll_num
    reward = args.reward
    remote = args.remote

    dataloader = DataLoader(dataset=dataset, n_samples=n_samples)
    # 'sc' and 'bestn' share the 'cot' data setup, so all three load with
    # method='cot'; any other method is passed through unchanged.
    load_method = 'cot' if method in ('cot', 'sc', 'bestn') else method
    data = dataloader.load_data(method=load_method, n_examples=n_examples)

    # Only build a reward model when one was requested. Binding None otherwise
    # means later code can test `reward_model is not None` instead of hitting
    # a NameError.
    reward_model = None
    if reward:
        reward_model = reward_factory(reward, remote)