'''
Compute evaluation metrics for the mix_eval splits from saved model responses.

Usage:
python -m mix_eval.compute_metrics \
    --split close_multichoice_hard \
    --multi_choice_parser model \
    --extract_base_model_response \
    --results_closeended_multichoice mix_eval/data/model_responses/close_multichoice_hard \
    --models_to_eval \
    gpt_4_turbo_2024_04_09

'''
import json
import argparse
import os
from tqdm import tqdm
import time
import warnings
warnings.simplefilter("ignore", category=DeprecationWarning)
warnings.simplefilter("ignore", category=FutureWarning)

import mix_eval.api.registry
from mix_eval.utils.common_utils import set_seed
from mix_eval.utils.metric_utils import (
    parse_multi_choice_response_rule,
    parse_multi_choice_response_model,
    eval_multi_choice,
    eval_freeform_model,
    parse_freeform_response_rule,
    eval_freeform_rule,
    eval_openended,
    eval_openended_hard,
    get_task_pairs,
    get_model_answer_avrlen,
    get_scores_from_results_openhard,
    )
from mix_eval.models import AVAILABLE_MODELS

def parse_args(argv=None):
    """Define and parse the command-line options for metric computation.

    Args:
        argv: optional list of argument strings. When ``None`` (the default),
            argparse falls back to ``sys.argv[1:]``, so existing callers that
            pass nothing keep their behavior.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--split",
        type=str,
        choices=["close_freeform", "close_multichoice", "close_freeform_hard", "close_multichoice_hard", "open", "open_hard"],
        required=True,
        help="Split to evaluate."
        )
    parser.add_argument(
        "--results_closeended_freeform",
        type=str,
        default="mix_eval/data/model_responses/close_freeform",
        help="Path to close-ended free-form model responses."
        )
    parser.add_argument(
        "--results_closeended_multichoice",
        type=str,
        default="mix_eval/data/model_responses/close_multichoice",
        help="Path to close-ended multiple-choice model responses."
        )
    parser.add_argument(
        "--multichoice_judge",
        type=str,
        default="gpt-3.5-turbo-0125",
        help="Judge model for multiple-choice score computation."
        )
    parser.add_argument(
        "--freeform_judge",
        type=str,
        default="gpt-3.5-turbo-0125",
        help="Judge model for freeform score computation."
        )
    parser.add_argument(
        "--results_openended",
        type=str,
        default="mix_eval/data/model_responses/open",
        help="Path to open-ended model responses."
        )
    parser.add_argument(
        "--results_openended_hard",
        type=str,
        default="mix_eval/data/model_responses/open_hard",
        help="Path to open-ended hard model responses."
        )
    parser.add_argument(
        "--baseline_model_response_openended_hard",
        type=str,
        default="gpt_4_0314",
        help="Baseline model for open-hard score computation."
        )
    parser.add_argument(
        "--open_judge",
        type=str,
        default="gpt-4-1106-preview",
        help="Judge model for openended score computation (both open and open-hard). "
        "Here you should specify the official model name instead of the model names in our AVAILABLE_MODELS."
        )
    parser.add_argument(
        "--show_elo_score",
        action="store_true",
        help="Whether to show elo score in openended and open-hard evaluation, otherwise will show win rate."
        )
    parser.add_argument(
        "--models_to_eval",
        nargs='+',
        default=None,
        help="Models to evaluate."
        )
    parser.add_argument(
        "--free_form_parser",
        type=str,
        default="model",
        choices=["model", "rule"],
        help="Parser for freeform responses, either model parser or rule-based parser.")
    parser.add_argument(
        "--multi_choice_parser",
        type=str,
        default="model",
        choices=["model", "rule"],
        help="Parser for multiple-choice responses, either model parser or rule-based parser."
        )
    parser.add_argument(
        "--api_parallel_num",
        type=int,
        default=100,
        # Adjacent string literals are concatenated verbatim, so each
        # sentence boundary needs an explicit trailing space.
        help="Number of parallel threads for calling the judge model API when using model parsing. "
        "If you hit rate limit errors frequently, try reducing this number."
        )
    parser.add_argument(
        "--extract_base_model_response",
        action="store_true",
        help="Enable response extraction for unfinetuned (base) models, which tend to "
        "produce endless output that may influence the model parse score."
        )
    parser.add_argument(
        "--compute_score_from_judged_file",
        action="store_true",
        help="Whether to compute scores directly from previously judged files. "
        "This saves budget for models that have been judged before, "
        "and makes analysis easier without running judgements again."
        )
    return parser.parse_args(argv)

def compute_metric_closeended_freeform_modelparse_from_judgefile(args):
    """Re-aggregate close-ended free-form scores from saved model-judge results.

    For each model, reads
    ``<results_closeended_freeform>/<model>/judge_results_ff_model_judge_<freeform_judge>.jsonl``
    and averages the ``judge_score`` field overall and per ``benchmark_name``
    (rounded to 3 decimals). Models without a judge file are reported and
    skipped. The nested score dict is written to
    ``score_dict_model_fromjudge.json`` in the results directory.

    Args:
        args: parsed CLI namespace (see ``parse_args``).

    Raises:
        FileNotFoundError: if ``args.models_to_eval`` is None and the results
            directory does not exist, so no models can be discovered.
    """
    results_dir = args.results_closeended_freeform
    if args.models_to_eval is not None:
        models = [m for m in args.models_to_eval if m in AVAILABLE_MODELS]
    elif os.path.exists(results_dir):
        models = [m for m in os.listdir(results_dir) if m in AVAILABLE_MODELS]
    else:
        # Fail loudly instead of reaching the loop with `models` unbound.
        raise FileNotFoundError(f"Results directory not found: {results_dir}")

    score_dict = {}
    for model in models:
        print(f"Processing model: {model}")
        judge_file = os.path.join(
            results_dir,
            model,
            f"judge_results_ff_model_judge_{args.freeform_judge}.jsonl",
        )
        if not os.path.exists(judge_file):
            print(f"Judge file not found: {judge_file}")
            continue

        # Collect scores under 'overall' and under each benchmark name.
        score_lists = {}
        with open(judge_file, "r") as f:
            for line in f:
                judge_dict = json.loads(line)
                judge_score = judge_dict["judge_score"]
                score_lists.setdefault('overall', []).append(judge_score)
                score_lists.setdefault(judge_dict['benchmark_name'], []).append(judge_score)

        score_dict[model] = {
            key: round(sum(values) / len(values), 3)
            for key, values in score_lists.items()
        }

    with open(os.path.join(results_dir, "score_dict_model_fromjudge.json"), "w") as f:
        f.write(json.dumps(score_dict, indent=4) + "\n")

def compute_metric_closeended_freeform_ruleparse_from_judgefile(args):
    """Re-aggregate close-ended free-form accuracy from saved rule-parser results.

    For each model, reads
    ``<results_closeended_freeform>/<model>/judge_results_ff_rule.jsonl``,
    counts a record as 1 when its ``eval_result`` is truthy (else 0), and
    averages overall and per ``benchmark_name`` (rounded to 3 decimals).
    Models without a judge file are reported and skipped. The nested score dict
    is written to ``score_dict_rule_fromjudge.json`` in the results directory.

    Args:
        args: parsed CLI namespace (see ``parse_args``).

    Raises:
        FileNotFoundError: if ``args.models_to_eval`` is None and the results
            directory does not exist, so no models can be discovered.
    """
    results_dir = args.results_closeended_freeform
    if args.models_to_eval is not None:
        models = [m for m in args.models_to_eval if m in AVAILABLE_MODELS]
    elif os.path.exists(results_dir):
        models = [m for m in os.listdir(results_dir) if m in AVAILABLE_MODELS]
    else:
        # Fail loudly instead of reaching the loop with `models` unbound.
        raise FileNotFoundError(f"Results directory not found: {results_dir}")

    score_dict = {}
    for model in models:
        print(f"Processing model: {model}")
        judge_file = os.path.join(results_dir, model, "judge_results_ff_rule.jsonl")
        if not os.path.exists(judge_file):
            print(f"Judge file not found: {judge_file}")
            continue

        # Collect 0/1 scores under 'overall' and under each benchmark name.
        score_lists = {}
        with open(judge_file, "r") as f:
            for line in f:
                judge_dict = json.loads(line)
                judge_score = 1 if judge_dict["eval_result"] else 0
                score_lists.setdefault('overall', []).append(judge_score)
                score_lists.setdefault(judge_dict['benchmark_name'], []).append(judge_score)

        score_dict[model] = {
            key: round(sum(values) / len(values), 3)
            for key, values in score_lists.items()
        }

    with open(os.path.join(results_dir, "score_dict_rule_fromjudge.json"), "w") as f:
        f.write(json.dumps(score_dict, indent=4) + "\n")

def compute_metric_closeended_multichoice_modelparse_from_judgefile(args):
    """Re-aggregate close-ended multiple-choice accuracy from saved model-judge results.

    For each model, reads
    ``<results_closeended_multichoice>/<model>/judge_results_mp_model_judge_<multichoice_judge>.jsonl``.
    Each record's single target index is mapped to its option letter
    (``A``, ``B``, ...) and compared with the stored ``judge_option`` via
    ``eval_multi_choice``. The resulting 0/1 scores are averaged overall and
    per ``benchmark_name`` (rounded to 3 decimals). Models without a judge file
    are reported and skipped. The nested score dict is written to
    ``score_dict_model_fromjudge.json`` in the results directory.

    Args:
        args: parsed CLI namespace (see ``parse_args``).

    Raises:
        FileNotFoundError: if ``args.models_to_eval`` is None and the results
            directory does not exist, so no models can be discovered.
        AssertionError: if a record's ``target`` is not a one-element list.
    """
    results_dir = args.results_closeended_multichoice
    if args.models_to_eval is not None:
        models = [m for m in args.models_to_eval if m in AVAILABLE_MODELS]
    elif os.path.exists(results_dir):
        models = [m for m in os.listdir(results_dir) if m in AVAILABLE_MODELS]
    else:
        # Fail loudly instead of reaching the loop with `models` unbound.
        raise FileNotFoundError(f"Results directory not found: {results_dir}")

    score_dict = {}
    for model in models:
        print(f"Processing model: {model}")
        judge_file = os.path.join(
            results_dir,
            model,
            f"judge_results_mp_model_judge_{args.multichoice_judge}.jsonl",
        )
        if not os.path.exists(judge_file):
            print(f"Judge file not found: {judge_file}")
            continue

        # Collect 0/1 scores under 'overall' and under each benchmark name.
        score_lists = {}
        with open(judge_file, "r") as f:
            for line in f:
                judge_dict = json.loads(line)
                options = judge_dict["options"]
                target = judge_dict["target"]
                assert isinstance(target, list) and len(target) == 1, \
                    f"Invalid target: {target}"
                all_choices = [chr(ord("A") + i) for i in range(len(options))]
                target_id = all_choices[target[0]]
                judge_score = 1 if eval_multi_choice(target_id, judge_dict['judge_option']) else 0
                score_lists.setdefault('overall', []).append(judge_score)
                score_lists.setdefault(judge_dict['benchmark_name'], []).append(judge_score)

        score_dict[model] = {
            key: round(sum(values) / len(values), 3)
            for key, values in score_lists.items()
        }

    with open(os.path.join(results_dir, "score_dict_model_fromjudge.json"), "w") as f:
        f.write(json.dumps(score_dict, indent=4) + "\n")

def compute_metric_closeended_multichoice_ruleparse_from_judgefile(args):
    """Re-aggregate close-ended multiple-choice accuracy from saved rule-parser results.

    For each model, reads
    ``<results_closeended_multichoice>/<model>/judge_results_mp_rule.jsonl``,
    counts a record as 1 when its ``eval_result`` is truthy (else 0), and
    averages overall and per ``benchmark_name`` (rounded to 3 decimals).
    Models without a judge file are reported and skipped. The nested score dict
    is written to ``score_dict_rule_fromjudge.json`` in the results directory.

    Args:
        args: parsed CLI namespace (see ``parse_args``).

    Raises:
        FileNotFoundError: if ``args.models_to_eval`` is None and the results
            directory does not exist, so no models can be discovered.
    """
    results_dir = args.results_closeended_multichoice
    if args.models_to_eval is not None:
        models = [m for m in args.models_to_eval if m in AVAILABLE_MODELS]
    elif os.path.exists(results_dir):
        models = [m for m in os.listdir(results_dir) if m in AVAILABLE_MODELS]
    else:
        # Fail loudly instead of reaching the loop with `models` unbound.
        raise FileNotFoundError(f"Results directory not found: {results_dir}")

    score_dict = {}
    for model in models:
        print(f"Processing model: {model}")
        judge_file = os.path.join(results_dir, model, "judge_results_mp_rule.jsonl")
        if not os.path.exists(judge_file):
            print(f"Judge file not found: {judge_file}")
            continue

        # Collect 0/1 scores under 'overall' and under each benchmark name.
        score_lists = {}
        with open(judge_file, "r") as f:
            for line in f:
                judge_dict = json.loads(line)
                judge_score = 1 if judge_dict["eval_result"] else 0
                score_lists.setdefault('overall', []).append(judge_score)
                score_lists.setdefault(judge_dict['benchmark_name'], []).append(judge_score)

        score_dict[model] = {
            key: round(sum(values) / len(values), 3)
            for key, values in score_lists.items()
        }

    with open(os.path.join(results_dir, "score_dict_rule_fromjudge.json"), "w") as f:
        f.write(json.dumps(score_dict, indent=4) + "\n")

def compute_metric_openended_from_judgefile(args):
    """Placeholder for scoring the open-ended split from saved judge files.

    Re-aggregating open-ended scores from existing judge output is not
    supported yet.

    Args:
        args: parsed CLI namespace (unused).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()


