import argparse
import json 
import re
from utils.model import ModelWrapper
from tqdm import tqdm

# System prompt for the evaluation ("judge") model. It is sent as the
# "system" message alongside a user message holding the question, the
# generated solution and the standard solution. The prompt asks the judge to
# finish with '# Answer: True' or '# Answer: False'; the script looks for
# True/False on the last line of the reply to derive the score.
instruction = """Compare the generated solution to the provided standard solution to determine its correctness. Your evaluation should include:
A comparison of the generated solution with these elements, assessing accuracy, completeness, and clarity.
A final judgment on whether the generated solution is correct or incorrect.
If correct, return: '# Answer: True', else return: '# Answer: False'
Be detailed and concise in your explanation.
"""


# Matches the verdict token the judge is asked to emit on its final line.
VERDICT_PATTERN = re.compile(r'(True|False|true|false)')


def select_correct_responses(item, max_candidates=10, limit=3):
    """Return up to *limit* responses of *item* that are flagged as correct.

    The per-response correctness flags are read from ``item['corrects']``
    when that key exists, otherwise from ``item['cor_flag']``. Only the first
    *max_candidates* responses are considered. Responses and flags are paired
    with ``zip`` so an item with fewer entries than *max_candidates* no
    longer raises ``IndexError``.
    """
    flags = item['corrects'] if 'corrects' in item else item['cor_flag']
    candidates = zip(item['response'][:max_candidates], flags)
    return [res for res, is_correct in candidates if is_correct][:limit]


def extract_cot(res):
    """Return the reasoning part of a raw response string.

    The trailing answer is removed first: everything from the first
    '# Answer:' marker onward, or, when there is no marker, the last
    blank-line-separated paragraph. Then everything up to and including the
    first ':' is dropped and the remainder is stripped of whitespace.
    """
    if '# Answer:' in res:
        cot = res.split('# Answer:')[0]
    elif '\n\n' in res:
        cot = '\n\n'.join(res.split('\n\n')[:-1])
    else:
        cot = res
    # NOTE(review): presumably this removes a leading "<label>:" prefix. When
    # the text contains no ':' at all the result is an empty string, exactly
    # as in the original implementation -- confirm that this is intended.
    return cot.partition(':')[2].strip()


def parse_verdict(evaluation):
    """Parse the judge's verdict from the last line of *evaluation*.

    Returns ``True`` or ``False`` according to the first True/False token on
    the last line (case of the first letter is ignored), or ``None`` when the
    last line contains no such token.
    """
    last_line = evaluation.strip().split('\n')[-1]
    found = VERDICT_PATTERN.search(last_line)
    if found is None:
        return None
    # Compare the text explicitly: bool('False') is True, because any
    # non-empty string is truthy.
    return found.group(1).lower() == 'true'


def main():
    """Judge previously generated solutions and write the verdicts to JSON.

    Reads ``./result/<dataset>/<model>/sc10_e3_200.json``, sends up to three
    correct-flagged responses per item (for the first ``--n_samples`` items)
    to the evaluation model, and writes one record per judged response to
    ``./result/<dataset>/<model>/<eval_model>_eval_<n_samples>.json``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='Llama3_1_8b_chat')
    parser.add_argument('--dataset', type=str, default='math')
    parser.add_argument('--eval_model', type=str, default='gpt-4o')
    parser.add_argument('--n_samples', type=int, default=200)
    args = parser.parse_args()

    model_name = args.model
    dataset = args.dataset
    eval_model = args.eval_model
    n_samples = args.n_samples

    judge = ModelWrapper(eval_model)
    result_path = f'./result/{dataset}/{model_name}/sc10_e3_200.json'
    with open(result_path, 'r') as f:
        # The last entry of the file is skipped -- presumably a summary
        # record rather than a sample; verify against the producing script.
        data = json.load(f)[:-1]

    results = []
    for item in tqdm(data[:n_samples]):
        for res in select_correct_responses(item):
            cot = extract_cot(res)
            prompt = f"Question: {item['question']}\nGenerated Solution: {cot}\nStandard Solution: {item['reason']}"
            messages = [
                {"role": "system", "content": instruction},
                {"role": "user", "content": prompt},
            ]
            evaluation = judge.generate(messages)
            results.append({
                'id': item['id'],
                'question': item['question'],
                'cot': cot,
                'gold_cot': item['reason'],
                'evaluation': evaluation,
                'score': parse_verdict(evaluation),
            })

    eval_path = f'./result/{dataset}/{model_name}/{eval_model}_eval_{n_samples}.json'
    with open(eval_path, 'w') as f:
        json.dump(results, f, indent=4)


if __name__ == '__main__':
    main()