#!/usr/bin/env python

import argparse
import json
import os
import pathlib
import sys
import time
import traceback

import datasets

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from epic import *
from run_utils import *
from utils import (
    create_dir_w_timestamp,
    get_txt_filepaths_from_dirs,
    get_prog_filepaths_from_dirs,
    read_txt,
    write_jsonl,
    read_epics,
    write_txt,
    write_json,
    load_json
)

# Image-patch backends selectable with --model, keyed by their CLI name.
# The imgpatch_* names are not defined in this file; presumably they are
# provided by the star imports above.
models = {
    "test": imgpatch_test,
    "local": imgpatch_local,
    "4o_all": imgpatch_4o_all,
}

# Directory whose immediate sub-directories are the valid --kind values.
KINDS_DIR = os.path.join(os.path.dirname(__file__), "../datasets/gqa")

parser = argparse.ArgumentParser()
parser.add_argument(
    '-k', '--kind',
    help='Target kind',
    # Sorted so that --help and error messages list the kinds in a stable
    # order regardless of the order os.listdir() happens to return.
    choices=sorted(
        entry for entry in os.listdir(KINDS_DIR)
        if os.path.isdir(os.path.join(KINDS_DIR, entry))
    ),
    required=True,
)
parser.add_argument('-m', '--model', help='Model to use', choices=list(models), required=True)

# Exactly one execution mode has to be selected.
mode_group = parser.add_mutually_exclusive_group(required=True)
mode_group.add_argument('-e', '--epic', help='Run using EPIC', action='store_true', default=False)
mode_group.add_argument('-p', '--python', help='Run using Python', action='store_true', default=False)
# `async` is a reserved word, so the value is stored under `pya` instead.
mode_group.add_argument('-a', '--async', help='Run using async Python', dest='pya', action='store_true', default=False)
mode_group.add_argument('-s', '--set', help='Run using Set Python', action='store_true', default=False)

# Recording and replaying model calls cannot be combined.
replay_group = parser.add_mutually_exclusive_group()
replay_group.add_argument('--record', help='Record calls for replay', action='store_true', default=False)
replay_group.add_argument('--replay', help='Replay recorded calls', action='store_true', default=False)

# A lone option needs no mutually exclusive group of its own.
parser.add_argument('--rounds', help='Count rounds', action='store_true', default=False)

args = parser.parse_args()

# ---- Settings derived from the parsed command line ----

# What to run and which backend to run it with.
KIND = args.kind
MODEL = args.model

# Execution mode. Everything that is neither EPIC nor async Python counts
# as synchronous.
EPIC = args.epic
SET = args.set
SYNC = not EPIC and not args.pya

# Auxiliary behaviours.
RECORD = args.record
REPLAY = args.replay
ROUNDS = args.rounds

# Execution results are only written when the run neither records model
# calls nor counts rounds.
OUTPUT = not RECORD and not ROUNDS

# Directory for the vipergpt programs of the selected kind.
DIRNAME = KINDS_DIR + "/" + KIND

# Load the vipergpt programs from the sub-directory that matches the
# execution mode; plain and async Python both fall through to "progs_py".
if EPIC:
    _progs_subdir = "progs_epic"
elif SET:
    _progs_subdir = "progs_set"
else:
    _progs_subdir = "progs_py"
filepaths = get_prog_filepaths_from_dirs([f"{DIRNAME}/{_progs_subdir}"])

split = "val"
dataset_images = datasets.load_dataset("lmms-lab/GQA", f"{split}_all_images", split=split)
dataset_instructions = datasets.load_dataset("lmms-lab/GQA", f"{split}_all_instructions", split=split)

# Image id -> image, built by iterating over every item of the image split.
image_lookup = {item["id"]: item["image"] for item in dataset_images}

# TODO: this is a hack to make it faster because iterating over the whole
# dataset to construct `problem_image_lookup` is too slow. The full version
# would be:
#   {item["id"]: item["imageId"] for item in dataset_instructions}
import random
# Seeds the module-level RNG so the sampled subset is the same on every run.
random.seed(2025)
num_sample = 1000
# Keep this exact call shape: drawing `num_sample` items directly would
# select a different subset for the same seed.
indices = random.sample(range(len(dataset_instructions)), len(dataset_instructions))[0:num_sample]

# Problem id -> image id for the sampled instructions only. Each row is
# fetched from the dataset once instead of once per field.
problem_image_lookup = {}
for _index in indices:
    _row = dataset_instructions[_index]
    problem_image_lookup[_row["id"]] = _row["imageId"]

# Tag that identifies the model and execution mode in directory names.
if EPIC:
    _mode_tag = "epic"
elif SYNC:
    _mode_tag = "py"
else:
    _mode_tag = "async"
suffix = f"{MODEL}_{_mode_tag}"

# Where execution results, recorded model calls and round counts are stored.
EVAL_DIRNAME = f"{DIRNAME}/exec_{suffix}" + ("_replay" if REPLAY else "") + ("_set" if SET else "")
RECORD_DIRNAME = f"{DIRNAME}/recordings/{MODEL}"
# NOTE(review): unlike EVAL_DIRNAME, this path does not encode REPLAY or SET,
# so those runs share a rounds directory with the plain run -- confirm that
# this is intended.
ROUNDS_DIRNAME = f"{DIRNAME}/rounds_{suffix}"

# Only create the directories this run is actually going to write to.
for _enabled, _dirname in (
    (OUTPUT, EVAL_DIRNAME),
    (RECORD, RECORD_DIRNAME),
    (ROUNDS, ROUNDS_DIRNAME),
):
    if _enabled:
        os.makedirs(_dirname, exist_ok=True)

