import anthropic
import argparse
import re
import numpy as np
from tqdm import tqdm
from datasets import load_dataset

# Command-line interface: one flag per setting of the evaluation run.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--prompt_file',
    type=str,
    default='lib_prompt/prompt_original.txt',
)
parser.add_argument(
    '--anthropic_key',
    type=str,
    default='sk-',
    help='Anthropic key for claude-instant',
)
parser.add_argument(
    '--engine',
    type=str,
    default='claude-instant-v1.0',
    help='Engine for claude-instant',
)
# When set, skip generation and only score an existing output file.
parser.add_argument(
    '--eval_only',
    action='store_true',
    help='Only evaluate the model',
)
parser.add_argument(
    '--output_file',
    type=str,
    default='outputs/claude_instant_gsm8k_test.txt',
    help='Output file for claude-instant',
)

# Load the 'main' configuration of GSM8K and keep a handle on its test split,
# which is what the evaluation loop and the accuracy computation iterate over.
gsm8k = load_dataset("gsm8k", "main")
gsm8k_test = gsm8k["test"]

# Rows of the training split picked out by the saved index array.
validation_index = np.load("lib_prompt/validation_index.npy")
validation_data = gsm8k["train"].select(validation_index)

def parse_answer_file(answer_file):
    """Score a saved output file and print the accuracy as a percentage.

    The file is expected to follow the layout written by ``main``: every case
    begins with a ``===== CASE`` header line, followed by the question, the
    model answer, a ``Reference Answer`` marker line, and the reference
    solution whose final line starts with ``####``.

    The prediction for a case is the last run of digits seen on any line
    scanned before the reference solution starts. It is compared, as a
    string, with the text following ``####``.

    Args:
        answer_file: Path of the output file to score.

    The accuracy is computed over the size of the module-level GSM8K test
    split (``gsm8k_test``), not over the number of cases found in the file,
    so a partially written file is scored as if the missing cases were wrong.
    """
    accuracy = 0
    # None never equals a reference string, so a '####' line seen before any
    # digits were found is simply counted as wrong.
    last_number = None
    should_find_answer = True

    with open(answer_file, 'r') as f:
        for line in f:
            # Scan before handling the markers below: a header line is only
            # scanned if scanning was already enabled when it was reached.
            if should_find_answer:
                numbers = re.findall(r'\d+', line)
                if numbers:
                    last_number = numbers[-1]

            if line.startswith('####'):
                reference_answer = line.split('####')[1].strip()
                if reference_answer == last_number:
                    accuracy += 1
            elif line.startswith('===== CASE'):
                should_find_answer = True
            elif line.startswith('Reference Answer'):
                # Stop scanning so digits inside the reference solution do not
                # overwrite the model's prediction.
                should_find_answer = False

    print('Accuracy: ', accuracy / len(gsm8k_test['question']) * 100)

def main(args):
    """Query the model on every GSM8K test question, save the answers, and score them.

    Args:
        args: Parsed command-line namespace. Reads ``eval_only``,
            ``output_file``, ``prompt_file``, ``anthropic_key`` and ``engine``.

    With ``args.eval_only`` set, only the existing ``args.output_file`` is
    scored; the prompt file is not read and no API client is created.
    Otherwise ``args.output_file`` is overwritten with one block per test
    question (case header, question, model answer, reference answer) and then
    scored with ``parse_answer_file``.
    """
    if args.eval_only:
        parse_answer_file(args.output_file)
        return

    with open(args.prompt_file, 'r') as prompt_fh:
        prompt = prompt_fh.read()
    client = anthropic.Client(args.anthropic_key)

    questions = gsm8k_test['question']
    answers = gsm8k_test['answer']

    with open(args.output_file, 'w') as f:
        progress = tqdm(zip(questions, answers), total=len(questions))
        # Number cases consecutively from 1 so headers match the case order.
        for run_count, (q, a) in enumerate(progress, start=1):
            claude_prompt = anthropic.HUMAN_PROMPT + "\n" + prompt + '\nQuestion: ' + q + '\n' + anthropic.AI_PROMPT

            response = client.completion(
                prompt=claude_prompt,
                stop_sequences=[anthropic.HUMAN_PROMPT, anthropic.AI_PROMPT],
                model=args.engine,
                max_tokens_to_sample=300,
                temperature=0,
            )
            cleaned_response = response['completion'].strip()

            # parse_answer_file relies on these exact line prefixes.
            f.write(f'===== CASE {run_count} =====\n')
            f.write(f'Question\n: {q}\n')
            f.write(f'Claude-instant Answer\n: {cleaned_response}\n')
            f.write(f'Reference Answer\n: {a}\n\n')

    parse_answer_file(args.output_file)

if __name__ == '__main__':
    # Script entry point: parse the command-line flags and run the driver.
    main(parser.parse_args())