import json
import sys
from collections import Counter, defaultdict
from pathlib import Path

# Root directory that main() scans recursively for *.json / *.jsonl files.
# NOTE(review): "...." looks like a placeholder path; point this at the real
# dataset location before running.
DATASET_DIR = Path(r"....")

def iter_records_from_file(fp: Path):
    """Yield every JSON object (dict) stored in *fp*.

    The format is chosen by the file suffix (case-insensitive):

    * ``.json``  - the file holds one JSON value. A top-level object is
      yielded as-is; for a top-level array, its dict elements are yielded in
      order and all other elements are ignored.
    * ``.jsonl`` - one JSON value per line. Blank lines are skipped, and only
      lines that decode to a dict are yielded.

    Any other suffix yields nothing. Files are decoded as UTF-8, and a leading
    UTF-8 byte-order mark is tolerated.

    The scan is best-effort and never raises for bad input:

    * An unreadable or unparseable ``.json`` file yields nothing.
    * A read/decode failure part-way through a ``.jsonl`` file stops
      iteration for that file.
    * Malformed ``.jsonl`` lines are skipped.

    Every such skip is reported on stderr rather than silently discarded.

    Yields:
        dict: one record per JSON object found.
    """
    suf = fp.suffix.lower()
    if suf == ".json":
        try:
            # "utf-8-sig" strips a BOM if present (common in files saved by
            # Windows editors); otherwise it decodes exactly like "utf-8".
            with fp.open("r", encoding="utf-8-sig") as f:
                obj = json.load(f)
        except (OSError, ValueError, RecursionError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
            # RecursionError is raised for pathologically deep nesting.
            print(f"Warning: skipped {fp}: {exc}", file=sys.stderr)
            return
        if isinstance(obj, dict):
            yield obj
        elif isinstance(obj, list):
            for it in obj:
                if isinstance(it, dict):
                    yield it
    elif suf == ".jsonl":
        bad_lines = 0
        try:
            with fp.open("r", encoding="utf-8-sig") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        obj = json.loads(line)
                    except (ValueError, RecursionError):
                        bad_lines += 1
                        continue
                    if isinstance(obj, dict):
                        yield obj
        except (OSError, UnicodeDecodeError) as exc:
            # Records yielded before the failure are kept; the rest of the
            # file is abandoned.
            print(f"Warning: stopped reading {fp}: {exc}", file=sys.stderr)
        if bad_lines:
            print(
                f"Warning: skipped {bad_lines} malformed line(s) in {fp}",
                file=sys.stderr,
            )

def classify_source_strict(src: object) -> str:
    """Map a record's raw ``source`` value to a reporting category.

    Matching is exact after trimming surrounding whitespace and lower-casing:

    * ``"textbook"``              -> ``"Textbook"``
    * ``"real-world assessment"`` -> ``"Real-world"``
    * anything else -> ``"Unknown"``. This includes ``None``, empty or
      unrecognised strings, and non-string JSON values (numbers, lists,
      objects).

    Args:
        src: The raw ``source`` field of a record (may be absent/``None``).

    Returns:
        One of ``"Textbook"``, ``"Real-world"`` or ``"Unknown"``.
    """
    if not isinstance(src, str):
        # JSON "source" fields are not guaranteed to be strings; calling
        # .strip() on e.g. an int would raise and abort the whole scan.
        return "Unknown"
    s = src.strip().lower()
    if s == "textbook":
        return "Textbook"
    if s == "real-world assessment":
        return "Real-world"
    return "Unknown"

def _subtopic_label(value: object) -> str:
    """Return a printable, hashable label for a record's ``subtopic`` value.

    A missing or null subtopic becomes ``"MISSING"``. Non-string values
    (numbers, lists, ...) are stringified so they can be counted, measured
    with ``len`` and sorted alongside ordinary string labels.
    """
    if value is None:
        return "MISSING"
    return value if isinstance(value, str) else str(value)


def _print_count_table(label: str, counts: Counter) -> None:
    """Print *counts* as a two-column ``<label>  Count`` table.

    Rows are ordered by descending count, with ties broken case-insensitively
    by name. The first column is widened to fit the longest name (or *label*).
    """
    col_w = max([len(label)] + [len(name) for name in counts])
    print(f"{label:<{col_w}}  Count")
    print("-" * (col_w + 8))
    for name, count in sorted(counts.items(), key=lambda x: (-x[1], x[0].lower())):
        print(f"{name:<{col_w}}  {count}")


def main(dataset_dir: Path | str | None = None) -> None:
    """Scan a dataset directory and print topic/subtopic/source statistics.

    Every ``*.json`` / ``*.jsonl`` file found recursively under the directory
    is treated as one topic, named after the file's stem. Files with the same
    stem in different folders, or ``x.json`` and ``x.jsonl``, are therefore
    merged into one topic. Each record read by ``iter_records_from_file``
    contributes to:

    * a per-topic subtopic histogram (from the record's ``subtopic`` field),
    * an overall per-topic record count, and
    * a per-topic Textbook / Real-world split (from ``source`` via
      ``classify_source_strict``). Records whose source is neither are only
      counted in the warning line of the summary.

    Args:
        dataset_dir: Directory to scan. Defaults to the module-level
            ``DATASET_DIR`` when omitted.
    """
    root = DATASET_DIR if dataset_dir is None else Path(dataset_dir)
    if not root.is_dir():
        # rglob() on a missing directory just yields nothing, which would
        # otherwise produce an all-zero report with no hint of the cause.
        print(f"Error: dataset directory not found: {root}")
        return

    topic_subtopic_counts: dict[str, Counter] = defaultdict(Counter)
    overall_topics: Counter = Counter()

    topic_source_counts: dict[str, Counter] = defaultdict(Counter)
    overall_sources: Counter = Counter()

    files_scanned = 0
    records_total = 0
    unknown_sources = 0

    for fp in root.rglob("*"):
        if not fp.is_file() or fp.suffix.lower() not in {".json", ".jsonl"}:
            continue
        files_scanned += 1
        topic = fp.stem
        for rec in iter_records_from_file(fp):
            records_total += 1

            # Subtopic: normalise null / non-string values so the tables
            # below can measure and sort them.
            subtopic = _subtopic_label(rec.get("subtopic"))
            topic_subtopic_counts[topic][subtopic] += 1
            overall_topics[topic] += 1

            # Source
            src_cat = classify_source_strict(rec.get("source"))
            if src_cat == "Unknown":
                unknown_sources += 1
            else:
                topic_source_counts[topic][src_cat] += 1
                overall_sources[src_cat] += 1

    # Summary
    print(f"Scanned directory: {root}")
    print(f"Files scanned: {files_scanned}")
    print(f"Records found: {records_total}")
    if unknown_sources:
        print(f"Warning: {unknown_sources} record(s) had unexpected 'source' values.")
    print()

    # Per-topic subtopic breakdown
    for topic in sorted(topic_subtopic_counts, key=str.lower):
        counter = topic_subtopic_counts[topic]
        print(f"=== Topic: {topic} (Total: {sum(counter.values())}) ===")
        _print_count_table("Subtopic", counter)
        print()

    # Overall by Topic
    if overall_topics:
        print("=== Overall by Topic ===")
        _print_count_table("Topic", overall_topics)
        print()

    # Source distribution by topic (only records with a recognised source).
    # Keep the historical 26-char minimum, but widen for long topic names so
    # the columns stay aligned.
    topic_w = max([26] + [len(t) for t in topic_source_counts])
    print("=== Source distribution by topic ===")
    hdr = (
        f"{'Topic (file)':<{topic_w}}  {'Total':>5}  {'Textbook':>9}"
        f"  {'Real-world':>11}  {'% Real-world':>13}"
    )
    print(hdr)
    print("-" * len(hdr))
    for topic in sorted(topic_source_counts, key=str.lower):
        c = topic_source_counts[topic]
        tb = c.get("Textbook", 0)
        rw = c.get("Real-world", 0)
        total = tb + rw
        pct_rw = (100.0 * rw / total) if total else 0.0
        print(f"{topic:<{topic_w}}  {total:>5}  {tb:>9}  {rw:>11}  {pct_rw:>12.1f}%")
    print()

    # Overall source distribution
    if overall_sources:
        tb = overall_sources.get("Textbook", 0)
        rw = overall_sources.get("Real-world", 0)
        total_all = tb + rw
        pct_rw = (100.0 * rw / total_all) if total_all else 0.0
        print("=== Overall Source Distribution ===")
        print(f"Total: {total_all}  |  Textbook: {tb}  |  Real-world: {rw}  |  % Real-world: {pct_rw:.1f}%")

# Run the report only when this file is executed as a script (not on import).
if __name__ == "__main__":
    main()
