"""Benchmark lossless compression on quantized layer codes.

For each quantized layer, reports entropy and compressed size (bits/param)
using multiple algorithms: bzip2, gzip, lzma, zlib, zstandard, lz4.

Usage:
    python -m scripts.compression_benchmark --run_dir /path/to/quantized/run
    python -m scripts.compression_benchmark --run_dir /path/to/run --format csv
"""

from __future__ import annotations

import argparse
import bz2
import gzip
import lzma
import os
import sys
import time
import zlib
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from quant_layerwise.storage import LayerArtifact, RunManifest

# Codec registry: algorithm name -> {"compress": fn, "decompress": fn}.
# Stays None until the first call to _get_algorithms(), which fills it in.
ALGORITHMS = None  # Lazy init: lz4/zstandard are imported on first use, not at module import


def _get_algorithms():
    """Return the codec registry, building and caching it on the first call.

    The registry maps each algorithm name to a dict with a ``"compress"`` and
    a ``"decompress"`` callable, both taking and returning bytes. The
    third-party ``lz4`` and ``zstandard`` packages are imported here rather
    than at module level, so an ImportError surfaces on first use.
    """
    global ALGORITHMS
    if ALGORITHMS is None:
        import lz4.frame
        import zstandard as zstd

        # Every compressor is pinned to its strongest setting.
        def _bzip2(data):
            return bz2.compress(data, compresslevel=9)

        def _gzip(data):
            return gzip.compress(data, compresslevel=9)

        def _lzma(data):
            return lzma.compress(data, preset=9)

        def _zlib(data):
            return zlib.compress(data, level=9)

        def _zstd(data):
            return zstd.ZstdCompressor(level=22).compress(data)

        def _unzstd(data):
            return zstd.ZstdDecompressor().decompress(data)

        def _lz4(data):
            return lz4.frame.compress(
                data, compression_level=lz4.frame.COMPRESSIONLEVEL_MAX
            )

        codecs = (
            ("bzip2", _bzip2, bz2.decompress),
            ("gzip", _gzip, gzip.decompress),
            ("lzma", _lzma, lzma.decompress),
            ("zlib", _zlib, zlib.decompress),
            ("zstd", _zstd, _unzstd),
            ("lz4", _lz4, lz4.frame.decompress),
        )
        ALGORITHMS = {
            label: {"compress": packer, "decompress": unpacker}
            for label, packer, unpacker in codecs
        }
    return ALGORITHMS


# Order in which algorithms are run by compress_all() and shown as report
# columns; every name here is looked up in the registry from _get_algorithms().
ALGO_NAMES = ["lz4", "zlib", "gzip", "bzip2", "zstd", "lzma"]


def compute_entropy(codes: np.ndarray) -> float:
    """Compute Shannon entropy in bits/symbol over the full flattened array.

    Args:
        codes: Array of symbols; its shape is ignored, all elements are
            treated as one sequence.

    Returns:
        The empirical entropy ``-sum(p * log2(p))`` as a Python float, or
        0.0 for an empty array.
    """
    flat = np.asarray(codes).ravel()
    if flat.size == 0:
        return 0.0
    # Count symbols in native code instead of hashing every element through a
    # Python-level Counter; this function is called once per matrix column.
    _, counts = np.unique(flat, return_counts=True)
    probs = counts / flat.size
    # np.unique only reports symbols that occur, so every probability is
    # strictly positive and log2 is finite. Subtracting from 0.0 (rather than
    # negating) keeps the single-symbol case at +0.0 instead of -0.0.
    return float(0.0 - np.sum(probs * np.log2(probs)))


def compute_column_entropies(Z: np.ndarray) -> np.ndarray:
    """Return the Shannon entropy of every column of a 2-D code matrix.

    Z is indexed as ``Z[:, col]``, i.e. shape (out_features, in_features).
    The result is a float array of shape (in_features,) whose entry ``col``
    is ``compute_entropy`` of that column, in bits.
    """
    column_count = Z.shape[1]
    per_column = (compute_entropy(Z[:, idx]) for idx in range(column_count))
    return np.fromiter(per_column, dtype=np.float64, count=column_count)


def pack_codes(Z: torch.Tensor) -> Tuple[bytes, torch.dtype]:
    """Pack integer codes into the narrowest integer dtype and serialize them.

    The tensor is transposed before serialization, so the bytes of a 2-D
    ``Z`` are laid out column by column.

    Args:
        Z: Tensor of integer codes.

    Returns:
        A ``(raw_bytes, dtype)`` pair: the serialized codes and the torch
        dtype they were cast to (int8, int16, int32 or int64). An empty
        tensor yields ``(b"", torch.int8)``.
    """
    if Z.numel() == 0:
        # min()/max() are undefined on an empty tensor; there is nothing to pack.
        return b"", torch.int8

    zmin = Z.min().item()
    zmax = Z.max().item()
    if zmin >= -127 and zmax <= 127:
        dtype = torch.int8
    elif zmin >= -32767 and zmax <= 32767:
        dtype = torch.int16
    elif zmin >= -(2**31) and zmax <= 2**31 - 1:
        dtype = torch.int32
    else:
        # Casting to int32 here would silently wrap the values, so fall back
        # to a width that keeps the packing lossless.
        dtype = torch.int64
    return Z.to(dtype).T.contiguous().numpy().tobytes(), dtype


