# compute_entity_stats.py

import os
import json
import argparse
import numpy as np


def _read_entity_counts(path):
    """Return one entity count per JSON record in the JSONL file at *path*.

    Blank lines are skipped. A record whose ``entities`` key is missing or
    ``null`` counts as 0 entities.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON; the
            message is prefixed with ``path:line_no`` to locate the record.
    """
    counts = []
    # utf-8-sig transparently drops a leading BOM (which json.loads rejects)
    # and is otherwise identical to utf-8.
    with open(path, 'r', encoding='utf-8-sig') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank / whitespace-only lines
            try:
                data = json.loads(line)
            except json.JSONDecodeError as exc:
                # Keep the same exception type, but say where it happened.
                raise json.JSONDecodeError(
                    f"{path}:{line_no}: {exc.msg}", exc.doc, exc.pos
                ) from exc
            # `or []` also covers an explicit "entities": null.
            counts.append(len(data.get('entities') or []))
    return counts


def collect_entity_counts(dataset_root):
    """Walk *dataset_root* recursively and gather per-report entity counts.

    Files ending in ``_en_entities.jsonl`` feed the English list and files
    ending in ``_zh_entities.jsonl`` feed the Chinese list; all other files
    are ignored.

    Args:
        dataset_root: Directory to search, including all subdirectories.

    Returns:
        ``(en_counts, zh_counts)``: two lists of ints, one entry per JSONL
        record, in ``os.walk`` traversal order.
    """
    en_counts = []
    zh_counts = []
    # Suffix -> destination list; a filename can match at most one suffix.
    buckets = (
        ('_en_entities.jsonl', en_counts),
        ('_zh_entities.jsonl', zh_counts),
    )
    for dirpath, _dirnames, filenames in os.walk(dataset_root):
        for fn in filenames:
            for suffix, bucket in buckets:
                if fn.endswith(suffix):
                    bucket.extend(_read_entity_counts(os.path.join(dirpath, fn)))
                    break
    return en_counts, zh_counts


def _format_number(value):
    """Render a statistic as a plain int when integral, else with 2 decimals."""
    value = float(value)
    return str(int(value)) if value.is_integer() else f"{value:.2f}"


def print_stats(name, counts, max_bins=10):
    """Print summary statistics for a list of per-report entity counts.

    Args:
        name: Label used in the section header.
        counts: Sequence of entity counts (one int per report). May be empty.
        max_bins: Number of individual values shown in the count
            distribution before the remaining ones are grouped into a single
            ``>=max_bins`` line.
    """
    arr = np.asarray(counts, dtype=int)
    total = arr.size

    print(f"\n--- {name} Entity Statistics ---")
    print(f"Total reports: {total}")
    if total == 0:
        # Ratios, min/max and percentiles are undefined for an empty sample.
        print("No reports to summarize")
        return

    zeros = int((arr == 0).sum())
    nonzeros = total - zeros
    print(f"Reports with 0 entities : {zeros} ({zeros/total:.2%})")
    print(f"Reports with >=1 entities: {nonzeros} ({nonzeros/total:.2%})\n")

    print(f"Min entities per report   : {arr.min()}")
    print(f"Max entities per report   : {arr.max()}")
    print(f"Mean entities per report  : {arr.mean():.2f}")
    # Median/percentiles can be fractional (e.g. 1.5); don't truncate them.
    print(f"Median entities per report: {_format_number(np.median(arr))}\n")

    percentiles = [25, 50, 75, 90, 95, 99]
    print("Percentiles:")
    for p, value in zip(percentiles, np.percentile(arr, percentiles)):
        print(f"  {p:>2}th percentile: {_format_number(value)}")

    # Simplified distribution: one line per value up to max_bins, then a tail.
    print("\nCount distribution (entities -> reports):")
    top = int(arr.max())
    for i in range(min(top + 1, max_bins)):
        print(f"  {i:>3} -> {int((arr == i).sum())}")
    if top >= max_bins:
        print(f"  >={max_bins} -> {int((arr >= max_bins).sum())}")


def main(argv=None):
    """Parse CLI arguments and print en/zh entity statistics.

    Args:
        argv: Optional argument list; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so script behavior is unchanged.
    """
    parser = argparse.ArgumentParser(
        description="Count entity distribution in all subdirectories for en/zh"
    )
    parser.add_argument(
        '--dataset_root', '-d', required=True,
        help='Dataset root directory, e.g., /root/autodl-tmp/dataset'
    )
    args = parser.parse_args(argv)

    # os.walk silently yields nothing for a missing path, which would be
    # misreported below as "no files found"; fail loudly instead (exit 2).
    if not os.path.isdir(args.dataset_root):
        parser.error(f"dataset root is not a directory: {args.dataset_root}")

    en_counts, zh_counts = collect_entity_counts(args.dataset_root)
    if not en_counts and not zh_counts:
        print("No *_en_entities.jsonl or *_zh_entities.jsonl files found")
        return

    if en_counts:
        print_stats("English Report", en_counts)
    else:
        print("No English entity files found")

    if zh_counts:
        print_stats("Chinese Report", zh_counts)
    else:
        print("No Chinese entity files found")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()