"""
VLM-based auto-interpretability evaluation for sparse coding features.

Architecture:
- Embeddings are extracted in dataset order (no shuffling) → embedding[i] = dataset[i]
- Embeddings and encodings are cached separately for reuse
- Images are loaded lazily on demand via dataset index

Pipeline:
1. For each feature, find top-k activating images by code activation
2. Send images to VLM → get natural language explanation
3. Score using CLIP: correlate explanation embedding with held-out images
"""

# Import juliacall before torch to avoid segfault
# See: https://github.com/pytorch/pytorch/issues/78829
from juliacall import Main as jl

import base64
import hashlib
import io
import json
import os
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm


_imagenet_labels = None


def get_imagenet_label_names(cache_dir: str | None = None) -> list[str]:
    """Return the ImageNet-1k class names, loading them at most once.

    The names are read from the ``label`` feature of the streaming
    ``imagenet-1k`` train split and memoised in the module-level
    ``_imagenet_labels``. ``cache_dir`` is therefore only used by the call
    that actually performs the load.
    """
    global _imagenet_labels
    if _imagenet_labels is not None:
        return _imagenet_labels

    from datasets import load_dataset

    stream = load_dataset("imagenet-1k", split="train", streaming=True, cache_dir=cache_dir)
    _imagenet_labels = stream.features["label"].names
    return _imagenet_labels


def get_imagenet_label_name(label_idx: int, cache_dir: str | None = None) -> str:
    """Translate an ImageNet label index into its class name.

    An index outside ``[0, number of classes)`` does not raise; it yields the
    placeholder ``"unknown_<label_idx>"`` instead.
    """
    class_names = get_imagenet_label_names(cache_dir)
    out_of_range = label_idx < 0 or label_idx >= len(class_names)
    if out_of_range:
        return f"unknown_{label_idx}"
    return class_names[label_idx]


@dataclass
class AutoInterpConfig:
    """Configuration for autointerp evaluation.

    Bundles the settings read by the pipeline in this file: which features
    to evaluate, which VLM and CLIP models to use, and where the dataset and
    the embedding/code caches live. All fields have defaults, so
    ``AutoInterpConfig()`` is a valid configuration.
    """
    # Feature selection
    n_features: int = 20  # number of dictionary features to evaluate
    feature_selection: str = "top_variance"  # or "random" or "stratified"

    # VLM settings
    vlm_model: str = "gpt-4o-mini"  # model name passed to the chat-completions call
    top_k_images: int = 8  # images to show VLM

    # CLIP scoring
    clip_model: str = "ViT-B-32"  # open_clip architecture name
    n_scoring_images: int = 100  # held-out images for scoring

    # Data
    dataset: str = "imagenet-1k"  # dataset name handed to load_dataset_sorted
    n_samples: int | None = None  # None = full dataset
    batch_size: int = 256  # A100 80GB can handle large batches
    cache_dir: str | None = None  # HuggingFace cache
    work_dir: str = "cache/autointerp"  # our cache for embeddings/codes


class LazyImageDataset:
    """Thin index-based view over a dataset that decodes images only when asked.

    The wrapped ``dataset`` must support ``len()`` and integer indexing that
    returns a mapping; ``image_key`` and ``label_key`` name the entries that
    hold the image and its label.
    """

    def __init__(self, dataset, image_key: str = "image", label_key: str = "label"):
        self.dataset = dataset
        self.label_key = label_key
        self.image_key = image_key

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Image.Image:
        """Return the image at ``idx``, converted to RGB when it is in another mode."""
        record = self.dataset[int(idx)]
        picture = record[self.image_key]
        return picture if picture.mode == "RGB" else picture.convert("RGB")

    def get_images(self, indices: list[int]) -> list[Image.Image]:
        """Return the RGB images for ``indices``, in the order given."""
        loaded = []
        for position in indices:
            loaded.append(self[position])
        return loaded

    def get_label(self, idx: int) -> int:
        """Return the label stored for the row at ``idx``."""
        record = self.dataset[int(idx)]
        return record[self.label_key]


