"""CLI entrypoint for ICML 2026 experiments."""

import typer
from pathlib import Path
from typing import Annotated

# Root Typer application; every function decorated with @app.command()
# below is registered on it as a subcommand.
app = typer.Typer(
    help="DB-KSVD vs SAE comparison on ViT embeddings",
)


@app.command()
def extract(
    output: Annotated[Path, typer.Option(help="Output HDF5 path")] = Path("data/embeddings.h5"),
    model: Annotated[str, typer.Option(help="DINOv2 model name")] = "dinov2_vits14",
    dataset: Annotated[str, typer.Option(help="HuggingFace dataset name")] = "imagenet-1k",
    split: Annotated[str, typer.Option(help="Dataset split")] = "train",
    max_samples: Annotated[int | None, typer.Option(help="Max samples (None=all)")] = None,
    batch_size: Annotated[int, typer.Option(help="Batch size")] = 128,
    streaming: Annotated[bool, typer.Option(help="Use streaming mode")] = False,
    cache_dir: Annotated[str | None, typer.Option(help="HuggingFace cache dir")] = None,
    device: Annotated[str, typer.Option(help="Device (cuda/cpu)")] = "cuda",
    no_wandb: Annotated[bool, typer.Option(help="Disable wandb")] = False,
    dry_run: Annotated[bool, typer.Option(help="Quick test with 1k samples, cifar100, streaming, no wandb")] = False,
):
    """Extract DINOv2 embeddings from a dataset.

    All options are forwarded into an ``EmbeddingConfig`` and handed to
    ``extract_embeddings``. With ``--dry-run`` the dataset, sample cap,
    streaming and wandb options are overridden for a quick local test.
    """
    from .config import EmbeddingConfig
    from .extract_embeddings import extract_embeddings

    if dry_run:
        # cifar100 needs no auth. Streaming uses a small shuffle buffer, so
        # shuffling is only approximate: not suitable for production runs.
        dataset, max_samples, streaming, no_wandb = "cifar100", 1000, True, True
        for banner_line in (
            "DRY RUN: 1k samples, cifar100, streaming, no wandb",
            "  (streaming mode - shuffle is approximate, for local debugging only)",
        ):
            print(banner_line)

    extract_embeddings(
        EmbeddingConfig(
            model_name=model,
            dataset=dataset,
            split=split,
            batch_size=batch_size,
            max_samples=max_samples,
            output_path=output,
            streaming=streaming,
            cache_dir=cache_dir,
            device=device,
        ),
        no_wandb=no_wandb,
    )


@app.command()
def train_sae(
    embeddings: Annotated[Path, typer.Option(help="Input embeddings HDF5")] = Path("data/embeddings.h5"),
    output: Annotated[Path, typer.Option(help="Output model path")] = Path("models/sae.pt"),
    dict_size: Annotated[int, typer.Option(help="Dictionary size")] = 4096,
    k: Annotated[int, typer.Option(help="TopK sparsity")] = 16,
    lr: Annotated[float, typer.Option(help="Learning rate")] = 1e-3,
    batch_size: Annotated[int, typer.Option(help="Batch size")] = 2048,
    num_steps: Annotated[int, typer.Option(help="Training steps")] = 50_000,
    device: Annotated[str, typer.Option(help="Device")] = "cuda",
    no_wandb: Annotated[bool, typer.Option(help="Disable wandb")] = False,
    dry_run: Annotated[bool, typer.Option(help="Quick test with 1k steps, no wandb")] = False,
):
    """Train TopK SAE on embeddings.

    The SAE input dimension is read from the ``embedding_dim`` attribute of
    the embeddings HDF5 file. ``--dry-run`` caps training at 1k steps and
    disables wandb.

    Raises:
        typer.BadParameter: if the embeddings file does not exist.
    """
    import h5py
    from .config import SAEConfig
    from .sae import train_sae as _train_sae

    # Fail with a clear CLI error instead of an h5py traceback when the
    # extract step has not been run yet.
    if not embeddings.exists():
        raise typer.BadParameter(
            f"embeddings file not found: {embeddings}",
            param_hint="'--embeddings'",
        )

    if dry_run:
        num_steps = 1000
        no_wandb = True
        print("DRY RUN: 1k steps, no wandb")

    # Get input dim from embeddings file. h5py returns numeric attributes as
    # NumPy scalars (e.g. numpy.int64); cast so the config holds a plain int
    # rather than a NumPy type (which json, for one, cannot serialise).
    with h5py.File(embeddings, "r") as f:
        input_dim = int(f.attrs["embedding_dim"])

    config = SAEConfig(
        input_dim=input_dim,
        dict_size=dict_size,
        k=k,
        learning_rate=lr,
        batch_size=batch_size,
        num_steps=num_steps,
        device=device,
    )

    _train_sae(embeddings, config, output, no_wandb=no_wandb)


