#!/usr/bin/env python3
"""
Mock tests to verify correctness of every experiment component.

Uses synthetic data with KNOWN properties to validate code logic.
Does NOT depend on real model behavior.
"""

import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# ============================================================================
# Test utilities - synthetic data with known properties
# ============================================================================

def set_seed(seed=42):
    """Seed the global PyTorch and NumPy random generators with ``seed``."""
    # PyTorch is seeded first, then NumPy.
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(seed)


def make_isotropic_embeddings(n, d):
    """Return an ``(n, d)`` standard-normal sample divided by ``sqrt(d)``.

    Every coordinate is drawn identically, so no dimension is favoured.
    The global RNGs are reseeded through ``set_seed()`` first, which makes
    repeated calls with the same shape return the same values.
    """
    set_seed()
    gaussian = torch.randn(n, d)
    scale = np.sqrt(d)
    return gaussian / scale


def make_anisotropic_embeddings(n, d, top_ratio=0.9):
    """Return ``(n, d)`` unit-norm rows whose spread is dominated by column 0.

    A standard-normal sample is drawn after reseeding via ``set_seed()``.
    Column 0 is scaled by ``sqrt(top_ratio * d)`` and every other column by
    ``sqrt((1 - top_ratio) * d / (d - 1))``; finally each row is divided by
    its L2 norm.
    """
    set_seed()
    x = torch.randn(n, d)

    lead_scale = np.sqrt(top_ratio * d)
    rest_scale = np.sqrt((1 - top_ratio) * d / (d - 1))
    x[:, 0] *= lead_scale
    x[:, 1:] *= rest_scale

    row_norms = x.norm(dim=1, keepdim=True)
    return x / row_norms


class MockModel(nn.Module):
    """Mock model: one bias-free linear projection of the flattened input.

    Args:
        embed_dim: Size of the output embedding.
        input_dim: Number of elements in one flattened sample. Defaults to
            ``3 * 224 * 224``.
    """

    def __init__(self, embed_dim=64, input_dim=3 * 224 * 224):
        super().__init__()
        self.embed_dim = embed_dim
        # Simple projection
        self.proj = nn.Linear(input_dim, embed_dim, bias=False)

    def forward(self, x):
        """Flatten every sample of ``x`` and project it to ``embed_dim``."""
        # reshape() copies when needed, so non-contiguous inputs (e.g. a
        # permuted NHWC tensor) are accepted where view() would raise.
        return self.proj(x.reshape(x.shape[0], -1))


class PerfectBindingModel(nn.Module):
    """Parameter-free mock that embeds left/right image-half intensities.

    The first three embedding coordinates hold the mean of the left half,
    the mean of the right half, and their difference; the remaining
    coordinates are zero. The result is L2-normalised.

    Args:
        embed_dim: Size of the output embedding (at least 3). Defaults to 64.
    """

    def __init__(self, embed_dim=64):
        super().__init__()
        if embed_dim < 3:
            raise ValueError(f"embed_dim must be >= 3, got {embed_dim}")
        self.embed_dim = embed_dim

    def forward(self, x):
        """Encode based on left/right pixel regions.

        ``x`` must be 4-D with the width on the last axis. The split point
        is the horizontal midpoint (112 for a 224-wide input), so a width
        below 2 leaves the left half empty.
        """
        B = x.shape[0]
        half = x.shape[-1] // 2
        # Left half average
        left = x[:, :, :, :half].mean(dim=(1, 2, 3))
        # Right half average
        right = x[:, :, :, half:].mean(dim=(1, 2, 3))
        # Allocate on the input's device/dtype so the embedding follows the
        # input instead of always using the default device and dtype.
        emb = x.new_zeros(B, self.embed_dim)
        emb[:, 0] = left
        emb[:, 1] = right
        emb[:, 2] = left - right  # Encodes structure
        return F.normalize(emb, dim=-1)


# ============================================================================
# Test 1: Global Isotropy Metrics
# ============================================================================

