import argparse, os, json
from openai import OpenAI
from evaluation.utils import (multi_thread_scoring, 
                            multi_thread_response_generation,
                            process_subject_data,
                            get_mr_score,
                            construct_critic_prompts,
                            construct_refine_prompts,
                            multi_thread_critic_generation,
                            multi_thread_refine_generation)

# Benchmark subjects; each is looked for as "<subject>.json" inside the dataset directory.
SUBJECTS = ['biology', 'math', 'physics', 'medicine', 'coding', 'chemistry', 'logic']
# Key under which every data point stores the evaluated model's structured response.
EVAL_KEY = 'Evaluation_Result'

def load_dataset(dataset_path, k_shot, demo_path, hint):
    """Load every available subject split and construct its evaluation prompts.

    Args:
        dataset_path: Directory expected to contain one ``<subject>.json`` file per
            entry of ``SUBJECTS``; subjects without a file are skipped.
        k_shot: Number of demonstrations; when non-zero, ``demo_path`` is read.
        demo_path: JSON file holding the few-shot demonstrations.
        hint: Forwarded to ``process_subject_data``.

    Returns:
        Dict mapping subject name to its loaded data, after ``process_subject_data``
        has been applied to it. Empty when ``dataset_path`` is not a directory.
    """
    kshot_data = None
    if k_shot != 0:
        # Explicit UTF-8: the data may contain non-ASCII text (results are written with
        # ensure_ascii=False) and the platform default encoding may not be UTF-8.
        with open(demo_path, encoding='utf-8') as file:
            kshot_data = json.load(file)
    # load corresponding subjects and construct corresponding evaluation prompt
    loaded_dataset = {}
    if not os.path.isdir(dataset_path):
        # Previously this case returned an empty dataset silently.
        print(f'Warning: {dataset_path} is not a directory, no subject data loaded.')
        return loaded_dataset
    for subject in SUBJECTS:
        subject_file = os.path.join(dataset_path, f"{subject}.json")
        if not os.path.exists(subject_file):
            continue
        print(f'Loading subject {subject} data ......')
        with open(subject_file, encoding='utf-8') as file:
            subject_data = json.load(file)
        process_subject_data(subject_data, k_shot, kshot_data, hint)
        loaded_dataset[subject] = subject_data
    return loaded_dataset

def generate_response(client, benchmark, model_name, temperature, top_p, max_tokens, stop_token_ids, unscored_save_path):
    """Query the evaluated model for every subject that has no cached responses yet.

    ``benchmark`` maps subject -> either a dict ``{question_uuid: [query_dict, ...]}``
    (not evaluated yet) or a list of query dicts (cached results, reused as-is).
    ``multi_thread_response_generation`` is expected to fill the query dicts in place,
    with the response stored under ``EVAL_KEY``.

    Results are checkpointed to ``unscored_save_path`` after every generated subject,
    and once more at the end so cached subjects are never dropped from the file.

    Returns:
        Dict mapping subject -> list of query dicts.
    """
    results = {}
    generated_any = False

    def _checkpoint():
        # Persist everything gathered so far so an interrupted run can resume.
        with open(unscored_save_path, 'w', encoding='utf-8') as file:
            json.dump(results, file, indent=2, ensure_ascii=False)

    for subject, subject_data in benchmark.items():
        print(f"Generating answers of {subject} ...... ")
        if isinstance(subject_data, list):
            results[subject] = subject_data
            continue  # skip the subject that has already been evaluated
        all_queries = [qs_dict
                       for question_queries in subject_data.values()
                       for qs_dict in question_queries]
        multi_thread_response_generation(all_queries,
                                         client,
                                         model_name,
                                         temperature,
                                         top_p,
                                         max_tokens,
                                         stop_token_ids,
                                         EVAL_KEY)
        # these query dicts are modified in place and contain the response results
        results[subject] = all_queries
        generated_any = True
        _checkpoint()
    if generated_any:
        # Cached subjects iterated after the last generated one were not included in
        # any per-subject checkpoint above; write the complete result set.
        _checkpoint()
    return results

