import pytest
import torch
import torch.nn as nn

from prune.obs_hybrid_pruner_vectorized_column import (
    HybridOBSLinearPruner,
    HybridOBSConv1dPruner,
    HybridOBSConv2dPruner,
)

# --------------------------------------------------------------------------- #
# Helpers                                                                     #
# --------------------------------------------------------------------------- #
def _generic_test(pruner_cls, layer, input_builder, block_size, cand_blocks):
    """Run a minimal OBS cycle: Hessian → importance → prune-1.

    Args:
        pruner_cls: Pruner class, invoked as ``pruner_cls(layer, block_size)``.
        layer: Module whose ``weight`` must change once a block is pruned.
        input_builder: Zero-argument callable returning one calibration batch.
        block_size: Forwarded unchanged to the pruner constructor.
        cand_blocks: Block indices handed to ``set_selected_blocks``.

    Raises:
        AssertionError: If any of the checked invariants does not hold.
    """
    pruner = pruner_cls(layer, block_size)
    pruner.set_selected_blocks(cand_blocks)

    # accumulate Hessian (seeded so random input builders are reproducible)
    torch.manual_seed(0)
    for _ in range(3):
        pruner.add_batch(input_builder())

    pruner.finalize_calibration()

    importance = pruner.importance_all()
    assert len(importance) == len(cand_blocks)
    # Other tests in this file consume importance_all() as a tensor, so accept
    # a tensor of scores as well as a mapping of block -> score.
    scores = importance if torch.is_tensor(importance) else importance.values()
    # as_tensor handles plain floats and passes existing tensors through
    # without the copy-construct warning that torch.tensor(tensor) emits.
    assert all(bool(torch.isfinite(torch.as_tensor(score)).all()) for score in scores)

    w_before = layer.weight.detach().clone()
    blk, _ = pruner.prune_lowest()
    assert blk in cand_blocks
    assert len(pruner.selected_blocks) == len(cand_blocks) - 1
    # pruning one block must modify the layer weights
    assert not torch.allclose(w_before, layer.weight)

    # importance length shrank
    assert len(pruner.importance_all()) == len(cand_blocks) - 1


def _run_stress(pruner_cls, layer, build_input, block_size, n_blocks, n_prunes):
    """Calibrate a pruner on eight batches, then prune ``n_prunes`` blocks in a row.

    Blocks ``0 .. n_blocks - 1`` are selected up front. Before each prune the
    importance result must contain one entry per remaining block, and after
    each prune ``selected_blocks`` must have shrunk by exactly one.

    Raises:
        AssertionError: If either bookkeeping invariant is violated.
    """
    pruner = pruner_cls(layer, block_size)
    pruner.set_selected_blocks(list(range(n_blocks)))

    # calibration phase
    for _ in range(8):
        pruner.add_batch(build_input())
    pruner.finalize_calibration()

    # `remaining` counts down n_blocks, n_blocks - 1, ... once per prune
    for remaining in range(n_blocks, n_blocks - n_prunes, -1):
        assert len(pruner.importance_all()) == remaining
        pruner.prune_lowest()
        assert len(pruner.selected_blocks) == remaining - 1


# --------------------------------------------------------------------------- #
# Smoke tests (one prune)                                                     #
# --------------------------------------------------------------------------- #
@pytest.mark.parametrize(
    "layer,block_size,cand_blocks,input_builder",
    [
        pytest.param(
            nn.Linear(8, 4, bias=False), 2, [0, 1, 2],
            lambda: torch.randn(5, 8, dtype=torch.bfloat16),
            id="linear",
        ),
        pytest.param(
            nn.Conv1d(6, 3, kernel_size=3, bias=False), 2, [0, 1],
            lambda: torch.randn(4, 6, 12, dtype=torch.bfloat16),
            id="conv1d",
        ),
        pytest.param(
            nn.Conv2d(6, 3, kernel_size=3, bias=False), 2, [0, 1],
            lambda: torch.randn(4, 6, 10, 10, dtype=torch.bfloat16),
            id="conv2d",
        ),
    ],
)
def test_smoke(layer, block_size, cand_blocks, input_builder):
    """Single-prune smoke test for each supported layer type.

    The pruner class is chosen from the concrete type of ``layer``. When CUDA
    is available, both the layer and every generated batch are moved to the GPU.
    """
    # layer type -> matching pruner class
    pruner_by_layer_type = {
        nn.Linear: HybridOBSLinearPruner,
        nn.Conv1d: HybridOBSConv1dPruner,
        nn.Conv2d: HybridOBSConv2dPruner,
    }
    pruner_cls = pruner_by_layer_type[type(layer)]

    make_batch = input_builder
    if torch.cuda.is_available():
        layer = layer.cuda()

        def make_batch():
            # build on CPU exactly as parametrized, then transfer
            return input_builder().cuda()

    _generic_test(pruner_cls, layer, make_batch, block_size, cand_blocks)


