#!/usr/bin/env python3
"""Test random feature selection for autointerp to check if scores diverge."""

import json
import numpy as np
import os
from pathlib import Path

# Lazy imports to avoid Julia
def run_test(model: str, method: str, k_values: list[int], n_features: int = 20, seed: int = 42):
    """Print diagnostics for randomly selected features at each value of k.

    For every ``k`` this reads ``results/autointerp/{method}_{model}_k{k}.json``
    (relative to the current working directory), uses its ``"dictionary"``
    entry to locate the cached codes archive under ``cache/autointerp``, draws
    up to ``n_features`` active features at random, and prints their variance
    range together with the top-ranked rows of the highest-variance feature.

    A ``k`` whose results file, dictionary file or codes archive is missing,
    or whose codes contain nothing to sample, is reported and skipped so the
    remaining ``k`` values are still processed.

    Args:
        model: Model name used in the results file name.
        method: Method name used in the results file name.
        k_values: Values of k to inspect; each is handled independently.
        n_features: Maximum number of features to draw per k.
        seed: Seed for the random draw; the same seed is reused for every k.

    Returns:
        None. All output is printed to stdout.
    """
    import hashlib  # function-scope import: hashlib is not imported at file level

    results_dir = Path("results/autointerp")
    work_dir = Path("cache/autointerp")

    print(f"\n=== Random Feature Test: {method.upper()} {model} ===\n")

    for k in k_values:
        # The existing results file records which dictionary was used.
        results_file = results_dir / f"{method}_{model}_k{k}.json"
        if not results_file.exists():
            print(f"Results not found: {results_file}")
            continue
        with open(results_file) as f:
            existing = json.load(f)

        dict_path = Path(existing["dictionary"])
        if not dict_path.exists():
            print(f"Dictionary not found: {dict_path}")
            continue

        # The codes cache is keyed by dictionary stem, the first 8 hex digits
        # of the dictionary file's MD5, and k.
        dict_hash = hashlib.md5(dict_path.read_bytes()).hexdigest()[:8]
        codes_path = work_dir / f"codes_{dict_path.stem}_{dict_hash}_k{k}.npz"
        if not codes_path.exists():
            print(f"Codes not found: {codes_path}")
            continue

        # Context manager closes the .npz archive once the array is read.
        with np.load(codes_path) as archive:
            codes = archive["codes"]
        print(f"k={k}: Loaded codes shape {codes.shape}")

        # A feature counts as active when its largest code is positive.
        # NOTE(review): the original comment says this mirrors autointerp.py.
        if codes.size:
            active = np.where(codes.max(axis=0) > 0)[0]
        else:
            active = np.empty(0, dtype=np.intp)

        n_draw = min(n_features, active.size)
        if n_draw <= 0:
            # Without this guard the min()/max() below fail on empty arrays.
            print(f"  Nothing to sample (active features: {active.size}, requested: {n_features})")
            print()
            continue

        # A fresh legacy RandomState per k yields the same draw as
        # np.random.seed(seed) followed by np.random.choice, but leaves
        # numpy's global RNG state untouched.
        rng = np.random.RandomState(seed)
        feature_indices = rng.choice(active, size=n_draw, replace=False)

        preview = sorted(int(i) for i in feature_indices[:5])
        print(f"  Selected features: {preview}... (showing first 5)")

        # Variance of each selected feature over the rows of ``codes``.
        variances = codes[:, feature_indices].var(axis=0)
        print(f"  Variance range: {variances.min():.4f} - {variances.max():.4f}")

        # Rank rows for the highest-variance selected feature.
        top_feat = feature_indices[np.argmax(variances)]
        activations = codes[:, top_feat]
        nonzero = activations[activations != 0]
        if nonzero.size and nonzero.mean() < 0:
            # Mostly-negative feature: flip the sign so the descending sort
            # ranks the most negative rows first.
            activations = -activations
        top_images = np.argsort(activations)[::-1][:8]
        print(f"  Top feature {top_feat}: images {top_images[:4]}...")
        print()


