#!/usr/bin/env python3
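"""
Compute the Jacobian singular-value spectrum for Figure 3.

For each model, estimate singular values of the model's Jacobian at random
Gaussian-noise input images, average them across images, and write the
per-model spectra to a JSON file (default: outputs/fig3_spectrum.json).

Default target models: Barlow Twins, CLIP, DINOv2, MAE (override with --models).
"""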

import argparse
import json
import sys
from pathlib import Path

import torch
import numpy as np
from tqdm import tqdm

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.models import load_model
from src.jacobian import estimate_jacobian_singular_values


def _make_noise_images(n_samples, device, seed=42, size=224):
    """Build a batch of ImageNet-normalized Gaussian-noise images.

    Pixels are drawn from N(0.45, 0.225**2), clamped to [0, 1], and then
    normalized with the standard ImageNet per-channel mean/std.

    Args:
        n_samples: Number of images to generate.
        device: Torch device the images are created on.
        seed: Seed passed to ``torch.manual_seed``. Re-seeding on every call
            gives each model the same inputs within one run. The CPU and CUDA
            generators produce different streams, so the images depend on
            the device.
        size: Spatial resolution (height and width) of each image.

    Returns:
        Tensor of shape (n_samples, 3, size, size).
    """
    torch.manual_seed(seed)
    pixels = torch.randn(n_samples, 3, size, size, device=device) * 0.225 + 0.45
    pixels = pixels.clamp(0, 1)
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
    return (pixels - mean) / std


def _estimate_spectra(model, images, n_directions, n_power_iterations=10):
    """Estimate one descending singular-value spectrum per image.

    Each image is passed to ``estimate_jacobian_singular_values`` as a batch
    of one.

    Returns:
        Float array of shape (n_images, k). Here k is the number of singular
        values the estimator returns per image (presumably ``n_directions``;
        TODO confirm against src.jacobian).
    """
    rows = []
    for idx in tqdm(range(images.shape[0]), desc="Computing spectrum"):
        singular_values, _ = estimate_jacobian_singular_values(
            model,
            images[idx:idx + 1],
            n_directions=n_directions,
            n_power_iterations=n_power_iterations,
        )
        # Flatten so that both a (k,) and a (1, k) return become one row.
        # With np.concatenate, a 1-D return would merge all samples into one
        # long vector, and the per-index mean would collapse to a scalar.
        sv = singular_values.detach().cpu().numpy().reshape(-1)
        # Sort descending so that index 0 is always sigma_1 before averaging
        # across samples. This is a no-op if the values are already sorted.
        rows.append(np.sort(sv)[::-1])
    return np.stack(rows, axis=0)


def _summarize_spectra(all_sv):
    """Reduce per-sample spectra of shape (n_samples, k) to one model's JSON record."""
    mean_spectrum = all_sv.mean(axis=0)
    std_spectrum = all_sv.std(axis=0)
    # Normalize by the top (largest) mean singular value.
    top_sv = mean_spectrum[0]
    return {
        "mean_spectrum": mean_spectrum.tolist(),
        "std_spectrum": std_spectrum.tolist(),
        "normalized_spectrum": (mean_spectrum / top_sv).tolist(),
        "top_singular_value": float(top_sv),
        # sigma_1 / sigma_k, where k is the length of the spectrum.
        "decay_ratio": float(mean_spectrum[0] / mean_spectrum[-1]),
    }


def main():
    """Estimate and save Jacobian singular-value spectra for the Figure 3 models.

    For each model, the loop does the following:
      1. Generate seeded noise images.
      2. Estimate one spectrum per image.
      3. Average the spectra across images.
      4. Print a short summary.
      5. Free the model before loading the next one.

    All results are written as one JSON object, keyed by model name, to the
    ``--output`` path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--models", nargs="+", default=None)
    parser.add_argument("--output", type=str, default="outputs/fig3_spectrum.json")
    parser.add_argument("--n-samples", type=int, default=50)
    parser.add_argument("--n-directions", type=int, default=64)
    args = parser.parse_args()
    # Fail fast with a usage message rather than an opaque error after the
    # (expensive) model load.
    if args.n_samples < 1:
        parser.error("--n-samples must be a positive integer")
    if args.n_directions < 1:
        parser.error("--n-directions must be a positive integer")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    # Figure 3 models
    target_models = args.models or [
        "barlow_twins_resnet50",
        "clip_vitb16",
        "dinov2_vitb14",
        "mae_vitb16",
    ]

    subscript_digits = str.maketrans("0123456789", "₀₁₂₃₄₅₆₇₈₉")
    results = {}

    for model_name in target_models:
        print(f"\n{'='*60}")
        print(f"Processing {model_name}...")
        print('='*60)

        model_wrapper = load_model(model_name, device=device)
        model = model_wrapper.model.eval()

        images = _make_noise_images(args.n_samples, device)
        all_sv = _estimate_spectra(model, images, args.n_directions)
        summary = _summarize_spectra(all_sv)
        results[model_name] = summary

        # Label the decay ratio with the actual spectrum length. The old
        # hard-coded "64" was wrong whenever --n-directions differed.
        k_label = str(all_sv.shape[1]).translate(subscript_digits)
        print(f"  Top SV: {summary['top_singular_value']:.4f}")
        print(f"  Decay ratio (σ₁/σ{k_label}): {summary['decay_ratio']:.2f}")

        # Drop references (and any cached GPU memory) before the next model loads.
        del model, model_wrapper, images
        torch.cuda.empty_cache()

    # Save
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved to {output_path}")


if __name__ == "__main__":
    # Run the spectrum computation only when executed as a script, not on import.
    main()
