#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.11"
# dependencies = ["plotext"]
# ///
"""Plot grid of autointerp correlation histograms comparing SAE vs KSVD."""

import argparse
import json
from pathlib import Path

import plotext as plt


def load_corrs(path: Path) -> list[float]:
    """Load correlation values from a JSON result file.

    The file must contain a top-level ``"results"`` list whose entries each
    carry a ``"correlation"`` key. Entries whose correlation is ``None`` are
    dropped; all other values are returned in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If ``"results"`` or an entry's ``"correlation"`` is absent.
    """
    # Use a context manager so the handle is closed deterministically instead
    # of being left for the garbage collector.
    with open(path, encoding="utf-8") as fh:
        data = json.load(fh)
    return [r["correlation"] for r in data["results"] if r["correlation"] is not None]


def plot_single(title: str, sae: list[float], ksvd: list[float], width: int, height: int):
    """Draw one standalone figure overlaying the SAE and KSVD histograms.

    The figure is cleared first, sized to ``width`` x ``height``, and shown
    immediately. The colour legend is embedded in the title.
    """
    bin_count = 15
    x_range = (-0.2, 1.0)

    plt.clf()
    plt.plot_size(width, height)
    # SAE is drawn first, then KSVD; the title's colour legend relies on
    # that drawing order.
    for series in (sae, ksvd):
        plt.hist(series, bins=bin_count)
    plt.title(f"{title} | Blue=SAE, Green=KSVD")
    plt.xlim(*x_range)
    plt.show()


def plot_model(model: str, results_dir: Path, width: int, height: int, layout: str):
    """Plot histograms for a single model (standard vs matryoshka).

    Six panels are produced: two variants (standard, matryoshka) times three
    k values (16, 32, 64). Result files are looked up in ``results_dir`` as
    ``sae_{model}_k{k}{suffix}.json`` and ``ksvd_{model}_k{k}{suffix}.json``.

    Args:
        model: Model name used in the file names and panel titles.
        results_dir: Directory holding the JSON result files.
        width: Plot width passed to plotext.
        height: Plot height passed to plotext (tripled for the column layout).
        layout: ``"grid"`` (2x3 subplots), ``"column"`` (6x1 subplots) or
            ``"single"`` (each panel shown as its own figure).

    A panel whose SAE or KSVD file is missing is reported on stdout in the
    ``"single"`` layout and shown as an empty, titled subplot otherwise.
    """
    configs = [
        ("", "standard"),
        ("_matryoshka", "matryoshka"),
    ]
    ks = [16, 32, 64]
    n_cols = len(ks)

    if layout == "grid":
        plt.subplots(len(configs), n_cols)
        plt.plot_size(width, height)

    if layout == "column":
        # Lift plotext's height limit so the tall stacked figure is kept as is.
        plt.limitsize(height=False)
        plt.subplots(len(configs) * n_cols, 1)
        plt.plot_size(width, height * 3)

    def select_subplot(idx: int) -> None:
        """Activate the subplot for 1-based panel number ``idx``."""
        if layout == "grid":
            # Row-major position. Both the "missing" and the normal branch go
            # through this helper, so placeholders land in the same cell the
            # histogram would have used.
            row, col = divmod(idx - 1, n_cols)
            plt.subplot(row + 1, col + 1)
        else:
            plt.subplot(idx, 1)

    idx = 1
    for suffix, label in configs:
        for k in ks:
            sae_file = results_dir / f"sae_{model}_k{k}{suffix}.json"
            ksvd_file = results_dir / f"ksvd_{model}_k{k}{suffix}.json"
            title = f"{model} {label} k={k}"

            if not sae_file.exists() or not ksvd_file.exists():
                if layout == "single":
                    print(f"{title}: missing")
                else:
                    select_subplot(idx)
                    plt.title(f"{title} (missing)")
                idx += 1
                continue

            sae = load_corrs(sae_file)
            ksvd = load_corrs(ksvd_file)

            if layout == "single":
                plot_single(title, sae, ksvd, width, height)
            else:
                select_subplot(idx)
                plt.hist(sae, bins=15)
                plt.hist(ksvd, bins=15)
                plt.title(title)
                plt.xlim(-0.2, 1.0)
            idx += 1

    if layout != "single":
        # Subplot titles have no room for a legend, so print it once up front.
        print(f"{model.upper()}: Blue=SAE, Green=KSVD")
        print()
        plt.show()