def compare_with_topvar(
    model: str,
    method: str,
    k_values: list[int],
    n_random: int = 100,
    seed: int = 42,
):
    """Compare the mean variance of evaluated features with a random sample.

    For every ``k`` this reads ``results/autointerp/{method}_{model}_k{k}.json``
    (relative to the current working directory), takes the ``feature_idx`` of
    each entry in its ``"results"`` list, and prints the mean code variance of
    those features next to the mean variance of a random sample of active
    features, plus the ratio of the two.

    A ``k`` whose results file, dictionary file or codes archive is missing,
    or whose codes have no active feature, is reported and skipped.

    Args:
        model: Model name used in the results file name.
        method: Method name used in the results file name.
        k_values: Values of k to inspect; each is handled independently.
        n_random: Maximum number of active features in the random sample. The
            sample is capped at the number of active features, so small
            dictionaries no longer raise ``ValueError``.
        seed: Seed for the random sample; reused for every k.

    Returns:
        None. All output is printed to stdout.
    """
    import hashlib  # function-scope import: hashlib is not imported at file level

    results_dir = Path("results/autointerp")
    work_dir = Path("cache/autointerp")

    print("\n=== Comparing Top Variance vs Random Features ===\n")

    for k in k_values:
        results_file = results_dir / f"{method}_{model}_k{k}.json"
        if not results_file.exists():
            print(f"Results not found: {results_file}")
            continue
        with open(results_file) as f:
            existing = json.load(f)

        # Features that were actually evaluated in the stored results.
        topvar_features = [r["feature_idx"] for r in existing["results"]]

        dict_path = Path(existing["dictionary"])
        if not dict_path.exists():
            print(f"Dictionary not found: {dict_path}")
            continue

        # The codes cache is keyed by dictionary stem, the first 8 hex digits
        # of the dictionary file's MD5, and k.
        dict_hash = hashlib.md5(dict_path.read_bytes()).hexdigest()[:8]
        codes_path = work_dir / f"codes_{dict_path.stem}_{dict_hash}_k{k}.npz"
        if not codes_path.exists():
            print(f"Codes not found: {codes_path}")
            continue

        # Context manager closes the .npz archive once the array is read.
        with np.load(codes_path) as archive:
            codes = archive["codes"]

        # A feature counts as active when its largest code is positive.
        if codes.size:
            active = np.where(codes.max(axis=0) > 0)[0]
        else:
            active = np.empty(0, dtype=np.intp)

        n_draw = min(n_random, active.size)
        if n_draw <= 0:
            print(f"k={k}: no active features to sample, skipping")
            print()
            continue

        all_variances = codes.var(axis=0)
        topvar_variance = all_variances[topvar_features].mean()

        # Local legacy RandomState: same stream as np.random.seed(seed) +
        # np.random.choice, without touching numpy's global RNG state.
        rng = np.random.RandomState(seed)
        random_features = rng.choice(active, size=n_draw, replace=False)
        random_variance = all_variances[random_features].mean()

        print(f"k={k}:")
        print(f"  Top variance features: mean_var={topvar_variance:.4f}")
        print(f"  Random features: mean_var={random_variance:.4f}")
        if random_variance > 0:
            print(f"  Ratio: {topvar_variance/random_variance:.2f}x")
        else:
            # Avoid a divide-by-zero warning and a meaningless inf/nan ratio.
            print("  Ratio: n/a (random mean variance is zero)")
        print()


if __name__ == "__main__":
    # Command-line entry point: run both diagnostics for one model/method.
    # Every option is optional; the defaults reproduce the previous
    # hard-coded behaviour (k in 16/32/64, 20 features, seed 42).
    import argparse

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model", default="vitb14", help="model name used in results file names")
    parser.add_argument("--method", default="ksvd", help="method name used in results file names")
    parser.add_argument(
        "--k-values",
        type=int,
        nargs="+",
        default=[16, 32, 64],
        help="values of k to inspect",
    )
    parser.add_argument(
        "--n-features",
        type=int,
        default=20,
        help="number of random features drawn per k by run_test",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="random seed used by run_test",
    )
    args = parser.parse_args()

    run_test(
        args.model,
        args.method,
        args.k_values,
        n_features=args.n_features,
        seed=args.seed,
    )
    compare_with_topvar(args.model, args.method, args.k_values)
