from transformers import AutoTokenizer
import json
import os

# Model id whose tokenizer is used for all token counting below.
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# JSON result files to measure; the paths are relative, so they resolve
# against the current working directory.
# NOTE(review): the previous note here read "[csqa, med]", which does not
# match the arc_c / csqa files listed — confirm the intended datasets.
file_paths = [
    "results_output/arc_c_test_infer_v0_layer16.json",
    "results_output/csqa_val_infer_v0_layer16.json",
]

# Key of the text field that is tokenized in every record.
OUTPUT_KEY = "with_steering_output_1.0"

# Token counts of every record across all files, for the overall average.
all_token_counts = []

for file_path in file_paths:
    # Each file is expected to hold a JSON array of objects that carry
    # OUTPUT_KEY; a record without that key raises KeyError.
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # NOTE(review): encode() is called with its defaults; for Hugging Face
    # tokenizers the default add_special_tokens=True may prepend special
    # tokens (e.g. BOS) that are then included in every count — confirm
    # whether that is intended.
    token_counts = [len(tokenizer.encode(entry[OUTPUT_KEY])) for entry in data]
    all_token_counts.extend(token_counts)

    # An empty result file would otherwise abort the run with
    # ZeroDivisionError; report it and carry on with the remaining files.
    average_tokens = sum(token_counts) / len(token_counts) if token_counts else None

    print(f"\nFile: {os.path.basename(file_path)}")
    if average_tokens is None:
        print("Average number of tokens in this file: n/a (file contains no entries)")
    else:
        print("Average number of tokens in this file:", average_tokens)

# Same guard for the global figure: no files, or only empty files, means
# there is nothing to average.
overall_average = (
    sum(all_token_counts) / len(all_token_counts) if all_token_counts else None
)
print("\n====================")
if overall_average is None:
    print("Overall average number of tokens across all files: n/a (no entries found)")
else:
    print("Overall average number of tokens across all files:", overall_average)