# Evaluate a Claude model on MMLU using chain-of-thought prompts converted to
# Claude's dialog format, with either single-round or multi-round exemplars.

import json
import os

import anthropic
import numpy as np
from datasets import load_dataset
from tqdm import tqdm

from utils import *

# parse arguments
import argparse
parser = argparse.ArgumentParser()
# Prefer the ANTHROPIC_API_KEY environment variable so the key need not be typed
# on the command line; 'sk' is kept as the last-resort placeholder default.
parser.add_argument('--anthropic_key', type=str,
                    default=os.environ.get('ANTHROPIC_API_KEY', 'sk'),
                    help='Anthropic API key (defaults to $ANTHROPIC_API_KEY)')
parser.add_argument('--engine', type=str, default='claude-instant-v1.0',
                    help='Engine for claude, either claude-v1.3 or claude-instant-v1.0')
# choices= rejects an unsupported prompt type at parse time, before any dataset
# download or output-file truncation happens in main().
parser.add_argument('--prompt_type', type=str, default='single',
                    choices=['single', 'multiple'],
                    help='[single, multiple], single round dialog or multiple round dialog')
args = parser.parse_args()

# The 57 MMLU subjects evaluated by default, in alphabetical order. Each name is
# passed as the configuration argument to load_dataset("lukaemon/mmlu", ...).
TASKS = """
    abstract_algebra anatomy astronomy business_ethics clinical_knowledge
    college_biology college_chemistry college_computer_science
    college_mathematics college_medicine college_physics computer_security
    conceptual_physics econometrics electrical_engineering
    elementary_mathematics formal_logic global_facts
    high_school_biology high_school_chemistry high_school_computer_science
    high_school_european_history high_school_geography
    high_school_government_and_politics high_school_macroeconomics
    high_school_mathematics high_school_microeconomics high_school_physics
    high_school_psychology high_school_statistics high_school_us_history
    high_school_world_history human_aging human_sexuality international_law
    jurisprudence logical_fallacies machine_learning management marketing
    medical_genetics miscellaneous moral_disputes moral_scenarios nutrition
    philosophy prehistory professional_accounting professional_law
    professional_medicine professional_psychology public_relations
    security_studies sociology us_foreign_policy virology world_religions
""".split()


# One anthropic.Client per API key, reused across calls instead of being rebuilt
# for every question.
_CLIENTS = {}


def get_response(**kwargs):
    """Send one completion request to Claude and return the raw response.

    The API key comes from the module-level parsed CLI ``args``
    (``args.anthropic_key``), not from a parameter. All keyword arguments are
    forwarded unchanged to ``client.completion``. ``main`` reads
    ``response['completion']`` from the returned value.
    """
    key = args.anthropic_key
    client = _CLIENTS.get(key)
    if client is None:
        client = anthropic.Client(key)
        _CLIENTS[key] = client
    return client.completion(**kwargs)

# Converted chain-of-thought exemplar files, keyed by --prompt_type.
_PROMPT_LIB_PATHS = {
    'single': 'lib_prompt/mmlu-cot-claude-single.json',
    'multiple': 'lib_prompt/mmlu-cot-claude-multiple.json',
}


def _load_prompt_lib(prompt_type):
    """Load the per-task exemplar prompts (a JSON object keyed by task) once.

    Raises:
        ValueError: if ``prompt_type`` is not 'single' or 'multiple'.
    """
    try:
        path = _PROMPT_LIB_PATHS[prompt_type]
    except KeyError:
        raise ValueError('Prompt type not supported') from None
    with open(path, encoding='utf-8') as fd:
        return json.load(fd)


def _format_question(q_, task):
    """Render one MMLU test row (keys 'input', 'A'-'D') as 'Q: ... A:' text."""
    task_mod = task.replace('_', ' ')
    q = 'Q: ' + q_['input'] + '\n'

    # The instruction line varies by subject matter.
    if task_mod in ["business ethics",
                    "computer security",
                    "marketing"]:
        q += "Which one of the four choices completes the question correctly, (A), (B), (C) or (D)?" + "\nChoices:" + "\n"
    elif task_mod in ["college medicine",
                      "high school biology",
                      "high school european history",
                      "high school geography",
                      "high school government and politics",
                      "high school macroeconomics",
                      "moral disputes"]:
        # NOTE(review): unlike the other branches there is no newline around
        # "Choices:" here; presumably this mirrors the converted exemplars - verify.
        q += "Choices:"
    elif task_mod == "college physics":
        q += "Which one of the four choices is correct about the question, (A), (B), (C) or (D)?" + "\nChoices:" + "\n"
    else:
        q += "Which one of the four choices is correct, (A), (B), (C) or (D)?" + "\nChoices:" + "\n"

    q += ''.join('(' + letter + ') ' + q_[letter] + ' ' for letter in ['A', 'B', 'C', 'D'])

    # Step-by-step (chain-of-thought) cue, then the answer slot.
    q += "\nLet's think step by step."
    q += "\nA:"
    return q


