#!/usr/bin/env python
"""
Visualization script: Feature → Top Images

For a set of features, find their top-activating images and output visualization data.
"""

import argparse
import json
import os
import sys
from pathlib import Path

# Add the parent of this script's directory to sys.path so the `src` package can be imported
sys.path.insert(0, str(Path(__file__).parent.parent))

import numpy as np
from PIL import Image
from tqdm import tqdm

from src.autointerp import (
    AutoInterpConfig,
    load_dataset_sorted,
    extract_or_load_embeddings,
    compute_or_load_codes,
    get_embeddings_path,
    get_codes_path,
    get_imagenet_label_name,
    get_vlm_explanation,
    image_to_base64,
)


def get_top_activating_images(
    codes: np.ndarray,
    feature_idx: int,
    n_images: int = 8,
) -> tuple[np.ndarray, np.ndarray, bool]:
    """
    Get indices and activations of top-activating images for a feature.

    Args:
        codes: 2-D array indexed as ``codes[image, feature]``.
        feature_idx: Column of ``codes`` to rank images by.
        n_images: Maximum number of images to return.

    Returns:
        top_indices: Image indices sorted by descending activation. Ties are
            broken by ascending image index so the result is reproducible;
            NaN activations sort last.
        top_activations: Activation values for those images (after any sign
            flip).
        sign_flipped: Whether sign was flipped (for KSVD compatibility)

    Raises:
        ValueError: If ``n_images`` is negative.
    """
    # A negative count would otherwise be interpreted by slicing as
    # "everything except the last |n| entries".
    if n_images < 0:
        raise ValueError(f"n_images must be non-negative, got {n_images}")

    # No copy needed: the column is never modified in place below.
    activations = codes[:, feature_idx]

    # Flip sign if feature is predominantly negative (for KSVD)
    nonzero = activations[activations != 0]
    sign_flipped = bool(len(nonzero) > 0 and np.mean(nonzero) < 0)
    if sign_flipped:
        activations = -activations

    # Stable ascending sort of the negated values == descending sort of the
    # activations with ties kept in index order. Sparse codes contain many
    # equal (zero) entries, so the tie order matters for reproducibility.
    order = np.argsort(-activations, kind="stable")
    top_indices = order[:n_images]
    top_activations = activations[top_indices]

    return top_indices, top_activations, sign_flipped


def save_thumbnail(img: Image.Image, path: Path, size: int = 256):
    """Save a downscaled JPEG copy of ``img`` to ``path``.

    The input image is left untouched. Parent directories of ``path`` are
    created if they do not exist.

    Args:
        img: Source image.
        path: Destination file path.
        size: Maximum width/height of the thumbnail in pixels; the aspect
            ratio is preserved.
    """
    if img.mode in ("RGB", "L"):
        thumb = img.copy()
    else:
        # Pillow's JPEG writer rejects modes such as RGBA and P, so normalise
        # everything else (including CMYK) to RGB. Converting before resizing
        # also lets the LANCZOS filter operate on real colour values rather
        # than palette indices.
        thumb = img.convert("RGB")

    thumb.thumbnail((size, size), Image.LANCZOS)
    path.parent.mkdir(parents=True, exist_ok=True)
    thumb.save(path, "JPEG", quality=85)