def test_global_isotropy_discrimination():
    """Global isotropy must separate isotropic from anisotropic samples.

    One isotropic and one strongly anisotropic (``top_ratio=0.95``) sample
    of identical shape are scored; the test checks absolute thresholds on
    the score and relative gaps on score and effective dimension.
    """
    print("\n[TEST] Global isotropy discrimination...")
    from src.metrics import compute_global_isotropy

    n_points, dim = 500, 128
    iso_emb = make_isotropic_embeddings(n_points, dim)
    aniso_emb = make_anisotropic_embeddings(n_points, dim, top_ratio=0.95)

    # Score both samples, isotropic first.
    iso_m, aniso_m = (
        compute_global_isotropy(sample) for sample in (iso_emb, aniso_emb)
    )

    print(f"  Isotropic:   score={iso_m.isotropy_score:.4f}, eff_dim={iso_m.effective_dim:.1f}")
    print(f"  Anisotropic: score={aniso_m.isotropy_score:.4f}, eff_dim={aniso_m.effective_dim:.1f}")

    # Absolute thresholds on each score.
    assert iso_m.isotropy_score > 0.5, f"Isotropic should have score > 0.5, got {iso_m.isotropy_score}"
    assert aniso_m.isotropy_score < 0.1, f"Anisotropic should have score < 0.1, got {aniso_m.isotropy_score}"

    # Relative gaps between the two samples.
    assert iso_m.isotropy_score > 5 * aniso_m.isotropy_score, "Should discriminate by at least 5x"
    assert iso_m.effective_dim > 10 * aniso_m.effective_dim, "Effective dim should differ greatly"

    print("  PASSED")


def test_eigenvalue_properties():
    """The reported eigenvalue spectrum must be descending and non-negative.

    Also requires a strictly positive sum, i.e. some variance is present.
    """
    print("\n[TEST] Eigenvalue properties...")
    from src.metrics import compute_global_isotropy

    m = compute_global_isotropy(make_isotropic_embeddings(200, 64))
    eigs = m.eigenvalues
    print(f"  Top 5 eigenvalues: {eigs[:5].tolist()}")

    # Every eigenvalue must be at least as large as its successor.
    is_descending = all(cur >= nxt for cur, nxt in zip(eigs[:-1], eigs[1:]))
    assert is_descending, "Must be sorted descending"

    no_negative = (eigs >= 0).all()
    assert no_negative, "Must be non-negative"

    total_variance = eigs.sum()
    assert total_variance > 0, "Must have positive total variance"

    print("  PASSED")


# ============================================================================
# Test 2: Local Anisotropy Metrics
# ============================================================================

def test_local_anisotropy_range():
    """Summary statistics of local anisotropy must fall in valid ranges.

    Checks that the mean lies in [0, 1], the standard deviation is
    non-negative, and the mean effective dimension is at least 1.
    """
    print("\n[TEST] Local anisotropy range...")
    from src.metrics import compute_local_anisotropy_random

    sampling = dict(n_samples=50, k=16, verbose=False)
    points = make_isotropic_embeddings(300, 128)
    m = compute_local_anisotropy_random(points, **sampling)

    print(f"  Mean anisotropy: {m.mean_anisotropy:.4f}")
    print(f"  Std anisotropy: {m.std_anisotropy:.4f}")

    assert 0 <= m.mean_anisotropy <= 1, f"Must be in [0,1], got {m.mean_anisotropy}"
    assert m.std_anisotropy >= 0, "Std must be non-negative"
    assert m.mean_effective_dim >= 1, "Effective dim must be >= 1"

    print("  PASSED")


def test_local_anisotropy_single():
    """A neighbourhood stretched along one axis must register as anisotropic.

    Sixteen 64-dimensional Gaussian points have their first coordinate
    amplified tenfold before being scored.
    """
    print("\n[TEST] Local anisotropy single neighborhood...")
    from src.metrics import compute_local_anisotropy_single

    set_seed()
    n_neighbors, dim = 16, 64
    neighborhood = torch.randn(n_neighbors, dim)
    # Stretch the sample along the first coordinate, in place.
    neighborhood[:, 0].mul_(10)

    aniso, top2, eff_dim = compute_local_anisotropy_single(neighborhood)

    print(f"  Anisotropy: {aniso:.4f}")
    print(f"  Top-2 ratio: {top2:.4f}")
    print(f"  Effective dim: {eff_dim:.1f}")

    assert aniso > 0.5, f"Should detect concentrated variance, got {aniso}"
    assert top2 >= aniso, "Top-2 should be >= top-1"
    assert eff_dim < 5, f"Effective dim should be low, got {eff_dim}"

    print("  PASSED")