def _build_claude_prompt(prompt_lib, task, q, prompt_type):
    """Wrap the task's exemplars plus question ``q`` in Claude's dialog markers.

    'single' puts the exemplars and the question in one human turn. 'multiple'
    appends the question as a fresh human turn after the exemplar dialog.
    ``prompt_type`` must already have been validated by ``_load_prompt_lib``.
    """
    if prompt_type == 'single':
        return anthropic.HUMAN_PROMPT + prompt_lib[task] + "\n\n" + q + anthropic.AI_PROMPT
    return prompt_lib[task] + "\n\n" + anthropic.HUMAN_PROMPT + "\n" + q + anthropic.AI_PROMPT


def _macro_average_accuracy(acc_path):
    """Return the unweighted mean of the per-task accuracies in ``acc_path``.

    Lines are expected as ``'<task> acc <value>'``, the format ``main`` writes.
    ``Average acc ...`` lines from earlier runs are skipped. The file is opened
    in append mode, so re-runs accumulate entries; when a task appears more than
    once, only its latest value counts.
    """
    latest = {}
    with open(acc_path, encoding='utf-8') as fd:
        for line in fd:
            fields = line.split()
            if len(fields) != 3 or fields[0] == 'Average':
                continue
            latest[fields[0]] = float(fields[2])
    return np.mean(list(latest.values()))


def main(args, tasks=TASKS):
    """Evaluate ``args.engine`` on each MMLU task and record accuracies.

    Args:
        args: parsed CLI namespace. This function reads ``engine`` and
            ``prompt_type``. NOTE: ``get_response`` takes the API key from the
            module-level ``args``, not from this parameter.
        tasks: iterable of "lukaemon/mmlu" configuration names.

    Side effects:
        Overwrites ``outputs/test_<engine>_<task>.txt`` with per-question
        transcripts, and appends per-task and average accuracy lines to
        ``outputs/test_<engine>_<prompt_type>_acc.txt``.
        NOTE(review): the transcript name omits prompt_type, so 'single' and
        'multiple' runs overwrite each other's transcripts.

    Raises:
        ValueError: if ``args.prompt_type`` is unsupported. This is raised
            before any dataset is loaded or any file is written.
    """
    prompt_lib = _load_prompt_lib(args.prompt_type)
    os.makedirs('outputs', exist_ok=True)
    acc_path = 'outputs/test_%s_%s_acc.txt' % (args.engine, args.prompt_type)

    for task in tasks:
        print('Testing %s ...' % task)
        task_data = load_dataset("lukaemon/mmlu", task)
        test_rows = task_data['test']
        n_test = len(test_rows)
        acc = 0

        with open('outputs/test_%s_%s.txt' % (args.engine, task), 'w', encoding='utf-8') as fd:
            for q_ in tqdm(test_rows, total=n_test):
                q = _format_question(q_, task)
                claude_prompt = _build_claude_prompt(prompt_lib, task, q, args.prompt_type)

                # Greedy decoding; stop as soon as the model starts a new human turn.
                response = get_response(
                    model=args.engine,
                    prompt=claude_prompt,
                    stop_sequences=[anthropic.HUMAN_PROMPT],
                    max_tokens_to_sample=300,
                    temperature=0,
                )

                ans_ = response['completion'].strip()
                a = q_['target']
                fd.write('%s\nA_model:\n%s\nA:\n%s\n\n' % (q, ans_, a))
                if test_answer_mmlu_claude_instant(ans_, a):
                    acc += 1

        task_acc = acc / n_test
        print('%s acc %.4f' % (task, task_acc))
        with open(acc_path, 'a', encoding='utf-8') as fd:
            fd.write('%s acc %.4f\n' % (task, task_acc))

    # Compute the average before reopening the file for the summary line.
    average = _macro_average_accuracy(acc_path)
    with open(acc_path, 'a', encoding='utf-8') as fd:
        fd.write('Average acc %.4f\n' % average)

if __name__ == '__main__':
    # Script entry point: evaluate every default subject with the parsed CLI args.
    main(args, tasks=TASKS)