def compute_metric_openhard_from_judgefile(args):
    """Recompute open-hard scores from previously written judge files.

    Loads the baseline responses from
    ``mix_eval/data/open_hard_baseline/<baseline>.jsonl`` to get the baseline's
    average answer length. For each model, it then loads
    ``<results_openended_hard>/<model>/judge_<open_judge>_base_<baseline>.jsonl``,
    tags every record with ``'model'``, and records the model's average answer
    length. Everything is passed to ``get_scores_from_results_openhard``.

    Models (or non-model directory entries) without that judge file are reported
    and skipped, matching the close-ended ``*_from_judgefile`` functions.

    Args:
        args: parsed CLI namespace (see ``parse_args``).

    Raises:
        FileNotFoundError: if ``args.models_to_eval`` is None and the results
            directory does not exist, or if the baseline file is missing.
        AssertionError: if the baseline model is not in AVAILABLE_MODELS.
    """
    results_dir = args.results_openended_hard
    if args.models_to_eval is not None:
        models = [m for m in args.models_to_eval if m in AVAILABLE_MODELS]
    elif os.path.exists(results_dir):
        # NOTE: unlike the close-ended splits, discovered directory entries are
        # not filtered against AVAILABLE_MODELS here.
        models = os.listdir(results_dir)
    else:
        # Fail loudly instead of reaching the loop with `models` unbound.
        raise FileNotFoundError(f"Results directory not found: {results_dir}")

    baseline = args.baseline_model_response_openended_hard
    assert baseline in AVAILABLE_MODELS, f"Invalid baseline model: {baseline}"
    with open(f"mix_eval/data/open_hard_baseline/{baseline}.jsonl", "r") as f:
        tasks_ref = [json.loads(line) for line in f]
    model_lengths = {baseline: get_model_answer_avrlen(tasks_ref)}

    judge_name = f"judge_{args.open_judge}_base_{baseline}.jsonl"
    results = []
    for model in models:
        print(f"Processing model: {model}")
        judge_file = os.path.join(results_dir, model, judge_name)
        if not os.path.exists(judge_file):
            print(f"Judge file not found: {judge_file}")
            continue
        with open(judge_file, 'r') as f:
            results_m = [json.loads(line) for line in f]

        model_lengths[model] = get_model_answer_avrlen(results_m)
        for result_m in results_m:
            result_m['model'] = model
        results.extend(results_m)

    get_scores_from_results_openhard(args, results, model_lengths)
    