def compress_all(raw_bytes: bytes) -> Dict[str, Dict]:
    """Run every algorithm in ALGO_NAMES over ``raw_bytes`` and time it.

    Returns a dict keyed by algorithm name. Each value holds ``"size"`` (the
    compressed length in bytes) plus ``"compress_ms"`` and ``"decompress_ms"``
    (wall-clock milliseconds for one compression and one decompression of the
    compressed output).
    """

    def _timed(fn, payload):
        # One call, measured with the monotonic high-resolution clock.
        started = time.perf_counter()
        output = fn(payload)
        return output, time.perf_counter() - started

    registry = _get_algorithms()
    report = {}
    for algo in ALGO_NAMES:
        codec = registry[algo]
        blob, compress_seconds = _timed(codec["compress"], raw_bytes)
        # The decompressed output is discarded; only the timing is kept.
        _, decompress_seconds = _timed(codec["decompress"], blob)
        report[algo] = {
            "size": len(blob),
            "compress_ms": compress_seconds * 1000,
            "decompress_ms": decompress_seconds * 1000,
        }
    return report


def benchmark_layer(artifact_path: Path, verbose: bool = False) -> Dict:
    """Measure entropy and compressed size for one stored layer artifact.

    The artifact is loaded on CPU. If its method is ``"fullprec"`` or its
    payload has neither a ``"Z"`` nor a ``"Qint"`` entry, a short
    ``{"name": ..., "skipped": True}`` record is returned. Otherwise the
    record carries the parameter count, shape, whole-matrix and mean
    per-column entropy, the payload's own ``"entropy"`` value (or None), the
    packing dtype, per-algorithm bits/param, the smallest-output algorithm
    and the raw compression results. With ``verbose`` a per-algorithm
    breakdown is printed as well.
    """
    artifact = LayerArtifact.load(artifact_path, map_location="cpu")

    if artifact.method == "fullprec":
        return {"name": artifact.module_name, "skipped": True}

    # Prefer the "Z" codes and fall back to "Qint" when they are absent.
    codes = artifact.payload.get("Z")
    if codes is None:
        codes = artifact.payload.get("Qint")
        if codes is None:
            return {"name": artifact.module_name, "skipped": True}

    param_count = codes.numel()
    codes_np = codes.numpy().astype(np.int32)

    # Entropy over the whole matrix, and the mean of the per-column entropies.
    matrix_entropy = compute_entropy(codes_np)
    column_entropy = compute_column_entropies(codes_np)
    mean_column_entropy = float(column_entropy.mean())
    recorded_entropy = artifact.payload.get("entropy", None)

    # Serialize at minimum byte width, then run every compressor on the bytes.
    packed, packed_dtype = pack_codes(codes)
    packed_len = len(packed)
    per_algo = compress_all(packed)

    # Smallest compressed output wins; ties go to the earliest algorithm.
    winner = min(per_algo, key=lambda algo: per_algo[algo]["size"])

    bits_per_param = {
        algo: stats["size"] * 8 / param_count for algo, stats in per_algo.items()
    }

    if verbose:
        print(f"\n  {artifact.module_name} ({param_count:,} params, packed as {packed_dtype}, shape {tuple(codes.shape)})")
        if recorded_entropy is not None:
            recorded_text = f"{recorded_entropy:.3f}"
        else:
            recorded_text = "n/a"
        print(f"    Entropy: global={matrix_entropy:.3f}  avg_column={mean_column_entropy:.3f}  stored={recorded_text}  bits/param")
        print(f"    Column entropy: min={column_entropy.min():.3f}  max={column_entropy.max():.3f}  std={column_entropy.std():.3f}")
        print(f"    Original: {packed_len:,} bytes")
        print(f"    {'Algorithm':<8} {'Size':>12} {'bpp':>8} {'Compress':>12} {'Decompress':>14}")
        print(f"    {'-'*60}")
        for algo in ALGO_NAMES:
            stats = per_algo[algo]
            rate = bits_per_param[algo]
            flag = " *" if algo == winner else ""
            print(f"    {algo:<8} {stats['size']:>12,} {rate:>8.3f} {stats['compress_ms']:>10.1f} ms {stats['decompress_ms']:>12.1f} ms{flag}")

    return {
        "name": artifact.module_name,
        "num_params": param_count,
        "shape": tuple(codes.shape),
        "global_entropy": matrix_entropy,
        "avg_col_entropy": mean_column_entropy,
        "stored_entropy": recorded_entropy,
        "pack_dtype": str(packed_dtype),
        "algo_bpp": bits_per_param,
        "best_algo": winner,
        "best_bpp": bits_per_param[winner],
        "comp_results": per_algo,
        "skipped": False,
    }


