#!/usr/bin/env python3
"""
Compute layer-wise Jacobian effective rank across network depth.

For each model, we hook into intermediate layers and compute the
Jacobian effective rank at each depth level.

This reveals whether rank collapse happens early or late in the network.
"""

import json
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from tqdm import tqdm
from typing import Dict, List, Tuple
from functools import partial

from src.models import load_model


def get_layer_hooks_vit(model: nn.Module, model_name: str) -> List[Tuple[str, nn.Module]]:
    """Collect the ViT layers to hook, in forward order.

    Transformer blocks are named ``block_<i>``; post-block modules keep their
    attribute names (``norm``, ``fc_norm``, ``head``, ``ln_post``).

    Args:
        model: A ViT-style model. Supported layouts are a top-level
            ``.blocks``, a CLIP ``.visual.transformer.resblocks``, a CLIP
            ``.visual.trunk.blocks``, or a top-level
            ``.transformer.resblocks``.
        model_name: Unused; kept for interface compatibility.

    Returns:
        List of ``(name, module)`` pairs; empty if no known layout matched.
    """
    layers: List[Tuple[str, nn.Module]] = []

    def add_blocks(blocks) -> None:
        for i, block in enumerate(blocks):
            layers.append((f'block_{i}', block))

    def add_optional(name: str, module) -> None:
        # Optional sub-modules may be None (which would crash hook
        # registration) or an nn.Identity placeholder (timm's VisionTransformer
        # uses one for whichever of norm/fc_norm is unused). Hooking an Identity
        # only repeats the previous layer's output, so skip both cases.
        if module is not None and not isinstance(module, nn.Identity):
            layers.append((name, module))

    if hasattr(model, 'blocks'):  # DINOv2, MAE, timm ViTs
        add_blocks(model.blocks)
        add_optional('norm', getattr(model, 'norm', None))
        add_optional('fc_norm', getattr(model, 'fc_norm', None))
        add_optional('head', getattr(model, 'head', None))
    elif hasattr(model, 'visual'):  # CLIP
        visual = model.visual
        if hasattr(visual, 'transformer'):  # OpenAI CLIP
            add_blocks(getattr(visual.transformer, 'resblocks', []))
            add_optional('ln_post', getattr(visual, 'ln_post', None))
        elif hasattr(visual, 'trunk'):  # OpenCLIP with timm backend
            trunk = visual.trunk
            if hasattr(trunk, 'blocks'):
                add_blocks(trunk.blocks)
                add_optional('norm', getattr(trunk, 'norm', None))
    elif hasattr(model, 'transformer'):  # Some CLIP variants
        # Guard against transformers without `resblocks` instead of raising.
        add_blocks(getattr(model.transformer, 'resblocks', []))

    return layers


