#!/usr/bin/env python3
"""
calc_kl.py

Compute KL divergence between each topic's test shard and the union of all other topics'
test shards, using empirical unigram token distributions (no model logits).

Given a directory containing files like:
  <topic>_test.bin, <topic>_train.bin, ...

We only read *_test.bin files. For each topic X, we compare:
  P_X = token frequency distribution from <X>_test.bin
  Q_-X = token frequency distribution from all other *_test.bin except X

We apply add-epsilon smoothing to avoid zeros and report KL(P_X || Q_-X) in nats and bits.

Example:
  python calc_kl.py --in_dir simple_stories_topics

By default the report is written to <in_dir>/kl_report.json.

Author: you :)
"""

from __future__ import annotations

import argparse
import json
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
from tqdm.auto import tqdm


@dataclass
class TopicStats:
    """Summary of one ``*_test.bin`` shard, as produced by ``load_counts_from_bin``."""

    # Topic name derived from the shard's file stem (the ``_test`` marker removed).
    topic: str
    # Path of the shard file, stored as a string.
    file: str
    # Number of token ids stored in the shard.
    total_tokens: int
    # Length of the shard's bincount vector, i.e. highest token id + 1.
    # NOTE: this is not the number of distinct token ids that occur.
    vocab_seen: int


@dataclass
class TopicKL:
    """KL(P_topic || Q_others) result for one topic; serialized into the JSON report."""

    # Topic name (matches ``TopicStats.topic``).
    topic: str
    # KL divergence computed with the natural logarithm.
    kl_nats: float
    # Same divergence converted to bits (``kl_nats / ln 2``).
    kl_bits: float
    # Token count of this topic's shard (before smoothing).
    total_tokens_topic: int
    # Combined token count of all other topics' shards (before smoothing).
    total_tokens_others: int
    # Length of the common, padded count vectors used for both distributions.
    vocab_size_used: int
    # Additive smoothing constant that was added to every count.
    epsilon: float
    # Path of this topic's shard file.
    file: str


def load_counts_from_bin(bin_path: Path) -> Tuple[np.ndarray, TopicStats]:
    """
    Load a uint16 memmap of token ids and return bincounts for that shard.

    Args:
        bin_path: Path to a ``<topic>_test.bin`` file; its bytes are
            interpreted as a flat array of uint16 token ids.

    Returns:
        A ``(counts, stats)`` pair. ``counts`` is an int64 array in which
        ``counts[i]`` is the number of occurrences of token id ``i``. Its
        length is the highest token id in the shard plus one, so callers must
        pad shards to a common length before comparing them. ``stats``
        records the topic name, the file path, the token count and the length
        of ``counts``.

    Errors raised by ``np.memmap`` (for example for an unreadable file)
    propagate to the caller.
    """
    arr = np.memmap(bin_path, dtype=np.uint16, mode="r")
    total = int(arr.shape[0])

    # bincount with minimal length; we'll pad later to a common vocab
    counts = np.bincount(arr.astype(np.int64))

    # Strip only the trailing "_test" marker. A blanket str.replace would also
    # remove "_test" from the middle of a topic name (e.g.
    # "unit_test_cases_test" -> "unitcases") and could make distinct shards
    # collide on the same topic.
    suffix = "_test"
    stem = bin_path.stem
    topic = stem[: -len(suffix)] if stem.endswith(suffix) else stem

    stats = TopicStats(
        topic=topic,
        file=str(bin_path),
        total_tokens=total,
        vocab_seen=int(counts.shape[0]),
    )
    # bincount already returns a fresh in-memory array (never memmap-backed);
    # only convert when the platform's default integer is not int64.
    return counts.astype(np.int64, copy=False), stats


def kl_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """
    Return KL(P || Q) in nats for two probability vectors of the same shape.

    Both inputs are assumed to already be valid distributions (non-negative,
    summing to 1); no normalization or validation is performed here.

    Args:
        p: Reference distribution P.
        q: Comparison distribution Q.

    Returns:
        The divergence as a Python float, computed with the natural logarithm.
    """
    # Entries where P is zero contribute nothing (0 * log(0 / q) := 0), so
    # restrict the sum to P's support.
    support = p > 0
    p_pos = p[support]
    q_pos = q[support]
    log_ratio = np.log(p_pos) - np.log(q_pos)
    return float((p_pos * log_ratio).sum())


