import json
import os

# JSON result files to summarise, one per backward method. Names are bare
# relative paths, so they are looked up in the current working directory.
FILES = [
    "full_backward_250_acc.json",
    "traditional_backward_250_acc.json",
    "one_shot_backward_250_acc.json",
]

def sum_tokens(filename):
    """Sum the prompt, output and total token counts stored in a JSON file.

    The file is expected to hold a JSON array of objects. Each object may
    carry ``prompt_tokens``, ``output_tokens`` and ``total_tokens`` fields.
    A missing field, or one whose value is JSON ``null``, counts as 0.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        A ``(prompt_sum, output_sum, total_sum)`` tuple of the three sums.
        If the file does not exist, a warning is printed and
        ``(0, 0, 0)`` is returned.
    """
    if not os.path.exists(filename):
        print(f"[Warning] File not found: {filename}")
        return (0, 0, 0)

    with open(filename, "r", encoding="utf-8") as f:
        data = json.load(f)

    p_sum = o_sum = t_sum = 0
    for item in data:
        # `.get(key, 0)` alone would still yield None for a key explicitly
        # set to null, which makes the addition raise TypeError; `or 0`
        # normalises both the missing and the null case to 0.
        p_sum += item.get("prompt_tokens") or 0
        o_sum += item.get("output_tokens") or 0
        t_sum += item.get("total_tokens") or 0

    return (p_sum, o_sum, t_sum)


def main():
    """Print a per-file token usage summary for every file in ``FILES``.

    All files are read first (so any "file not found" warnings appear
    before the report). Then the prompt/output/total sums for each file
    are printed in ``FILES`` order.
    """
    print("==============================================")
    # Derive the method count from FILES instead of hard-coding it, so the
    # banner stays correct if files are added or removed.
    print(f" Token Cost Summary for {len(FILES)} Backward Methods")
    print("==============================================\n")

    # Collect every result before printing anything.
    all_stats = {fname: sum_tokens(fname) for fname in FILES}

    # print results
    for fname, (p, o, t) in all_stats.items():
        print(f"File: {fname}")
        print(f"  prompt_tokens sum : {p}")
        print(f"  output_tokens sum : {o}")
        print(f"  total_tokens sum  : {t}")
        print("----------------------------------------------")

    print("============== Finished ======================")

if __name__ == "__main__":
    main()