@app.command()
def evaluate(
    embeddings: Annotated[Path, typer.Option(help="Embeddings HDF5")] = Path("data/embeddings.h5"),
    sae_model: Annotated[Path, typer.Option(help="SAE model path")] = Path("models/sae.pt"),
    sae_dict: Annotated[Path, typer.Option(help="SAE dictionary path")] = Path("models/sae.D.npy"),
    ksvd_dict: Annotated[Path, typer.Option(help="KSVD dictionary path")] = Path("models/ksvd.npy"),
    k: Annotated[int, typer.Option(help="Sparsity for KSVD encoding")] = 16,
    device: Annotated[str, typer.Option(help="Device")] = "cuda",
    output: Annotated[Path | None, typer.Option(help="Output JSON path")] = None,
    evals: Annotated[str, typer.Option(help="Which evals to run: all, recon, cluster, probe, sparse_probe (comma-separated)")] = "all",
    cache_codes: Annotated[Path | None, typer.Option(help="Cache sparse codes to/from this directory")] = None,
):
    """Evaluate and compare SAE vs DB-KSVD.

    Runs the selected evaluations via ``compare_methods``, prints a
    comparison table and, if ``--output`` is given, writes each method's
    result attributes (``vars(result)``) to a JSON file.

    ``--evals`` is a comma-separated list; whitespace around each name is
    ignored (so ``"recon, cluster"`` works) and ``all`` expands to recon,
    cluster and sparse_probe.

    Raises:
        typer.BadParameter: if ``--evals`` contains no eval names.
    """
    import json
    from .evaluate import compare_methods, print_comparison_table

    # Names advertised in the --evals help text.
    known_evals = {"recon", "cluster", "probe", "sparse_probe"}
    # NOTE(review): "all" does not include "probe" even though the help text
    # lists it — confirm this is intended.
    all_evals = {"recon", "cluster", "sparse_probe"}

    # Strip whitespace around each name and drop empty tokens left by stray
    # or trailing commas.
    eval_set = {name.strip() for name in evals.split(",")} - {""}
    if not eval_set:
        raise typer.BadParameter("expected at least one eval name", param_hint="'--evals'")
    if "all" in eval_set:
        eval_set = (eval_set - {"all"}) | all_evals

    unknown = eval_set - known_evals
    if unknown:
        # Warn rather than fail: compare_methods may accept names beyond the
        # ones listed in the help text.
        typer.echo(
            f"Warning: unrecognised eval name(s): {', '.join(sorted(unknown))}",
            err=True,
        )

    results = compare_methods(
        embeddings_path=embeddings,
        sae_model_path=sae_model,
        sae_dict_path=sae_dict,
        ksvd_dict_path=ksvd_dict,
        k=k,
        device=device,
        evals=eval_set,
        cache_codes_dir=cache_codes,
    )

    print_comparison_table(results)

    if output is not None:
        output.parent.mkdir(parents=True, exist_ok=True)
        results_dict = {name: vars(r) for name, r in results.items()}
        with open(output, "w", encoding="utf-8") as f:
            json.dump(results_dict, f, indent=2)
        print(f"Results saved to {output}")


@app.command()
def run_ksvd(
    embeddings: Annotated[Path, typer.Option(help="Embeddings HDF5")] = Path("data/embeddings.h5"),
    output: Annotated[Path, typer.Option(help="Output dictionary path")] = Path("models/ksvd.npy"),
    dict_size: Annotated[int, typer.Option(help="Dictionary size")] = 4096,
    k: Annotated[int, typer.Option(help="Sparsity (nnz per col)")] = 16,
    no_wandb: Annotated[bool, typer.Option(help="Disable wandb")] = False,
):
    """Print the Julia command to run DB-KSVD.

    Nothing is executed: the command is printed for the user to copy. Path
    arguments are shell-quoted so the command stays valid when they contain
    spaces or shell metacharacters (plain paths print unchanged).
    """
    import shlex

    emb_arg = shlex.quote(str(embeddings))
    out_arg = shlex.quote(str(output))
    # Relative paths (KSVD.jl, scripts/) — presumably meant to be run from
    # the repository root.
    cmd = f"""julia --project=KSVD.jl scripts/ksvd_dino.jl \\
    {emb_arg} {out_arg} \\
    --dict-size={dict_size} --nnz={k}"""

    if no_wandb:
        cmd += " --no-wandb"

    print("Run the following command to train DB-KSVD:\n")
    print(cmd)
    print("\nNote: Make sure KSVD.jl dependencies are installed:")
    print("  cd KSVD.jl && julia --project -e 'using Pkg; Pkg.instantiate()'")


