#!/usr/bin/env python3
"""
Example script demonstrating how to load and analyze extracted features.
"""

import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt


def analyze_features(feature_file):
    """
    Load and analyze features from a single image.

    Prints the array shape, global statistics (mean/std/min/max), and
    per-layer statistics (mean, standard deviation, and the mean L2 norm
    taken over the last axis).

    Args:
        feature_file: Path (str or ``pathlib.Path``) to the .npy feature file.
            The stored array must be 3-D with shape [L, N, C]
            (layers, tokens, feature dimension).

    Returns:
        The loaded feature array.

    Raises:
        ValueError: If the stored array is not 3-dimensional.
    """
    # Load features
    features = np.load(feature_file)

    # Validate up front: a 2-D array would crash halfway through the report
    # with an IndexError, and a 4-D array would be mislabelled silently.
    if features.ndim != 3:
        raise ValueError(
            f"Expected a 3-D feature array [L, N, C] in {feature_file}, "
            f"got shape {features.shape}"
        )
    num_layers, num_tokens, feat_dim = features.shape

    print("=" * 80)
    print(f"Analyzing: {Path(feature_file).name}")
    print("=" * 80)

    # Basic information
    print(f"\nFeature shape: {features.shape}")
    print(f"  - Number of layers (L): {num_layers}")
    print(f"  - Number of tokens (N): {num_tokens}")
    print(f"  - Feature dimension (C): {feat_dim}")

    # min()/max() raise on zero-size arrays, so stop before the statistics.
    if features.size == 0:
        print("\nFeature array is empty; skipping statistics.")
        return features

    # Statistics
    print("\nFeature statistics:")
    print(f"  - Mean: {features.mean():.4f}")
    print(f"  - Std: {features.std():.4f}")
    print(f"  - Min: {features.min():.4f}")
    print(f"  - Max: {features.max():.4f}")

    # Per-layer statistics
    print("\nPer-layer statistics:")
    for layer_idx, layer_features in enumerate(features):
        # L2 norm of each token's feature vector, then averaged over tokens.
        mean_token_norm = np.linalg.norm(layer_features, axis=-1).mean()
        print(f"  Layer {layer_idx:2d}: mean={layer_features.mean():7.4f}, "
              f"std={layer_features.std():.4f}, "
              f"norm={mean_token_norm:.4f}")

    return features


def compare_features(feature_files):
    """
    Compare features across multiple images.

    Loads every file and reports whether all arrays share one shape. It then
    prints the cosine similarity between every pair of images, computed on
    the full flattened feature arrays.

    Args:
        feature_files: List of paths to .npy feature files

    Returns:
        dict mapping ``(i, j)`` index pairs (``i < j``, indices into
        ``feature_files``) to their cosine similarity as a float. Pairs whose
        shapes differ are omitted. A pair where either array has zero norm
        maps to ``nan``.
    """
    print("\n" + "=" * 80)
    print("Comparing multiple images")
    print("=" * 80)

    all_features = []
    for f in feature_files:
        features = np.load(f)
        all_features.append(features)
        print(f"{Path(f).name}: {features.shape}")

    if not all_features:
        print("\nNo feature files given; nothing to compare.")
        return {}

    # Check if all have the same shape
    shapes = [f.shape for f in all_features]
    if len(set(shapes)) == 1:
        print(f"\n✓ All features have the same shape: {shapes[0]}")
    else:
        print(f"\n✗ Warning: Features have different shapes: {shapes}")

    # Flatten once per image. Casting to float64 avoids overflow in the dot
    # product and norms when features are stored in a narrow dtype (e.g.
    # float16, whose max is ~65504).
    flat = [f.reshape(-1).astype(np.float64) for f in all_features]
    norms = [np.linalg.norm(v) for v in flat]

    # Cosine similarity of the whole flattened arrays (not a per-token average).
    print("\nPairwise cosine similarities (computed on flattened features):")
    similarities = {}
    n_images = len(all_features)
    for i in range(n_images):
        for j in range(i + 1, n_images):
            # Element-wise comparison is meaningless (and np.dot fails on
            # unequal lengths) when shapes differ.
            if shapes[i] != shapes[j]:
                print(f"  Image {i} vs Image {j}: skipped (shape mismatch)")
                continue
            denom = norms[i] * norms[j]
            if denom > 0:
                similarity = float(np.dot(flat[i], flat[j]) / denom)
            else:
                similarity = float("nan")
            similarities[(i, j)] = similarity
            print(f"  Image {i} vs Image {j}: {similarity:.4f}")

    return similarities