def parse_layer_name(name: str) -> Tuple[int, str]:
    """Parse 'layers.0.attention.wq' -> (0, 'wq').

    A name is recognized when it has at least four dot-separated parts, the
    first is ``"layers"`` and the second parses as an integer; the result is
    then ``(int(parts[1]), parts[3])``.

    Returns:
        ``(layer_index, weight_type)`` for a recognized name, otherwise
        ``(-1, name)`` with the name passed through unchanged.
    """
    parts = name.split(".")
    if len(parts) >= 4 and parts[0] == "layers":
        try:
            return int(parts[1]), parts[3]
        except ValueError:
            # A non-numeric index means this is not a per-layer weight name;
            # use the same fallback as any other unrecognized name instead of
            # aborting the whole run.
            pass
    return -1, name


def _weighted_mean(rows, value_of):
    """Return the ``num_params``-weighted mean of ``value_of(row)`` over rows.

    Rows for which ``value_of`` returns None contribute neither to the sum
    nor to the weight, so missing values do not pull the mean toward zero.
    Returns None when no row contributes any weight.
    """
    weighted_sum = 0
    weight = 0
    for row in rows:
        value = value_of(row)
        if value is None:
            continue
        weighted_sum += value * row["num_params"]
        weight += row["num_params"]
    if weight == 0:
        return None
    return weighted_sum / weight