def compute_metric_closeended_freeform_modelparse(args):
    """Score close-ended free-form responses with a judge model.

    For each model, loads ``<results_closeended_freeform>/<model>/<model>.jsonl``
    and passes the records to ``eval_freeform_model``. That call returns the
    cases, each carrying a ``judge_score``. Files written in the model's
    directory:
      * ``score_ff_model_<score>.txt``: empty marker file named by the mean score.
      * ``judge_results_ff_model_judge_<freeform_judge>.jsonl``: all judged cases.
      * ``error_cases_ff_model_judge_<freeform_judge>.jsonl``: cases scoring below 0.3.

    Mean scores of all models go to ``score_dict_model.json`` and are printed as
    percentages. Models with no judged cases are reported and skipped. With
    ``--compute_score_from_judged_file`` the saved judge files are re-aggregated
    instead.

    Args:
        args: parsed CLI namespace (see ``parse_args``).

    Raises:
        FileNotFoundError: if ``args.models_to_eval`` is None and the results
            directory does not exist, so no models can be discovered.
    """
    if args.compute_score_from_judged_file:
        return compute_metric_closeended_freeform_modelparse_from_judgefile(args)

    results_dir = args.results_closeended_freeform
    if args.models_to_eval is not None:
        models = [m for m in args.models_to_eval if m in AVAILABLE_MODELS]
    elif os.path.exists(results_dir):
        models = [m for m in os.listdir(results_dir) if m in AVAILABLE_MODELS]
    else:
        # Fail loudly instead of reaching the loop with `models` unbound.
        raise FileNotFoundError(f"Results directory not found: {results_dir}")

    score_dict = {}
    for idx, model in enumerate(models):
        print(f"\n\n\nProcessing model: {model}\n\n\n")
        if args.extract_base_model_response:
            # Name of the registered model's first base class; presumably read
            # by the parsing helpers via `args` (TODO confirm).
            args.model_type = mix_eval.api.registry.get_model(model).__bases__[0].__name__
        model_dir = os.path.join(results_dir, model)
        with open(os.path.join(model_dir, f"{model}.jsonl"), "r") as f:
            tasks = [json.loads(line) for line in f]

        results = eval_freeform_model(args, tasks)
        if not results:
            # Avoid a ZeroDivisionError when computing the mean below.
            print(f"No judged responses for model: {model}, skipping.")
            continue
        error_cases = [case for case in results if case["judge_score"] < 0.3]
        score_avr = round(sum(case["judge_score"] for case in results) / len(results), 3)
        score_dict[model] = score_avr

        with open(os.path.join(model_dir, f"score_ff_model_{score_avr}.txt"), "w") as f:
            pass
        with open(os.path.join(model_dir,
                               f"judge_results_ff_model_judge_{args.freeform_judge}.jsonl"), "w") as f:
            for case in results:
                f.write(json.dumps(case) + "\n")
        with open(os.path.join(model_dir,
                               f"error_cases_ff_model_judge_{args.freeform_judge}.jsonl"), "w") as f:
            for case in error_cases:
                f.write(json.dumps(case) + "\n")

        # Throttle only between models; sleeping after the last one is wasted time.
        if idx < len(models) - 1:
            print("Sleep 60 seconds to avoid ratelimit error ... ")
            time.sleep(60)

    with open(os.path.join(results_dir, "score_dict_model.json"), "w") as f:
        f.write(json.dumps(score_dict, indent=4) + "\n")
    print("[Close-ended Free-form Model Parser]")
    for model, score in score_dict.items():
        # Scores carry 3 decimals, so one decimal of percentage is exact and
        # avoids float noise such as 56.699999999999996.
        print(f"{model}: {score * 100:.1f}")
        