def main():
    """Command-line entry point: rank images per feature and emit JSON.

    Steps:
      1. Parse CLI arguments and, optionally, an autointerp JSON file whose
         ``results`` entries supply feature indices and explanations.
      2. Load the dictionary, pick the DINO model from its first dimension,
         then load the dataset, embeddings and sparse codes.
      3. For every requested feature collect its top-activating images,
         optionally generate a missing explanation via the VLM and
         optionally write JPEG thumbnails.
      4. Write the collected data to ``--output`` or print it to stdout.

    Raises:
        ValueError: If ``--features=from-json`` is used without autointerp
            data, the dictionary's embedding dimension is unknown, or a
            requested feature index is outside the range of the codes.
        SystemExit: Via ``argparse`` for malformed command-line values.
    """
    parser = argparse.ArgumentParser(description="Feature → Top Images visualization")
    parser.add_argument("--dictionary", required=True, help="Path to dictionary .npy file")
    parser.add_argument("--k", type=int, default=32, help="Sparsity level")
    parser.add_argument("--features", default="0,1,2",
                        help="Comma-separated feature indices, or 'from-json' to use autointerp results")
    parser.add_argument("--autointerp-json", help="Path to autointerp JSON (for explanations)")
    parser.add_argument("--n-images", type=int, default=8, help="Number of top images per feature")
    parser.add_argument("--output", help="Output JSON path")
    parser.add_argument("--save-images", help="Directory to save 256px JPEG thumbnails")
    parser.add_argument("--api-key", help="OpenAI API key for generating explanations")
    parser.add_argument("--n-samples", type=int, help="Limit dataset size (for dry runs)")
    parser.add_argument("--cache-dir", help="HuggingFace cache directory")
    parser.add_argument("--work-dir", default="cache/autointerp", help="Working directory")
    parser.add_argument("--device", default="cuda", help="Device for embeddings")
    args = parser.parse_args()

    # A negative count would be interpreted by slicing as "all but the last
    # |n|" images, which is never what the user meant.
    if args.n_images < 0:
        parser.error("--n-images must be non-negative")

    # Load autointerp JSON if provided
    autointerp_data = None
    autointerp_lookup = {}
    if args.autointerp_json:
        with open(args.autointerp_json, encoding="utf-8") as f:
            autointerp_data = json.load(f)
        # Build lookup by feature_idx
        autointerp_lookup = {r["feature_idx"]: r for r in autointerp_data.get("results", [])}

    # Determine feature indices
    if args.features == "from-json":
        if not autointerp_data:
            raise ValueError("--autointerp-json required when using --features=from-json")
        feature_indices = [r["feature_idx"] for r in autointerp_data["results"]]
    else:
        try:
            # Empty tokens (e.g. from a trailing comma) are skipped.
            feature_indices = [int(tok) for tok in args.features.split(",") if tok.strip()]
        except ValueError:
            parser.error(
                f"--features must be comma-separated integers or 'from-json', got {args.features!r}"
            )
        if not feature_indices:
            parser.error("--features did not contain any feature indices")

    # Infer DINO model from dictionary dimension
    dictionary = np.load(args.dictionary)
    emb_dim = dictionary.shape[0]
    dino_model = {384: "dinov2_vits14", 768: "dinov2_vitb14"}.get(emb_dim)
    if dino_model is None:
        raise ValueError(f"Unknown embedding dim {emb_dim}")
    print(f"Dictionary: {args.dictionary} (shape={dictionary.shape}, expects {dino_model})")

    # Load dataset and embeddings
    config = AutoInterpConfig(
        n_samples=args.n_samples,
        cache_dir=args.cache_dir,
        work_dir=args.work_dir,
    )

    emb_cache = get_embeddings_path(config.work_dir, config.dataset, dino_model, config.n_samples)
    # The dataset is needed whether or not the embeddings are cached
    # (labels, thumbnails and explanation images all come from it).
    images = load_dataset_sorted(config.dataset, config.n_samples, config.cache_dir)
    if emb_cache.exists():
        print(f"Loading cached embeddings from {emb_cache}")
        # Close the archive once the array has been read out of it.
        with np.load(emb_cache) as cached:
            embeddings = cached["embeddings"]
    else:
        embeddings = extract_or_load_embeddings(
            images, dino_model, config.batch_size, args.device, emb_cache
        )

    # Load sparse codes
    codes_cache = get_codes_path(config.work_dir, args.dictionary, args.k)
    codes = compute_or_load_codes(embeddings, args.dictionary, args.k, codes_cache)

    # Fail before any API call or thumbnail is produced. Negative indices
    # would otherwise wrap around silently and report a different feature.
    n_features = codes.shape[1]
    out_of_range = [idx for idx in feature_indices if not 0 <= idx < n_features]
    if out_of_range:
        raise ValueError(
            f"Feature indices out of range for {n_features} features: {out_of_range}"
        )

    # API key for on-the-fly generation
    api_key = args.api_key or os.environ.get("OPENAI_API_KEY")

    # Process each feature
    results = []
    for feat_idx in tqdm(feature_indices, desc="Processing features"):
        top_indices, top_activations, sign_flipped = get_top_activating_images(
            codes, feat_idx, args.n_images
        )

        # Get explanation
        explanation = None
        if feat_idx in autointerp_lookup:
            explanation = autointerp_lookup[feat_idx].get("explanation")

        # Generate on-the-fly if missing and API key available
        if explanation is None and api_key:
            print(f"  Generating explanation for feature {feat_idx}...")
            top_imgs = images.get_images(top_indices.tolist())
            try:
                explanation = get_vlm_explanation(top_imgs, api_key)
            except Exception as e:
                # Best effort: a failed explanation must not abort the run.
                print(f"    Error: {e}")
                explanation = None

        # Build image info
        top_images_info = []
        # .tolist() yields plain Python scalars, matching the list of ints
        # already passed to images.get_images above.
        ranked = zip(top_indices.tolist(), top_activations.tolist())
        for i, (img_idx, activation) in enumerate(ranked):
            label = images.get_label(img_idx)
            label_name = get_imagenet_label_name(label, config.cache_dir)

            img_info = {
                "index": int(img_idx),
                "activation": float(activation),
                "label": int(label),
                "label_name": label_name,
            }

            # Save thumbnail if requested
            if args.save_images:
                img = images[img_idx]
                thumb_path = Path(args.save_images) / f"feat{feat_idx}_img{i}.jpg"
                save_thumbnail(img, thumb_path)
                img_info["thumbnail"] = str(thumb_path)

            top_images_info.append(img_info)

        results.append({
            "feature_idx": int(feat_idx),
            "explanation": explanation,
            "sign_flipped": sign_flipped,
            "top_images": top_images_info,
        })

        if explanation:
            print(f"  Feature {feat_idx}: '{explanation}'")

    # Output
    output_data = {
        "dictionary": args.dictionary,
        "k": args.k,
        "n_images": args.n_images,
        "features": results,
    }

    if args.output:
        Path(args.output).parent.mkdir(parents=True, exist_ok=True)
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=2)
        print(f"\nSaved to {args.output}")
    else:
        print(json.dumps(output_data, indent=2))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
