#!/usr/bin/env python

import os
import sys
import json
import argparse
import datasets
import string

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import (
    get_prog_filepaths_from_dirs,
)

KINDS_DIR = os.path.join(os.path.dirname(__file__), "../datasets/gqa")

# Every sub-directory of KINDS_DIR is a selectable dataset "kind". Sorted so the
# choices shown by --help and in argparse error messages are stable regardless
# of the filesystem's directory-listing order.
_available_kinds = sorted(
    entry
    for entry in os.listdir(KINDS_DIR)
    if os.path.isdir(os.path.join(KINDS_DIR, entry))
)

parser = argparse.ArgumentParser()
parser.add_argument('-k', '--kind', help='Target kind', choices=_available_kinds, required=True)
parser.add_argument('-m', '--model', help='Model to use', required=True)
# Exactly one execution backend must be selected.
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('-e', '--epic', help='Run using EPIC', action='store_true', default=False)
group.add_argument('-p', '--python', help='Run using Python', action='store_true', default=False)
# NOTE: an async-Python backend option ('-a', '--async', dest='pya') is currently disabled.
args = parser.parse_args()

# Directory for vipergpt programs
# Directory for vipergpt programs of the selected dataset kind.
DATASET_NAME = args.kind
DIRNAME = f"{KINDS_DIR}/{DATASET_NAME}"

EPIC = args.epic
MODEL = args.model

split = "val"
dataset_instructions = datasets.load_dataset("lmms-lab/GQA", f"{split}_all_instructions", split=split)

# HACK: building an id -> row map over the whole split is too slow, so only a
# fixed random subset of problems is loaded.
# Keep the seed and the "shuffle every index, then take a prefix" form exactly
# as-is. random.sample(range(n), num_sample) can choose a *different* subset
# (CPython picks its sampling algorithm based on k). Changing it would silently
# evaluate against a different set of problems.
import random
random.seed(2025)
num_sample = 1000
indices = random.sample(range(len(dataset_instructions)), len(dataset_instructions))[0:num_sample]

# Read each sampled row exactly once and key it by its problem id.
dataset_instructions_by_id = {
    row["id"]: row for row in (dataset_instructions[i] for i in indices)
}

# Per-problem execution results (<problem_id>.json) are read from this directory.
suffix = f"{MODEL}_{'epic' if EPIC else 'py'}_replay"
EVAL_DIRNAME = f"{DIRNAME}/exec_{suffix}"

# Programs for the chosen backend live in either progs_epic/ or progs_py/.
_prog_subdir = "progs_epic" if EPIC else "progs_py"
filepaths = get_prog_filepaths_from_dirs([f"{DIRNAME}/{_prog_subdir}"])

# Translation table that deletes ASCII punctuation; built once, not per problem.
_PUNCT_TABLE = str.maketrans('', '', string.punctuation)


def _normalize_answer(text):
    """Return *text* with ASCII punctuation removed and lowercased.

    ``None`` normalizes to ``""`` so a missing result never counts as a match.
    Any other non-string value is converted with ``str`` first.
    """
    if text is None:
        return ""
    return str(text).translate(_PUNCT_TABLE).lower()


matches = 0
nonmatches = []  # (problem_id, normalized prediction, raw ground truth)
missing = 0      # programs with no execution-result file
unsampled = 0    # programs whose problem id is not in the sampled subset

for fn_path in filepaths:
    problem_id = os.path.splitext(os.path.basename(fn_path))[0]
    problem = dataset_instructions_by_id.get(problem_id)
    if problem is None:
        # Only a random subset of the split is loaded. A program for a problem
        # outside that subset cannot be scored, so skip it instead of crashing.
        unsampled += 1
        continue

    exec_filepath = os.path.join(EVAL_DIRNAME, f"{problem_id}.json")
    try:
        with open(exec_filepath, "r", encoding="utf-8") as f:
            result = json.load(f)
    except FileNotFoundError:
        missing += 1
        continue

    # NOTE(review): result["result"] is assumed to usually be a string; None and
    # other types are handled defensively by _normalize_answer.
    pred = _normalize_answer(result["result"])
    ground_truth = problem["answer"]

    # Lenient scoring: a prediction is correct when the normalized ground truth
    # appears anywhere inside it (substring, not exact equality). Both sides use
    # the same normalization, so answers containing punctuation (e.g. "t-shirt")
    # can still match.
    if _normalize_answer(ground_truth) in pred:
        matches += 1
    else:
        nonmatches.append((problem_id, pred, ground_truth))

for problem_id, pred, ground_truth in nonmatches:
    print(problem_id, ":", pred, "|", ground_truth)
print(matches, len(nonmatches), missing)
if unsampled:
    print(
        f"warning: skipped {unsampled} program(s) outside the sampled problem set",
        file=sys.stderr,
    )

evaluated = matches + len(nonmatches)
# Coverage: share of sampled problems that produced a result.
# Accuracy: share of those results that matched the ground truth.
num_problems = len(dataset_instructions_by_id)
coverage = evaluated / num_problems * 100 if num_problems else 0.0
accuracy = matches / evaluated * 100 if evaluated else 0.0
print(f"{coverage:0.1f}%, {accuracy:0.1f}%")