# --------------------------------------------------------------------------- #
# Stress tests (multi-prune loop)                                             #
# --------------------------------------------------------------------------- #
@pytest.mark.parametrize(
    "layer,block_size,n_blocks,n_prunes,input_builder,pruner_cls",
    [
        pytest.param(
            nn.Linear(24, 32, bias=False), 3, 6, 5,
            lambda: torch.randn(16, 24, dtype=torch.bfloat16),
            HybridOBSLinearPruner,
            id="big_linear",
        ),
        pytest.param(
            nn.Conv1d(16, 12, kernel_size=3, padding=1, bias=False), 4, 3, 3,
            lambda: torch.randn(8, 16, 64, dtype=torch.bfloat16),
            HybridOBSConv1dPruner,
            id="big_conv1d",
        ),
        pytest.param(
            nn.Conv2d(12, 10, kernel_size=3, padding=1, bias=False), 3, 4, 3,
            lambda: torch.randn(6, 12, 32, 32, dtype=torch.bfloat16),
            HybridOBSConv2dPruner,
            id="big_conv2d",
        ),
    ],
)
def test_stress(layer, block_size, n_blocks, n_prunes, input_builder, pruner_cls):
    """Repeated-prune stress test on larger layers.

    When CUDA is available, the layer and every generated batch are moved to
    the GPU before the multi-prune loop runs.
    """
    make_batch = input_builder
    if torch.cuda.is_available():
        layer = layer.cuda()

        def make_batch():
            # build on CPU exactly as parametrized, then transfer
            return input_builder().cuda()

    _run_stress(pruner_cls, layer, make_batch, block_size, n_blocks, n_prunes)