# When replaying, the replay backend stands in for the selected model.
_backend = imgpatch_replay if REPLAY else models[MODEL]
CONTEXT = make_context(_backend, RECORD, SYNC, track_rounds=ROUNDS)

def _remove_if_exists(path):
    """Delete *path*; a file that is already absent is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def clear_existing():
    """Print the active traceback and delete the current program's outputs.

    Meant to be called from inside an ``except`` block. Reads the
    module-level per-program paths (``model_outputs_filepath``,
    ``exec_filepath``, ``rounds_filepath``) assigned by the loop below, and
    only removes the files that the current run mode would have written.
    """
    traceback.print_exc()

    if RECORD:
        _remove_if_exists(model_outputs_filepath)
    if OUTPUT:
        _remove_if_exists(exec_filepath)
    if ROUNDS:
        _remove_if_exists(rounds_filepath)


# Run every program in turn. A failure in one program is reported, its
# output files are removed, and the loop moves on to the next program.
results = []
total = len(filepaths)
for i, fn_path in enumerate(filepaths, 1):
    filename = os.path.basename(fn_path)
    print(f"[{i}/{total}] Processing: {filename}")

    # The file name without its extension is the problem id.
    filename_no_ext = os.path.splitext(filename)[0]
    problem_id = filename_no_ext

    # `problem_image_lookup` only covers a sample of the instructions, so a
    # program may legitimately have no entry; skip it instead of aborting
    # the whole run with a KeyError.
    try:
        image_id = problem_image_lookup[problem_id]
    except KeyError:
        print(f"        Skipping {filename} (problem id not found in problem_image_lookup)")
        continue

    exec_filepath = os.path.join(EVAL_DIRNAME, f"{filename_no_ext}.json")
    if RECORD or REPLAY:
        model_outputs_filepath = os.path.join(RECORD_DIRNAME, f"{filename_no_ext}.json")
    if ROUNDS:
        rounds_filepath = os.path.join(ROUNDS_DIRNAME, f"{filename_no_ext}.json")

    # `.get` yields None for an unknown image id; that None is handed on to
    # WrappedImage unchanged.
    # NOTE(review): this uses `imgpatch` while the result handling below
    # uses `imagepatch` -- confirm both names are provided by the star imports.
    image = image_lookup.get(image_id)
    image = imgpatch.WrappedImage(image, image_id, CONTEXT)

    py_code = read_txt(fn_path)

    try:
        # --- Prepare the program for the selected execution mode ---
        if EPIC:
            syntax._symbol_next_id = 1000  # HACK: Should find highest var in AST
            try:
                epics_expr = epics_syntax.from_str(read_epics(fn_path))
            except FileNotFoundError:
                print("Skipping due to missing translation.")
                clear_existing()
                continue
            try:
                epic_final = epics_vipergpt.finalize(
                    epics_expr, [image], epics_vipergpt.make_mappings(CONTEXT.METHODS)
                )
            except NotImplementedError:
                print("Ignoring due to translation failure.")
                clear_existing()
                continue
        else:
            execute_command = get_py_exec_command(py_code, filename, CONTEXT, ASYNC=not SYNC, SET=SET)
            if execute_command is None:
                print(f"        Skipping {filename} (no execute_command found)")
                continue

        # --- Set up recording / replay / round tracking ---
        if REPLAY:
            try:
                load_model_outputs(model_outputs_filepath, CONTEXT)
            except FileNotFoundError:
                print("Skipping due to missing recording.")
                clear_existing()
                continue
        elif RECORD:
            reset_model_outputs(CONTEXT)

        if ROUNDS:
            CONTEXT.ROUNDS = []

        # --- Execute and time the program ---
        if EPIC:
            try:
                start = time.perf_counter_ns()
                # The last element produced by the reduction is the final term.
                epic_exec_result = tuple(semantics.reduce_graph_opportunistic(epic_final))[-1]
                elapsed = time.perf_counter_ns() - start
            except (KeyError, AssertionError):
                print("Ignoring due to eval failure.")
                clear_existing()
                continue

            try:
                epic_exec_result_value = epics_syntax.observe_term_as_value(epic_exec_result)
                if type(epic_exec_result_value) is str:
                    exec_result = epic_exec_result_value
                elif type(epic_exec_result_value) is imagepatch.ImagePatch:
                    exec_result = imagepatch.info(epic_exec_result_value)
                else:
                    # Raised explicitly rather than via `assert`, which is
                    # stripped when Python runs with -O.
                    raise TypeError((type(epic_exec_result_value), epic_exec_result_value))
            except Exception:
                print("Unable to find result")
                clear_existing()
                continue
        else:
            start = time.perf_counter_ns()
            exec_result = execute_command(image)
            elapsed = time.perf_counter_ns() - start

        print("=== RESULT ===")
        print(exec_result)

        # --- Persist whatever this run mode produces ---
        if RECORD:
            with open(model_outputs_filepath, "w") as f:
                json.dump(CONTEXT.MODEL_OUTPUTS, f)

        if OUTPUT:
            with open(exec_filepath, "w") as f:
                json.dump({
                    "result": exec_result,
                    "time_ns" : elapsed,
                }, f)

        if ROUNDS:
            with open(rounds_filepath, "w") as f:
                json.dump(CONTEXT.ROUNDS, f)

        print(f"         Execution completed in {elapsed / 1e6:.2f} ms")
        print(f"         Execution result: {exec_result}")
    except Exception as e:
        # Top-level boundary for one program: report the failure, drop any
        # partially written outputs, and carry on with the next program.
        print(f"        Execution failed: {e}")
        clear_existing()

