#!/usr/bin/env python3
"""
Main experiment script for Table 1.

Evaluates 21 vision encoders on:
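- Geometric metrics (G.PR, G.Iso, L.Iso) on synthetic ImageNet-proxy images
  (random noise with ImageNet-like pixel statistics; no real ImageNet data is loaded)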
- Jacobian Effective Rank (JER) on random noise
- Compositional benchmarks (Attribute Binding, Same/Different)
"""

import argparse
import json
import os
import sys
from pathlib import Path

import torch
import numpy as np
from tqdm import tqdm

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.models import load_model, list_models
from src.metrics import compute_global_isotropy, compute_local_anisotropy_random
from src.jacobian import compute_jacobian_anisotropy
from src.benchmarks import (
    SameDifferentBenchmark,
    AttributeBindingBenchmark,
    get_default_transform,
)


# The 21 encoders reported in the paper, grouped by training family.
# When no explicit model list is given on the command line, this list is
# evaluated as-is, so its order is also the row order of the summary table.
PAPER_MODELS = [
    # -- Vision-Language --
    "clip_vitb32",
    "clip_vitb16",
    "clip_vitl14",
    "siglip_vitb16",
    "eva_clip_vitb16",
    # -- Self-Distillation --
    "dinov2_vits14",
    "dinov2_vitb14",
    "dinov2_vitl14",
    "dino_vits16",
    "dino_vitb16",
    # -- Masked Prediction --
    "mae_vitb16",
    "mae_vitl16",
    "ijepa_vitb16",
    "ijepa_vitl16",
    "beit_vitb16",
    "beitv2_vitb16",
    # -- Variance-Decorrelation --
    "barlow_twins_resnet50",
    "vicreg_resnet50",
    # -- Clustering / Contrastive --
    "swav_resnet50",
    "simclr_resnet50",
    # -- Supervised --
    "convnext_base",
]


def generate_imagenet_proxy(
    n_samples: int, device: str, image_size: int = 224
) -> torch.Tensor:
    """Generate normalized random-noise images as a stand-in for ImageNet.

    No real images are loaded. Each pixel is drawn from N(0.45, 0.225**2),
    clamped to [0, 1], and then standardized per channel with the ImageNet
    mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225).

    Note: this seeds the *global* torch RNG with 42 for reproducibility, which
    also fixes any random draws made after this call.

    Args:
        n_samples: Number of images to generate.
        device: Device the returned tensor is moved to.
        image_size: Height and width of each square image (default 224).

    Returns:
        Float tensor of shape (n_samples, 3, image_size, image_size) on `device`.
    """
    torch.manual_seed(42)

    # Noise is drawn on the CPU and only moved at the end, so the same seed
    # yields the same pixels regardless of the target device.
    images = torch.randn(n_samples, 3, image_size, image_size) * 0.225 + 0.45
    images = images.clamp(0, 1)

    # Per-channel standardization, broadcast over (N, C, H, W).
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    images = (images - mean) / std

    return images.to(device)


def extract_embeddings(model, images: torch.Tensor, batch_size: int = 32) -> torch.Tensor:
    """Run ``model.model`` over ``images`` in mini-batches and return pooled embeddings.

    Each raw forward output is reduced by ``_pool_embedding_output`` (see its
    docstring for the dict / tuple / list / tensor-rank rules).

    Args:
        model: Wrapper exposing ``.eval()`` and the forward callable as ``.model``.
        images: Preprocessed inputs; sliced into batches along dim 0.
        batch_size: Number of images per forward pass.

    Returns:
        CPU tensor with the per-batch embeddings concatenated along dim 0.
    """
    model.eval()
    embeddings = []

    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            output = model.model(images[start:start + batch_size])
            embeddings.append(_pool_embedding_output(output).cpu())

    return torch.cat(embeddings, dim=0)


def _pool_embedding_output(emb):
    """Reduce one raw forward output to one embedding row per image.

    Rules, applied in order:
      * dict -> the "image" entry, else "features", else the first value;
      * tuple / list -> the first element;
      * 4-D tensor (assumed N, C, H, W ConvNet feature map) -> mean over dims 2, 3;
      * 3-D tensor (assumed N, tokens, D from a ViT) -> the first token;
      * any other tensor is returned unchanged.

    Raises:
        ValueError: if the model returned an empty dict.
    """
    if isinstance(emb, dict):
        # Look keys up lazily instead of eagerly building a fallback value.
        if "image" in emb:
            emb = emb["image"]
        elif "features" in emb:
            emb = emb["features"]
        elif emb:
            emb = next(iter(emb.values()))
        else:
            raise ValueError("model returned an empty dict; cannot extract embeddings")
    if isinstance(emb, (tuple, list)):
        emb = emb[0]
    if emb.dim() == 4:  # ConvNet: global average pool over spatial dims
        emb = emb.mean(dim=[2, 3])
    elif emb.dim() == 3:  # ViT: first (CLS) token
        emb = emb[:, 0]
    return emb


def compute_jer(model, device: str, n_samples: int = 20, n_directions: int = 32) -> float:
    """Average the Jacobian effective rank of ``model.model`` over random-noise inputs.

    Seeds the global torch RNG with 42, then draws ``n_samples`` images of shape
    (3, 224, 224) directly on ``device`` from N(0.45, 0.225**2). The images are
    clamped to [0, 1] and standardized with the ImageNet channel mean/std. Each
    image is passed on its own (batch of one) to ``compute_jacobian_anisotropy``,
    and the ``'effective_rank'`` entries it returns are averaged.

    Args:
        model: Wrapper whose ``.model`` attribute is handed to the Jacobian routine.
        device: Device on which the noise images are created.
        n_samples: Number of noise images to average over.
        n_directions: Forwarded to ``compute_jacobian_anisotropy``.

    Returns:
        ``np.mean`` of the per-image effective ranks.
    """
    torch.manual_seed(42)

    noise = (torch.randn(n_samples, 3, 224, 224, device=device) * 0.225 + 0.45).clamp(0, 1)

    channel_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
    inputs = (noise - channel_mean) / channel_std

    ranks = [
        compute_jacobian_anisotropy(
            model.model,
            inputs[idx:idx + 1],  # keep the batch dim: shape (1, 3, 224, 224)
            n_directions=n_directions,
            n_power_iterations=5,
        )['effective_rank']
        for idx in tqdm(range(n_samples), desc="Computing JER")
    ]
    return np.mean(ranks)


def _save_results(results: dict, output_path: str) -> None:
    """Serialize ``results`` to ``output_path`` as indented JSON.

    The data is written to ``<output_path>.tmp`` and then moved into place with
    ``os.replace``. If the process dies mid-write, the previous snapshot is kept
    instead of a truncated file.
    """
    tmp_path = f"{output_path}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(results, f, indent=2)
    os.replace(tmp_path, output_path)


def _evaluate_model(model, geo_images: torch.Tensor, device: str,
                    n_benchmark_samples: int, model_results: dict) -> None:
    """Compute every Table 1 metric for one loaded model.

    Each metric is written into ``model_results`` as soon as it is computed.
    If a later step raises, the caller still holds the metrics finished so far.
    """
    # 1. Embeddings of the proxy images, shared by both geometric metrics.
    print("  Extracting embeddings...")
    embeddings = extract_embeddings(model, geo_images).to(device)

    # 2. Global isotropy; G.PR is the participation ratio divided by embed_dim.
    print("  Computing global isotropy...")
    global_metrics = compute_global_isotropy(embeddings)
    model_results["G.PR"] = float(global_metrics.participation_ratio / global_metrics.embed_dim)
    model_results["G.Iso"] = float(global_metrics.isotropy_score)
    print(f"    G.PR: {model_results['G.PR']:.4f}, G.Iso: {model_results['G.Iso']:.4f}")

    # 3. Local isotropy, computed on (at most) the first 500 embeddings.
    print("  Computing local isotropy...")
    local_metrics = compute_local_anisotropy_random(
        embeddings[:500], n_samples=200, k=16, verbose=False
    )
    # L.Iso = 1 - local_anisotropy (convert from anisotropy to isotropy).
    model_results["L.Iso"] = float(1 - local_metrics.mean_anisotropy)
    print(f"    L.Iso: {model_results['L.Iso']:.4f}")

    # 4. Jacobian Effective Rank on random noise.
    print("  Computing JER...")
    model_results["JER"] = float(compute_jer(model, device))
    print(f"    JER: {model_results['JER']:.2f}")

    # 5. Same/Different benchmark.
    print("  Running Same/Different...")
    sd_bench = SameDifferentBenchmark(device=device, num_samples=n_benchmark_samples)
    model_results["Same/Diff"] = float(sd_bench.evaluate(model.model).accuracy)
    print(f"    Same/Diff: {model_results['Same/Diff']:.4f}")

    # 6. Attribute Binding benchmark.
    print("  Running Attribute Binding...")
    ab_bench = AttributeBindingBenchmark(device=device, num_samples=n_benchmark_samples)
    model_results["Binding"] = float(ab_bench.evaluate(model.model).accuracy)
    print(f"    Binding: {model_results['Binding']:.4f}")


def _print_summary(results: dict) -> None:
    """Print the Table 1 summary; any entry containing an "error" key prints ERROR."""
    print("\n" + "=" * 90)
    print("RESULTS SUMMARY (Table 1)")
    print("=" * 90)
    print(f"{'Model':<25} {'G.PR':>8} {'G.Iso':>8} {'L.Iso':>8} {'JER':>8} {'S/D':>8} {'Bind':>8}")
    print("-" * 90)

    for model_key, data in results.items():
        if "error" in data:
            print(f"{model_key:<25} ERROR")
            continue
        print(f"{model_key:<25} {data.get('G.PR', 0):>8.4f} {data.get('G.Iso', 0):>8.4f} "
              f"{data.get('L.Iso', 0):>8.4f} {data.get('JER', 0):>8.2f} "
              f"{data.get('Same/Diff', 0):>8.4f} {data.get('Binding', 0):>8.4f}")


def run_experiment(
    models: list,
    output_path: str,
    device: str = "cuda",
    n_geo_samples: int = 1000,
    n_benchmark_samples: int = 500,
):
    """Evaluate each model on all Table 1 metrics and save the results as JSON.

    A model that fails to load, or whose evaluation raises, is recorded with an
    ``"error"`` entry (plus any metrics already computed) and the run continues.
    The JSON file is rewritten after every model, including failed ones, so it
    always reflects every model processed so far.

    Args:
        models: Model keys understood by ``load_model``.
        output_path: Destination JSON file (its directory must already exist).
        device: Torch device string used for loading and evaluation.
        n_geo_samples: Number of proxy images for the geometric metrics.
        n_benchmark_samples: Samples per compositional benchmark.

    Returns:
        Dict mapping each model key to its metrics (or error) dict.
    """
    results = {}

    # The same proxy images are reused for every model.
    print("Generating images for geometric metrics...")
    geo_images = generate_imagenet_proxy(n_geo_samples, device)

    for model_key in models:
        print(f"\n{'='*60}")
        print(f"Evaluating: {model_key}")
        print('='*60)

        try:
            model = load_model(model_key, torch.device(device))
        except Exception as e:
            print(f"Failed to load {model_key}: {e}")
            results[model_key] = {"error": str(e)}
            # Persist the failure too; previously it was lost if no later model succeeded.
            _save_results(results, output_path)
            continue

        model_results = {}
        try:
            model_results["model_name"] = model.name
            model_results["embed_dim"] = model.embed_dim
            _evaluate_model(model, geo_images, device, n_benchmark_samples, model_results)
        except Exception as e:
            # One broken model (e.g. OOM, unsupported output) must not abort the whole run.
            print(f"  Evaluation of {model_key} failed: {type(e).__name__}: {e}")
            model_results["error"] = str(e)
        finally:
            # Release the model before loading the next one, even on failure.
            del model
            torch.cuda.empty_cache()

        results[model_key] = model_results

        # Save incrementally so partial progress survives a crash.
        _save_results(results, output_path)

    _print_summary(results)
    return results


def main():
    """Command-line entry point: parse arguments and run the Table 1 experiment.

    With ``--list-models``, the available models are printed (paper models are
    marked) and the function returns. Otherwise the selected models are evaluated.
    """
    parser = argparse.ArgumentParser(description="Run main experiment (Table 1)")
    parser.add_argument("--models", type=str, default=None,
                        help="Comma-separated model list, or 'all' for all 21 models")
    parser.add_argument("--output", type=str, default="outputs/table1.json")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--n-geo-samples", type=int, default=1000,
                        help="Number of proxy images used for the geometric metrics")
    parser.add_argument("--n-benchmark-samples", type=int, default=500,
                        help="Number of samples per compositional benchmark")
    parser.add_argument("--list-models", action="store_true")
    args = parser.parse_args()

    if args.list_models:
        print("Available models:")
        for m in list_models():
            marker = " [paper]" if m in PAPER_MODELS else ""
            print(f"  {m}{marker}")
        return

    if args.models is None or args.models.strip().lower() == "all":
        # Pass a copy so the module-level list cannot be mutated downstream.
        models = list(PAPER_MODELS)
    else:
        # Drop empty entries so "a,b," or "a,,b" does not try to load a model named "".
        models = [m.strip() for m in args.models.split(",") if m.strip()]
        if not models:
            parser.error("--models did not contain any model names")

    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)

    run_experiment(
        models,
        args.output,
        args.device,
        n_geo_samples=args.n_geo_samples,
        n_benchmark_samples=args.n_benchmark_samples,
    )


if __name__ == "__main__":
    main()