def plot_all(results_dir: Path, width: int, height: int, layout: str):
    """Plot histograms comparing both models.

    Twelve panels are produced: (vits14, vitb14) x (standard, matryoshka) x
    k in (16, 32, 64). With ``layout == "grid"`` they fill a 4x3 subplot grid
    in row-major order, with ``"column"`` a 12x1 stack, and with ``"single"``
    every panel is shown as its own figure via ``plot_single``.
    """
    variants = [
        ("vits14", "", "vits14 std"),
        ("vits14", "_matryoshka", "vits14 mat"),
        ("vitb14", "", "vitb14 std"),
        ("vitb14", "_matryoshka", "vitb14 mat"),
    ]
    k_values = [16, 32, 64]
    # Flatten to one entry per panel so the panel number comes from enumerate.
    panels = [
        (model, suffix, label, k)
        for model, suffix, label in variants
        for k in k_values
    ]

    if layout == "grid":
        plt.subplots(4, 3)
        plt.plot_size(width, height)
    elif layout == "column":
        plt.limitsize(height=False)
        plt.subplots(12, 1)
        plt.plot_size(width, height * 6)

    for position, (model, suffix, label, k) in enumerate(panels, start=1):
        stem = f"{model}_k{k}{suffix}.json"
        sae_path = results_dir / f"sae_{stem}"
        ksvd_path = results_dir / f"ksvd_{stem}"
        title = f"{label} k={k}"

        # Both files must be present; otherwise the panel is marked missing.
        corrs = None
        if sae_path.exists() and ksvd_path.exists():
            corrs = (load_corrs(sae_path), load_corrs(ksvd_path))

        if layout == "single":
            if corrs is None:
                print(f"{title}: missing")
            else:
                plot_single(title, corrs[0], corrs[1], width, height)
            continue

        if layout == "grid":
            row, col = divmod(position - 1, 3)
            plt.subplot(row + 1, col + 1)
        else:
            plt.subplot(position, 1)

        if corrs is None:
            plt.title(f"{title} (missing)")
            continue

        for series in corrs:
            plt.hist(series, bins=15)
        plt.title(title)
        plt.xlim(-0.2, 1.0)

    if layout == "single":
        return

    print("Legend: Blue=SAE, Green=KSVD")
    print()
    plt.show()


def main():
    """Parse command-line options and render the requested histograms.

    Returns:
        1 if the results directory does not exist, 0 otherwise.
    """
    option_specs = [
        (
            "--model",
            {
                "choices": ["vits14", "vitb14", "all"],
                "default": "all",
                "help": "Model to plot (default: all)",
            },
        ),
        (
            "--results-dir",
            {
                "default": "results/autointerp",
                "help": "Directory containing JSON results",
            },
        ),
        (
            "--layout",
            {
                "choices": ["grid", "column", "single"],
                "default": "grid",
                "help": "Layout: grid (2x3/4x3), column (Nx1), single (one at a time)",
            },
        ),
        ("--width", {"type": int, "default": 140, "help": "Plot width (default: 140)"}),
        ("--height", {"type": int, "default": 40, "help": "Plot height (default: 40)"}),
    ]

    cli = argparse.ArgumentParser(description="Plot autointerp correlation histograms")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    results_dir = Path(opts.results_dir)
    if not results_dir.exists():
        print(f"Error: Results directory not found: {results_dir}")
        return 1

    if opts.model != "all":
        plot_model(opts.model, results_dir, opts.width, opts.height, opts.layout)
    else:
        # The combined view has twice as many rows, so give it twice the height.
        plot_all(results_dir, opts.width, opts.height * 2, opts.layout)

    return 0


if __name__ == "__main__":
    # The bare ``exit`` builtin is installed by the ``site`` module and is not
    # available under ``python -S`` or in some embedded/frozen interpreters.
    # Raising SystemExit directly is always available and propagates main()'s
    # return value as the process exit status.
    raise SystemExit(main())