# ============================================================================
# Test 3: Jacobian Computation
# ============================================================================

def test_jacobian_linear_model():
    """Check Jacobian anisotropy metrics on the bias-free linear mock model.

    For a linear map the Jacobian is the constant weight matrix, so this
    test only requires the reported quantities to satisfy sanity bounds.
    """
    print("\n[TEST] Jacobian on linear model...")
    from src.jacobian import compute_jacobian_anisotropy

    # Seed before building the model and the inputs so the random weights
    # and images, and therefore the outcome, are reproducible.
    set_seed()

    n_directions = 16

    # Linear model: J = W (constant)
    # Use 224x224 to match MockModel's expected input
    model = MockModel(embed_dim=32)
    images = torch.randn(2, 3, 224, 224)  # Standard size

    metrics = compute_jacobian_anisotropy(
        model, images, n_directions=n_directions, n_power_iterations=3
    )

    print(f"  Top singular value: {metrics['top_singular']:.4f}")
    print(f"  Effective rank: {metrics['effective_rank']:.2f}")

    assert metrics['top_singular'] > 0, "Must have positive singular values"
    # NOTE(review): the upper bound is assumed to be the number of probed
    # directions -- confirm against compute_jacobian_anisotropy.
    assert 1 <= metrics['effective_rank'] <= n_directions, (
        f"Rank must be in [1, {n_directions}], got {metrics['effective_rank']}"
    )
    assert np.isfinite(metrics['spectral_ratio']), "Spectral ratio must be finite"

    print("  PASSED")


def test_jvp_finite_difference():
    """Finite-difference JVP must match the analytic JVP of a linear map.

    For ``f(x) = W x`` the Jacobian-vector product is exactly ``W v``, which
    gives a closed-form reference for ``_compute_jvp_finite_diff``.
    """
    print("\n[TEST] JVP finite difference accuracy...")
    from src.jacobian import _compute_jvp_finite_diff

    # Seed first: W, x and v are random, and the relative-error threshold
    # below should not depend on an arbitrary draw.
    set_seed()

    # For linear model f(x) = Wx, Jv = Wv
    batch = 2
    d_in, d_out = 64, 32
    W = torch.randn(d_out, d_in)

    class LinearModel(nn.Module):
        def forward(self, x):
            return x.view(x.shape[0], -1) @ W.T

    model = LinearModel()
    x = torch.randn(batch, 1, 8, 8)  # 1 * 8 * 8 == d_in
    v = torch.randn(batch, 1, 1, 8, 8)  # one direction per sample

    # True JVP
    true_jvp = v.reshape(batch, d_in) @ W.T  # (batch, d_out)

    # Finite diff JVP; the singleton direction axis is squeezed away.
    fd_jvp = _compute_jvp_finite_diff(model, x, v, eps=1e-4).squeeze(1)  # (batch, d_out)

    rel_error = (fd_jvp - true_jvp).norm() / (true_jvp.norm() + 1e-8)
    print(f"  Relative error: {rel_error:.6f}")

    assert rel_error < 0.01, f"FD error too high: {rel_error}"

    print("  PASSED")


# ============================================================================
# Test 4: Benchmark Logic
# ============================================================================

def test_benchmark_trial_generation():
    """Check the structure of trials from AttributeBindingBenchmark.

    Every generated trial must provide a query whose ``size`` is
    ``(224, 224)``, exactly four candidates of that same size, and a
    correct-answer index that addresses one of the candidates. Colour
    disjointness is not inspected here.
    """
    print("\n[TEST] Benchmark trial generation...")
    from src.benchmarks import AttributeBindingBenchmark

    expected_size = (224, 224)
    n_trials = 5

    bench = AttributeBindingBenchmark(device='cpu', num_samples=10, seed=42)

    # Generate multiple trials and check properties
    for _ in range(n_trials):
        query, candidates, correct_idx = bench.generate_trial()

        assert query.size == expected_size, f"Query wrong size: {query.size}"
        assert len(candidates) == 4, f"Should have 4 candidates, got {len(candidates)}"
        assert 0 <= correct_idx < 4, f"Invalid correct_idx: {correct_idx}"

        for c in candidates:
            assert c.size == expected_size, f"Candidate wrong size: {c.size}"

    print(f"  Generated {n_trials} valid trials")
    print("  PASSED")