def _parse_first_error_step(raw_step):
    """Normalize a predicted first error step to a bare digit string.

    Accepts forms such as ``"3"``, ``" 3 "``, ``"step 3"`` or ``"Step 42"`` (the old
    lookup table only covered "step 0" .. "step 29"). Returns ``''`` when no step
    number can be extracted, including for ``None``.
    """
    if raw_step is None:
        return ''
    text = str(raw_step).strip().lower()
    if text.startswith('step'):
        text = text[len('step'):].strip()
    return text if text.isdigit() else ''


def _first_error_step_matches(raw_prediction, ground_truth):
    """Return True when the parsed prediction equals the annotated first error step.

    An unparseable prediction never matches, even against an empty annotation.
    The annotation is compared as a stripped string (presumably a digit string in
    the dataset - TODO confirm).
    """
    predicted = _parse_first_error_step(raw_prediction)
    return predicted != '' and predicted == str(ground_truth).strip()


def _is_error_reason_judged_correct(verdict):
    """Interpret the scoring model's ``error_reason_correctness`` verdict.

    A bare ``'correct' in verdict`` test also matches "incorrect" (and "not correct"),
    which would count rejected error reasons as correct.
    """
    text = str(verdict).lower()
    return 'correct' in text and 'incorrect' not in text and 'not correct' not in text


def _needs_error_reason_review(subject, data):
    """Decide whether ``data``'s error reason must be judged by the scoring model.

    Only incorrect solutions that the evaluated model also judged incorrect qualify;
    beyond that, coding items always qualify and other subjects need a matching
    first error step.
    """
    if data['Model_Solution_Correctness'] != 'incorrect':
        return False
    if data[EVAL_KEY]['solution_correctness'].strip().lower() != 'incorrect':
        return False
    if subject == 'coding':
        # for coding task, the first error step is only a rough indicator and should
        # be scored by scoring model or annotator
        return True
    return _first_error_step_matches(data[EVAL_KEY]['first_error_step'],
                                     data['Model_Solution_First_Error_Step'])


def score_error_reason(score_client, score_model, eval_results, scored_save_path):
    """Flag data points whose error reason needs judging, then judge them.

    Sets ``data['Need_Error_Reason_Review']`` on every data point. For each subject,
    the flagged points that do not yet carry ``'Error_Reason_Correctness_Analysis'``
    are passed to ``multi_thread_scoring``. All results are written to
    ``scored_save_path`` after each subject.

    Returns:
        ``eval_results``, mutated in place.
    """
    for subject, subject_results in eval_results.items():
        for data in subject_results:
            data['Need_Error_Reason_Review'] = _needs_error_reason_review(subject, data)
    # score with the scoring model (e.g. gpt4)
    for subject, subject_results in eval_results.items():
        print(f"scoring answers of {subject} ......")
        # Points that already hold a judgement are skipped, so a rerun on saved
        # scored data does not re-query them.
        to_be_scored_data = [data for data in subject_results
                             if data['Need_Error_Reason_Review']
                             and 'Error_Reason_Correctness_Analysis' not in data]
        multi_thread_scoring(to_be_scored_data, score_client, subject, score_model)
        with open(scored_save_path, 'w', encoding='utf-8') as file:
            json.dump(eval_results, file, indent=2, ensure_ascii=False)
    return eval_results


