# evaluating GPT-3.5 turbo model on MMLU

import openai
import re
import time
import json

import numpy as np

from tqdm import tqdm
from datasets import load_dataset
from tenacity import retry, stop_after_attempt, wait_chain, wait_fixed
from utils import *

# Command-line interface: the only option is the OpenAI API key, which
# main() later assigns to openai.api_key.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--api_key',
    default='sk',
    type=str,
)
args = parser.parse_args()

# Names of the MMLU subject configurations evaluated by default, in
# alphabetical order. Each name is used as the dataset config, as the key
# into the few-shot prompt file, and in the per-task output filename.
TASKS = [
    'abstract_algebra',
    'anatomy',
    'astronomy',
    'business_ethics',
    'clinical_knowledge',
    'college_biology',
    'college_chemistry',
    'college_computer_science',
    'college_mathematics',
    'college_medicine',
    'college_physics',
    'computer_security',
    'conceptual_physics',
    'econometrics',
    'electrical_engineering',
    'elementary_mathematics',
    'formal_logic',
    'global_facts',
    'high_school_biology',
    'high_school_chemistry',
    'high_school_computer_science',
    'high_school_european_history',
    'high_school_geography',
    'high_school_government_and_politics',
    'high_school_macroeconomics',
    'high_school_mathematics',
    'high_school_microeconomics',
    'high_school_physics',
    'high_school_psychology',
    'high_school_statistics',
    'high_school_us_history',
    'high_school_world_history',
    'human_aging',
    'human_sexuality',
    'international_law',
    'jurisprudence',
    'logical_fallacies',
    'machine_learning',
    'management',
    'marketing',
    'medical_genetics',
    'miscellaneous',
    'moral_disputes',
    'moral_scenarios',
    'nutrition',
    'philosophy',
    'prehistory',
    'professional_accounting',
    'professional_law',
    'professional_medicine',
    'professional_psychology',
    'public_relations',
    'security_studies',
    'sociology',
    'us_foreign_policy',
    'virology',
    'world_religions',
]

# Delay schedule (seconds) between successive attempts: three waits of 3s,
# two of 5s, then 10s. No stop condition is passed to `retry`, so tenacity's
# defaults decide which errors are retried and when to give up.
@retry(wait=wait_chain(*(wait_fixed(seconds) for seconds in (3, 3, 3, 5, 5, 10))))
def completion_with_backoff(**kwargs):
    """Call ``openai.ChatCompletion.create`` with retry/back-off.

    All keyword arguments are forwarded unchanged to the OpenAI client and
    its response object is returned as-is.
    """
    return openai.ChatCompletion.create(**kwargs)

def main(args, tasks=TASKS):
    """Evaluate gpt-3.5-turbo on the given MMLU tasks and print accuracies.

    For every task, each question of the ``test`` split is appended to the
    task's few-shot chain-of-thought prompt (loaded from
    ``lib_prompt/mmlu-cot.json``), sent to the chat model, and the extracted
    answer is compared with the reference. A per-task transcript is written
    to ``outputs/test_gpt_3.5_turbo_<task>.txt`` and the task accuracy is
    printed to stdout.

    Args:
        args: Object exposing an ``api_key`` attribute (the argparse
            namespace when run as a script).
        tasks: Iterable of MMLU subject names; each must be a key of the
            prompt file and a valid config of the ``lukaemon/mmlu`` dataset.

    Returns:
        None.
    """
    # Local import keeps this block self-contained; only needed for makedirs.
    import os

    openai.api_key = args.api_key

    # Use a context manager so the prompt file handle is closed promptly.
    with open('lib_prompt/mmlu-cot.json', encoding='utf-8') as prompt_file:
        mmlu_prompt = json.load(prompt_file)

    # Avoid a FileNotFoundError on a fresh checkout.
    os.makedirs('outputs', exist_ok=True)

    for task in tasks:
        print('Testing %s ...' % task)
        acc = 0
        test_set = load_dataset("lukaemon/mmlu", task)['test']
        num_questions = len(test_set)

        # Explicit UTF-8: question text may contain non-ASCII characters and
        # the platform default encoding is not guaranteed to handle them.
        with open('outputs/test_gpt_3.5_turbo_%s.txt' % task, 'w',
                  encoding='utf-8') as fd:
            for q_ in tqdm(test_set, total=num_questions):
                # Question stem, then the four options on one line, then the
                # chain-of-thought trigger used by the few-shot examples.
                options = ''.join(
                    '(%s) %s ' % (letter, q_[letter])
                    for letter in ('A', 'B', 'C', 'D')
                )
                q = q_['input'] + '\n' + options
                q += "\nA: Let's think step by step."

                prompt_q = mmlu_prompt[task] + "\n\n" + q

                response = completion_with_backoff(
                    model="gpt-3.5-turbo",
                    messages=[
                        {"role": "system", "content": "Follow the given examples and answer the question."},
                        {"role": "user", "content": prompt_q},
                    ],
                    temperature=0,
                )
                ans_model = response['choices'][0]['message']['content']
                # extract_ans returns a pair; only the first element is used.
                ans_, _ = extract_ans(ans_model)

                a = q_['target']
                fd.write('Q: %s\nA_model:\n%s\nA:\n%s\n\n' % (q, ans_, a))

                # test_answer_mmlu_ comes from utils; presumably truthy when
                # the extracted answer matches the reference letter.
                if test_answer_mmlu_(ans_, a):
                    acc += 1

        if num_questions:
            print('%s acc %.4f' % (task, acc / num_questions))
        else:
            # Guard against ZeroDivisionError on an empty split.
            print('%s acc n/a (empty test split)' % task)

# Run the evaluation over the default task list when executed as a script,
# using the module-level `args` namespace.
if __name__ == '__main__':
    main(args)