#!/usr/bin/env python
"""
Visualization script: Image → Top Features

For a given image, show which features activate most strongly and their explanations.
"""

import argparse
import json
import os
import sys
from pathlib import Path

# Add the directory above this script's folder to sys.path so `from src...` imports resolve
sys.path.insert(0, str(Path(__file__).parent.parent))

import numpy as np
from PIL import Image
from tqdm import tqdm

from src.autointerp import (
    AutoInterpConfig,
    load_dataset_sorted,
    extract_or_load_embeddings,
    compute_or_load_codes,
    get_embeddings_path,
    get_codes_path,
    get_imagenet_label_name,
    get_vlm_explanation,
)


def get_top_features_for_image(
    codes: np.ndarray,
    image_idx: int,
    n_features: int = 5,
) -> list[dict]:
    """
    Get top-k features by activation magnitude for a given image.

    Only nonzero activations are considered, so fewer than ``n_features``
    entries are returned when the image's code has fewer nonzero entries.
    Features with equal magnitude are ordered by ascending feature index.

    Args:
        codes: 2-D array indexed as ``codes[image, feature]``.
        image_idx: Row of ``codes`` to inspect.
        n_features: Maximum number of features to return; values <= 0
            yield an empty list.

    Returns:
        List of dicts, strongest first, with keys:
        ``feature_idx`` (int), ``activation`` (float, signed value for this
        image) and ``sign_flipped`` (bool, True when the feature's mean
        nonzero activation over all images is negative, i.e. the feature
        would be sign-flipped in autointerp).
    """
    # Clamp: a negative slice bound would otherwise mean "all but the last k".
    n_features = max(int(n_features), 0)

    code = codes[image_idx]
    nonzero_indices = np.flatnonzero(code)
    nonzero_values = code[nonzero_indices]

    # Stable sort on negated magnitude => largest |activation| first, ties by
    # ascending feature index (reversing an unstable ascending argsort leaves
    # tie order unspecified).
    order = np.argsort(-np.abs(nonzero_values), kind="stable")[:n_features]
    top_indices = nonzero_indices[order]
    top_values = nonzero_values[order]

    results = []
    for feat_idx, activation in zip(top_indices, top_values):
        # Determine if this feature would be sign-flipped in autointerp
        feat_activations = codes[:, feat_idx]
        nonzero_feat = feat_activations[feat_activations != 0]
        sign_flipped = nonzero_feat.size > 0 and nonzero_feat.mean() < 0

        results.append({
            "feature_idx": int(feat_idx),
            "activation": float(activation),
            "sign_flipped": bool(sign_flipped),
        })

    return results


def save_image(img: Image.Image, path: Path, size: int = 256):
    """Save ``img`` as a JPEG thumbnail at ``path``.

    The image is resized with LANCZOS to fit inside a ``size`` x ``size``
    box, preserving aspect ratio; ``img`` itself is not modified. Modes that
    JPEG cannot store (e.g. RGBA, LA, P) are converted to RGB first, because
    Pillow refuses to write them as JPEG. Parent directories are created as
    needed. ``path`` may be a ``Path`` or a ``str``.
    """
    # convert() already returns a new image, so it doubles as the defensive copy.
    thumb = img.copy() if img.mode in ("RGB", "L") else img.convert("RGB")
    thumb.thumbnail((size, size), Image.LANCZOS)
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    thumb.save(path, "JPEG", quality=85)


