import torch
import torch.nn as nn

# Pruner under test, imported from the project's `prune` package
from prune.hybrid_obs_pruner_certified import HybridOBSLinearPruner


def build_basis_calibration(in_features: int, repeats: int = 1) -> torch.Tensor:
    """
    Return a float32 calibration batch made of stacked standard basis vectors.

    Each row is a one-hot vector over ``in_features`` inputs, so ``X^T X`` is a
    (scaled) identity and therefore diagonal. The intent is that, after the
    pruner's diagonal whitening, the Hessian becomes the identity on the
    selected indices.

    Args:
        in_features: Number of input features (columns of the batch).
        repeats: How many identity copies to stack vertically. Any value
            ``<= 1`` yields a single copy.

    Returns:
        A float32 tensor of shape ``(in_features, in_features)`` when
        ``repeats <= 1``, otherwise ``(repeats * in_features, in_features)``.
    """
    # Values <= 1 collapse to a single copy, mirroring the "only repeat when > 1" rule.
    copies = repeats if repeats > 1 else 1
    basis = torch.eye(in_features)
    return basis.repeat(copies, 1).to(torch.float32)


def test_prune_lowest_block_linear_identity_hessian():
    """
    Prune the lowest-importance block of a small Linear layer under a diagonal Hessian.

    Builds a 3x8 bias-free ``nn.Linear`` whose columns form four 2-column blocks
    with clearly separated norms (block 2 is the lightest). It then calibrates a
    ``HybridOBSLinearPruner`` on basis vectors so that ``X^T X`` is diagonal.
    Finally it checks that ``prune_lowest``:

    * picks block 2,
    * zeroes that block's columns, and
    * leaves all other columns unchanged. With a diagonal Hessian the OBS
      compensation on surviving columns is expected to vanish.
    """
    torch.manual_seed(0)

    # Small, controlled Linear layer: out=3, in=8, block_size=2 -> 4 blocks
    out_features = 3
    in_features = 8
    block_size = 2
    layer = nn.Linear(in_features, out_features, bias=False)

    # Deterministic weights. Blocks: [0-1], [2-3], [4-5], [6-7].
    # Squared Frobenius norms per block: 3.5, 0.1, 0.0058, 0.19, so block 2 is
    # unambiguously the lightest.
    W = torch.zeros(out_features, in_features)
    W[:, 0:2] = torch.tensor([[1.0, -1.0], [0.5, 0.5], [0.0, 1.0]])      # heavy
    W[:, 2:4] = torch.tensor([[0.2, 0.0], [0.0, -0.2], [0.1, 0.1]])      # medium
    W[:, 4:6] = torch.tensor([[0.05, -0.05], [0.0, 0.0], [0.02, -0.02]])  # light (expected pruned)
    W[:, 6:8] = torch.tensor([[0.3, 0.0], [0.0, 0.3], [0.1, 0.0]])       # medium
    with torch.no_grad():
        layer.weight.copy_(W)

    pruner = HybridOBSLinearPruner(layer=layer, block_size=block_size)
    selected_blocks = [0, 1, 2, 3]
    pruner.set_selected_blocks(selected_blocks)

    # Calibration: basis vectors make X^T X diagonal (identity after whitening).
    X = build_basis_calibration(in_features, repeats=1)
    pruner.add_batch(X)
    pruner.finalize_calibration(min_damping=1e-2, max_damping=10.0)

    # Snapshot pre-prune weights
    W_before = layer.weight.detach().clone()

    # importance_all() is still called before prune_lowest(), preserving the
    # original call order in case the pruner caches state between the two.
    scores = pruner.importance_all(return_tensor=True)
    pruned_block, _ = pruner.prune_lowest()

    # Expected lowest block is block 2 (indices 4 and 5)
    assert pruned_block == 2, f"Expected to prune block 2, got {pruned_block}"
    # Normalise in case the pruner reports the index as a 0-d tensor.
    pruned_block = int(pruned_block)

    # Cross-check prune_lowest() against the importance scores.
    # NOTE(review): assumes importance_all(return_tensor=True) yields one score per
    # selected block, in selection order, with lower = less important. The guard
    # skips the check when the returned shape does not match that assumption.
    if isinstance(scores, torch.Tensor) and scores.numel() == len(selected_blocks):
        lowest = selected_blocks[int(torch.argmin(scores.flatten()))]
        assert lowest == pruned_block, (
            f"importance_all() ranks block {lowest} lowest but "
            f"prune_lowest() pruned block {pruned_block}"
        )

    start = pruned_block * block_size
    stop = start + block_size
    W_after = layer.weight.detach()

    # Pruned columns must be zeroed
    pruned_cols = W_after[:, start:stop]
    assert torch.allclose(pruned_cols, torch.zeros_like(pruned_cols)), \
        "Pruned block columns are not zeroed"

    # Because H is diagonal (identity after whitening), ΔW should be ~0 for survivors
    survivors = [c for c in range(in_features) if not start <= c < stop]
    assert torch.allclose(W_after[:, survivors], W_before[:, survivors]), \
        "Non-pruned columns changed unexpectedly with identity Hessian"

    print("Test passed: pruned lowest-norm block and zeroed correct columns with stable survivors.")


if __name__ == "__main__":
    # Allow running this check directly as a script, without going through pytest.
    test_prune_lowest_block_linear_identity_hessian()
