import argparse
import json
import os

def parse_args():
    """Parse this script's command-line options.

    Returns:
        argparse.Namespace: has two string attributes, ``model``
        (default ``"all"``) and ``output_dir`` (default ``"output/answer"``).
    """
    # (flag, default) pairs; every option is a plain string.
    option_table = (
        ("--model", "all"),
        ("--output_dir", "output/answer"),
    )

    cli = argparse.ArgumentParser()
    for flag, fallback in option_table:
        cli.add_argument(flag, type=str, default=fallback)

    namespace = cli.parse_args()
    return namespace


def load_jsonl(path, expected_count=None):
    """Read a JSON Lines file into a list of decoded records.

    Blank lines (e.g. a trailing empty line) are skipped instead of being
    passed to ``json.loads``.

    Args:
        path: Path of the ``.jsonl`` file to read.
        expected_count: When not ``None``, the exact number of records the
            file must contain.

    Returns:
        list: One decoded JSON value per non-blank line, in file order.

    Raises:
        ValueError: If ``expected_count`` is given and the number of records
            read differs from it.
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]

    # Explicit check rather than `assert`, so it still runs under `python -O`
    # and reports which file is incomplete.
    if expected_count is not None and len(records) != expected_count:
        raise ValueError(
            f"{path}: expected {expected_count} records, found {len(records)}"
        )
    return records


def write_jsonl(path, records):
    """Write ``records`` to ``path`` as JSON Lines, one record per line.

    The parent directory of ``path`` is created if it does not exist yet.

    Args:
        path: Destination file path; an existing file is overwritten.
        records: Iterable of JSON-serializable values.
    """
    parent = os.path.dirname(path)
    if parent:  # empty when `path` has no directory component
        os.makedirs(parent, exist_ok=True)

    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(record) + "\n" for record in records)


if __name__ == "__main__":
    args = parse_args()

    # "all" expands to every known model; any other value is taken as a
    # single model name.
    if args.model == "all":
        models = [
            "llama-3.2-3b",
            "llama-3.1-8b",
            "llama-3.1-70b",
            "gemma-2-9b",
            "gemma-2-27b",
            "qwen-2.5-1.5b",
            "qwen-2.5-3b",
            "qwen-2.5-7b",
            "qwen-2.5-72b",
            "mistral-7b-v0.3",
            "phi-3.5-mini",
            "gpt-3.5-turbo",
            "gpt-4o",
            "gemma-2-2b",
            "llama-3.2-1b",
            "mistral-small",
        ]
    else:
        models = [args.model]

    for model in models:
        # Load the per-benchmark results; the expected counts guard against
        # truncated or partially written evaluation files.
        humaneval_results = load_jsonl(
            f'{args.output_dir}/humaneval/{model}_eval.jsonl', expected_count=164
        )
        mbpp_results = load_jsonl(
            f'{args.output_dir}/mbpp/{model}_eval.jsonl', expected_count=974
        )

        # Combined file: HumanEval records first, then MBPP records.
        combined_results = humaneval_results + mbpp_results
        write_jsonl(
            f'{args.output_dir}/humaneval_mbpp/{model}_eval.jsonl',
            combined_results,
        )