def calculate_mr_score(scored_eval_results):
    """Aggregate per-subject task statistics and convert them with ``get_mr_score``.

    Task 1 counts correctness predictions agreeing with the annotation ('t1-tp' for
    correct solutions, 't1-tn' for the rest). Task 2 ('t2_corr_num') and task 3
    ('t3_corr_num_auto') are only considered when an incorrect solution is correctly
    judged incorrect.

    Returns:
        Dict mapping subject -> ``get_mr_score`` of that subject's statistics.
    """
    mr_score_stats = {}
    for subject, subject_results in scored_eval_results.items():
        stats = {
            't1-tp': 0,
            't1-tn': 0,
            't2_corr_num': 0,
            't3_corr_num_auto': 0,
            'correct_sol_num': 0,
            'incorrect_sol_num': 0,
        }
        for data in subject_results:
            ground_truth = data['Model_Solution_Correctness']
            if ground_truth == 'correct':
                stats['correct_sol_num'] += 1
            else:
                stats['incorrect_sol_num'] += 1
            correctness_pred = data[EVAL_KEY]['solution_correctness'].strip().lower()
            if ground_truth != correctness_pred:
                continue
            if ground_truth == 'correct':
                stats['t1-tp'] += 1
                continue
            stats['t1-tn'] += 1
            # Only incorrect solutions the model agrees are incorrect reach tasks 2/3.
            # For coding, the line number is only a rough indicator, so a positive
            # judgement of the error reason earns both task 2 and task 3 credit;
            # otherwise coding falls through to the line-number comparison below.
            if subject == 'coding' and _is_error_reason_judged_correct(
                    data['Error_Reason_Correctness_Analysis']['error_reason_correctness']):
                stats['t2_corr_num'] += 1
                stats['t3_corr_num_auto'] += 1
                continue
            if _first_error_step_matches(data[EVAL_KEY]['first_error_step'],
                                         data['Model_Solution_First_Error_Step']):
                stats['t2_corr_num'] += 1
                if _is_error_reason_judged_correct(
                        data['Error_Reason_Correctness_Analysis']['error_reason_correctness']):
                    stats['t3_corr_num_auto'] += 1
        mr_score_stats[subject] = stats
    return {subject: get_mr_score(stats) for subject, stats in mr_score_stats.items()}
            


def main(args):
    """Run the MR-BEAN evaluation: generate responses, score error reasons, compute MR scores.

    Generated responses are cached in the "..._eval_results.json" file under
    ``args.output_dir`` and reused on the next run. Scored results are written to
    the matching "..._scored_eval_results.json" file.

    Returns:
        Dict mapping subject -> MR score, as returned by ``calculate_mr_score``.
    """
    # name of open-sourced model served by vllm is the absolute path of the downloaded
    # model folder; keep its last component (a trailing slash is ignored).
    # Commercial model names without '/' are kept unchanged.
    succint_model_name = args.eval_model_name.rstrip('/').split('/')[-1]
    unscored_save_path = f"{args.output_dir}/{succint_model_name}_{args.shot_num}shot_cot_True_hint_{args.hint}_eval_results.json"
    scored_save_path = f"{args.output_dir}/{succint_model_name}_{args.shot_num}shot_cot_True_hint_{args.hint}_scored_eval_results.json"
    # Create the output directory up front: otherwise the first checkpoint write would
    # only fail after a whole subject had already been generated.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    eval_client = OpenAI(base_url=args.eval_base_url, api_key=args.eval_api_key)
    # load the MR-BEAN dataset and construct the corresponding evaluation prompts
    mr_bean_dataset = load_dataset(args.dataset_path, args.shot_num, args.demo_path, args.hint)
    generated_res = None
    try:
        with open(unscored_save_path, encoding='utf-8') as file:
            generated_res = json.load(file)
    except FileNotFoundError:
        print("No cached eval result found, proceed to full dataset evaluation ....")
    except (OSError, ValueError) as err:
        # ValueError covers corrupt JSON and undecodable bytes.
        print(f"Cached eval result {unscored_save_path} could not be read ({err}), proceed to full dataset evaluation ....")
    if isinstance(generated_res, dict):
        # cached subjects are lists of already-evaluated query dicts
        mr_bean_dataset.update(generated_res)
        print("Cached eval results found, reusing generated subject data ....")
    elif generated_res is not None:
        print(f"Cached eval result {unscored_save_path} has an unexpected format, proceed to full dataset evaluation ....")
    mr_bean_eval_results = generate_response(eval_client,
                                             mr_bean_dataset,
                                             args.eval_model_name,
                                             args.temperature,
                                             args.top_p,
                                             args.max_tokens,
                                             args.stop_token_ids,
                                             unscored_save_path)
    score_client = OpenAI(base_url=args.score_base_url, api_key=args.score_api_key)
    scored_mr_bean_eval_results = score_error_reason(score_client, args.score_model_name, mr_bean_eval_results, scored_save_path)
    mr_scores = calculate_mr_score(scored_mr_bean_eval_results)
    print(mr_scores)
    return mr_scores


