# Script to generate predictions for all images in the ChartQA test set so that they can be evaluated using the TinyChart code.
# The input image directory, output_path, and reasoning effort are passed as command-line arguments.

import argparse
import json
import os
import sys

# Add the parent directory to the path so we can import from modules
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)

from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm

from modules.prompts import PROMPT_EXTRACT_DATA_FROM_CHART_3
from modules.utils import extract_data_from_chart


def _load_existing_predictions(output_path):
    """Return the predictions already stored at ``output_path``.

    Returns an empty list when the file does not exist or cannot be parsed
    as JSON (a warning is printed in the latter case).

    Raises:
        ValueError: if the file parses but does not contain a JSON list.
    """
    if not os.path.exists(output_path):
        return []

    try:
        with open(output_path, "r", encoding="utf-8") as f:
            existing = json.load(f)
    except ValueError as err:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        # Best effort: restart rather than abort, but say so loudly because
        # the unreadable file will be overwritten by the next save.
        print(
            f"Warning: could not parse existing predictions in {output_path} "
            f"({err}); starting from scratch"
        )
        return []

    if not isinstance(existing, list):
        raise ValueError(
            f"Expected a JSON list of predictions in {output_path}, "
            f"got {type(existing).__name__}"
        )
    return existing


def _save_predictions(predictions, output_path):
    """Write ``predictions`` to ``output_path`` as indented JSON, atomically."""
    # Write to a sibling temp file and swap it in, so an interruption during
    # the write can never leave a truncated JSON file at output_path.
    tmp_path = f"{output_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=4)
    os.replace(tmp_path, output_path)


def main(input_images_dir, output_path, reasoning_effort, code_interpreter):
    """Extract data from every chart image in a directory and save it as JSON.

    The run is resumable: predictions already present in ``output_path`` are
    loaded first, images listed there are skipped, and the file is rewritten
    after every newly processed image.

    Args:
        input_images_dir: Image directory, joined onto ``project_root``
            (an absolute path is used as-is by ``os.path.join``).
        output_path: JSON file holding a list of
            ``{"image", "answer", "input_tokens", "output_tokens"}`` entries.
        reasoning_effort: Value passed to ``extract_data_from_chart`` as
            ``reasoning={"effort": reasoning_effort}``.
        code_interpreter: Flag forwarded to ``extract_data_from_chart``.

    Raises:
        ValueError: if ``output_path`` exists but does not hold a JSON list.
    """
    load_dotenv()

    if code_interpreter:
        print("Using code interpreter for data extraction")

    # No explicit API key is passed; the client is constructed with defaults
    # after load_dotenv() has populated the environment.
    client = OpenAI()
    model = "gpt-5"

    image_dir = os.path.join(project_root, input_images_dir)

    # Collect all supported image files; sorted so that the processing order
    # is deterministic (os.listdir returns entries in arbitrary order).
    image_extensions = (".png", ".jpg", ".jpeg", ".gif")
    all_images = sorted(
        f for f in os.listdir(image_dir) if f.endswith(image_extensions)
    )

    print(f"Found {len(all_images)} images in {image_dir}")

    answer_dict = _load_existing_predictions(output_path)

    # Keep track of already processed images to avoid duplicates
    already_done = {entry["image"] for entry in answer_dict}

    # Make sure the destination directory exists before any API call is made,
    # so a missing directory does not waste the first extraction.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    for image_name in tqdm(all_images, desc="Processing images"):
        if image_name in already_done:
            continue
        image_path = os.path.join(image_dir, image_name)
        res = extract_data_from_chart(
            image_path,
            client,
            model,
            prompt=PROMPT_EXTRACT_DATA_FROM_CHART_3,
            reasoning={"effort": reasoning_effort},
            code_interpreter=code_interpreter,
        )

        answer_dict.append(
            {
                "image": image_name,
                "answer": res["csv_data"],
                "input_tokens": res["input_tokens"],
                "output_tokens": res["output_tokens"],
            }
        )

        # Persist after every image so an interrupted run can be resumed.
        _save_predictions(answer_dict, output_path)


if __name__ == "__main__":
    # Command-line entry point: three required positionals plus one flag.
    cli = argparse.ArgumentParser(
        description="Generate predictions for ChartQA test images"
    )

    positional_args = (
        ("input_images_dir", "Path to input image directory"),
        ("output_path", "Path to output JSON file for predictions"),
        ("reasoning_effort", "Reasoning effort to use for data extraction"),
    )
    for arg_name, arg_help in positional_args:
        cli.add_argument(arg_name, help=arg_help)

    # Optional switch; absent means False.
    cli.add_argument(
        "--code_interpreter",
        action="store_true",
        default=False,
        required=False,
        help="Whether to use the code interpreter for data extraction",
    )

    options = cli.parse_args()
    main(
        options.input_images_dir,
        options.output_path,
        options.reasoning_effort,
        options.code_interpreter,
    )
