import json
import argparse
from collections import Counter
from transformers import AutoTokenizer
import os


def classify_token(freq, rare_th=5, freq_th=100):
    """Map a corpus frequency onto a coarse frequency bucket.

    The thresholds are inclusive upper bounds, checked in order: first
    ``rare_th`` ("rare"), then ``freq_th`` ("medium"). Anything that
    exceeds both checks is "frequent".

    Args:
        freq: Frequency of the token in the reference corpus.
        rare_th: Largest frequency still labelled "rare".
        freq_th: Largest frequency still labelled "medium".

    Returns:
        One of "rare", "medium" or "frequent".
    """
    # Ordered (inclusive upper bound, label) pairs; the first bound that
    # ``freq`` does not exceed decides the label.
    for upper_bound, label in ((rare_th, "rare"), (freq_th, "medium")):
        if freq <= upper_bound:
            return label
    return "frequent"


def _load_json(path):
    """Return the parsed contents of the UTF-8 JSON file at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _dump_json(obj, path):
    """Write *obj* to *path* as indented JSON, keeping non-ASCII text literal."""
    # ensure_ascii=False emits raw non-ASCII characters, so the file encoding
    # must be pinned: the platform default (e.g. cp1252 on Windows) can raise
    # UnicodeEncodeError or silently mangle them.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, ensure_ascii=False)


def _item_text(item):
    """Concatenate an AIME item's "question" with its "thinking_trajectories"."""
    trajectories = item["thinking_trajectories"]
    if isinstance(trajectories, str):
        # A lone string would otherwise be joined character by character.
        trajectories = [trajectories]
    return item["question"] + " " + " ".join(trajectories)


def _count_decoded_tokens(tokenizer, data):
    """Count decoded token strings over all items, in first-occurrence order."""
    id_counts = Counter()
    for item in data:
        id_counts.update(tokenizer.encode(_item_text(item), add_special_tokens=False))

    # Decode each distinct id once rather than once per occurrence. Distinct
    # ids can decode to the same string, so their counts are summed.
    # NOTE(review): if the tokenizer is byte-level BPE, an id that holds only
    # part of a multi-byte UTF-8 character decodes to "\ufffd" on its own, so
    # such ids collapse into one key. This only matches the math table if it
    # was built the same way — confirm.
    token_counts = Counter()
    for tid, count in id_counts.items():
        token_counts[tokenizer.decode([tid])] += count
    return token_counts


def main(aime_file, math_freq_file, out_json, rare_json, *,
         model_name="Qwen/Qwen2.5-7B-Instruct", rare_th=5, freq_th=100):
    """Compare AIME token usage with a math-corpus token frequency table.

    Each item of the AIME file contributes its "question" followed by its
    "thinking_trajectories". That text is tokenized and each token id is
    decoded back to a string. The resulting strings are counted and looked up
    in the math frequency table. A token absent from the table is labelled
    "unseen"; otherwise it is bucketed with ``classify_token``.

    Args:
        aime_file: JSON file holding a list of objects with "question" (str)
            and "thinking_trajectories" (list of str, or a single str).
        math_freq_file: JSON object mapping token string -> corpus frequency
            (presumably built with the same decode scheme; verify).
        out_json: Output path for the per-token comparison of every AIME token.
        rare_json: Output path for the subset of tokens classified "rare".
        model_name: Checkpoint passed to ``AutoTokenizer.from_pretrained``.
        rare_th: Inclusive upper bound for "rare" (forwarded to classify_token).
        freq_th: Inclusive upper bound for "medium" (forwarded to classify_token).

    Side effects:
        Writes both JSON files (UTF-8) and prints a summary to stdout.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Load math token frequency distribution
    math_token_freq = _load_json(math_freq_file)

    # Load AIME24
    data = _load_json(aime_file)

    aime_token_counts = _count_decoded_tokens(tokenizer, data)

    comparison_results = {}
    summary = {"rare": 0, "medium": 0, "frequent": 0, "unseen": 0}
    rare_tokens = {}

    for token, aime_freq in aime_token_counts.items():
        if token in math_token_freq:
            math_freq = math_token_freq[token]
            category = classify_token(math_freq, rare_th, freq_th)
        else:
            math_freq = 0
            category = "unseen"
        summary[category] += 1

        comparison_results[token] = {
            "aime_frequency": aime_freq,
            "math_corpus_frequency": math_freq,
            "category": category
        }

        # Save rare tokens with freq counts
        if category == "rare":
            rare_tokens[token] = {
                "aime_frequency": aime_freq,
                "math_corpus_frequency": math_freq
            }

    # Save full comparison JSON, then the rare-only subset
    _dump_json(comparison_results, out_json)
    _dump_json(rare_tokens, rare_json)

    print("✅ Full comparison saved to:", os.path.abspath(out_json))
    print("✅ Rare tokens saved to:", os.path.abspath(rare_json))

    print("\n===== SUMMARY =====")
    print(f"Total unique AIME tokens: {len(aime_token_counts)}")
    print(f"Rare tokens: {summary['rare']}")
    print(f"Medium tokens: {summary['medium']}")
    print(f"Frequent tokens: {summary['frequent']}")
    print(f"Unseen tokens: {summary['unseen']}")

    # Print sample rare tokens (insertion order = first occurrence in AIME text)
    print("\nSample rare tokens (up to 10):")
    for token, info in list(rare_tokens.items())[:10]:
        print(f"{repr(token)} -> AIME freq: {info['aime_frequency']}, Math freq: {info['math_corpus_frequency']}")


if __name__ == "__main__":
    # Command-line entry point: every option is a required string path.
    cli = argparse.ArgumentParser(
        description="Compare AIME24 tokens with math token frequency distribution"
    )

    # (flag, help text) pairs, registered in this order so --help output
    # lists them in the same sequence.
    required_paths = (
        ("--aime_file", "Path to aime24.json"),
        ("--math_freq_file", "Path to math_token_freq_distribution1.json"),
        ("--out_json", "Output full comparison JSON"),
        ("--rare_json", "Output JSON containing only rare tokens"),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    opts = cli.parse_args()

    main(opts.aime_file, opts.math_freq_file, opts.out_json, opts.rare_json)