def generate_self_refine(client, benchmark, model_name, temperature, top_p, max_tokens, stop_token_ids, critic_unscored_save_path,
                         refine_unscored_save_path):
    """Run the two-stage self-refine experiment (critic, then refine) on ``benchmark``.

    ``benchmark`` maps subject -> list of already-evaluated data points. Critic
    prompts are built and queried first (checkpointed to
    ``critic_unscored_save_path``), then refine prompts (checkpointed to
    ``refine_unscored_save_path``). Finally each data point's ``EVAL_KEY`` entry is
    replaced by its ``'Refine_Response'``.

    Returns:
        Dict mapping subject -> list of refined data points (modified in place).
    """
    def _generate_per_subject(generate_fn, data_by_subject, save_path):
        # Query the model subject by subject, checkpointing everything generated so far.
        phase_results = {}
        for subject, subject_data in data_by_subject.items():
            generate_fn(subject_data,
                        client,
                        model_name,
                        temperature,
                        top_p,
                        max_tokens,
                        stop_token_ids)
            # these query dicts are modified in place and contain the response results
            phase_results[subject] = subject_data
            with open(save_path, 'w', encoding='utf-8') as file:
                json.dump(phase_results, file, indent=2, ensure_ascii=False)
        return phase_results

    # for each data point, we update the prompt to perform self critic and improvement
    construct_critic_prompts(benchmark)
    critic_results = _generate_per_subject(multi_thread_critic_generation, benchmark,
                                           critic_unscored_save_path)
    # After the critic, construct the refine prompt and query the model
    construct_refine_prompts(critic_results)
    refine_results = _generate_per_subject(multi_thread_refine_generation, critic_results,
                                           refine_unscored_save_path)
    # overwrite the original eval key for simplicity for now
    for subject_results in refine_results.values():
        for sol in subject_results:
            sol[EVAL_KEY] = sol['Refine_Response']
            # An error-reason judgement carried over from the scored input was made for
            # the pre-refine response. score_error_reason skips points that still hold
            # one, so move it aside to force re-scoring of the refined response.
            if 'Error_Reason_Correctness_Analysis' in sol:
                sol['Pre_Refine_Error_Reason_Correctness_Analysis'] = sol.pop('Error_Reason_Correctness_Analysis')
    return refine_results
    