def get_layer_hooks_resnet(model: nn.Module) -> List[Tuple[str, nn.Module]]:
    """Collect the ResNet layers to hook, in forward order.

    Two layouts are recognized:

    * a standard ResNet with ``layer1``..``layer4`` (plus ``avgpool`` / ``fc``);
    * an ``nn.Sequential`` wrapper (e.g. SimCLR). Its nested ``nn.Sequential``
      children become ``stage_<child_index>``; pooling, flatten and linear
      children get fixed labels. If a child is itself a ResNet (has
      ``layer1``), that child's layers are returned instead.

    Labels are made unique by suffixing ``_1``, ``_2``, ... so that hook
    outputs, which are keyed by name, never overwrite one another.

    Returns:
        List of ``(name, module)`` pairs; empty if no known layout matched.
    """
    layers: List[Tuple[str, nn.Module]] = []
    taken = set()

    def add(name: str, module: nn.Module) -> None:
        # Several children can map to the same label (e.g. an MLP projection
        # head containing multiple Linear layers); keep each one distinct.
        unique, k = name, 1
        while unique in taken:
            unique = f'{name}_{k}'
            k += 1
        taken.add(unique)
        layers.append((unique, module))

    def present(module) -> bool:
        # None cannot be hooked; an Identity only repeats its input.
        return module is not None and not isinstance(module, nn.Identity)

    if hasattr(model, 'layer1'):  # Standard ResNet structure
        for stage in ('layer1', 'layer2', 'layer3', 'layer4'):
            if hasattr(model, stage):
                add(stage, getattr(model, stage))
        # Post-conv layers
        avgpool = getattr(model, 'avgpool', None)
        if present(avgpool):
            add('avgpool', avgpool)
        fc = getattr(model, 'fc', None)
        if present(fc):
            add('fc', fc)
    elif isinstance(model, nn.Sequential):
        # nn.Sequential-wrapped models: look for ResNet pieces among children.
        for i, child in enumerate(model.children()):
            if isinstance(child, nn.Sequential):
                # Presumably a bundled residual stage (layer1, layer2, ...).
                add(f'stage_{i}', child)
            elif isinstance(child, nn.AdaptiveAvgPool2d):
                add('avgpool', child)
            elif isinstance(child, nn.Flatten):
                add('flatten', child)
            elif isinstance(child, nn.Linear):
                add('fc', child)
            elif hasattr(child, 'layer1'):
                # Nested ResNet backbone: report its layers instead.
                return get_layer_hooks_resnet(child)

    return layers