def load_dataset_sorted(
    dataset_name: str = "imagenet-1k",
    n_samples: int | None = None,
    cache_dir: str | None = None,
) -> LazyImageDataset:
    """Load the train split of ``dataset_name`` without shuffling.

    Row ``i`` of the returned wrapper is row ``i`` of the underlying split.
    When ``n_samples`` is given and smaller than the split, only the first
    ``n_samples`` rows are kept. ``cifar100`` uses the ``img`` /
    ``fine_label`` columns; every other dataset is read through ``image`` /
    ``label``.
    """
    from datasets import load_dataset

    print(f"Loading {dataset_name} (sorted order)...")

    is_cifar = dataset_name == "cifar100"
    if is_cifar:
        split = load_dataset("cifar100", split="train", cache_dir=cache_dir)
    else:
        split = load_dataset(dataset_name, split="train", streaming=False, cache_dir=cache_dir)
    image_key, label_key = ("img", "fine_label") if is_cifar else ("image", "label")

    wants_prefix = n_samples is not None and n_samples < len(split)
    if wants_prefix:
        split = split.select(range(n_samples))

    print(f"  Loaded {len(split)} samples")
    return LazyImageDataset(split, image_key=image_key, label_key=label_key)


def get_embeddings_path(work_dir: str, dataset: str, model: str, n_samples: int | None) -> Path:
    """Get cache path for embeddings."""
    samples_str = str(n_samples) if n_samples is not None else "all"
    return Path(work_dir) / f"embeddings_{dataset}_{model}_{samples_str}.npz"