def compute_metric_closeended_freeform_ruleparse(args):
    """Score close-ended free-form responses with the rule-based parser.

    For each model, every record in
    ``<results_closeended_freeform>/<model>/<model>.jsonl`` is handled as follows:
    its ``response`` is parsed with ``parse_freeform_response_rule`` and checked
    against its ``target`` list with ``eval_freeform_rule``. The prediction and
    verdict are stored on the record as ``parsed_result`` and ``eval_result``.

    Files written in the model's directory:
      * ``accuracy_ff_rule_<acc>.txt``: empty marker file.
      * ``judge_results_ff_rule.jsonl``: all records.
      * ``error_cases_ff_rule.jsonl``: incorrect records.

    Accuracies of all models go to ``score_dict_rule.json`` and are printed as
    percentages. Models with an empty response file are reported and skipped.

    Args:
        args: parsed CLI namespace (see ``parse_args``).

    Raises:
        FileNotFoundError: if ``args.models_to_eval`` is None and the results
            directory does not exist, so no models can be discovered.
        AssertionError: if a record's ``target`` is not a list.
    """
    if args.compute_score_from_judged_file:
        return compute_metric_closeended_freeform_ruleparse_from_judgefile(args)

    results_dir = args.results_closeended_freeform
    if args.models_to_eval is not None:
        models = [m for m in args.models_to_eval if m in AVAILABLE_MODELS]
    elif os.path.exists(results_dir):
        models = [m for m in os.listdir(results_dir) if m in AVAILABLE_MODELS]
    else:
        # Fail loudly instead of reaching the loop with `models` unbound.
        raise FileNotFoundError(f"Results directory not found: {results_dir}")

    score_dict = {}
    for model in tqdm(models):
        if args.extract_base_model_response:
            # Name of the registered model's first base class; presumably read
            # by the parsing helpers via `args` (TODO confirm).
            args.model_type = mix_eval.api.registry.get_model(model).__bases__[0].__name__
        model_dir = os.path.join(results_dir, model)

        results = []
        error_cases = []
        with open(os.path.join(model_dir, f"{model}.jsonl"), "r") as f:
            for line in f:
                ans_dict = json.loads(line)
                target = ans_dict["target"]
                assert isinstance(target, list)
                preds = parse_freeform_response_rule(args, ans_dict["response"])
                # Evaluate once and reuse the verdict for counting and storage.
                is_correct = eval_freeform_rule(target, preds)
                ans_dict['parsed_result'] = preds
                ans_dict['eval_result'] = is_correct
                results.append(ans_dict)
                if not is_correct:
                    error_cases.append(ans_dict)

        if not results:
            # Avoid a ZeroDivisionError when computing accuracy below.
            print(f"No responses found for model: {model}, skipping.")
            continue
        acc = round((len(results) - len(error_cases)) / len(results), 3)
        score_dict[model] = acc

        with open(os.path.join(model_dir, f"accuracy_ff_rule_{acc}.txt"), "w") as f:
            pass
        with open(os.path.join(model_dir, "judge_results_ff_rule.jsonl"), "w") as f:
            for case in results:
                f.write(json.dumps(case) + "\n")
        with open(os.path.join(model_dir, "error_cases_ff_rule.jsonl"), "w") as f:
            for case in error_cases:
                f.write(json.dumps(case) + "\n")

    with open(os.path.join(results_dir, "score_dict_rule.json"), "w") as f:
        f.write(json.dumps(score_dict, indent=4) + "\n")
    print("[Close-ended Free-form Rule Parser]")
    for model, score in score_dict.items():
        # Accuracies carry 3 decimals, so one decimal of percentage is exact.
        print(f"{model}: {score * 100:.1f}")