# --------------------------------------------------------------------------- #
# Static weights test (verification of importance scores and weight updates)  #
# --------------------------------------------------------------------------- #
def test_static_linear_importance_and_weight_update():
    """Check importance scores and the OBS weight update on fixed data.

    A 4-input / 3-output linear layer with hand-written weights is calibrated
    five times on the same hand-written batch, with every column registered as
    its own block (block size 1). The test then checks that:

    * one finite importance score is reported per column,
    * pruning removes exactly one candidate and zeroes that column,
    * the other columns stay non-zero and the layer output stays finite,
    * the importance result shrinks by one entry afterwards.

    Intermediate matrices are printed for debugging.
    """
    layer = nn.Linear(4, 3, bias=False)

    # Fixed weights so the run is reproducible
    fixed_weights = torch.tensor(
        [
            [1.0, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0],
        ],
        dtype=torch.float32,
    )
    layer.weight.data = fixed_weights.clone()

    # Fixed calibration batch: 4 samples x 4 features
    fixed_input = torch.tensor(
        [
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0],
            [4.0, 5.0, 6.0, 7.0],
        ],
        dtype=torch.float32,
    )

    if torch.cuda.is_available():
        layer = layer.cuda()
        fixed_input = fixed_input.cuda()

    # One block per column, all four columns are pruning candidates
    candidates = [0, 1, 2, 3]
    pruner = HybridOBSLinearPruner(layer, 1)
    pruner.set_selected_blocks(candidates)

    # Feed the identical batch several times to accumulate the Hessian
    for _ in range(5):
        pruner.add_batch(fixed_input)
    pruner.finalize_calibration()

    # ---- debug dump of the calibration state ----
    hessian = pruner._H_accum
    print(f"Hessian shape: {hessian.shape if hessian is not None else 'None'}")
    h_panel = pruner.H_panel
    print(f"H_panel shape: {h_panel.shape if h_panel is not None else 'None'}")
    g_cc = pruner.G_CC
    print(f"G_CC shape: {g_cc.shape if g_cc is not None else 'None'}")

    if hessian is not None:
        print(f"Hessian:\n{hessian}")
        # eigenvalues are only computed on finite input
        if not torch.isfinite(hessian).all():
            print("Hessian contains NaN/Inf values")
        else:
            print(f"Hessian eigenvalues: {torch.linalg.eigvals(hessian)}")

    if g_cc is not None:
        print(f"G_CC:\n{g_cc}")
        # eigenvalues are only computed on finite input
        if not torch.isfinite(g_cc).all():
            print("G_CC contains NaN/Inf values")
            print(f"G_CC has {torch.isnan(g_cc).sum()} NaN values")
            print(f"G_CC has {torch.isinf(g_cc).sum()} Inf values")
        else:
            print(f"G_CC eigenvalues: {torch.linalg.eigvals(g_cc)}")

    # ---- importance scores: one finite value per column ----
    scores = pruner.importance_all()
    assert scores.numel() == 4
    # Sign is deliberately not checked: OBS importance scores may be negative
    assert torch.isfinite(scores).all()
    print(f"Importance scores: {scores}")

    # ---- prune the least important column ----
    before = layer.weight.detach().clone()
    pruned, _ = pruner.prune_lowest()
    print(f"Pruned block: {pruned}")

    assert pruned in candidates
    assert len(pruner.selected_blocks) == len(candidates) - 1
    assert pruned not in pruner.selected_blocks

    after = layer.weight.detach().clone()

    # OBS is expected to change the weights ...
    assert not torch.allclose(before, after)

    # ... zeroing the pruned column ...
    pruned_column = after[:, pruned]
    assert torch.allclose(pruned_column, torch.zeros_like(pruned_column))

    # ... while leaving every other column non-zero
    survivors = [idx for idx in range(4) if idx != pruned]
    for idx in survivors:
        column = after[:, idx]
        assert not torch.allclose(column, torch.zeros_like(column))

    # The pruned layer must still yield a finite (batch=4, out_features=3) output
    with torch.no_grad():
        result = layer(fixed_input)
        assert result.shape == (4, 3)
        assert torch.isfinite(result).all()

    # One fewer importance entry after the prune
    rescored = pruner.importance_all()
    assert rescored.numel() == len(candidates) - 1

    print(f"Original weights:\n{before}")
    print(f"Updated weights:\n{after}")
    print(f"Pruned column: {pruned}")
    print(f"Remaining columns: {survivors}")


def test_debug_importance_calculation():
    """Debug test to understand the importance calculation step by step.

    Builds a 2-input / 1-output linear layer, calibrates a block-size-1 pruner
    on a single fixed batch and prints the intermediate quantities next to a
    hand-computed importance (``0.5 * w^2 / G_pp`` per column, using the
    diagonal of ``G_CC``). The comparison is only printed, not asserted.
    """
    # Create a very simple case
    layer = nn.Linear(2, 1, bias=False)  # 1 output, 2 inputs

    # Simple weights
    layer.weight.data = torch.tensor([[1.0, 2.0]], dtype=torch.float32)

    # Simple input: 4 samples x 2 features, so the feature dimension matches
    # the layer's in_features=2
    static_input = torch.tensor([
        [1.0, 2.0],
        [2.0, 3.0],
        [3.0, 4.0],
        [4.0, 5.0],
    ], dtype=torch.float32)

    if torch.cuda.is_available():
        layer = layer.cuda()
        static_input = static_input.cuda()

    # Create pruner
    pruner = HybridOBSLinearPruner(layer, block_size=1)
    pruner.set_selected_blocks([0, 1])  # Both columns

    # Add batch
    pruner.add_batch(static_input)
    pruner.finalize_calibration()

    # Manual calculation to understand what's happening
    print(f"Layer weights: {layer.weight}")
    print(f"Input: {static_input}")

    if pruner._H_accum is not None:
        print(f"Hessian: {pruner._H_accum}")
        print(f"Hessian eigenvalues: {torch.linalg.eigvals(pruner._H_accum)}")

    if pruner.G_CC is not None:
        print(f"G_CC: {pruner.G_CC}")
        print(f"G_CC eigenvalues: {torch.linalg.eigvals(pruner.G_CC)}")

    # Get importance scores
    importance_scores = pruner.importance_all()
    print(f"Importance scores: {importance_scores}")

    # Manual verification of the calculation
    if pruner.G_CC is not None:
        # For block_size=1, G_CC should be 2x2 and each block is 1x1, so the
        # per-block G_pp is just the matching diagonal element.
        # Shape (2, 1, 1) lines up element-wise with W_blocks below; a (2, 1)
        # shape would broadcast against it to (2, 2, 1).
        G_pp = pruner.G_CC.diagonal().reshape(-1, 1, 1)  # (2, 1, 1)
        print(f"G_pp (diagonal blocks): {G_pp}")

        # Weights for each block (each column)
        W_blocks = layer.weight.detach().T.unsqueeze(-1)  # (2, 1, 1)
        print(f"W_blocks: {W_blocks}")

        # Solve G_pp * x = w for each block (plain division since G_pp is 1x1)
        sol = W_blocks / G_pp  # (2, 1, 1)
        print(f"Solution x: {sol}")

        # Importance = 0.5 * w^T * x
        importance_manual = 0.5 * (W_blocks * sol).reshape(-1)  # (2,)
        print(f"Manual importance: {importance_manual}")

        print(f"Computed importance: {importance_scores}")
        # Align dtype/device first: allclose rejects mismatched dtypes
        print(f"Match: {torch.allclose(importance_manual.to(importance_scores), importance_scores)}")

        # Let's also check what the actual importance calculation does
        print(f"G_CC shape: {pruner.G_CC.shape}")
        print(f"Selected blocks: {pruner.selected_blocks}")
        print(f"Block size: {pruner.block_size}")
        print(f"Kernel elems: {getattr(pruner, 'kernel_elems', 1)}")
        