def extract_or_load_embeddings(
    images: LazyImageDataset,
    model_name: str = "dinov2_vits14",
    batch_size: int = 64,
    device: str = "cuda",
    cache_path: Path | None = None,
) -> np.ndarray:
    """Return DINOv2 embeddings for ``images`` in dataset order.

    If ``cache_path`` exists, the ``embeddings`` array stored there is
    returned and ``images`` is never touched. Otherwise the model is fetched
    through ``torch.hub``, every image is resized/cropped/normalised and
    embedded batch by batch, and the stacked result is written to
    ``cache_path`` (when given) before being returned.
    """
    cache_hit = cache_path and cache_path.exists()
    if cache_hit:
        print(f"Loading cached embeddings from {cache_path}")
        return np.load(cache_path)["embeddings"]

    # Cache miss: run the model over the whole dataset
    from torchvision import transforms
    print(f"Extracting {model_name} embeddings for {len(images)} images...")

    backbone = torch.hub.load("facebookresearch/dinov2", model_name)
    backbone = backbone.to(device).eval()

    preprocessing_steps = [
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    transform = transforms.Compose(preprocessing_steps)

    chunks = []
    for start in tqdm(range(0, len(images), batch_size), desc="Extracting DINO"):
        stop = min(start + batch_size, len(images))
        pictures = [images[row] for row in range(start, stop)]
        batch = torch.stack([transform(picture) for picture in pictures]).to(device)
        with torch.no_grad():
            features = backbone(batch)
        chunks.append(features.cpu().numpy())

    embeddings = np.vstack(chunks)

    # Persist for later runs
    if cache_path:
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        np.savez(cache_path, embeddings=embeddings)
        print(f"Cached embeddings to {cache_path}")

    return embeddings


def get_codes_path(work_dir: str, dict_path: str, k: int, matryoshka: bool = False) -> Path:
    """Build the sparse-codes cache file path for a dictionary file.

    The name combines the dictionary file's stem, the first 8 hex digits of
    the MD5 of its bytes, the sparsity ``k`` and, for matryoshka coding, a
    ``_matryoshka`` suffix. The dictionary file is read in full to hash it.
    """
    dict_file = Path(dict_path)
    digest = hashlib.md5(dict_file.read_bytes()).hexdigest()
    dict_hash = digest[:8]
    dict_name = dict_file.stem
    if matryoshka:
        suffix = "_matryoshka"
    else:
        suffix = ""
    return Path(work_dir) / f"codes_{dict_name}_{dict_hash}_k{k}{suffix}.npz"


def compute_or_load_codes(
    embeddings: np.ndarray,
    dictionary_path: str,
    k: int,
    cache_path: Path | None = None,
    matryoshka: bool = False,
) -> np.ndarray:
    """Compute sparse codes with the Julia KSVD package, or load them from cache.

    Args:
        embeddings: Embedding matrix; transposed and cast to float32 before
            being handed to Julia. The inline comments below treat it as
            (n, d).
        dictionary_path: Path to a dictionary array readable by ``np.load``;
            the inline comments below treat it as (d, dict_size).
        k: Passed as ``max_nnz`` to ``KSVD.ParallelMatchingPursuit``.
        cache_path: Optional ``.npz`` file. If it exists, its ``codes`` array
            is returned without computing anything; otherwise freshly
            computed codes are written to it.
        matryoshka: When true, calls ``KSVD.sparse_coding_matryoshka`` (with
            ``log2min=8``) instead of ``KSVD.sparse_coding``.

    Returns:
        The coding result converted to a NumPy array and transposed;
        (n, dict_size) according to the inline comment.

    Note:
        Relies on the module-level ``jl`` handle (``juliacall.Main``) imported
        at the top of the file.
    """
    # Check cache
    if cache_path and cache_path.exists():
        print(f"Loading cached codes from {cache_path}")
        return np.load(cache_path)["codes"]

    # Compute fresh
    import juliacall

    def jlmat32(M):
        # Convert a NumPy array into a Julia Matrix{Float32}
        return juliacall.convert(jl.Matrix[jl.Float32], M)

    dictionary = np.load(dictionary_path).astype(np.float32)
    coding_type = "matryoshka" if matryoshka else "standard"
    print(f"Computing sparse codes (k={k}, {coding_type}) for dictionary {dictionary_path}...")
    jl.seval("using KSVD")

    # embeddings: (n, d), dictionary: (d, dict_size)
    D = dictionary  # (d, dict_size)
    Y = embeddings.T.astype(np.float32)  # (d, n)

    # Precompute for efficiency: both products are passed to the Julia
    # routines as keyword arguments below
    DtD = (D.T @ D).astype(np.float32)
    DtY = (D.T @ Y).astype(np.float32)

    # NOTE(review): presumably max_nnz caps the non-zeros per code — confirm
    # against the KSVD package
    sparse_coding_method = jl.KSVD.ParallelMatchingPursuit(max_nnz=k)

    if matryoshka:
        # NOTE(review): the meaning of log2min=8 is not visible in this file
        codes = np.array(jl.KSVD.sparse_coding_matryoshka(
            sparse_coding_method, jlmat32(Y), jlmat32(D),
            log2min=8, DtD=jlmat32(DtD), DtY=jlmat32(DtY)
        ))
    else:
        codes = np.array(jl.KSVD.sparse_coding(
            sparse_coding_method, jlmat32(Y), jlmat32(D),
            DtD=jlmat32(DtD), DtY=jlmat32(DtY)
        ))
    codes = codes.T  # (n, dict_size)

    # Cache
    if cache_path:
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        np.savez(cache_path, codes=codes)
        print(f"Cached codes to {cache_path}")

    return codes


# === VLM + CLIP scoring ===

def image_to_base64(img: Image.Image, max_size: int = 512) -> str:
    """Encode a PIL image as a base64 JPEG string, downscaling large images.

    Args:
        img: Image to encode. Modes other than ``RGB``/``L`` (for example
            ``RGBA`` or ``P``) are converted to ``RGB`` first, because the
            JPEG writer rejects images with alpha or a palette.
        max_size: Upper bound for the longer side in pixels; larger images
            are resized with the aspect ratio preserved.

    Returns:
        The JPEG bytes (quality 85) encoded as an ASCII base64 string.
    """
    # Convert before resizing so palette images are resampled as true colour
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")

    longest_side = max(img.size)
    if longest_side > max_size:
        ratio = max_size / longest_side
        # Clamp to 1px: truncation would otherwise give a 0-pixel side for
        # extreme aspect ratios, which resize() rejects
        new_size = tuple(max(1, int(side * ratio)) for side in img.size)
        img = img.resize(new_size, Image.LANCZOS)

    buffer = io.BytesIO()
    img.save(buffer, format="JPEG", quality=85)
    return base64.b64encode(buffer.getvalue()).decode()


def get_vlm_explanation(images: list[Image.Image], api_key: str, model: str = "gpt-4o-mini") -> str:
    """Ask a vision-language model what a set of top-activating images share.

    Sends one user message made of a fixed text prompt followed by every
    image as a low-detail base64 JPEG data URL, and returns the stripped
    text of the first completion choice.
    """
    from openai import OpenAI
    client = OpenAI(api_key=api_key)

    prompt = (
        "These images all strongly activate the same feature in a neural network. "
        "What visual concept, pattern, or attribute do they share? "
        "Be specific but concise. Respond with a single phrase or short sentence "
        "describing what this feature detects. Examples: 'dogs playing outdoors', "
        "'red circular objects', 'images with strong diagonal lines', 'close-up faces'."
    )
    image_parts = [
        {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{image_to_base64(picture)}",
                "detail": "low",
            },
        }
        for picture in images
    ]
    user_message = {
        "role": "user",
        "content": [{"type": "text", "text": prompt}, *image_parts],
    }

    response = client.chat.completions.create(
        model=model,
        messages=[user_message],
        max_tokens=100,
    )
    reply = response.choices[0].message.content
    return reply.strip()


