import argparse
import json
import sys

# GPT-5 token pricing in US dollars per one million tokens. These values are
# fixed in the script; confirm them against current OpenAI pricing before
# relying on the computed totals.
PRICE_PER_MILLION_INPUT_GPT5: float = 1.25  # charged on input tokens
PRICE_PER_MILLION_OUTPUT_GPT5: float = 10.0  # charged on output tokens


def tally_api_cost(
    preds_path,
    *,
    input_price_per_million=None,
    output_price_per_million=None,
):
    """Sum token usage in a predictions file and print the resulting API cost.

    The file must contain JSON that is either a list of prediction records or
    an object whose values are prediction records. Each record is a JSON
    object that may carry ``input_tokens`` and ``output_tokens`` counts; a
    missing or null count contributes 0.

    Args:
        preds_path: Path to the predictions JSON file.
        input_price_per_million: USD charged per 1M input tokens. When None,
            ``PRICE_PER_MILLION_INPUT_GPT5`` is used.
        output_price_per_million: USD charged per 1M output tokens. When
            None, ``PRICE_PER_MILLION_OUTPUT_GPT5`` is used.

    Returns:
        The total cost in USD (input cost plus output cost). The same figure
        is printed to stdout together with the input/output breakdown.

    Raises:
        SystemExit: With code 1, after printing a message to stderr, if the
            file cannot be read or parsed, or if its contents are not a
            list/object of JSON-object records.
    """
    if input_price_per_million is None:
        input_price_per_million = PRICE_PER_MILLION_INPUT_GPT5
    if output_price_per_million is None:
        output_price_per_million = PRICE_PER_MILLION_OUTPUT_GPT5

    try:
        with open(preds_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file. ValueError also covers
        # json.JSONDecodeError and UnicodeDecodeError (both subclasses).
        print(f"Error loading predictions file: {e}", file=sys.stderr)
        sys.exit(1)

    # Iterating a dict directly would yield its keys (strings), not records,
    # so an id -> record mapping is walked via its values instead.
    if isinstance(data, dict):
        records = list(data.values())
    elif isinstance(data, list):
        records = data
    else:
        print(
            "Error loading predictions file: expected a JSON list or object, "
            f"got {type(data).__name__}",
            file=sys.stderr,
        )
        sys.exit(1)

    total_input_tokens = 0
    total_output_tokens = 0
    for item in records:
        if not isinstance(item, dict):
            print(
                f"Error loading predictions file: record is not a JSON object: {item!r}",
                file=sys.stderr,
            )
            sys.exit(1)
        # `or 0` treats an explicit null count the same as a missing one.
        total_input_tokens += item.get("input_tokens") or 0
        total_output_tokens += item.get("output_tokens") or 0

    input_cost = (total_input_tokens / 1_000_000) * input_price_per_million
    output_cost = (total_output_tokens / 1_000_000) * output_price_per_million
    total_cost = input_cost + output_cost
    print(
        f"Total OpenAI API cost: ${total_cost:.2f} (input: ${input_cost:.2f}, output: ${output_cost:.2f})"
    )
    return total_cost


if __name__ == "__main__":
    # Command-line entry point: a single positional argument naming the
    # predictions JSON file whose token usage should be costed.
    cli = argparse.ArgumentParser(
        description="Sum API cost from predictions file."
    )
    cli.add_argument(
        "preds_file",
        type=str,
        help="Path to predictions JSON file.",
    )
    parsed = cli.parse_args()
    tally_api_cost(parsed.preds_file)
