#!/usr/bin/env python3
"""
Example script demonstrating feature loading and layer analysis.

This script shows how to:
1. Load features from a single image
2. Understand layer numbering
3. Access specific layers and tokens
4. Perform basic analysis
"""

import numpy as np
from pathlib import Path


def print_separator(title=""):
    """Print a 70-character ``=`` rule, optionally framing a title.

    With a non-empty ``title`` the output is a blank line, a rule, the title
    indented by two spaces, and a closing rule. With an empty ``title`` a
    single rule is printed.
    """
    rule = "=" * 70

    # No title: a lone rule is all that is needed.
    if not title:
        print(rule)
        return

    print(f"\n{rule}")
    print(f"  {title}")
    print(rule)


def load_and_inspect_features(feature_file):
    """Load a feature array with ``np.load`` and print a summary of it.

    The summary reports the file name, the shape broken down as
    (L layers, N spatial tokens, C feature channels), the dtype and the
    in-memory size in MB.

    Args:
        feature_file: Path (``str`` or ``Path``) of the file to load.

    Returns:
        The loaded array.

    Raises:
        ValueError: If the loaded array is not 3-dimensional.
    """
    features = np.load(feature_file)

    # Validate before printing anything: the summary below indexes
    # shape[0..2], so a wrong rank would otherwise fail with a bare
    # IndexError after a partially printed report.
    if features.ndim != 3:
        raise ValueError(
            f"Expected a 3-D (layers, tokens, channels) array in "
            f"{feature_file}, got shape {features.shape}"
        )

    print_separator("BASIC FEATURE INFORMATION")

    num_layers, num_tokens, num_channels = features.shape

    print(f"\nFile: {Path(feature_file).name}")
    print(f"Shape: {features.shape}")
    print(f"  - L (Layers): {num_layers}")
    print(f"  - N (Spatial Tokens): {num_tokens}")
    print(f"  - C (Feature Channels): {num_channels}")
    print(f"\nData type: {features.dtype}")
    print(f"Memory size: {features.nbytes / 1024 / 1024:.2f} MB")

    return features


def demonstrate_layer_access(features):
    """Print examples of selecting whole layers from a feature array.

    Shows the first, middle and last layer and a slice of the first eight
    layers. The middle and last indices are derived from the number of
    layers in ``features`` instead of being fixed at 11 and 23, so the demo
    also runs on arrays with a different depth; for 24 layers the output is
    unchanged.

    Args:
        features: Array whose first axis indexes layers.

    Raises:
        IndexError: If ``features`` has no layers.
    """
    print_separator("LAYER ACCESS EXAMPLES")

    num_layers = features.shape[0]
    last_idx = num_layers - 1
    # Lower-middle index: 11 for a 24-layer array.
    mid_idx = last_idx // 2

    print("\n1. First Layer (Index 0 = Block 0 = Layer 1 in papers)")
    first_layer = features[0]
    print(f"   features[0].shape: {first_layer.shape}")
    print(f"   Mean: {first_layer.mean():.4f}, Std: {first_layer.std():.4f}")

    # The "Layer k in papers" label is the 1-based form of the 0-based index.
    print(f"\n2. Middle Layer (Index {mid_idx} = Block {mid_idx} = "
          f"Layer {mid_idx + 1} in papers)")
    middle_layer = features[mid_idx]
    print(f"   features[{mid_idx}].shape: {middle_layer.shape}")
    print(f"   Mean: {middle_layer.mean():.4f}, Std: {middle_layer.std():.4f}")

    print(f"\n3. Last Layer (Index {last_idx} = Block {last_idx} = "
          f"Layer {last_idx + 1} in papers)")
    last_layer = features[last_idx]
    print(f"   features[{last_idx}].shape: {last_layer.shape}")
    print(f"   features[-1].shape: {features[-1].shape}  (equivalent)")
    print(f"   Mean: {last_layer.mean():.4f}, Std: {last_layer.std():.4f}")

    print("\n4. Layer Range (First 8 layers - Early features)")
    # Slicing never raises; shallower arrays simply yield fewer layers.
    early_layers = features[0:8]
    print(f"   features[0:8].shape: {early_layers.shape}")