def test_benchmark_with_perfect_model():
    """Run AttributeBindingBenchmark.evaluate on the structure-aware mock.

    Only sanity bounds are asserted: accuracy within [0, 1] and the sample
    count echoed back. The mock is simple, so above-chance accuracy is
    printed for inspection rather than required.
    """
    print("\n[TEST] Benchmark with structure-aware mock model...")
    from src.benchmarks import AttributeBindingBenchmark

    n_eval = 20

    # The mock looks at the left/right spatial structure of each image.
    model = PerfectBindingModel()
    result = AttributeBindingBenchmark(
        device='cpu', num_samples=n_eval, seed=42
    ).evaluate(model)

    print(f"  Accuracy: {result.accuracy:.4f} (chance: {result.chance_level:.4f})")

    # Range checks only; see the docstring for why chance is not asserted.
    assert result.accuracy >= 0.0, "Accuracy must be non-negative"
    assert result.accuracy <= 1.0, "Accuracy must be <= 1"
    assert result.num_samples == n_eval, f"Wrong num_samples: {result.num_samples}"

    print("  PASSED")


def test_same_different_benchmark():
    """SameDifferentBenchmark must report a valid accuracy and 0.5 chance."""
    print("\n[TEST] Same/Different benchmark...")
    from src.benchmarks import SameDifferentBenchmark

    bench_config = dict(device='cpu', num_samples=20, seed=42)

    model = PerfectBindingModel()
    result = SameDifferentBenchmark(**bench_config).evaluate(model)

    print(f"  Accuracy: {result.accuracy:.4f} (chance: {result.chance_level:.4f})")

    # Accuracy must lie in the unit interval; chance must be exactly 0.5.
    assert 0 <= result.accuracy <= 1, "Accuracy must be in [0, 1]"
    assert result.chance_level == 0.5, "Chance should be 0.5"

    print("  PASSED")


# ============================================================================
# Test 5: Shape Generator
# ============================================================================

def test_shape_generator_colors():
    """Rendered objects must show their requested colour on the right side.

    A red circle is requested on the left and a blue square on the right;
    the test counts strongly red / strongly blue pixels and compares the
    mean column of each colour against the image midline (x = 112).
    """
    print("\n[TEST] Shape generator colors...")
    from src.benchmarks import SyntheticShapeGenerator

    gen = SyntheticShapeGenerator(seed=42)

    # Scene: red circle (left) and blue square (right).
    img = gen.create_image([
        {'shape': 'circle', 'color': 'red', 'position': 'left'},
        {'shape': 'square', 'color': 'blue', 'position': 'right'},
    ])
    pixels = np.array(img)

    # Pull out the first three channels once and threshold them.
    ch_r, ch_g, ch_b = (pixels[:, :, idx] for idx in range(3))
    # Red:  R > 200, G < 100, B < 100
    red_mask = (ch_r > 200) & (ch_g < 100) & (ch_b < 100)
    # Blue: R < 100, G < 100, B > 200
    blue_mask = (ch_r < 100) & (ch_g < 100) & (ch_b > 200)

    red_count = red_mask.sum()
    blue_count = blue_mask.sum()

    print(f"  Red pixels: {red_count}")
    print(f"  Blue pixels: {blue_count}")

    assert red_count > 1000, f"Should have red pixels, got {red_count}"
    assert blue_count > 1000, f"Should have blue pixels, got {blue_count}"

    # Mean column index of each colour gives its horizontal position.
    _, red_cols = np.where(red_mask)
    red_x = red_cols.mean()
    _, blue_cols = np.where(blue_mask)
    blue_x = blue_cols.mean()

    print(f"  Red center x: {red_x:.0f}, Blue center x: {blue_x:.0f}")

    assert red_x < 112, f"Red should be on left (x < 112), got {red_x}"
    assert blue_x > 112, f"Blue should be on right (x > 112), got {blue_x}"

    print("  PASSED")