_clip_cache = {}


def get_clip_model(clip_model_name: str = "ViT-B-32", device: str = "cuda"):
    """Return a cached ``(model, preprocess, tokenizer)`` triple for CLIP.

    The triple is created on first use with the ``openai`` pretrained weights
    and memoised in the module-level ``_clip_cache`` per
    ``(clip_model_name, device)`` pair.

    Args:
        clip_model_name: open_clip architecture name.
        device: Device the model is created on.

    Returns:
        Tuple of the model (in eval mode), the image preprocessing transform
        and the tokenizer.
    """
    import open_clip
    key = (clip_model_name, device)
    if key not in _clip_cache:
        model, _, preprocess = open_clip.create_model_and_transforms(
            clip_model_name, pretrained="openai", device=device
        )
        # open_clip hands back the model in train mode; scoring must run in
        # eval mode so BatchNorm / stochastic-depth variants are deterministic
        model.eval()
        tokenizer = open_clip.get_tokenizer(clip_model_name)
        _clip_cache[key] = (model, preprocess, tokenizer)
    return _clip_cache[key]


def compute_clip_scores(
    explanation: str,
    images: list[Image.Image],
    clip_model_name: str = "ViT-B-32",
    device: str = "cuda",
    batch_size: int = 256,
) -> np.ndarray:
    """Compute CLIP cosine similarity between an explanation and each image.

    Args:
        explanation: Text whose embedding is compared against the images.
        images: Images to score.
        clip_model_name: open_clip architecture name.
        device: Device used for encoding.
        batch_size: Number of images encoded per forward pass.

    Returns:
        1-D array with one similarity per image, in input order. The result
        always has shape ``(len(images),)``, including for zero or one image.
    """
    if not images:
        # Nothing to score: skip loading the model altogether
        return np.empty(0, dtype=np.float32)

    model, preprocess, tokenizer = get_clip_model(clip_model_name, device)

    text_tokens = tokenizer([explanation]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    image_features = []
    for i in range(0, len(images), batch_size):
        batch = images[i:i + batch_size]
        batch_tensors = torch.stack([preprocess(img) for img in batch]).to(device)
        with torch.no_grad():
            feats = model.encode_image(batch_tensors)
            feats = feats / feats.norm(dim=-1, keepdim=True)
        image_features.append(feats.cpu().numpy())

    image_features = np.vstack(image_features)
    similarities = image_features @ text_features.cpu().numpy().T  # (n_images, 1)
    # reshape rather than squeeze: squeeze would collapse a single-image
    # result to a 0-d array
    return similarities.reshape(-1)


# === Feature evaluation ===

def evaluate_feature(
    feature_idx: int,
    codes: np.ndarray,
    images: LazyImageDataset,
    api_key: str,
    config: AutoInterpConfig,
    device: str = "cuda",
) -> dict:
    """Evaluate interpretability of a single feature.

    The top ``config.top_k_images`` activating images are shown to the VLM to
    obtain an explanation. The explanation is then scored on a disjoint set
    of up to ``config.n_scoring_images`` images (half the next-highest
    activations, half drawn at random from the rest) by the Spearman
    correlation between CLIP similarity and activation.

    Args:
        feature_idx: Column of ``codes`` to evaluate.
        codes: Sparse codes, indexed as ``codes[sample, feature]``.
        images: Image source whose row ``i`` corresponds to ``codes[i]``.
        api_key: API key forwarded to the VLM call.
        config: Evaluation settings.
        device: Device used for CLIP scoring.

    Returns:
        Dict with ``feature_idx``, ``explanation``, ``correlation`` and
        ``p_value`` (``None`` when undefined or when fewer than two scoring
        images are available), ``top_activation``, ``mean_activation`` and
        ``sign_flipped``.
    """
    from scipy.stats import spearmanr

    activations = codes[:, feature_idx]

    # Flip sign if feature is predominantly negative (for KSVD compatibility)
    nonzero = activations[activations != 0]
    sign_flipped = bool(len(nonzero) > 0 and np.mean(nonzero) < 0)
    if sign_flipped:
        activations = -activations

    # Top-k for VLM explanation
    sorted_indices = np.argsort(activations)[::-1]
    top_indices = sorted_indices[:config.top_k_images]

    # Scoring set: mix of high-activating and random (excluding top-k)
    remaining = sorted_indices[config.top_k_images:]
    n_high = config.n_scoring_images // 2
    n_random = config.n_scoring_images - n_high
    high_indices = remaining[:n_high]
    random_pool = remaining[n_high:]
    if len(random_pool) > 0:
        # Size is bounded by the pool itself; deriving it from
        # len(remaining) - n_high goes negative on small datasets
        random_indices = np.random.choice(
            random_pool, size=min(n_random, len(random_pool)), replace=False
        )
    else:
        random_indices = random_pool
    scoring_indices = np.concatenate([high_indices, random_indices])

    # Get VLM explanation (lazy load only needed images)
    top_images = images.get_images(top_indices.tolist())
    explanation = get_vlm_explanation(top_images, api_key, config.vlm_model)

    # Score using CLIP (lazy load only needed images). A rank correlation
    # needs at least two points, so smaller scoring sets are reported as None.
    correlation = None
    p_value = None
    if len(scoring_indices) >= 2:
        scoring_images = images.get_images(scoring_indices.tolist())
        clip_scores = compute_clip_scores(explanation, scoring_images, config.clip_model, device)
        actual_activations = activations[scoring_indices]
        rho, p = spearmanr(clip_scores, actual_activations)
        # spearmanr yields NaN for constant input; report that as None
        if not np.isnan(rho):
            correlation = float(rho)
        if not np.isnan(p):
            p_value = float(p)

    return {
        "feature_idx": int(feature_idx),
        "explanation": explanation,
        "correlation": correlation,
        "p_value": p_value,
        "top_activation": float(activations[top_indices[0]]),
        "mean_activation": float(activations.mean()),
        "sign_flipped": sign_flipped,
    }


def select_features(
    codes: np.ndarray, n_features: int, method: str = "top_variance", bin: int | None = None,
) -> list[int]:
    """Select features to evaluate.

    Args:
        codes: Sparse codes array (n_samples, n_features)
        n_features: Number of features to select
        method: Selection method (top_variance, random, stratified)
        bin: For stratified, which bin 1-5 (1=top 0-20%, 5=bottom 80-100%)
    """
    if method == "top_variance":
        variances = codes.var(axis=0)
        return np.argsort(variances)[::-1][:n_features].tolist()
    elif method == "random":
        np.random.seed(42)
        active = np.where(codes.max(axis=0) > 0)[0]
        return np.random.choice(active, size=min(n_features, len(active)), replace=False).tolist()
    elif method == "stratified":
        if bin is None or bin < 1 or bin > 5:
            raise ValueError("For stratified, --bin must be 1-5")
        np.random.seed(42)
        variances = codes.var(axis=0)
        sorted_indices = np.argsort(variances)[::-1]  # highest variance first
        n_total = len(sorted_indices)
        # Bins: 1=0-20%, 2=20-40%, 3=40-60%, 4=60-80%, 5=80-100%
        start_idx = int(n_total * (bin - 1) * 0.2)
        end_idx = int(n_total * bin * 0.2)
        candidates = sorted_indices[start_idx:end_idx]
        return np.random.choice(candidates, size=min(n_features, len(candidates)), replace=False).tolist()
    else:
        raise ValueError(f"Unknown method: {method}")


# === Main entry point ===

def run_autointerp(
    dictionary_path: str,
    k: int,
    config: AutoInterpConfig,
    api_key: str | None = None,
    output_path: str | None = None,
    device: str = "cuda",
    embeddings_only: bool = False,
    codes_only: bool = False,
    matryoshka: bool = False,
    bin: int | None = None,
) -> dict | None:
    """
    Run full autointerp evaluation.

    Stages (each can be run independently):
    1. embeddings_only: Extract and cache embeddings, then exit
    2. codes_only: Compute and cache codes for this dictionary, then exit
    3. Full run: Load/compute everything, run VLM+CLIP evaluation

    Args:
        dictionary_path: Dictionary array file; its first dimension selects
            the DINOv2 variant (384 or 768).
        k: Sparsity level used for sparse coding.
        config: Evaluation settings.
        api_key: VLM API key; required unless a ``*_only`` stage is requested.
        output_path: Optional JSON file the summary is written to.
        device: Device for embedding extraction and CLIP scoring.
        embeddings_only: Stop after stage 1.
        codes_only: Stop after stage 2.
        matryoshka: Use the matryoshka sparse coding variant.
        bin: Variance bin (1-5) for stratified feature selection.

    Returns:
        Summary dict with aggregate correlations and per-feature results, or
        ``None`` when stopping after stage 1 or 2.

    Raises:
        ValueError: Unknown embedding dimension, missing API key for a full
            run, or cached codes/embeddings whose row count does not match
            the data they are used with.
    """
    work_dir = Path(config.work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)

    # Infer DINO model from dictionary dimension
    dictionary = np.load(dictionary_path)
    emb_dim = dictionary.shape[0]
    dino_model = {384: "dinov2_vits14", 768: "dinov2_vitb14"}.get(emb_dim)
    if dino_model is None:
        raise ValueError(f"Unknown embedding dim {emb_dim}")
    print(f"Dictionary: {dictionary_path} (shape={dictionary.shape}, expects {dino_model})")

    # Fail fast: a full run needs the key, so check before the expensive
    # embedding and coding stages rather than after them
    if api_key is None and not (embeddings_only or codes_only):
        raise ValueError("API key required for VLM evaluation")

    # Stage 1: Embeddings
    emb_cache = get_embeddings_path(work_dir, config.dataset, dino_model, config.n_samples)
    images = None  # loaded only when actually needed
    if emb_cache.exists():
        print(f"Embeddings cached at {emb_cache}")
        if embeddings_only:
            return None
        embeddings = np.load(emb_cache)["embeddings"]
    else:
        images = load_dataset_sorted(config.dataset, config.n_samples, config.cache_dir)
        embeddings = extract_or_load_embeddings(
            images, dino_model, config.batch_size, device, emb_cache
        )
        if embeddings_only:
            return None

    # Stage 2: Codes
    codes_cache = get_codes_path(work_dir, dictionary_path, k, matryoshka)
    codes = compute_or_load_codes(embeddings, dictionary_path, k, codes_cache, matryoshka)
    if codes.shape[0] != embeddings.shape[0]:
        # The codes cache name depends only on the dictionary, k and the
        # coding variant, so codes cached for another dataset or n_samples
        # would otherwise be reused silently
        raise ValueError(
            f"Codes at {codes_cache} have {codes.shape[0]} rows but there are "
            f"{embeddings.shape[0]} embeddings; the cache was probably built for a "
            "different dataset or n_samples. Delete it and re-run."
        )
    if codes_only:
        return None

    # Stage 3: VLM + CLIP evaluation
    if images is None:
        # Embeddings came from cache; the dataset is still needed for image access
        images = load_dataset_sorted(config.dataset, config.n_samples, config.cache_dir)
    if len(images) != embeddings.shape[0]:
        # Row i of the embeddings must be image i of the dataset
        raise ValueError(
            f"Embeddings at {emb_cache} have {embeddings.shape[0]} rows but the dataset "
            f"has {len(images)} images. Delete the cache and re-run."
        )

    feature_indices = select_features(codes, config.n_features, config.feature_selection, bin)
    print(f"Evaluating {len(feature_indices)} features...")

    results = []
    for feat_idx in tqdm(feature_indices, desc="Autointerp"):
        # Deliberately best-effort: one failing feature (API error, bad
        # image, ...) is reported and skipped instead of aborting the run
        try:
            result = evaluate_feature(feat_idx, codes, images, api_key, config, device)
            results.append(result)
            r_str = f"{result['correlation']:.3f}" if result['correlation'] is not None else "N/A"
            print(f"  Feature {feat_idx}: '{result['explanation']}' (r={r_str})")
        except Exception as e:
            print(f"  Feature {feat_idx}: ERROR - {e}")

    # Aggregate over features with a defined correlation
    correlations = [r["correlation"] for r in results if r["correlation"] is not None]
    summary = {
        "dictionary": dictionary_path,
        "k": k,
        "n_samples": config.n_samples,
        "mean_correlation": float(np.mean(correlations)) if correlations else None,
        "median_correlation": float(np.median(correlations)) if correlations else None,
        "std_correlation": float(np.std(correlations)) if correlations else None,
        "n_features": len(results),
        "results": results,
    }

    if output_path:
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w") as f:
            json.dump(summary, f, indent=2)
        print(f"Saved results to {output_path}")

    return summary


if __name__ == "__main__":
    # Command-line entry point: parse arguments, build the config, run the
    # pipeline and print a short summary.
    import argparse

    parser = argparse.ArgumentParser(description="VLM-based autointerp evaluation")
    parser.add_argument("--dictionary", required=True, help="Path to dictionary .npy file")
    parser.add_argument("--k", type=int, required=True, help="Sparsity level")
    parser.add_argument("--n-features", type=int, default=20, help="Features to evaluate")
    parser.add_argument("--feature-selection", default="top_variance",
                        choices=["top_variance", "random", "stratified"],
                        help="Feature selection method")
    parser.add_argument("--bin", type=int, help="For stratified: bin 1-5 (1=top 0-20%%, 5=bottom 80-100%%)")
    parser.add_argument("--n-samples", type=int, default=None, help="Images to use (default: all)")
    parser.add_argument("--dataset", default="imagenet-1k", help="Dataset name")
    parser.add_argument("--output", help="Output JSON path")
    parser.add_argument("--api-key", help="OpenAI API key (or OPENAI_API_KEY env)")
    parser.add_argument("--cache-dir", help="HuggingFace cache directory")
    parser.add_argument("--work-dir", default="cache/autointerp", help="Working directory for caches")
    parser.add_argument("--device", default="cuda", help="Device")
    parser.add_argument("--embeddings-only", action="store_true", help="Only extract embeddings")
    parser.add_argument("--codes-only", action="store_true", help="Only compute codes")
    parser.add_argument("--matryoshka", action="store_true", help="Use matryoshka sparse coding (log2min=8)")
    args = parser.parse_args()

    # The key is only needed for the VLM stage, not for the caching stages
    api_key = args.api_key or os.environ.get("OPENAI_API_KEY")
    if not api_key and not (args.embeddings_only or args.codes_only):
        raise ValueError("Must provide --api-key or set OPENAI_API_KEY")

    config = AutoInterpConfig(
        n_features=args.n_features,
        feature_selection=args.feature_selection,
        n_samples=args.n_samples,
        dataset=args.dataset,
        cache_dir=args.cache_dir,
        work_dir=args.work_dir,
    )

    results = run_autointerp(
        dictionary_path=args.dictionary,
        k=args.k,
        config=config,
        api_key=api_key,
        output_path=args.output,
        device=args.device,
        embeddings_only=args.embeddings_only,
        codes_only=args.codes_only,
        matryoshka=args.matryoshka,
        bin=args.bin,
    )

    if results is not None:
        # Both statistics are None when no feature produced a valid
        # correlation; formatting None with ':.3f' would raise TypeError
        mean_r = results["mean_correlation"]
        median_r = results["median_correlation"]
        mean_str = f"{mean_r:.3f}" if mean_r is not None else "N/A"
        median_str = f"{median_r:.3f}" if median_r is not None else "N/A"
        print("\n=== Summary ===")
        print(f"Mean correlation: {mean_str}")
        print(f"Median correlation: {median_str}")