def compute_metric_closeended_freeform(args):
    """Score the close-ended free-form split with the parser chosen on the CLI.

    ``args.free_form_parser == "model"`` selects the judge-model parser; any
    other value falls through to the rule-based parser.
    """
    if args.free_form_parser != "model":
        compute_metric_closeended_freeform_ruleparse(args)
    else:
        compute_metric_closeended_freeform_modelparse(args)


def compute_metric_closeended_multichoice_ruleparse(args):
    """Score close-ended multiple-choice responses with the rule-based parser.

    For each record in ``<results_closeended_multichoice>/<model>/<model>.jsonl``,
    the options are labelled ``A``, ``B``, ... and the ``response`` is parsed
    with ``parse_multi_choice_response_rule``. The parsed choice is compared
    with the letter of the single target index via ``eval_multi_choice``. The
    choice and verdict are stored as ``parsed_result`` and ``eval_result``.

    Files written in the model's directory:
      * ``accuracy_mp_rule_<acc>.txt``: empty marker file.
      * ``judge_results_mp_rule.jsonl``: all records.
      * ``error_cases_mp_rule.jsonl``: incorrect records.

    Accuracies of all models go to ``score_dict_rule.json`` and are printed as
    percentages. Models with an empty response file are reported and skipped.

    Args:
        args: parsed CLI namespace (see ``parse_args``).

    Raises:
        FileNotFoundError: if ``args.models_to_eval`` is None and the results
            directory does not exist, so no models can be discovered.
        AssertionError: if a record's ``target`` is not a one-element list.
    """
    if args.compute_score_from_judged_file:
        return compute_metric_closeended_multichoice_ruleparse_from_judgefile(args)

    results_dir = args.results_closeended_multichoice
    if args.models_to_eval is not None:
        models = [m for m in args.models_to_eval if m in AVAILABLE_MODELS]
    elif os.path.exists(results_dir):
        models = [m for m in os.listdir(results_dir) if m in AVAILABLE_MODELS]
    else:
        # Fail loudly instead of reaching the loop with `models` unbound.
        raise FileNotFoundError(f"Results directory not found: {results_dir}")

    score_dict = {}
    for model in tqdm(models):
        if args.extract_base_model_response:
            # Name of the registered model's first base class; presumably read
            # by the parsing helpers via `args` (TODO confirm).
            args.model_type = mix_eval.api.registry.get_model(model).__bases__[0].__name__
        model_dir = os.path.join(results_dir, model)

        results = []
        error_cases = []
        with open(os.path.join(model_dir, f"{model}.jsonl"), "r") as f:
            for line in f:
                ans_dict = json.loads(line)
                options = ans_dict["options"]
                target = ans_dict["target"]
                assert isinstance(target, list) and len(target) == 1, \
                    f"Invalid target: {target}"
                all_choices = [chr(ord("A") + i) for i in range(len(options))]
                index2ans = dict(zip(all_choices, options))
                model_choice = parse_multi_choice_response_rule(
                    args,
                    ans_dict["response"],
                    all_choices,
                    index2ans,
                )
                # Evaluate once and reuse the verdict for counting and storage.
                is_correct = eval_multi_choice(all_choices[target[0]], model_choice)
                ans_dict['parsed_result'] = model_choice
                ans_dict['eval_result'] = is_correct
                results.append(ans_dict)
                if not is_correct:
                    error_cases.append(ans_dict)

        if not results:
            # Avoid a ZeroDivisionError when computing accuracy below.
            print(f"No responses found for model: {model}, skipping.")
            continue
        acc = round((len(results) - len(error_cases)) / len(results), 3)
        score_dict[model] = acc

        with open(os.path.join(model_dir, f"accuracy_mp_rule_{acc}.txt"), "w") as f:
            pass
        with open(os.path.join(model_dir, "judge_results_mp_rule.jsonl"), "w") as f:
            for case in results:
                f.write(json.dumps(case) + "\n")
        with open(os.path.join(model_dir, "error_cases_mp_rule.jsonl"), "w") as f:
            for case in error_cases:
                f.write(json.dumps(case) + "\n")

    with open(os.path.join(results_dir, "score_dict_rule.json"), "w") as f:
        f.write(json.dumps(score_dict, indent=4) + "\n")
    print("[Close-ended Multiple-choice Rule Parser]")
    for model, score in score_dict.items():
        # Accuracies carry 3 decimals, so one decimal of percentage is exact.
        print(f"{model}: {score * 100:.1f}")
                