@pytest.mark.parametrize(
    "pruner_cls,layer,input_shape,dead_indices,weight_zero_fn",
    [
        # Linear: dead input features (columns)
        (
            HybridOBSLinearPruner,
            nn.Linear(8, 4, bias=False),
            (2, 8),
            [2, 5],
            lambda w, dead: (w[:, dead].abs() < 1e-5).all(),
        ),
        # Conv1d: dead input channels
        (
            HybridOBSConv1dPruner,
            nn.Conv1d(6, 3, kernel_size=3, bias=False),
            (2, 6, 10),
            [1, 4],
            lambda w, dead: (w[:, dead, :].abs() < 1e-5).all(),
        ),
        # Conv2d: dead input channels
        (
            HybridOBSConv2dPruner,
            nn.Conv2d(6, 3, kernel_size=3, bias=False),
            (2, 6, 8, 8),
            [0, 5],
            lambda w, dead: (w[:, dead, :, :].abs() < 1e-5).all(),
        ),
    ],
)
def test_dead_direction_zeroing(pruner_cls, layer, input_shape, dead_indices, weight_zero_fn):
    """Check that ``finalize_calibration`` zeroes weights of dead Hessian directions.

    An identity Hessian with zeroed diagonal entries at ``dead_indices`` is
    injected straight into the pruner instead of being accumulated from data
    (``input_shape`` is therefore not used). After finalisation the weights at
    the dead indices must be ~0 and the remaining weights must not all be zero.
    """
    # Hand-built Hessian: identity with the dead diagonal entries cleared.
    # Its size is dim 1 of the weight for all three layer types.
    n_inputs = layer.weight.shape[1]
    hessian = torch.eye(n_inputs, device=layer.weight.device, dtype=layer.weight.dtype)
    for dead in dead_indices:
        hessian[dead, dead] = 0

    # Inject the calibration state directly into a fresh pruner
    pruner = pruner_cls(layer, block_size=1)
    pruner._H_accum = hessian
    pruner.calibration_count = 1
    pruner.selected_indices = list(range(n_inputs))

    # Start from all-ones weights so any zero afterwards comes from finalisation
    layer.weight.data.fill_(1.0)

    pruner.finalize_calibration()

    # Dead directions must have been zeroed
    assert weight_zero_fn(layer.weight.data, torch.tensor(dead_indices)), \
        f"Dead directions {dead_indices} not zeroed for {pruner_cls.__name__}"

    # Live directions must not all be zero; indexing dim 1 selects the same
    # slice for 2-D, 3-D and 4-D weights
    live = [idx for idx in range(n_inputs) if idx not in dead_indices]
    if isinstance(layer, (nn.Linear, nn.Conv1d, nn.Conv2d)):
        assert (layer.weight.data[:, live] != 0).any()

    print(f"{pruner_cls.__name__} dead direction zeroing test passed.")
        