def main(argv: list[str] | None = None) -> None:
    """
    Command-line entry point: compute per-topic KL divergence and write a JSON report.

    Every ``*_test.bin`` file in ``--in_dir`` is loaded with
    ``load_counts_from_bin``. For each topic, its add-epsilon smoothed unigram
    distribution is compared with the smoothed, pooled distribution of all
    other topics via ``kl_divergence``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.

    Raises:
        FileNotFoundError: If ``--in_dir`` contains no ``*_test.bin`` files.
        ValueError: If two shards map to the same topic name, or if fewer
            than two topics are found (there is no "others" set to compare
            against).
        SystemExit: Raised by argparse for invalid arguments, including a
            negative or non-finite ``--epsilon``.
    """
    ap = argparse.ArgumentParser(description="Compute topic-vs-others KL on *_test.bin shards.")
    ap.add_argument("--in_dir", type=str, required=True, help="Directory containing *_test.bin shards")
    ap.add_argument(
        "--out",
        type=str,
        default=None,
        help="Path of the JSON report (default: <in_dir>/kl_report.json)",
    )
    ap.add_argument("--epsilon", type=float, default=1e-8, help="Additive smoothing for counts")
    ap.add_argument("--progress", action="store_true", help="Show per-file progress bars")
    args = ap.parse_args(argv)

    eps = float(args.epsilon)
    # A negative or non-finite epsilon would yield negative probabilities or NaN.
    if not np.isfinite(eps) or eps < 0:
        ap.error("--epsilon must be a finite, non-negative number")

    in_dir = Path(args.in_dir)
    out_path = Path(args.out) if args.out else in_dir / "kl_report.json"

    test_bins = sorted(in_dir.glob("*_test.bin"))
    if not test_bins:
        raise FileNotFoundError(f"No *_test.bin files found in: {in_dir}")

    # 1) Load counts per topic (individual bincount length may differ)
    per_topic_counts: Dict[str, np.ndarray] = {}
    per_topic_stats: Dict[str, TopicStats] = {}

    iterator = tqdm(test_bins, desc="Loading *_test.bin", disable=not args.progress)
    for p in iterator:
        counts, stats = load_counts_from_bin(p)
        topic = stats.topic
        if topic in per_topic_counts:
            raise ValueError(f"Duplicate topic name detected for {topic} (file: {p})")
        per_topic_counts[topic] = counts
        per_topic_stats[topic] = stats

    topics = sorted(per_topic_counts)
    if len(topics) < 2:
        # With a single topic the "others" counts are all zero, so Q would be
        # nothing but the smoothing constant and the result meaningless.
        raise ValueError(
            f"Need at least two topics to compare a topic against the others; found {len(topics)} in: {in_dir}"
        )

    # 2) Determine common vocab size (max token id + 1 across all shards)
    vocab_size = max(arr.shape[0] for arr in per_topic_counts.values())

    # 3) Zero-pad all count vectors to the common vocab size and stack them
    #    into a (num_topics, vocab_size) matrix.
    counts_matrix = np.stack(
        [
            np.pad(per_topic_counts[t], (0, vocab_size - per_topic_counts[t].shape[0]))
            for t in topics
        ],
        axis=0,
    )
    totals = counts_matrix.sum(axis=1)  # per-topic token totals
    total_all = int(totals.sum())

    # 4) For each topic, form "others" by subtracting its counts from global sum
    global_counts = counts_matrix.sum(axis=0)  # shape (V,)

    kl_results: List[TopicKL] = []

    for i, topic in enumerate(topics):
        topic_counts = counts_matrix[i]
        others_counts = global_counts - topic_counts

        total_topic = int(totals[i])
        total_others = total_all - total_topic

        # Smoothing (Dirichlet add-epsilon)
        p_counts = topic_counts.astype(np.float64) + eps
        q_counts = others_counts.astype(np.float64) + eps

        # Normalize to probabilities
        p_dist = p_counts / p_counts.sum()
        q_dist = q_counts / q_counts.sum()

        kl_nats = kl_divergence(p_dist, q_dist)
        # Convert nats to bits; cast so the report holds a plain Python float.
        kl_bits = float(kl_nats / np.log(2.0))

        kl_results.append(
            TopicKL(
                topic=topic,
                kl_nats=kl_nats,
                kl_bits=kl_bits,
                total_tokens_topic=total_topic,
                total_tokens_others=total_others,
                vocab_size_used=vocab_size,
                epsilon=eps,
                file=per_topic_stats[topic].file,
            )
        )

    # 5) Write JSON report
    report = {
        "in_dir": str(in_dir.resolve()),
        "num_topics": len(topics),
        "topics": topics,
        "vocab_size_used": vocab_size,
        "epsilon": eps,
        "total_tokens_all_topics": total_all,
        "results": [asdict(r) for r in kl_results],
    }

    # A user-supplied --out may point into a directory that does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    print(f"Wrote KL report to: {out_path}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()