def demonstrate_token_access(features):
    """Print examples of selecting individual tokens from a feature array.

    The centre-token example previously used index 98 while describing it as
    grid position (7, 7); with ``index = row * 14 + col`` index 98 is
    actually (7, 0). The centre index is now computed from the grid size
    (105 for a 14×14 grid).

    Args:
        features: Array indexed as ``[layer, token, channel]``. The fixed
            examples index layer 10, layer 15 and token 50, so at least
            16 layers and 51 tokens are required.

    Raises:
        IndexError: If ``features`` is too small for the fixed examples.
    """
    print_separator("TOKEN ACCESS EXAMPLES")

    num_layers = features.shape[0]
    num_tokens = features.shape[1]
    num_channels = features.shape[2]

    # Assumes tokens are stored row-major on a square grid (14×14 for 196
    # tokens) - TODO confirm against the feature extractor. int() floors, so
    # a non-square token count still yields an in-range, approximate centre.
    grid_size = int(np.sqrt(num_tokens))
    center_row = center_col = grid_size // 2
    center_idx = center_row * grid_size + center_col

    print("\n1. First Token Across All Layers")
    first_token = features[:, 0, :]
    print(f"   features[:, 0, :].shape: {first_token.shape}")
    print(f"   This gives the first spatial token's features across all "
          f"{num_layers} layers")

    print("\n2. Center Token (approximate)")
    center_token = features[:, center_idx, :]
    print(f"   features[:, {center_idx}, :].shape: {center_token.shape}")
    print(f"   Token {center_idx} is at grid position "
          f"({center_row}, {center_col}) - center of image")

    print("\n3. All Tokens from Layer 10")
    layer_10_tokens = features[10, :, :]
    print(f"   features[10, :, :].shape: {layer_10_tokens.shape}")
    print(f"   This gives all {num_tokens} spatial tokens from layer 10")

    print("\n4. Specific Token from Specific Layer")
    single_vector = features[15, 50, :]
    print(f"   features[15, 50, :].shape: {single_vector.shape}")
    print(f"   This gives the {num_channels}-dim feature vector for token 50 "
          f"at layer 15")


def demonstrate_spatial_mapping(features):
    """Print the mapping between token indices and (row, col) grid positions.

    The landmark table is generated through the two conversion helpers
    rather than typed by hand. The hand-written table listed token 98 as the
    centre (7, 7), which contradicted ``position_to_token(7, 7) == 105``.

    Args:
        features: Array indexed as ``[layer, token, channel]``. The grid
            side is the floored square root of the token count (14 for 196
            tokens).

    Raises:
        ValueError: If ``features`` has no spatial tokens.
    """
    # Assumes tokens are stored row-major on a square grid - TODO confirm
    # against the feature extractor.
    grid_size = int(np.sqrt(features.shape[1]))
    if grid_size < 1:
        raise ValueError("features must contain at least one spatial token")

    print_separator("SPATIAL POSITION MAPPING")

    def token_to_position(token_idx):
        """Convert token index to (row, col) position"""
        return divmod(token_idx, grid_size)

    def position_to_token(row, col):
        """Convert (row, col) position to token index"""
        return row * grid_size + col

    last = grid_size - 1
    mid = grid_size // 2

    landmarks = [
        ("Top-left", 0, 0),
        ("Top-right", 0, last),
        ("Center", mid, mid),
        ("Bottom-left", last, 0),
        ("Bottom-right", last, last),
    ]

    print(f"\nToken indices form a {grid_size}×{grid_size} grid:")
    for label, row, col in landmarks:
        token_idx = position_to_token(row, col)
        # Round-trip through the inverse helper so the printed position is
        # always the one the index really maps to.
        shown_row, shown_col = token_to_position(token_idx)
        position = f"({shown_row}, {shown_col})"
        print(f"  Token {token_idx:<3} → Position {position:<8} [{label}]")

    print("\n\nExample: Extract features at specific positions")

    examples = [
        ("Top-left corner", 0, 0),
        ("Center", mid, mid),
        ("Bottom-right corner", last, last),
    ]
    for number, (label, row, col) in enumerate(examples, start=1):
        token_idx = position_to_token(row, col)
        token_features = features[:, token_idx, :]
        print(f"\n{number}. {label} (row={row}, col={col}):")
        print(f"   Token index: {token_idx}")
        print(f"   Features shape: {token_features.shape}")
        print(f"   Mean across layers: {token_features.mean():.4f}")


def analyze_layer_progression(features):
    """Print per-layer statistics followed by a first/middle/last summary.

    For each layer the table lists the mean and standard deviation over all
    values, and the L2 norm computed along axis 1 of the layer and then
    averaged (one norm per token when a layer is ``(tokens, channels)``).

    The summary indices are derived from the number of layers instead of
    being fixed at 11 and 23, so arrays with fewer than 24 layers no longer
    raise ``IndexError``. If there are no layers, only the empty table is
    printed.

    Args:
        features: Array whose first axis indexes layers.
    """
    print_separator("LAYER PROGRESSION ANALYSIS")

    print("\nComputing statistics for each layer...")

    layer_stats = [
        {
            'layer': layer_idx,
            'mean': layer_features.mean(),
            'std': layer_features.std(),
            'norm': np.linalg.norm(layer_features, axis=1).mean(),
        }
        for layer_idx, layer_features in enumerate(features)
    ]

    # Numeric headers are right-aligned at width 11 to sit over the
    # right-aligned values printed in the rows below.
    print(f"\n{'Layer':<8} {'Block':<8} {'Mean':>11} {'Std':>11} {'Avg L2 Norm':>11}")
    print("-" * 60)

    for stats in layer_stats:
        layer_idx = stats['layer']
        print(f"{layer_idx:<8} "
              f"Block_{layer_idx:<2} "
              f"{stats['mean']:>11.4f} "
              f"{stats['std']:>11.4f} "
              f"{stats['norm']:>11.4f}")

    if not layer_stats:
        # Nothing to summarise for an array with zero layers.
        return

    last_idx = len(layer_stats) - 1
    # Lower-middle index: 11 for a 24-layer array.
    mid_idx = last_idx // 2

    # Highlight key layers
    print("\n\nKey observations:")
    print(f"  - First layer (0): mean={layer_stats[0]['mean']:.4f}")
    print(f"  - Middle layer ({mid_idx}): mean={layer_stats[mid_idx]['mean']:.4f}")
    print(f"  - Last layer ({last_idx}): mean={layer_stats[last_idx]['mean']:.4f}")