def _parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line options (``argv`` defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(description="Image → Top Features visualization")
    parser.add_argument("--dictionary", required=True, help="Path to dictionary .npy file")
    parser.add_argument("--k", type=int, default=32, help="Sparsity level")
    parser.add_argument("--image-index", default="random",
                        help="Image index in dataset, or 'random'")
    parser.add_argument("--autointerp-json", help="Path to autointerp JSON (for explanations)")
    parser.add_argument("--n-features", type=int, default=5, help="Number of top features to show")
    parser.add_argument("--output", help="Output JSON path")
    parser.add_argument("--save-image", help="Path to save the input image as 256px JPEG")
    parser.add_argument("--api-key", help="OpenAI API key for generating missing explanations")
    parser.add_argument("--n-samples", type=int, help="Limit dataset size (for dry runs)")
    parser.add_argument("--cache-dir", help="HuggingFace cache directory")
    parser.add_argument("--work-dir", default="cache/autointerp", help="Working directory")
    parser.add_argument("--device", default="cuda", help="Device for embeddings")
    args = parser.parse_args(argv)

    # Reject a malformed index before the (potentially slow) data loading starts.
    if args.image_index != "random":
        try:
            int(args.image_index)
        except ValueError:
            parser.error(
                f"--image-index must be an integer or 'random', got {args.image_index!r}"
            )
    return args


def _load_autointerp_lookup(path) -> dict:
    """Return ``{feature_idx: record}`` from an autointerp JSON file.

    Records are read from the file's top-level ``"results"`` list. Returns an
    empty dict when ``path`` is falsy (no file given).
    """
    if not path:
        return {}
    with open(path, encoding="utf-8") as f:
        autointerp_data = json.load(f)
    return {r["feature_idx"]: r for r in autointerp_data.get("results", [])}


def _infer_dino_model(dictionary_path: str) -> str:
    """Pick the DINOv2 model name matching the dictionary's first dimension.

    The dictionary's axis 0 is treated as the embedding width.

    Raises:
        ValueError: if that width is not 384 or 768.
    """
    # Memory-map: only the shape is needed, not the matrix contents.
    dictionary = np.load(dictionary_path, mmap_mode="r")
    emb_dim = dictionary.shape[0]
    dino_model = {384: "dinov2_vits14", 768: "dinov2_vitb14"}.get(emb_dim)
    if dino_model is None:
        raise ValueError(f"Unknown embedding dim {emb_dim}")
    print(f"Dictionary: {dictionary_path} (shape={dictionary.shape}, expects {dino_model})")
    return dino_model


def _load_images_and_embeddings(config, dino_model: str, device: str):
    """Load the sorted dataset and its embeddings, preferring the .npz cache.

    Returns:
        ``(images, embeddings)``.
    """
    emb_cache = get_embeddings_path(config.work_dir, config.dataset, dino_model, config.n_samples)
    if emb_cache.exists():
        print(f"Loading cached embeddings from {emb_cache}")
        # The context manager closes the NpzFile handle; indexing reads the
        # array into memory, so it stays valid after the file is closed.
        with np.load(emb_cache) as cached:
            embeddings = cached["embeddings"]
        images = load_dataset_sorted(config.dataset, config.n_samples, config.cache_dir)
    else:
        images = load_dataset_sorted(config.dataset, config.n_samples, config.cache_dir)
        embeddings = extract_or_load_embeddings(
            images, dino_model, config.batch_size, device, emb_cache
        )
    return images, embeddings


def _resolve_image_index(spec: str, n_images: int) -> int:
    """Turn the ``--image-index`` value into a validated, non-negative index.

    ``"random"`` draws uniformly from ``[0, n_images)`` with fresh entropy.
    Negative integers count from the end (Python-style) and are normalized,
    so the reported index is the one actually used for ``codes`` and ``images``.

    Raises:
        ValueError: if the dataset is empty or the index is out of range.
    """
    if n_images <= 0:
        raise ValueError("Dataset is empty; no image to analyze")
    if spec == "random":
        # A local Generator avoids reseeding NumPy's global RNG as a side effect.
        return int(np.random.default_rng().integers(n_images))
    image_idx = int(spec)
    if image_idx < 0:
        image_idx += n_images
    if not 0 <= image_idx < n_images:
        raise ValueError(f"--image-index {spec} out of range for dataset of size {n_images}")
    return image_idx


def _top_activating_images(codes: np.ndarray, images, feat_idx: int, n_images: int = 8):
    """Fetch the ``n_images`` images with the highest activation of ``feat_idx``.

    If the feature's mean nonzero activation is negative, the column is
    negated first, so "highest" refers to the feature's dominant sign.
    """
    feat_activations = codes[:, feat_idx]
    nonzero = feat_activations[feat_activations != 0]
    if len(nonzero) > 0 and np.mean(nonzero) < 0:
        # Negation builds a new array; the view into `codes` is never mutated.
        feat_activations = -feat_activations
    top_indices = np.argsort(feat_activations)[::-1][:n_images]
    return images.get_images(top_indices.tolist())


def main(argv=None):
    """Print (and optionally save) the top dictionary features for one image.

    Steps:
    1. Infer the DINOv2 model from the dictionary's first dimension.
    2. Load the dataset, embeddings and sparse codes, using caches where present.
    3. Pick an image, either random or by ``--image-index``.
    4. List its strongest features with explanations. Explanations come from
       ``--autointerp-json``; when one is missing and an OpenAI key is
       available (``--api-key`` or ``OPENAI_API_KEY``), it is generated on
       the fly.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]``.
    """
    args = _parse_args(argv)

    autointerp_lookup = _load_autointerp_lookup(args.autointerp_json)
    dino_model = _infer_dino_model(args.dictionary)

    # Load dataset and embeddings
    config = AutoInterpConfig(
        n_samples=args.n_samples,
        cache_dir=args.cache_dir,
        work_dir=args.work_dir,
    )
    images, embeddings = _load_images_and_embeddings(config, dino_model, args.device)

    # Load sparse codes
    codes_cache = get_codes_path(config.work_dir, args.dictionary, args.k)
    codes = compute_or_load_codes(embeddings, args.dictionary, args.k, codes_cache)

    image_idx = _resolve_image_index(args.image_index, len(images))
    print(f"\nAnalyzing image {image_idx}...")

    # Get image info
    label = images.get_label(image_idx)
    label_name = get_imagenet_label_name(label, config.cache_dir)
    print(f"  Label: {label} ({label_name})")

    if args.save_image:
        save_image(images[image_idx], Path(args.save_image))
        print(f"  Saved to {args.save_image}")

    top_features = get_top_features_for_image(codes, image_idx, args.n_features)

    # API key for on-the-fly generation
    api_key = args.api_key or os.environ.get("OPENAI_API_KEY")

    print(f"\nTop {len(top_features)} features:")
    for feat_info in top_features:
        feat_idx = feat_info["feature_idx"]
        explanation = autointerp_lookup.get(feat_idx, {}).get("explanation")

        # Generate on-the-fly if missing and API key available
        if explanation is None and api_key:
            print(f"  Generating explanation for feature {feat_idx}...")
            top_imgs = _top_activating_images(codes, images, feat_idx)
            try:
                explanation = get_vlm_explanation(top_imgs, api_key)
            except Exception as e:  # best-effort: one failed call shouldn't abort the run
                print(f"    Error: {e}")

        feat_info["explanation"] = explanation
        sign_str = " (sign-flipped)" if feat_info["sign_flipped"] else ""
        print(f"  Feature {feat_idx}: activation={feat_info['activation']:.2f}{sign_str}")
        if explanation:
            print(f"    → {explanation}")

    output_data = {
        "dictionary": args.dictionary,
        "k": args.k,
        "image_index": image_idx,
        # NOTE(review): get_label's return type is not visible here; unwrap
        # NumPy scalars (e.g. np.int64) so json.dump can serialize them.
        "label": label.item() if isinstance(label, np.generic) else label,
        "label_name": label_name,
        "top_features": top_features,
    }

    if args.output:
        output_path = Path(args.output)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=2)
        print(f"\nSaved to {args.output}")
    else:
        print("\n" + json.dumps(output_data, indent=2))


if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status
    # (None -> 0), so running the script directly exits cleanly.
    raise SystemExit(main())
