import argparse
import glob
import json
import os
import sys
import re
import subprocess

# Get the absolute path of the parent directory
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
# Add the parent directory to sys.path
sys.path.insert(0, parent_dir)

from run_text_generation import get_output_path
from config import EvaluationConfig


def get_input_output_paths(input_path, task):
    """Get all input files and an output path for a merged file.

    Args:
        input_path: Either the path of an existing results file, or a path prefix
            that is passed as ``output_path`` to ``EvaluationConfig`` in order to
            glob the result files of every partition and dp rank.
        task: Task name used for the glob pattern and the merged file name.

    Returns:
        Tuple of (list of input file paths, path of the merged output file).
    """
    # Single input file.
    if os.path.exists(input_path):
        input_file_paths = [input_path]
        # Strip the extension instead of using str.replace: replace() rewrites every
        # ".jsonl" occurrence in the path and, when there is none, returns the path
        # unchanged, which would make the merged file overwrite the input file.
        output_file_path = os.path.splitext(input_path)[0] + "-merged.json"
    # Select multiple partitions and dp ranks.
    else:
        cfg = EvaluationConfig(task=task, output_path=input_path, partition_id="*")
        pattern = get_output_path(cfg, dp_rank="*")
        # glob returns matches in arbitrary order; sort so the merge is reproducible.
        input_file_paths = sorted(glob.glob(pattern))

        output_file_path = input_path + f"-{task}-merged.json"

    return input_file_paths, output_file_path


def convert_to_mmmu_format(input_path):
    """Convert input files to MMMU compatible format.

    Reads every JSON-lines input file resolved from ``input_path`` and writes a
    single JSON file that maps each sample_id to its prediction.

    Args:
        input_path: Path to input file(s), resolved by get_input_output_paths.

    Returns:
        Path of the merged JSON file that was written.
    """
    input_file_paths, output_file_path = get_input_output_paths(input_path, "MMMU")

    output = dict()

    for input_file_path in input_file_paths:
        with open(input_file_path, "r") as input_file:
            for line in input_file:
                # Tolerate blank lines instead of failing inside json.loads.
                if not line.strip():
                    continue

                res = json.loads(line)

                sample_id = res["sample_id"]

                # MMMU eval script expects just a sample_id to prediction mapping.
                # Skip possible duplicates before doing any parsing work for them.
                if sample_id in output:
                    continue

                prediction = res["prediction"]

                if res["question_type"] == "multiple-choice":
                    # Imported lazily: only needed once a multiple-choice sample is seen.
                    from MMMU.mmmu.utils.eval_utils import parse_multi_choice_response

                    prediction = parse_multi_choice_response(
                        prediction, res["all_choices"], res["index2ans"]
                    )

                output[sample_id] = prediction

    with open(output_file_path, "w") as output_file:
        json.dump(output, output_file)

    return output_file_path


def mmmu_eval(input_path, groundtruth_path):
    """Run MMMU evaluation.

    Args:
        input_path: Path to input file(s), forwarded to convert_to_mmmu_format.
        groundtruth_path: Path to the groundtruth file passed to the MMMU eval script.

    Returns:
        The 'Overall' accuracy parsed from the script output, multiplied by 100.

    Raises:
        RuntimeError: If the overall accuracy cannot be found in the script output.
    """
    result_file = convert_to_mmmu_format(input_path)

    # The MMMU repo has a script for running the actual evaluation but no API. So launching the script here.
    # Use the interpreter running this script so the child sees the same environment.
    output = subprocess.run(
        [
            sys.executable or "python",
            "examples/multimodal/MMMU/mmmu/main_eval_only.py",
            "--output_path",
            result_file,
            "--answer_path",
            groundtruth_path,
        ],
        capture_output=True,
        text=True,
    )

    print(output.stderr)
    print(output.stdout)

    # Raw string with escaped braces and dot so the pattern matches the printed dict literally.
    m = re.search(r"'Overall': \{'num': \d+, 'acc': (\d+(?:\.\d+)?)\}", output.stdout)
    if m is None:
        # Without this check a failed run surfaces as an opaque AttributeError on None.
        raise RuntimeError(
            "Could not find the overall accuracy in the MMMU evaluation output "
            f"(exit code {output.returncode})."
        )

    return float(m.group(1)) * 100.0


def main():
    """Parse the command line arguments, run the MMMU evaluation and print the accuracy."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--input-path", type=str, required=True, help="Path to input file(s)")
    # The default groundtruth is the validation file from the MMMU repo. This assumes the
    # MMMU github repo has been cloned under examples/multimodal.
    cli.add_argument(
        "--groundtruth-path",
        type=str,
        default="examples/multimodal/MMMU/mmmu/answer_dict_val.json",
        help="Path to groundtruth file. Defaults to the validation file in the MMMU repo.",
    )
    options = cli.parse_args()

    accuracy = mmmu_eval(options.input_path, options.groundtruth_path)
    print(f"MMMU average accuracy: {accuracy:.2f}")


# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