class LayerOutputCapture:
    """Record the outputs of hooked modules during forward passes.

    ``register`` attaches a forward hook to each ``(name, module)`` pair. On
    every forward pass the module's output is detached and stored in
    ``self.outputs[name]``, replacing any earlier value for that name.
    """

    def __init__(self):
        self.outputs = {}  # name -> detached output tensor from the latest forward
        self.hooks = []  # handles from register_forward_hook, used for removal

    def hook_fn(self, name: str, module: nn.Module, input: tuple, output: torch.Tensor):
        """Forward-hook callback: store ``module``'s tensor output under ``name``.

        For a tuple/list output, its first element is used. Anything that is
        not a tensor (including an empty sequence) is ignored instead of
        raising, because an exception here would abort the model's forward.
        """
        if isinstance(output, (tuple, list)):
            output = output[0] if len(output) > 0 else None
        if isinstance(output, torch.Tensor):
            self.outputs[name] = output.detach()

    def register(self, layers: List[Tuple[str, nn.Module]]):
        """Attach a capturing forward hook to every ``(name, module)`` pair."""
        for name, module in layers:
            hook = module.register_forward_hook(partial(self.hook_fn, name))
            self.hooks.append(hook)

    def clear(self):
        """Forget all captured outputs (hooks stay registered)."""
        self.outputs = {}

    def remove_hooks(self):
        """Detach every hook added by ``register``."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []


def _flatten_per_sample(output: torch.Tensor, batch_size: int):
    """Reshape a captured layer output to ``(batch_size, features)``.

    Returns ``None`` when the batch axis cannot be identified. The caller then
    skips the output instead of flattening values from different samples
    into the same row.
    """
    # Sequence-first transformers (such as OpenAI CLIP's visual transformer,
    # whose resblocks run on (seq, batch, dim) tensors) put the batch on axis
    # 1. Only swap when axis 0 clearly is NOT the batch. If seq_len happens to
    # equal batch_size the layout is ambiguous, and batch-first is assumed as
    # before.
    if (
        output.dim() == 3
        and output.shape[0] != batch_size
        and output.shape[1] == batch_size
    ):
        output = output.transpose(0, 1)
    if output.dim() == 0 or output.shape[0] != batch_size:
        return None
    return output.reshape(batch_size, -1)


def compute_layerwise_jacobian_rank(
    model_wrapper,
    images: torch.Tensor,
    n_directions: int = 32,
    eps: float = 1e-3,
    verbose: bool = False,
) -> Dict[str, float]:
    """
    Compute the Jacobian effective rank at each hooked layer of a model.

    For each layer, the Jacobian-vector product J·v is estimated with central
    finite differences, (f(x + eps·v) - f(x - eps·v)) / (2·eps). This is done
    for ``n_directions`` random directions v, each normalized to unit norm per
    sample. For each sample, the stacked JVPs form an
    (n_directions, output_dim) matrix with singular values σ. Its effective
    rank is (Σσ)² / Σσ², and the reported value is the mean over samples.

    Args:
        model_wrapper: Object exposing ``.model`` (the nn.Module) and ``.name``
            (used to choose ResNet vs. ViT layer discovery).
        images: Input batch of shape (B, C, H, W).
        n_directions: Number of random input directions per sample.
        eps: Finite-difference step size.
        verbose: Print per-layer progress.

    Returns:
        Dict mapping layer_name -> effective rank, in layer-discovery order.
        A layer is omitted when its output was not captured, could not be
        flattened per sample, or no per-sample SVD succeeded.
    """
    model = model_wrapper.model
    device = images.device
    B, C, H, W = images.shape

    # Determine model type and get layers
    model_name = model_wrapper.name
    lowered = model_name.lower()
    resnet_keys = ('resnet', 'barlow', 'vicreg', 'swav', 'simclr')
    if any(key in lowered for key in resnet_keys):
        layers = get_layer_hooks_resnet(model)
    else:
        layers = get_layer_hooks_vit(model, model_name)

    if not layers:
        print(f"Warning: No layers found for {model_name}")
        return {}

    if verbose:
        print(f"  Found {len(layers)} layers", flush=True)

    # Fixed seed so the same directions are used for every call with the same
    # image shape/device. NOTE: this reseeds the global torch RNG.
    torch.manual_seed(42)
    directions = torch.randn(B, n_directions, C, H, W, device=device)
    norms = directions.view(B, n_directions, -1).norm(dim=-1)  # (B, n_directions)
    directions = directions / norms[..., None, None, None]

    capture = LayerOutputCapture()
    capture.register(layers)
    layer_ranks = {}

    # try/finally so hooks are removed even if a forward pass raises.
    try:
        for layer_idx, (layer_name, _) in enumerate(layers):
            if verbose:
                print(f"  Layer {layer_idx+1}/{len(layers)}: {layer_name}", end="", flush=True)

            # NOTE: every forward pass fires all hooks, but only this layer's
            # output is used, so the model runs 2 * n_directions * len(layers)
            # times. Sweeping all layers per direction would be len(layers)
            # times cheaper, but it would hold every layer's JVPs in memory at once.
            layer_jvps = []
            for d_idx in range(n_directions):
                v = directions[:, d_idx]  # (B, C, H, W)

                with torch.no_grad():
                    capture.clear()
                    model(images + eps * v)
                    output_plus = capture.outputs.get(layer_name)

                    capture.clear()
                    model(images - eps * v)
                    output_minus = capture.outputs.get(layer_name)

                if output_plus is None or output_minus is None:
                    continue

                output_plus = _flatten_per_sample(output_plus, B)
                output_minus = _flatten_per_sample(output_minus, B)
                if output_plus is None or output_minus is None:
                    continue

                layer_jvps.append((output_plus - output_minus) / (2 * eps))

            if not layer_jvps:
                if verbose:
                    print(" -> skipped (no JVPs)", flush=True)
                continue

            # (B, n_valid_directions, output_dim)
            jvp_matrix = torch.stack(layer_jvps, dim=1)

            singular_values_batch = []
            for b in range(B):
                try:
                    # Only singular values are needed; svdvals skips U and V.
                    singular_values_batch.append(torch.linalg.svdvals(jvp_matrix[b]))
                except RuntimeError:
                    # torch.linalg.LinAlgError (e.g. non-convergence) is a
                    # RuntimeError; skip this sample but keep the others.
                    continue

            if not singular_values_batch:
                if verbose:
                    print(" -> skipped (SVD failed)", flush=True)
                continue

            singular_values = torch.stack(singular_values_batch)  # (B_ok, k)

            # Effective rank: (Σσ)² / Σσ², averaged over samples.
            sv_sum = singular_values.sum(dim=-1)
            sv_sq_sum = (singular_values ** 2).sum(dim=-1)
            effective_rank = (sv_sum ** 2 / (sv_sq_sum + 1e-8)).mean().item()

            layer_ranks[layer_name] = effective_rank

            if verbose:
                print(f" -> rank={effective_rank:.1f}", flush=True)
    finally:
        capture.remove_hooks()

    return layer_ranks


def _layer_sort_key(name: str) -> Tuple[int, int]:
    """Sort key that orders layer names by network depth.

    Numbered layers (``block_N``, ``stage_N``, ``layerN``) come first, by N.
    Known post-trunk layers follow in forward order. Anything else goes last;
    ``sorted`` is stable, so ties keep their original relative order.
    """
    for prefix in ('block_', 'stage_', 'layer'):
        if name.startswith(prefix):
            try:
                return (0, int(name[len(prefix):]))
            except ValueError:
                return (1, 0)
    # Post-block layers in forward order ('flatten' sits between the pooling
    # and the final linear layer in Sequential-wrapped ResNets).
    post_order = {'norm': 100, 'ln_post': 101, 'fc_norm': 102,
                  'avgpool': 103, 'flatten': 104, 'fc': 105, 'head': 106}
    return (1, post_order.get(name, 200))


def main():
    """Measure layer-wise Jacobian effective rank for a fixed list of models.

    Results are written to ``../outputs/layerwise_rank.json``; the path is
    resolved relative to the current working directory. A model that fails to
    load or evaluate is reported and skipped.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    # Models to analyze - comprehensive coverage
    models_to_test = [
        # Variance-Decorrelation
        ("barlow_twins_resnet50", "Barlow Twins"),
        ("vicreg_resnet50", "VICReg"),
        # Vision-Language
        ("clip_vitb16", "CLIP"),
        ("siglip_vitb16", "SigLIP"),
        # Self-Supervised (self-distillation)
        ("dinov2_vitb14", "DINOv2"),
        ("dino_vitb16", "DINO"),
        # Masked Image Modeling
        ("mae_vitb16", "MAE"),
        ("beit_vitb16", "BEiT"),
        # Other SSL
        ("simclr_resnet50", "SimCLR"),
        ("swav_resnet50", "SwAV"),
        ("ijepa_vitb16", "I-JEPA"),
    ]

    # Generate random images (matching main experiment settings)
    n_images = 20
    torch.manual_seed(42)
    images = torch.randn(n_images, 3, 224, 224, device=device)
    images = images * 0.225 + 0.45  # Match normalization from run_jacobian_all.py
    images = images.clamp(0, 1)

    results = {}

    for model_name, display_name in models_to_test:
        print(f"\nProcessing {display_name}...", flush=True)

        model_wrapper = None
        try:
            model_wrapper = load_model(model_name, device=device)
            model_wrapper.model.eval()

            layer_ranks = compute_layerwise_jacobian_rank(
                model_wrapper, images, n_directions=32, verbose=True
            )

            if layer_ranks:
                layer_names = sorted(layer_ranks, key=_layer_sort_key)
                ranks = [layer_ranks[name] for name in layer_names]

                results[display_name] = {
                    'layer_names': layer_names,
                    'effective_ranks': ranks,
                    'n_layers': len(layer_names),
                }

                print(f"  Layers: {len(layer_names)}")
                print(f"  Rank range: {min(ranks):.2f} - {max(ranks):.2f}")

        except Exception as e:
            # Top-level boundary: one broken model must not abort the sweep.
            print(f"  Error: {e}")
        finally:
            # Release the model even when it failed mid-run, so the next
            # model does not have to fit in memory alongside it.
            del model_wrapper
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Save results
    output_dir = Path("../outputs")
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / "layerwise_rank.json"

    with open(output_path, "w") as f:
        json.dump(results, f, indent=2)

    print(f"\nSaved to {output_path}")


if __name__ == "__main__":
    main()