def main():
    """Command-line entry point: benchmark every layer of a quantized run.

    Reads ``manifest.json`` from ``--run_dir``, benchmarks each listed
    artifact (optionally restricted by ``--layer_min`` / ``--layer_max``) and
    prints a per-layer table, parameter-weighted averages, a summary and a
    per-weight-type table. With ``--format csv`` a CSV section is printed
    after the tables. Exits with status 1 if the run directory is missing.

    The "Rate" average only covers layers whose artifact carries a stored
    entropy and shows ``n/a`` when none does. The per-weight-type table lists
    the well-known types first, followed by any other types in sorted order.
    """
    p = argparse.ArgumentParser(description="Benchmark lossless compression on quantized layers")
    p.add_argument("--run_dir", required=True, help="Path to quantized run directory")
    p.add_argument("--format", choices=["table", "csv"], default="table",
                   help="Output format (default: table)")
    p.add_argument("--layer_min", type=int, default=None,
                   help="Only benchmark layers >= this (default: all)")
    p.add_argument("--layer_max", type=int, default=None,
                   help="Only benchmark layers < this (default: all)")
    p.add_argument("--verbose", action="store_true",
                   help="Print per-layer per-algorithm details")
    args = p.parse_args()

    run_dir = Path(args.run_dir)
    if not run_dir.exists():
        print(f"Error: {run_dir} not found", file=sys.stderr)
        sys.exit(1)

    # Load manifest
    manifest = RunManifest.load(run_dir / "manifest.json")
    print(f"Model: {manifest.model_name}, Method: {manifest.method}")
    print(f"Run: {manifest.run_id}")
    print(f"Layers: {len(manifest.artifacts)}")

    # Benchmark each layer
    results = []
    for module_name, rel_path in sorted(manifest.artifacts.items()):
        lid, _ = parse_layer_name(module_name)
        # Names that do not parse as a layer get index -1, so they are dropped
        # by any non-negative --layer_min.
        if args.layer_min is not None and lid < args.layer_min:
            continue
        if args.layer_max is not None and lid >= args.layer_max:
            continue
        artifact_path = run_dir / rel_path
        if not artifact_path.exists():
            print(f"  WARNING: {artifact_path} not found, skipping")
            continue
        results.append(benchmark_layer(artifact_path, verbose=args.verbose))

    active = [r for r in results if not r["skipped"]]
    if not active:
        print("No quantized layers found.")
        return

    # Column fragments shared by the tables below: one 8-wide cell per
    # algorithm, each preceded by a single space.
    algo_header = "".join(f" {algo:>8}" for algo in ALGO_NAMES)
    algo_rule = f" {'-'*8}" * (len(ALGO_NAMES) + 1)  # algorithms + "Best"

    # Per-layer summary table
    layer_rule = f"{'-'*12} {'-'*8} {'-'*8} {'-'*8} " + algo_rule
    print(f"\n{'Layer':<12} {'Rate':>8} {'H(mat)':>8} {'H(col)':>8} {algo_header} {'Best':>8}  bits/param")
    print(layer_rule)

    for r in active:
        lid, wtype = parse_layer_name(r["name"])
        label = f"L{lid}_{wtype}"
        rate_str = f"{r['stored_entropy']:.3f}" if r["stored_entropy"] is not None else "    n/a"
        cells = "".join(f" {r['algo_bpp'][algo]:>8.3f}" for algo in ALGO_NAMES)
        print(f"{label:<12} {rate_str:>8} {r['global_entropy']:>8.3f} {r['avg_col_entropy']:>8.3f} {cells} {r['best_bpp']:>8.3f}")

    # Weighted averages
    total_params = sum(r["num_params"] for r in active)
    avg_global_entropy = _weighted_mean(active, lambda r: r["global_entropy"])
    avg_col_entropy = _weighted_mean(active, lambda r: r["avg_col_entropy"])
    avg_bpp = {
        algo: _weighted_mean(active, lambda r, algo=algo: r["algo_bpp"][algo])
        for algo in ALGO_NAMES
    }
    avg_best = _weighted_mean(active, lambda r: r["best_bpp"])

    # Layers without a stored entropy are excluded rather than counted as 0,
    # which would drag the average down while still weighting their params.
    avg_rate = _weighted_mean(active, lambda r: r["stored_entropy"])
    avg_rate_str = f"{avg_rate:.3f}" if avg_rate is not None else "n/a"

    avg_cells = "".join(f" {avg_bpp[algo]:>8.3f}" for algo in ALGO_NAMES)
    print(layer_rule)
    print(f"{'Avg':<12} {avg_rate_str:>8} {avg_global_entropy:>8.3f} {avg_col_entropy:>8.3f} {avg_cells} {avg_best:>8.3f}")

    # Summary
    print(f"\n{'='*60}")
    print(f"Summary ({len(active)} layers, {total_params/1e6:.1f}M params)")
    print(f"{'='*60}")
    print(f"  H(matrix) global:   {avg_global_entropy:.3f} bits/param")
    print(f"  H(col) avg column:  {avg_col_entropy:.3f} bits/param")
    for algo in ALGO_NAMES:
        ratio = avg_bpp[algo] / avg_col_entropy if avg_col_entropy > 0 else 0
        print(f"  {algo:<8}           {avg_bpp[algo]:.3f} bits/param  ({ratio:.3f}x H(col))")
    print(f"  Best per-layer:     {avg_best:.3f} bits/param")

    # Per weight-type summary
    rows_by_type = defaultdict(list)
    for r in active:
        _, wtype = parse_layer_name(r["name"])
        rows_by_type[wtype].append(r)

    # Well-known types keep their fixed order; anything else follows, sorted,
    # so no benchmarked layer is silently left out of this table.
    known_types = ["wq", "wk", "wv", "wo", "w1", "w2", "w3"]
    ordered_types = [t for t in known_types if t in rows_by_type]
    ordered_types += sorted(t for t in rows_by_type if t not in known_types)

    print("\nPer weight-type (bits/param):")
    print(f"  {'Type':<6} {'H(mat)':>8} {'H(col)':>8}{algo_header} {'Best':>8}")
    print(f"  {'-'*6} {'-'*8} {'-'*8}{algo_rule}")
    for wtype in ordered_types:
        rows = rows_by_type[wtype]
        type_global = _weighted_mean(rows, lambda r: r["global_entropy"])
        type_col = _weighted_mean(rows, lambda r: r["avg_col_entropy"])
        type_best = _weighted_mean(rows, lambda r: r["best_bpp"])
        type_cells = "".join(
            f" {_weighted_mean(rows, lambda r, algo=algo: r['algo_bpp'][algo]):>8.3f}"
            for algo in ALGO_NAMES
        )
        print(f"  {wtype:<6} {type_global:>8.3f} {type_col:>8.3f}{type_cells} {type_best:>8.3f}")

    if args.format == "csv":
        print("\n--- CSV ---")
        header = "layer,weight_type,num_params,global_entropy,avg_col_entropy," + ",".join(f"{a}_bpp" for a in ALGO_NAMES) + ",best_bpp,best_algo"
        print(header)
        for r in active:
            lid, wtype = parse_layer_name(r["name"])
            vals = ",".join(f"{r['algo_bpp'][a]:.4f}" for a in ALGO_NAMES)
            print(f"L{lid}_{wtype},{wtype},{r['num_params']},{r['global_entropy']:.4f},{r['avg_col_entropy']:.4f},{vals},{r['best_bpp']:.4f},{r['best_algo']}")


# Run the CLI only when executed as a script or via `python -m`, not on import.
if __name__ == "__main__":
    main()
