import json
import argparse
from collections import Counter
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
import os


def main(json_path, plot_file, out_json):
    """Tokenize a dataset with the Qwen2.5 tokenizer and report token frequencies.

    Reads ``json_path``, expected to hold a JSON list of records. Each record
    needs a ``"question"`` string and a ``"thinking_trajectories"`` field. The
    question plus the space-joined trajectories are tokenized (no special
    tokens), and then the function:

    * writes a ``{decoded_token_text: count}`` mapping to ``out_json``;
    * saves a log-log plot of frequency-of-frequencies to ``plot_file``;
    * prints the number of distinct token ids and the 10 most common tokens.

    Args:
        json_path: Path to the input JSON file.
        plot_file: Destination path for the plot image.
        out_json: Destination path for the token-frequency JSON.
    """
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

    # Explicit UTF-8: the input is free text, and the platform default encoding
    # (e.g. cp1252 on Windows) may be unable to decode it.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Count incrementally rather than materializing every token id in one list.
    token_counts = Counter()
    for item in data:
        question = item["question"]
        trajectories = item["thinking_trajectories"]
        # Presumably a list of strings (TODO confirm schema). A bare string is
        # wrapped so that " ".join() does not put spaces between its characters.
        if isinstance(trajectories, str):
            trajectories = [trajectories]
        combined_text = question + " " + " ".join(trajectories)

        token_counts.update(
            tokenizer.encode(combined_text, add_special_tokens=False)
        )

    # Frequency of frequencies: how many distinct token ids occur exactly k times.
    freq_of_freq = Counter(token_counts.values())

    # Distinct token ids can decode to the same text. For example, byte-level
    # pieces of a multi-byte character typically all decode to U+FFFD. Sum
    # their counts instead of letting later ids overwrite earlier ones.
    token_freq_dict = Counter()
    for token_id, count in token_counts.items():
        token_freq_dict[tokenizer.decode([token_id])] += count

    # ensure_ascii=False writes raw non-ASCII text, so the file must be UTF-8 too.
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(dict(token_freq_dict), f, indent=2, ensure_ascii=False)

    print(f"✅ Token frequency JSON saved to: {os.path.abspath(out_json)}")

    # Plot data: x = occurrence count, y = number of token ids with that count.
    x = sorted(freq_of_freq)
    y = [freq_of_freq[freq] for freq in x]

    plt.figure(figsize=(8, 5))
    plt.plot(x, y, marker="o")
    plt.xlabel("Token Frequency")
    plt.ylabel("Number of Tokens")
    plt.title("Token Frequency vs Count Distribution (Qwen2.5)")
    plt.xscale("log")
    plt.yscale("log")

    plt.savefig(plot_file, dpi=300, bbox_inches="tight")
    plt.close()

    print(f"✅ Plot saved to: {os.path.abspath(plot_file)}")

    # Statistics below are per token id (not per decoded string).
    print(f"Total unique tokens: {len(token_counts)}")
    print("Top 10 tokens:")
    for token_id, count in token_counts.most_common(10):
        print(f"{repr(tokenizer.decode([token_id]))}: {count}")


if __name__ == "__main__":
    # All CLI options are required string paths; declare them from one table
    # so each flag sits next to its help text.
    cli = argparse.ArgumentParser(
        description="Token frequency distribution using Qwen2.5 tokenizer"
    )
    option_specs = (
        ("--file", "Path to input JSON file"),
        ("--plot_file", "Path to save plot image (png)"),
        ("--out_json", "Path to save token frequency JSON"),
    )
    for flag, help_text in option_specs:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    parsed = cli.parse_args()
    main(parsed.file, parsed.plot_file, parsed.out_json)