@app.command()
def full_pipeline(
    output_dir: Annotated[Path, typer.Option(help="Output directory")] = Path("outputs"),
    model: Annotated[str, typer.Option(help="DINOv2 model")] = "dinov2_vits14",
    dataset: Annotated[str, typer.Option(help="HuggingFace dataset")] = "imagenet-1k",
    dict_size: Annotated[int, typer.Option(help="Dictionary size")] = 4096,
    k: Annotated[int, typer.Option(help="Sparsity")] = 16,
    max_samples: Annotated[int | None, typer.Option(help="Max samples")] = None,
    device: Annotated[str, typer.Option(help="Device")] = "cuda",
    no_wandb: Annotated[bool, typer.Option(help="Disable wandb")] = False,
    dry_run: Annotated[bool, typer.Option(help="Quick dry run")] = False,
):
    """Run full pipeline: extract → train SAE → print KSVD and evaluate commands.

    Steps 1-2 run in-process and write ``embeddings.h5`` and ``sae.pt`` under
    ``output_dir``. Steps 3-4 only print the shell commands for training
    DB-KSVD (a separate Julia script) and for evaluating both methods with
    the same ``k`` and ``device`` used here. ``--dry-run`` switches to a
    1k-sample, streamed cifar100 run with wandb disabled and 1k SAE steps.
    """
    import shlex

    import h5py
    from .config import EmbeddingConfig, SAEConfig
    from .extract_embeddings import extract_embeddings
    from .sae import train_sae as _train_sae

    output_dir.mkdir(parents=True, exist_ok=True)

    emb_path = output_dir / "embeddings.h5"
    sae_path = output_dir / "sae.pt"
    ksvd_path = output_dir / "ksvd.npy"
    # Mirrors evaluate's default naming (models/sae.pt -> models/sae.D.npy);
    # presumably where train_sae saves the dictionary — verify in .sae.
    sae_dict_path = sae_path.with_suffix(".D.npy")

    if dry_run:
        max_samples = 1000
        no_wandb = True
        dataset = "cifar100"  # No auth required
        print("=" * 60)
        print("DRY RUN MODE (using cifar100)")
        print("=" * 60)

    # Step 1: Extract embeddings
    print("\n[1/4] Extracting embeddings...")
    emb_config = EmbeddingConfig(
        model_name=model,
        dataset=dataset,
        split="train",
        batch_size=128,
        max_samples=max_samples,
        output_path=emb_path,
        streaming=dry_run,
        cache_dir=None,
        device=device,
    )
    extract_embeddings(emb_config, no_wandb=no_wandb)

    # Step 2: Train SAE
    print("\n[2/4] Training SAE...")
    with h5py.File(emb_path, "r") as f:
        # h5py returns numeric attributes as NumPy scalars; cast to a plain int.
        input_dim = int(f.attrs["embedding_dim"])

    sae_config = SAEConfig(
        input_dim=input_dim,
        dict_size=dict_size,
        k=k,
        learning_rate=1e-3,
        # Never ask for a batch larger than the (capped) sample count.
        batch_size=min(2048, max_samples) if max_samples else 2048,
        num_steps=1000 if dry_run else 50_000,
        device=device,
    )
    _train_sae(emb_path, sae_config, sae_path, no_wandb=no_wandb)

    # Shell-quote everything interpolated into copy-pasteable commands.
    emb_arg = shlex.quote(str(emb_path))
    sae_arg = shlex.quote(str(sae_path))
    sae_dict_arg = shlex.quote(str(sae_dict_path))
    ksvd_arg = shlex.quote(str(ksvd_path))

    # Step 3: Print KSVD command
    print("\n[3/4] To train DB-KSVD, run:")
    cmd = f"""julia --project=KSVD.jl scripts/ksvd_dino.jl \\
    {emb_arg} {ksvd_arg} \\
    --dict-size={dict_size} --nnz={k}"""
    if no_wandb:
        cmd += " --no-wandb"
    print(cmd)

    # Step 4: Print evaluate command. Pass k and device explicitly: evaluate's
    # own defaults (k=16, cuda) would otherwise silently override the values
    # this pipeline trained with.
    print("\n[4/4] After KSVD training, evaluate with:")
    print(f"  uv run python -m src.cli evaluate --embeddings={emb_arg} \\")
    print(f"    --sae-model={sae_arg} --sae-dict={sae_dict_arg} \\")
    print(f"    --ksvd-dict={ksvd_arg} --k={k} --device={shlex.quote(device)}")


if __name__ == "__main__":
    # Allow running the module directly (e.g. `python -m src.cli <command>`);
    # Typer dispatches to the subcommand named on the command line.
    app()