def demonstrate_reshaping(features):
    """Show how to reshape one layer's tokens into a (row, col, channel) grid.

    All dimensions are read from ``features`` rather than assumed to be
    (24, 196, 384), so the printed "Original shape" is the real one and the
    reshape works for any square token grid. The example layer is layer 10,
    or the last layer when there are fewer than 11.

    Args:
        features: 3-D array indexed as ``[layer, token, channel]`` whose
            token count is a perfect square.

    Raises:
        ValueError: If ``features`` is not 3-D or its token count is not a
            non-zero perfect square.
        IndexError: If ``features`` has no layers.
    """
    num_layers, num_tokens, num_channels = features.shape

    grid_size = int(np.sqrt(num_tokens))
    if grid_size < 1 or grid_size * grid_size != num_tokens:
        raise ValueError(
            f"Cannot arrange {num_tokens} tokens into a square grid"
        )

    layer_idx = min(10, num_layers - 1)
    last = grid_size - 1
    mid = grid_size // 2

    print_separator("RESHAPING FOR SPATIAL ANALYSIS")

    print(f"\nOriginal shape: {features.shape}")
    print(f"  - {num_layers} layers")
    print(f"  - {num_tokens} tokens ({grid_size}×{grid_size} grid)")
    print(f"  - {num_channels} feature channels")

    # Reshape the example layer from (tokens, channels) to a spatial grid.
    layer = features[layer_idx]
    layer_grid = layer.reshape(grid_size, grid_size, num_channels)

    print(f"\nAfter reshaping layer {layer_idx}:")
    print(f"  layer_{layer_idx}.shape: {layer.shape}")
    print(f"  layer_{layer_idx}_grid.shape: {layer_grid.shape}")
    print("  Now indexed as: [row, col, channel]")

    # Access specific positions
    grid_name = f"layer_{layer_idx}_grid"
    print("\nExamples:")
    print(f"  {grid_name}[0, 0, :] → Top-left corner features ({num_channels}-dim)")
    print(f"  {grid_name}[{mid}, {mid}, :] → Center features ({num_channels}-dim)")
    print(f"  {grid_name}[{last}, {last}, :] → Bottom-right corner features ({num_channels}-dim)")

    # Compute spatial statistics
    print(f"\nSpatial statistics for layer {layer_idx}:")
    spatial_norms = np.linalg.norm(layer_grid, axis=2)  # (grid, grid)

    def grid_position(flat_index):
        # Convert NumPy integer scalars to plain ints so the tuple prints as
        # "(r, c)" regardless of how the installed NumPy reprs its scalars.
        return tuple(
            int(i) for i in np.unravel_index(flat_index, spatial_norms.shape)
        )

    print(f"  Spatial norm map shape: {spatial_norms.shape}")
    print(f"  Max activation at position: {grid_position(spatial_norms.argmax())}")
    print(f"  Min activation at position: {grid_position(spatial_norms.argmin())}")


def main():
    """Run every demonstration on the first feature file found.

    Looks for ``*.npy`` files in ``features/cait_s24_224``, picks the first
    one in sorted order, and passes the loaded array to each demonstration
    in turn. If the directory or the files are missing, an error message is
    printed and the function returns without running anything.
    """
    feature_dir = Path("features/cait_s24_224")

    # Bail out early when there is nothing to analyse.
    if not feature_dir.exists():
        print(f"Error: Feature directory not found: {feature_dir}")
        print("Please run the feature extraction first.")
        return

    npy_paths = sorted(feature_dir.glob("*.npy"))
    if len(npy_paths) == 0:
        print(f"Error: No feature files found in {feature_dir}")
        return

    rule = "=" * 70

    print(f"\n{rule}")
    print("  CaiT Feature Analysis Examples")
    print("  Model: cait_s24_224")
    print(rule)

    # Load once, then hand the same array to each demo in order.
    features = load_and_inspect_features(npy_paths[0])
    demos = (
        demonstrate_layer_access,
        demonstrate_token_access,
        demonstrate_spatial_mapping,
        analyze_layer_progression,
        demonstrate_reshaping,
    )
    for demo in demos:
        demo(features)

    print_separator()
    print("\nAnalysis complete!")
    print("For more information, see: FEATURE_EXTRACTION_GUIDE.md")
    print(f"{rule}\n")


# Run the demonstrations only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