def visualize_feature_norms(features, save_path=None):
    """
    Visualize the L2 norms of features across layers and tokens.

    Draws two panels side by side: a layers-by-tokens heatmap of per-token
    L2 norms, and a line plot of the average norm per layer. The figure is
    closed afterwards, so repeated calls do not accumulate open figures.

    Args:
        features: Feature array of shape [L, N, C]
        save_path: Optional path to save the plot; when omitted the figure is
            displayed with ``plt.show()``.

    Raises:
        ValueError: If ``features`` is not 3-dimensional.
    """
    features = np.asarray(features)
    # A 2-D input would make imshow fail on a 1-D array. A 4-D input with
    # size 3/4 on its third axis would be drawn silently as an RGB image.
    if features.ndim != 3:
        raise ValueError(
            f"Expected a 3-D feature array [L, N, C], got shape {features.shape}"
        )

    # Keep the guard narrow: only a missing matplotlib counts as
    # "not installed". Errors raised while plotting must surface.
    # NOTE(review): the module-level `import matplotlib.pyplot` at the top of
    # this file fails first when matplotlib is absent, so this fallback only
    # helps if that top-level import is removed.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("\nWarning: matplotlib not installed. Skipping visualization.")
        return

    # Compute L2 norms across feature dimension
    norms = np.linalg.norm(features, axis=-1)  # Shape: [L, N]

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    try:
        # Plot 1: Heatmap of norms
        im = axes[0].imshow(norms, aspect='auto', cmap='viridis')
        axes[0].set_xlabel('Token Index')
        axes[0].set_ylabel('Layer Index')
        axes[0].set_title('L2 Norm of Features (Layers × Tokens)')
        fig.colorbar(im, ax=axes[0])

        # Plot 2: Average norm per layer
        avg_norms = norms.mean(axis=1)
        axes[1].plot(avg_norms, marker='o')
        axes[1].set_xlabel('Layer Index')
        axes[1].set_ylabel('Average L2 Norm')
        axes[1].set_title('Average Feature Norm per Layer')
        axes[1].grid(True)

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"\nVisualization saved to: {save_path}")
        else:
            plt.show()
    finally:
        # Release the figure even if plotting/saving fails. Otherwise pyplot
        # keeps every figure alive.
        plt.close(fig)


if __name__ == "__main__":
    import sys

    # Example usage.
    # Pass the feature directory as the first command-line argument, or edit
    # the default below to match your setup.
    MODEL_NAME = "vit_base_patch16_224"
    FEATURE_DIR = (
        sys.argv[1] if len(sys.argv) > 1 else f"path/to/features/{MODEL_NAME}"
    )

    feature_path = Path(FEATURE_DIR)

    # Use sys.exit rather than the site-provided exit(), which is missing
    # under `python -S` and in frozen executables.
    if not feature_path.is_dir():
        print(f"Error: Feature directory not found: {FEATURE_DIR}", file=sys.stderr)
        print("Please run the feature extraction pipeline first.", file=sys.stderr)
        sys.exit(1)

    # Get all feature files
    feature_files = sorted(feature_path.glob("*.npy"))

    if not feature_files:
        print(f"Error: No feature files found in {FEATURE_DIR}", file=sys.stderr)
        sys.exit(1)

    print(f"Found {len(feature_files)} feature files")

    # Analyze first feature file
    print("\n" + "=" * 80)
    print("Example 1: Analyzing a single image")
    print("=" * 80)
    features = analyze_features(feature_files[0])

    # Visualize features
    print("\n" + "=" * 80)
    print("Example 2: Visualizing feature norms")
    print("=" * 80)
    visualize_feature_norms(features, save_path="feature_visualization.png")

    # Compare multiple images (up to the first three); a pair is enough.
    if len(feature_files) >= 2:
        print("\n" + "=" * 80)
        print("Example 3: Comparing features from multiple images")
        print("=" * 80)
        compare_features(feature_files[:3])

    # Example: Access specific layer or token
    print("\n" + "=" * 80)
    print("Example 4: Accessing specific layers and tokens")
    print("=" * 80)

    print(f"\nOriginal feature shape: {features.shape}")

    # Get features from layer 5, but only if the model has that many layers
    # (indexing blindly raises IndexError for shallow models).
    num_layers = features.shape[0]
    if num_layers > 5:
        layer_5 = features[5]
        print(f"Layer 5 features shape: {layer_5.shape}")
    else:
        print(f"Layer 5 not available (only {num_layers} layers)")

    # Get features from first token across all layers
    token_0 = features[:, 0, :]
    print(f"Token 0 across all layers shape: {token_0.shape}")

    # Get features from last layer, all tokens
    last_layer = features[-1]
    print(f"Last layer features shape: {last_layer.shape}")

    print("\n" + "=" * 80)
    print("Analysis complete!")
    print("=" * 80)