def self_refine(args):
    """Run the self-refine experiment on the scored results of a previous ``main`` run.

    Reads the scored results that ``main`` writes (hint-qualified file name). If that
    file is missing, it falls back to the legacy name without the hint. Critic, refine
    and re-scored outputs are written next to it, hint-qualified as well.

    Returns:
        Dict mapping subject -> MR score of the refined responses.
    """
    # name of open-sourced model served by vllm is the absolute path of the downloaded
    # model folder; keep its last component (a trailing slash is ignored).
    succint_model_name = args.eval_model_name.rstrip('/').split('/')[-1]
    path_prefix = f"{args.output_dir}/{succint_model_name}_{args.shot_num}shot_cot_True"
    # Must match the file name main() writes, which includes the hint flag.
    scored_save_path = f"{path_prefix}_hint_{args.hint}_scored_eval_results.json"
    if not os.path.exists(scored_save_path):
        legacy_scored_save_path = f"{path_prefix}_scored_eval_results.json"
        if os.path.exists(legacy_scored_save_path):
            scored_save_path = legacy_scored_save_path
    # hint-qualified so hint / no-hint runs do not overwrite each other
    critic_unscored_save_path = f"{path_prefix}_hint_{args.hint}_critic_results.json"
    refine_unscored_save_path = f"{path_prefix}_hint_{args.hint}_refine_results.json"
    self_refine_scored_save_path = f"{path_prefix}_hint_{args.hint}_scored_self_refine_results.json"
    eval_client = OpenAI(base_url=args.eval_base_url, api_key=args.eval_api_key)
    with open(scored_save_path, encoding='utf-8') as file:
        mr_bean_dataset = json.load(file)
    mr_bean_eval_results = generate_self_refine(eval_client,
                                                mr_bean_dataset,
                                                args.eval_model_name,
                                                args.temperature,
                                                args.top_p,
                                                args.max_tokens,
                                                args.stop_token_ids,
                                                critic_unscored_save_path,
                                                refine_unscored_save_path)
    score_client = OpenAI(base_url=args.score_base_url, api_key=args.score_api_key)
    scored_mr_bean_eval_results = score_error_reason(score_client, args.score_model_name, mr_bean_eval_results, self_refine_scored_save_path)
    mr_scores = calculate_mr_score(scored_mr_bean_eval_results)
    print(mr_scores)
    return mr_scores
    
    

if __name__ == "__main__":
    # Command-line entry point: parse arguments, then run either the standard
    # evaluation (main) or the self-refine experiment (self_refine).
    parser = argparse.ArgumentParser(description="Evaluate open sourced models on MR.BEAN benchmark")
    parser.add_argument('--eval_base_url', type=str, required=True, help='The base url to the openAI-api compatible server of the evaluated model')
    parser.add_argument('--eval_api_key', type=str, required=False, help='The potential api-key to the api server of the evaluated model', default='placeholder')
    parser.add_argument('--eval_model_name', type=str, required=True, help='The name of the evaluated model, for local open-sourced model please provide absolute path to the model')
    parser.add_argument('--score_base_url', type=str, required=True, help='The base url to the openAI-api compatible server for scoring the error reason')
    parser.add_argument('--score_api_key', type=str, required=False, help='The potential api-key to the api server for scoring the error reason', default='')
    parser.add_argument('--score_model_name', type=str, required=True, help='The name of the scoring model. We recommend using gpt-4-turbo')
    # load_dataset() expects a directory holding one <subject>.json file per subject.
    parser.add_argument('--dataset_path', type=str, required=True, help='Path to the MR.BEAN dataset directory containing one <subject>.json file per subject')
    parser.add_argument('--output_dir', '-o', type=str, required=True, help='Output directory for saving evaluation results')
    parser.add_argument('--temperature', '-t', type=float, required=False, default=1.0, help='Temperature for sampling')
    parser.add_argument('--top_p', '-p', type=float, required=False, default=0.8, help='Top-p threshold for sampling')
    parser.add_argument('--max_tokens', '-m', type=int, required=False, default=1024, help='Max token numbers to generate during sampling')
    parser.add_argument('--stop_token_ids', type=int, required=False, nargs="+", help='List of stop token ids because default tokenizer used by vllm might not using correct stop tokens in chat models.')
    parser.add_argument('--shot_num', '-k', type=int, required=False, default=0, help='The number of demonstrations for evaluated model')
    parser.add_argument('--demo_path', type=str, required=False, default='', help='The path to the few shot demo file for evaluation')
    parser.add_argument('--hint', action='store_true', required=False, help='Whether or not we disclose the solution correctness to the evaluated model.')
    parser.add_argument('--self_refine', action='store_true', required=False, help='Whether or not we perform the self-refine experiment.')
    args = parser.parse_args()
    # main() opens demo_path whenever shot_num != 0; fail fast instead of crashing
    # later on open(''). self_refine() never reads the demonstrations.
    if not args.self_refine and args.shot_num != 0 and not args.demo_path:
        parser.error('--demo_path is required when --shot_num is non-zero')
    if args.self_refine:
        self_refine(args)
    else:
        main(args)