def compute_metric_closeended_multichoice_modelparse(args):
    """Score close-ended multiple-choice responses with a judge model.

    For each model, loads ``<results_closeended_multichoice>/<model>/<model>.jsonl``
    and passes the records to ``parse_multi_choice_response_model``. The records
    it returns carry a ``judge_option``, which is compared with the letter
    (``A``, ``B``, ...) of the single target index via ``eval_multi_choice``.

    Files written in the model's directory:
      * ``accuracy_mp_model_<acc>.txt``: empty marker file.
      * ``judge_results_mp_model_judge_<multichoice_judge>.jsonl``: all judged records.
      * ``error_cases_mp_model_judge_<multichoice_judge>.jsonl``: incorrect records.

    Accuracies of all models go to ``score_dict_model.json`` and are printed as
    percentages. Models with no judged records are reported and skipped. With
    ``--compute_score_from_judged_file`` the saved judge files are re-aggregated
    instead.

    Args:
        args: parsed CLI namespace (see ``parse_args``).

    Raises:
        FileNotFoundError: if ``args.models_to_eval`` is None and the results
            directory does not exist, so no models can be discovered.
        AssertionError: if a record's ``target`` is not a one-element list.
    """
    if args.compute_score_from_judged_file:
        return compute_metric_closeended_multichoice_modelparse_from_judgefile(args)

    results_dir = args.results_closeended_multichoice
    if args.models_to_eval is not None:
        models = [m for m in args.models_to_eval if m in AVAILABLE_MODELS]
    elif os.path.exists(results_dir):
        models = [m for m in os.listdir(results_dir) if m in AVAILABLE_MODELS]
    else:
        # Fail loudly instead of reaching the loop with `models` unbound.
        raise FileNotFoundError(f"Results directory not found: {results_dir}")

    score_dict = {}
    for idx, model in enumerate(models):
        print(f"\n\n\nProcessing model: {model}\n\n\n")
        if args.extract_base_model_response:
            # Name of the registered model's first base class; presumably read
            # by the parsing helpers via `args` (TODO confirm).
            args.model_type = mix_eval.api.registry.get_model(model).__bases__[0].__name__
        model_dir = os.path.join(results_dir, model)
        with open(os.path.join(model_dir, f"{model}.jsonl"), "r") as f:
            ans_dicts = [json.loads(line) for line in f]

        results = list(parse_multi_choice_response_model(args, ans_dicts))
        if not results:
            # Avoid a ZeroDivisionError when computing accuracy below.
            print(f"No judged responses for model: {model}, skipping.")
            continue

        error_cases = []
        for ans_dict_ws in results:
            target = ans_dict_ws["target"]
            assert isinstance(target, list) and len(target) == 1, \
                f"Invalid target: {target}"
            all_choices = [chr(ord("A") + i) for i in range(len(ans_dict_ws["options"]))]
            if not eval_multi_choice(all_choices[target[0]], ans_dict_ws['judge_option']):
                error_cases.append(ans_dict_ws)

        acc = round((len(results) - len(error_cases)) / len(results), 3)
        score_dict[model] = acc

        with open(os.path.join(model_dir, f"accuracy_mp_model_{acc}.txt"), "w") as f:
            pass
        with open(os.path.join(model_dir,
                               f"judge_results_mp_model_judge_{args.multichoice_judge}.jsonl"), "w") as f:
            for case in results:
                f.write(json.dumps(case) + "\n")
        with open(os.path.join(model_dir,
                               f"error_cases_mp_model_judge_{args.multichoice_judge}.jsonl"), "w") as f:
            for case in error_cases:
                f.write(json.dumps(case) + "\n")

        # Throttle only between models; sleeping after the last one is wasted time.
        if idx < len(models) - 1:
            print("Sleep 60 seconds to avoid ratelimit error ... ")
            time.sleep(60)

    with open(os.path.join(results_dir, "score_dict_model.json"), "w") as f:
        f.write(json.dumps(score_dict, indent=4) + "\n")
    print("[Close-ended Multiple-choice Model Parser]")
    for model, score in score_dict.items():
        # Accuracies carry 3 decimals, so one decimal of percentage is exact.
        print(f"{model}: {score * 100:.1f}")

