# evaluating Claude on MMLU

import anthropic
import re
import time
import json

import numpy as np

from tqdm import tqdm
from datasets import load_dataset
from tenacity import retry, stop_after_attempt, wait_chain, wait_fixed
from utils import *

# Parse command-line arguments.
# NOTE: parsing happens at module level, so `args` exists as a module global
# that the `__main__` guard at the bottom of the file passes to main().
import argparse
parser = argparse.ArgumentParser(
    description='Evaluate a Claude model on the MMLU test sets.')
parser.add_argument(
    '--engine', type=str, default='claude-v1.3',
    help='model name sent to the completion API; also used in output file names')
parser.add_argument(
    '--api_key', type=str, default='sk',
    help='API key used to create the Anthropic client')
args = parser.parse_args()

# MMLU subject names evaluated by default. Each name is passed to
# load_dataset() as the dataset configuration and is also used as the key
# into the few-shot prompt dictionary and in the output file name.
TASKS = [
    "abstract_algebra",
    "anatomy",
    "astronomy",
    "business_ethics",
    "clinical_knowledge",
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_medicine",
    "college_physics",
    "computer_security",
    "conceptual_physics",
    "econometrics",
    "electrical_engineering",
    "elementary_mathematics",
    "formal_logic",
    "global_facts",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_computer_science",
    "high_school_european_history",
    "high_school_geography",
    "high_school_government_and_politics",
    "high_school_macroeconomics",
    "high_school_mathematics",
    "high_school_microeconomics",
    "high_school_physics",
    "high_school_psychology",
    "high_school_statistics",
    "high_school_us_history",
    "high_school_world_history",
    "human_aging",
    "human_sexuality",
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "machine_learning",
    "management",
    "marketing",
    "medical_genetics",
    "miscellaneous",
    "moral_disputes",
    "moral_scenarios",
    "nutrition",
    "philosophy",
    "prehistory",
    "professional_accounting",
    "professional_law",
    "professional_medicine",
    "professional_psychology",
    "public_relations",
    "security_studies",
    "sociology",
    "us_foreign_policy",
    "virology",
    "world_religions",
]

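# NOTE: unused retry wrapper kept for reference. It targets the OpenAI chat
# API (`openai` is not imported in this file); the Claude completion call in
# main() is currently made without any retry/backoff.
# @retry(wait=wait_chain(*[wait_fixed(3) for i in range(3)] +
#                        [wait_fixed(5) for i in range(2)] +
#                        [wait_fixed(10)]))
# def completion_with_backoff(**kwargs):
#     return openai.ChatCompletion.create(**kwargs)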

def _format_question(example):
    """Render one test record as the question text appended to the few-shot prompt."""
    options = ''.join(
        '(%s) %s ' % (letter, example[letter]) for letter in ('A', 'B', 'C', 'D'))
    return example['input'] + '\n' + options + "\nA: Let's think step by step."


def main(args, tasks=TASKS):
    """Run the model on the test split of each task and report per-task accuracy.

    For every task, each test question is appended to that task's few-shot
    prompt loaded from ``lib_prompt/mmlu-cot.json``, sent to the completion
    API, and the extracted answer is compared against the record's ``target``.
    A transcript is written to ``outputs/test_<engine>_<task>.txt`` and the
    accuracy is printed once the task finishes.

    Args:
        args: object with ``api_key`` (used to build the client) and
            ``engine`` (model name; also part of the output file name).
        tasks: iterable of task names; each is used as the dataset
            configuration name and as the key into the prompt dictionary.

    Returns:
        None.

    NOTE(review): ``extract_ans`` and ``test_answer_mmlu_claude`` are not
    defined in this file; presumably they come from ``from utils import *``.
    """
    import os  # local import: only needed to create the output directory

    client = anthropic.Client(args.api_key)

    # Use a context manager so the prompt file handle is closed promptly.
    with open('lib_prompt/mmlu-cot.json', encoding='utf-8') as prompt_file:
        mmlu_prompt = json.load(prompt_file)

    # Avoid a FileNotFoundError on the first run when the directory is missing.
    os.makedirs('outputs', exist_ok=True)

    for task in tasks:
        print('Testing %s ...' % task)
        test_set = load_dataset("lukaemon/mmlu", task)['test']
        num_questions = len(test_set)
        acc = 0

        out_path = 'outputs/test_%s_%s.txt' % (args.engine, task)
        # Explicit UTF-8 so non-ASCII question/answer text cannot break the
        # write on platforms whose default encoding is not UTF-8.
        with open(out_path, 'w', encoding='utf-8') as fd:
            for q_ in tqdm(test_set, total=num_questions):
                q = _format_question(q_)
                prompt_q = mmlu_prompt[task] + "\n\n" + q
                claude_prompt = anthropic.HUMAN_PROMPT + prompt_q + anthropic.AI_PROMPT

                response = client.completion(
                    prompt=claude_prompt,
                    stop_sequences=[anthropic.HUMAN_PROMPT, anthropic.AI_PROMPT],
                    model=args.engine,
                    max_tokens_to_sample=300,
                    temperature=0
                )

                ans_model = response['completion'].strip()
                print(ans_model)
                # Only the first element of the returned pair is used.
                ans_, _ = extract_ans(ans_model)

                a = q_['target']
                fd.write('Q: %s\nA_model:\n%s\nA:\n%s\n\n' % (q, ans_, a))

                if test_answer_mmlu_claude(ans_, a):
                    acc += 1

        # Guard against an empty test split instead of dividing by zero.
        accuracy = acc / num_questions if num_questions else 0.0
        print('%s acc %.4f' % (task, accuracy))

# Script entry point: evaluate every default task using the parsed CLI arguments.
if __name__ == "__main__":
    main(args, tasks=TASKS)