def test_shape_generator_all_shapes():
    """Each supported shape must cover a visible area when drawn.

    For circle, square and triangle, pixels whose summed absolute channel
    difference from the reference grey (200, 200, 200) exceeds 50 are
    counted, and more than 500 of them are required.
    """
    print("\n[TEST] Shape generator all shapes...")
    from src.benchmarks import SyntheticShapeGenerator

    gen = SyntheticShapeGenerator(seed=42)
    bg_color = np.array([200, 200, 200])

    for shape in ('circle', 'square', 'triangle'):
        img = gen.create_image(
            [{'shape': shape, 'color': 'red', 'position': 'left'}]
        )
        pixels = np.array(img)

        # Per-pixel L1 distance from the background colour.
        offset = pixels.astype(float) - bg_color
        diff = np.abs(offset).sum(axis=2)
        non_bg = (diff > 50).sum()

        print(f"  {shape}: {non_bg} non-background pixels")
        assert non_bg > 500, f"{shape} should have visible area"

    print("  PASSED")


# ============================================================================
# Test 6: Integration - Metrics on Benchmark Embeddings
# ============================================================================

def test_metrics_on_controlled_embeddings():
    """Both metric families must return in-range values on clustered data.

    200 unit-norm points are built around 5 cluster centres (assigned
    round-robin) with small Gaussian noise; the global isotropy score and
    the mean local anisotropy must each lie strictly between 0 and 1.
    """
    print("\n[TEST] Metrics integration...")
    from src.metrics import compute_global_isotropy, compute_local_anisotropy_random

    n, d = 200, 64
    n_clusters = 5
    set_seed()

    # Cluster centres: random directions scaled to length 2.
    centers = 2 * F.normalize(torch.randn(n_clusters, d), dim=-1)
    # Round-robin cluster assignment.
    labels = torch.arange(n).remainder(n_clusters)
    noise = 0.1 * torch.randn(n, d)
    embeddings = F.normalize(centers[labels] + noise, dim=-1)

    global_m = compute_global_isotropy(embeddings)
    local_m = compute_local_anisotropy_random(
        embeddings, n_samples=50, k=10, verbose=False
    )

    print(f"  G.Iso: {global_m.isotropy_score:.4f}")
    print(f"  L.Ani: {local_m.mean_anisotropy:.4f}")

    # Only the open-interval range is enforced for both metrics.
    assert 0 < global_m.isotropy_score < 1, "G.Iso should be in (0, 1)"
    assert 0 < local_m.mean_anisotropy < 1, "L.Ani should be in (0, 1)"

    print("  PASSED")


# ============================================================================
# Main runner
# ============================================================================

def run_all_tests():
    """Run every test in order, print a summary, and report overall success.

    A test that raises ``AssertionError`` is recorded as a failed assertion;
    any other ``Exception`` is recorded as an error. Both count as failures
    and neither stops the run.

    Returns:
        True when no test failed, False otherwise.
    """
    rule = "=" * 70
    print(rule)
    print("MOCK IDENTITY TESTS FOR SUPP_MINIMAL")
    print(rule)

    tests = (
        # Metrics
        test_global_isotropy_discrimination,
        test_eigenvalue_properties,
        test_local_anisotropy_range,
        test_local_anisotropy_single,
        # Jacobian
        test_jacobian_linear_model,
        test_jvp_finite_difference,
        # Benchmarks
        test_benchmark_trial_generation,
        test_benchmark_with_perfect_model,
        test_same_different_benchmark,
        # Shape generator
        test_shape_generator_colors,
        test_shape_generator_all_shapes,
        # Integration
        test_metrics_on_controlled_embeddings,
    )

    errors = []
    for test in tests:
        try:
            test()
        except AssertionError as e:
            errors.append((test.__name__, f"ASSERT: {e}"))
            print(f"  FAILED: {e}")
        except Exception as e:
            errors.append((test.__name__, f"ERROR: {type(e).__name__}: {e}"))
            print(f"  ERROR: {type(e).__name__}: {e}")

    # Every test either passed or left exactly one entry in ``errors``.
    failed = len(errors)
    passed = len(tests) - failed

    print("\n" + rule)
    print(f"RESULTS: {passed} passed, {failed} failed")
    print(rule)

    if errors:
        print("\nFailures:")
        for name, msg in errors:
            print(f"  - {name}: {msg}")

    return not errors


if __name__ == "__main__":
    # Exit status 0 when every test passed, 1 otherwise.
    raise SystemExit(0 if run_all_tests() else 1)