def compute_metric_closeended_multichoice(args):
    """Score the close-ended multiple-choice split with the parser chosen on the CLI.

    ``args.multi_choice_parser == "model"`` selects the judge-model parser; any
    other value falls through to the rule-based parser.
    """
    if args.multi_choice_parser != "model":
        compute_metric_closeended_multichoice_ruleparse(args)
    else:
        compute_metric_closeended_multichoice_modelparse(args)
                
def compute_metric_openended(args):
    """Score open-ended responses with the judge model.

    For each model, loads ``<results_openended>/<model>/<model>.jsonl`` and
    passes it to ``eval_openended``. That call returns cases whose
    ``judge_score`` must hold exactly three numbers (one per turn). Per-turn
    means (3 decimals) and their average are printed and written to
    ``scores.txt``. All judged cases go to ``judge_results.jsonl`` in the
    model's directory. Models with no judged cases are reported and skipped.

    With ``--compute_score_from_judged_file`` this defers to
    ``compute_metric_openended_from_judgefile``, as the other splits do, rather
    than silently re-running (and paying for) the judge.

    Args:
        args: parsed CLI namespace (see ``parse_args``).

    Raises:
        FileNotFoundError: if ``args.models_to_eval`` is None and the results
            directory does not exist, so no models can be discovered.
        AssertionError: if a case does not carry exactly three scores.
    """
    if args.compute_score_from_judged_file:
        return compute_metric_openended_from_judgefile(args)

    results_dir = args.results_openended
    if args.models_to_eval is not None:
        models = [m for m in args.models_to_eval if m in AVAILABLE_MODELS]
    elif os.path.exists(results_dir):
        # NOTE: discovered directory entries are not filtered against
        # AVAILABLE_MODELS here, unlike the close-ended splits.
        models = os.listdir(results_dir)
    else:
        # Fail loudly instead of reaching the loop with `models` unbound.
        raise FileNotFoundError(f"Results directory not found: {results_dir}")

    for idx, model in enumerate(models):
        print(f"\n\n\nProcessing model: {model}\n\n\n")
        model_dir = os.path.join(results_dir, model)
        with open(os.path.join(model_dir, f"{model}.jsonl"), "r") as f:
            tasks = [json.loads(line) for line in f]

        results = eval_openended(args, tasks)
        if not results:
            # Avoid a ZeroDivisionError when averaging below.
            print(f"No judged responses for model: {model}, skipping.")
            continue

        turn_totals = [0, 0, 0]
        for task_dict in results:
            scores = task_dict["judge_score"]
            assert len(scores) == 3, \
                f"Invalid scores: {scores}"
            for turn, score in enumerate(scores):
                turn_totals[turn] += score
        turn_avrs = [round(total / len(results), 3) for total in turn_totals]
        overall_score = round(sum(turn_avrs) / 3, 3)

        print(f"[Open-ended]\nModel: {model}\nTurn 1 score: {turn_avrs[0]}"
              f"\nTurn 2 score: {turn_avrs[1]}\nTurn 3 score: {turn_avrs[2]}"
              f"\nOverall score: {overall_score}")
        with open(os.path.join(model_dir, "scores.txt"), "w") as f:
            for turn, avr in enumerate(turn_avrs, start=1):
                f.write(f"Turn {turn}: {avr}\n")
            f.write(f"Overall: {overall_score}\n")
        with open(os.path.join(model_dir, "judge_results.jsonl"), "w") as f:
            for case in results:
                f.write(json.dumps(case) + "\n")

        # Throttle only between models; sleeping after the last one is wasted time.
        if idx < len(models) - 1:
            print("Sleep 60 seconds to avoid ratelimit error ... ")
            time.sleep(60)


