import argparse
from tot.methods.bfs import solve, solve_new, solve_new_gsm8k
from tot.tasks.game24 import Game24Task
from tot.tasks.blocksworld2 import Blocksworld
from tot.tasks.gsm8k import GSM8K
import sys
from utils import *
from prompts import *
import copy
from tqdm import tqdm
import json
import re

# Backend model identifier, reused as the `backend` field of `args` below.
model_name = "gpt-3.5-turbo-0301"

# Run configuration, built directly as a Namespace rather than parsed from
# the command line. The field meanings are defined by the code that consumes
# `args`, not here.
args = argparse.Namespace(
    backend=model_name,
    temperature=1.0,
    task='gsm8k',
    naive_run=False,
    prompt_sample=None,
    method_generate='propose',
    method_evaluate='value',
    method_select='greedy',
    n_generate_sample=1,
    n_evaluate_sample=3,
    n_select_sample=5,
)


# Evaluation driver: runs the solver over every problem index for several
# independent trials and counts, per trial, how many problems had at least one
# returned candidate whose extracted answers contain `task.answer`.

_NUM_TRIALS = 5
# Hard-coded number of problem indices to evaluate.
# NOTE(review): presumably the size of the GSM8K test split -- confirm.
_NUM_PROBLEMS = 1319
# Compiled once: strips every character that is not an ASCII digit.
_NON_DIGIT_RE = re.compile(r"[^0-9]")


def _candidate_answers(completion: str) -> list:
    """Return the digit-only answer strings found in *completion*.

    Each line longer than 5 characters is split on ':' and the text after the
    first colon (up to any second colon) is taken as a field. Fields shorter
    than 10 characters have all non-digit characters removed and are returned.

    If any considered line contains no ':' the whole completion yields an
    empty list (best-effort parsing, same as the original behaviour).
    """
    try:
        fields = [
            line.split(':')[1].strip()
            for line in completion.split('\n')
            if len(line) > 5
        ]
    except IndexError:
        # A line without ':' has no second element. Only this error is
        # swallowed, so Ctrl-C and genuine bugs still propagate.
        return []
    return [_NON_DIGIT_RE.sub("", field) for field in fields if len(field) < 10]


trial_scores = []
for trial in range(_NUM_TRIALS):
    res = []
    for idx in range(_NUM_PROBLEMS):
        # A fresh task object per problem, as in the original; `task.answer`
        # is read only after the solver has run for this idx.
        task = GSM8K()
        ys, _infos = solve_new_gsm8k(args, task, idx)

        for ys_ in ys:
            # Only candidates that state a final answer are considered.
            if 'final answer' not in ys_.lower():
                continue
            if task.answer in _candidate_answers(ys_):
                res.append(1)
                break  # at most one point per problem

    # Report every trial as it finishes; previously the scores of all but the
    # last trial were computed and then silently discarded.
    trial_scores.append(sum(res))
    print(f'Trial {trial + 1}/{_NUM_TRIALS} score:\t', trial_scores[-1])

print('All trial scores:\t', trial_scores)
# Kept unchanged as the last line of output: the score of the last trial.
print('Final score:\t', sum(res))