def compute_metric_openhard(args):
    """Judge open-hard responses against a baseline model and compute scores.

    Loads the baseline responses from
    ``mix_eval/data/open_hard_baseline/<baseline>.jsonl``. For each model:
      * loads ``<results_openended_hard>/<model>/<model>.jsonl``;
      * builds task pairs with ``get_task_pairs`` and judges them with
        ``eval_openended_hard``;
      * tags each judged record with ``'model'``;
      * writes ``judge_<open_judge>_base_<baseline>.jsonl`` in the model's
        directory.

    Judged records and per-model average answer lengths are passed to
    ``get_scores_from_results_openhard``. With
    ``--compute_score_from_judged_file`` the saved judge files are used instead.

    Args:
        args: parsed CLI namespace (see ``parse_args``).

    Raises:
        FileNotFoundError: if ``args.models_to_eval`` is None and the results
            directory does not exist, or if the baseline file is missing.
        AssertionError: if the baseline model is not in AVAILABLE_MODELS.
    """
    if args.compute_score_from_judged_file:
        return compute_metric_openhard_from_judgefile(args)

    results_dir = args.results_openended_hard
    if args.models_to_eval is not None:
        models = [m for m in args.models_to_eval if m in AVAILABLE_MODELS]
    elif os.path.exists(results_dir):
        # NOTE: discovered directory entries are not filtered against
        # AVAILABLE_MODELS here, unlike the close-ended splits.
        models = os.listdir(results_dir)
    else:
        # Fail loudly instead of reaching the loop with `models` unbound.
        raise FileNotFoundError(f"Results directory not found: {results_dir}")

    baseline = args.baseline_model_response_openended_hard
    assert baseline in AVAILABLE_MODELS, f"Invalid baseline model: {baseline}"
    with open(f"mix_eval/data/open_hard_baseline/{baseline}.jsonl", "r") as f:
        tasks_ref = [json.loads(line) for line in f]
    model_lengths = {baseline: get_model_answer_avrlen(tasks_ref)}

    judge_name = f"judge_{args.open_judge}_base_{baseline}.jsonl"
    results = []
    for idx, model in enumerate(models):
        print(f"\n\n\nProcessing model: {model}\n\n\n")
        model_dir = os.path.join(results_dir, model)
        with open(os.path.join(model_dir, f"{model}.jsonl"), "r") as f:
            tasks = [json.loads(line) for line in f]

        # Presumably pairs each response with the baseline's answer to the same
        # task before judging (see get_task_pairs).
        task_pairs = get_task_pairs(tasks, tasks_ref)
        results_m = eval_openended_hard(args, task_pairs)
        model_lengths[model] = get_model_answer_avrlen(results_m)
        for result_m in results_m:
            result_m['model'] = model

        with open(os.path.join(model_dir, judge_name), "w") as f:
            for case in results_m:
                f.write(json.dumps(case) + "\n")
        results.extend(results_m)

        # Throttle only between models; sleeping after the last one is wasted time.
        if idx < len(models) - 1:
            print("Sleep 30 seconds to avoid ratelimit error ... ")
            time.sleep(30)

    get_scores_from_results_openhard(args, results, model_lengths)
    
    


def compute_metric(args):
    """Route metric computation to the handler for ``args.split``.

    The regular and ``_hard`` close-ended variants share a handler; they differ
    only in the results paths supplied on the command line.

    Raises:
        ValueError: if ``args.split`` is not a recognised split name.
    """
    split = args.split
    if split in ("close_freeform", "close_freeform_hard"):
        compute_metric_closeended_freeform(args)
    elif split in ("close_multichoice", "close_multichoice_hard"):
        compute_metric_closeended_multichoice(args)
    elif split == "open":
        compute_metric_openended(args)
    elif split == "open_hard":
        compute_metric_openhard(args)
    else:
        raise ValueError(f"Invalid split: {split}")


if __name__ == "__main__":
    # Seed before anything else runs (set_seed is presumably a global RNG seed).
    set_seed()
    compute_metric(parse_args())