# tests/test_halpe.py
import math
import torch
import pytest
from utils.layer_utils import TransformerConfig, TransformerLayerSchema, LayerType, LayerSpec
from prune.halpe import TransformerHALPE

# --- Dummy config and schema using real classes ---
# Small geometry shared by the fixture-based tests:
# hidden_size (8) == num_heads (2) * head_size (4), FFN width 4.
dummy_config = TransformerConfig(
    layer_type=LayerType.transformer,
    hidden_size=8,
    head_size=4,
    num_heads=2,
    intermediate_dimension=4,
)

# One LayerSpec per linear layer. The first two positional arguments match the
# attribute name and ModuleDict key used by DummyBlock; the third is a
# "row"/"column" tag (presumably the pruning axis - confirm against LayerSpec).
dummy_schema = TransformerLayerSchema(
    layer_name="DummyBlock",
    layer_type=LayerType.transformer,
    norm_type="pre_norm",
    layers={
        key: LayerSpec(container, key, orientation)
        for container, key, orientation in (
            ("attn", "q", "row"),
            ("attn", "k", "row"),
            ("attn", "v", "row"),
            ("attn", "o", "column"),
            ("ffn", "fc1", "row"),
            ("ffn", "fc2", "column"),
        )
    },
)

class DummyBlock(torch.nn.Module):
    """Bias-free stand-in for a transformer block, used as the pruning target.

    ``attn`` holds four 8->8 projections (q, k, v, o) and ``ffn`` holds an
    8->4 (fc1) and a 4->8 (fc2) projection. There is no attention score
    computation, normalisation or non-linearity; the module only supplies
    linear layers with the names and shapes the tests rely on.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the order q, k, v, o, fc1, fc2 so that a seeded
        # RNG always produces the same weights.
        self.attn = torch.nn.ModuleDict(
            {name: torch.nn.Linear(8, 8, bias=False) for name in ("q", "k", "v", "o")}
        )
        self.ffn = torch.nn.ModuleDict(
            {
                "fc1": torch.nn.Linear(8, 4, bias=False),
                "fc2": torch.nn.Linear(4, 8, bias=False),
            }
        )

    def forward(self, x):
        """Sum the q/k/v projections of ``x``, then apply o, fc1 and fc2 in turn."""
        attn, ffn = self.attn, self.ffn
        mixed = attn["q"](x) + attn["k"](x) + attn["v"](x)
        projected = attn["o"](mixed)
        return ffn["fc2"](ffn["fc1"](projected))

@pytest.fixture
def halpe():
    """Build a TransformerHALPE around a newly constructed DummyBlock.

    The config and schema are the shared module-level ``dummy_config`` and
    ``dummy_schema`` objects; only the block and the HALPE wrapper are new.
    """
    import logging

    return TransformerHALPE(
        DummyBlock(),
        dummy_config,
        dummy_schema,
        logging.getLogger(__name__),
    )

# ---------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------
def test_local_sensitivity_trace_of_hessian(halpe):
    """local_sensitivity must equal the summed squared activations divided by the row count."""
    batch = torch.arange(32.0).reshape(4, 8)

    halpe.add_batch(batch)
    halpe.compute_local_sensitivity()

    # Reference value computed straight from the calibration batch.
    squared_total = batch.pow(2).sum().item()
    reference = squared_total / batch.shape[0]
    assert math.isclose(halpe.local_sensitivity, reference, rel_tol=1e-6)

def test_initial_importances_shape_and_values(halpe):
    """get_initial_importances yields one [importance, layer_idx, block_idx] row per block."""
    n_blocks = halpe.layer_config.num_heads + halpe.layer_config.intermediate_dimension

    # Initialise the pruners first, then mark every block as a candidate.
    halpe.initialize_pruners()
    halpe.set_candidate_blocks(torch.arange(n_blocks))

    layer_idx = 7  # arbitrary value; it only has to be echoed back in column 1
    rows = halpe.get_initial_importances(layer_idx=layer_idx)
    assert rows.shape == (n_blocks, 3)

    importance_col, layer_col, block_col = rows[:, 0], rows[:, 1], rows[:, 2]
    # Column 0: every importance score is finite.
    assert torch.isfinite(importance_col).all()
    # Column 1: the layer index is repeated on every row.
    assert (layer_col == layer_idx).all()
    # Column 2: block indices ascend from 0 to n_blocks - 1.
    torch.testing.assert_close(block_col, torch.arange(n_blocks, dtype=torch.float32))

def test_pruned_block_mapping(halpe):
    """Flat block indices are split into head-local and ffn-local index lists."""
    n_blocks = halpe.layer_config.num_heads + halpe.layer_config.intermediate_dimension
    halpe.initialize_pruners()
    halpe.set_candidate_blocks(torch.arange(n_blocks))

    # Flat indices 0-1 are the heads; FFN blocks start at flat index 2.
    flat_indices = torch.tensor([0, 1, 2, 3])
    split = halpe.pruned_blocks_indices_per_layer(flat_indices)

    assert split == {"head": [0, 1], "ffn": [0, 1]}

def test_compute_exact_importances_stacking(halpe):
    """compute_exact_importances stacks head rows first, then FFN rows.

    Every row is [importance, layer_idx, block_idx]. With all blocks marked as
    candidates, the block_idx column must read 0 .. n_blocks - 1: heads at
    0 .. num_heads - 1, FFN blocks after them.
    """
    num_heads = halpe.layer_config.num_heads
    intermediate_dimension = halpe.layer_config.intermediate_dimension
    n_blocks = num_heads + intermediate_dimension

    halpe.initialize_pruners()  # Initialize pruners first
    halpe.set_candidate_blocks(torch.arange(n_blocks))

    calibration_data = torch.randn(10, 8)
    halpe.add_batch(calibration_data)  # For local sensitivity

    # Calibrate through the pruner hooks: register, run one forward pass,
    # then remove the hooks again.
    for pruner in halpe.pruners.values():
        pruner.register_hook()
    with torch.no_grad():
        halpe.layer(calibration_data)  # triggers pruner hooks
    for pruner in halpe.pruners.values():
        pruner.remove_hook()

    # Direct pruner calibration
    halpe.pruners['head'].add_batch(torch.randn(4, 8))
    halpe.pruners['ffn'].add_batch(torch.randn(4, 4))

    # Do NOT call finalize_calibration here; let compute_exact_importances handle it
    out = halpe.compute_exact_importances(layer_idx=3)

    assert out.shape == (n_blocks, 3)
    # Second column is layer_idx
    assert torch.all(out[:, 1] == 3)
    # Third column carries the candidate block indices, heads before FFN blocks.
    expected_block_indices = torch.arange(n_blocks, dtype=torch.float32)
    torch.testing.assert_close(out[:, 2], expected_block_indices)

def test_global_sensitivity_update(halpe):
    """Global sensitivity is local + alpha * the incoming global sensitivity."""
    local, incoming, alpha = 5.0, 10.0, 0.3

    halpe.local_sensitivity = local
    halpe.set_global_sensitivity(global_sensitivity=incoming, alpha=alpha)

    assert math.isclose(halpe.get_global_sensitivity(), local + alpha * incoming)

def test_transformerhalpe_end_to_end():
    """Drive a TransformerHALPE through calibration, scoring and pruning one head.

    The block, config and schema are built locally, so whatever the final
    prune step changes stays inside this test. The RNG is seeded before the
    block is constructed so that the weights and all calibration batches are
    reproducible from run to run.
    """
    import torch
    from utils.layer_utils import TransformerConfig, TransformerLayerSchema, LayerType, LayerSpec
    from prune.halpe import TransformerHALPE

    class SimpleTransformerBlock(torch.nn.Module):
        """Six bias-free linear layers named like a transformer block's projections."""

        def __init__(self):
            super().__init__()
            self.attn = torch.nn.ModuleDict({
                "q": torch.nn.Linear(8, 8, bias=False),
                "k": torch.nn.Linear(8, 8, bias=False),
                "v": torch.nn.Linear(8, 8, bias=False),
                "o": torch.nn.Linear(8, 8, bias=False),
            })
            self.ffn = torch.nn.ModuleDict({
                "fc1": torch.nn.Linear(8, 4, bias=False),
                "fc2": torch.nn.Linear(4, 8, bias=False),
            })

        def forward(self, x):
            x = self.attn["q"](x) + self.attn["k"](x) + self.attn["v"](x)
            x = self.attn["o"](x)
            x = self.ffn["fc1"](x)
            x = self.ffn["fc2"](x)
            return x

    config = TransformerConfig(
        layer_type=LayerType.transformer,
        hidden_size=8,
        head_size=4,
        num_heads=2,
        intermediate_dimension=4
    )
    schema = TransformerLayerSchema(
        layer_name="SimpleTransformerBlock",
        layer_type=LayerType.transformer,
        norm_type="pre_norm",
        layers={
            "q": LayerSpec("attn", "q", "row"),
            "k": LayerSpec("attn", "k", "row"),
            "v": LayerSpec("attn", "v", "row"),
            "o": LayerSpec("attn", "o", "column"),
            "fc1": LayerSpec("ffn", "fc1", "row"),
            "fc2": LayerSpec("ffn", "fc2", "column")
        }
    )

    # Seed BEFORE creating the block: nn.Linear draws its initial weights from
    # the global RNG, so seeding afterwards would leave them unseeded.
    torch.manual_seed(0)
    block = SimpleTransformerBlock()

    import logging
    logger = logging.getLogger(__name__)
    halpe = TransformerHALPE(block, config, schema, logger)

    num_heads = config.num_heads
    intermediate_dimension = config.intermediate_dimension
    # Always use global block indices for candidate blocks
    all_blocks = torch.arange(num_heads + intermediate_dimension)
    halpe.initialize_pruners()  # Initialize pruners first
    halpe.set_candidate_blocks(all_blocks)
    calibration_data = torch.randn(10, 8)
    halpe.add_batch(calibration_data)
    # Direct pruner calibration
    halpe.pruners['head'].add_batch(torch.randn(10, 8))
    halpe.pruners['ffn'].add_batch(torch.randn(10, 4))

    # Initial importances: one finite [importance, layer_idx, block_idx] row per block.
    initial_importances = halpe.get_initial_importances(layer_idx=0)
    assert initial_importances.shape[1] == 3
    assert initial_importances.shape[0] == num_heads + intermediate_dimension
    assert torch.isfinite(initial_importances[:, 0]).all()

    # Sensitivities: local must be positive, and adding a positive parent
    # contribution (alpha * global_sensitivity = 1.0) must increase it.
    halpe.compute_local_sensitivity()
    local_sens = halpe.local_sensitivity
    assert local_sens > 0
    halpe.set_global_sensitivity(global_sensitivity=2.0, alpha=0.5)
    global_sens = halpe.get_global_sensitivity()
    assert global_sens > local_sens

    exact_importances = halpe.compute_exact_importances(layer_idx=0)
    assert exact_importances.shape[0] == num_heads + intermediate_dimension
    assert torch.isfinite(exact_importances[:, 0]).all()

    prune_blocks = torch.tensor([0])  # Only prune the first head
    halpe.prune(prune_blocks)
    updated_config = halpe.get_updated_configs()
    assert updated_config.num_heads == num_heads - 1  # Should be 1

    pruned_blocks = halpe.get_pruned_blocks()
    assert len(pruned_blocks['head']) == 1

    mapping = halpe.pruned_blocks_indices_per_layer(prune_blocks)
    assert set(mapping.keys()) == {'head', 'ffn'}
    assert 0 in mapping['head'] or 0 in mapping['ffn']

def test_pruner_hook_and_finalize_behavior(halpe):
    """finalize_calibration may be called twice on every pruner.

    The second call must either succeed or raise one of RuntimeError,
    ValueError or AssertionError. Despite the test name, no forward hooks are
    registered here; the pruners are calibrated through add_batch directly.
    """
    # Setup candidate blocks and calibration data
    total_blocks = halpe.layer_config.num_heads + halpe.layer_config.intermediate_dimension
    halpe.initialize_pruners()  # Initialize pruners first
    halpe.set_candidate_blocks(torch.arange(total_blocks))
    calibration_data = torch.randn(4, 8)
    halpe.add_batch(calibration_data)
    # Direct pruner calibration
    halpe.pruners['head'].add_batch(torch.randn(4, 8))
    halpe.pruners['ffn'].add_batch(torch.randn(4, 4))
    for pruner in halpe.pruners.values():
        pruner.finalize_calibration()
        try:
            pruner.finalize_calibration()
        except (RuntimeError, ValueError, AssertionError):
            # A deliberate rejection of the repeated call is acceptable. Any
            # other exception type propagates and fails the test with its
            # original traceback.
            pass

def test_block_index_mapping_halpe(halpe):
    """Check the global -> local split and the indices reported after pruning."""
    num_heads = halpe.layer_config.num_heads
    intermediate_dimension = halpe.layer_config.intermediate_dimension
    total_blocks = num_heads + intermediate_dimension

    # Every block is a candidate: global 0-1 are heads, global 2-5 are FFN blocks.
    candidates = torch.arange(total_blocks)
    halpe.initialize_pruners()
    halpe.set_candidate_blocks(candidates)
    halpe.add_batch(torch.randn(4, 8))

    # Calibrate each pruner directly with a batch of its own input width.
    for pruner_name, width in (('head', 8), ('ffn', 4)):
        halpe.pruners[pruner_name].add_batch(torch.randn(4, width))

    # Global -> local: heads keep their index, FFN indices are shifted by num_heads.
    split = halpe.pruned_blocks_indices_per_layer(torch.tensor([0, 1, 2, 3, 4, 5]))
    assert split['head'] == [0, 1]
    assert split['ffn'] == [0, 1, 2, 3]

    # Prune global block 1 (head 1) and global block 3 (FFN local index 1).
    to_prune = torch.tensor([1, 3])
    for pruner_name in ('head', 'ffn'):
        halpe.pruners[pruner_name].finalize_calibration()
    halpe.prune(to_prune)
    pruned = halpe.get_pruned_blocks()

    # The pruned head is reported by its index 1.
    assert 1 in pruned['head']
    # Exactly one FFN block is reported, and it lies in the global FFN range
    # num_heads .. num_heads + intermediate_dimension - 1.
    ffn_global_range = range(num_heads, num_heads + intermediate_dimension)
    assert all(idx in ffn_global_range for idx in pruned['ffn'])
    assert len(pruned['ffn']) == 1


# ---------------------------------------------------------------------
# Candidate Block Selection Indexing Tests
# ---------------------------------------------------------------------

def test_get_initial_importances_global_indexing(halpe):
    """get_initial_importances numbers blocks globally: heads first, then FFN blocks."""
    rows = halpe.get_initial_importances(layer_idx=0)

    num_heads = halpe.layer_config.num_heads
    total_blocks = num_heads + halpe.layer_config.intermediate_dimension

    # One row of three columns per block.
    assert rows.shape == (total_blocks, 3)

    layer_col = rows[:, 1]
    block_col = rows[:, 2]

    # Every row carries the requested layer index.
    assert (layer_col == 0).all()

    # Block indices run 0 .. total_blocks - 1 without gaps.
    assert (block_col == torch.arange(total_blocks)).all()

    # The leading num_heads rows are the head blocks ...
    assert (block_col[:num_heads] == torch.arange(num_heads)).all()

    # ... and the remaining rows are the FFN blocks, numbered from num_heads on.
    assert (block_col[num_heads:] == torch.arange(num_heads, total_blocks)).all()


def test_pruned_blocks_indices_per_layer_mapping(halpe):
    """A mixed list of global indices is split into head and FFN-local indices."""
    # Global 1 is a head; globals 3, 8 and 10 fall on the FFN side.
    mixed = torch.tensor([1, 3, 8, 10])

    split = halpe.pruned_blocks_indices_per_layer(mixed)

    # Head indices are passed through unchanged.
    assert split['head'] == [1]

    # FFN indices become local by subtracting num_heads.
    offset = halpe.layer_config.num_heads
    assert split['ffn'] == [3 - offset, 8 - offset, 10 - offset]


def test_pruned_blocks_indices_per_layer_all_heads(halpe):
    """Indices that are all below num_heads end up in 'head' and leave 'ffn' empty."""
    only_heads = torch.tensor([0, 1])  # both below num_heads (2)

    split = halpe.pruned_blocks_indices_per_layer(only_heads)

    assert split['head'] == [0, 1]
    assert split['ffn'] == []


def test_pruned_blocks_indices_per_layer_all_ffn(halpe):
    """Indices that are all >= num_heads end up in 'ffn' as local offsets."""
    first_ffn = halpe.layer_config.num_heads
    only_ffn = torch.tensor([first_ffn, first_ffn + 1, first_ffn + 3])

    split = halpe.pruned_blocks_indices_per_layer(only_ffn)

    assert split['head'] == []
    # Local index = global index - num_heads, giving offsets 0, 1 and 3.
    assert split['ffn'] == [0, 1, 3]


def test_layer_importances_filtering():
    """Selecting one layer's rows with a boolean mask keeps their block indices intact."""
    # Rows are [importance, layer_idx, block_idx]; layers 0, 1 and 2 contribute
    # three, two and one block respectively.
    table = torch.tensor([
        [1.0, 0, 0],
        [2.0, 0, 1],
        [3.0, 0, 2],
        [4.0, 1, 0],
        [5.0, 1, 1],
        [6.0, 2, 0],
    ])

    # Keep only the rows whose layer column equals 0.
    in_layer_0 = table[:, 1] == 0
    layer_0_rows = table[in_layer_0]

    # Exactly the three layer-0 rows survive ...
    assert layer_0_rows.shape[0] == 3
    assert (layer_0_rows[:, 1] == 0).all()

    # ... and their block indices are untouched and in the original order.
    assert (layer_0_rows[:, 2] == torch.tensor([0, 1, 2])).all()


def test_candidate_block_selection_logic(halpe):
    """Pick the least important blocks and split them into head / FFN-local indices.

    The selection itself is re-enacted with plain tensor operations; only
    num_heads is taken from the fixture.
    """
    # Rows are [importance, layer_idx, block_idx] for a single layer. Whether a
    # block counts as head or FFN is decided further down by comparing its
    # index with num_heads.
    layer_importances = torch.tensor([
        [5.0, 0, 0],
        [1.0, 0, 1],  # least important
        [3.0, 0, 2],
        [4.0, 0, 8],
        [2.0, 0, 9],  # second least important
    ])

    # Rank rows by ascending importance and read off their block indices.
    ascending = torch.argsort(layer_importances[:, 0])
    ranked_blocks = layer_importances[ascending, 2]

    # The two least important blocks become the candidates.
    candidate_blocks = ranked_blocks[:2]
    assert (candidate_blocks == torch.tensor([1, 9])).all()

    # Split as HALPE does: indices below num_heads are heads, the rest are
    # FFN blocks shifted down by num_heads.
    num_heads = halpe.layer_config.num_heads
    is_head = candidate_blocks < num_heads
    head_indices = candidate_blocks[is_head]
    ffn_indices = candidate_blocks[candidate_blocks >= num_heads] - num_heads

    assert (head_indices == torch.tensor([1])).all()  # block 1 is a head
    assert (ffn_indices == torch.tensor([7])).all()   # block 9 -> local index 7 (9-2)


def test_end_to_end_indexing_consistency(halpe):
    """Block indices from get_initial_importances feed cleanly into the head/FFN split."""
    # Step 1: score every block of layer 0.
    scored = halpe.get_initial_importances(layer_idx=0)

    # Step 2: take the block indices of the first three rows as candidates.
    first_three = scored[:3, 2]
    assert (first_three == torch.tensor([0, 1, 2])).all()

    # Step 3: split them; with num_heads=2, 0 and 1 are heads and 2 is FFN-local 0.
    split = halpe.pruned_blocks_indices_per_layer(first_three)
    assert split['head'] == [0, 1]
    assert split['ffn'] == [0]  # 2 - 2 = 0

    # Step 4: repeat with a hand-picked mix of one head and three FFN indices.
    mixed = torch.tensor([1, 3, 8, 10])
    mixed_split = halpe.pruned_blocks_indices_per_layer(mixed)
    assert mixed_split['head'] == [1]
    assert mixed_split['ffn'] == [1, 6, 8]  # 3-2=1, 8-2=6, 10-2=8


def test_edge_cases(halpe):
    """pruned_blocks_indices_per_layer on empty, boundary and out-of-range input."""
    # Empty input produces two empty lists.
    split_empty = halpe.pruned_blocks_indices_per_layer(torch.tensor([]))
    assert split_empty['head'] == []
    assert split_empty['ffn'] == []

    # Boundary: the last head index and the first FFN index side by side.
    num_heads = halpe.layer_config.num_heads
    last_head, first_ffn = num_heads - 1, num_heads
    split_boundary = halpe.pruned_blocks_indices_per_layer(torch.tensor([last_head, first_ffn]))
    assert split_boundary['head'] == [last_head]
    assert split_boundary['ffn'] == [0]  # num_heads - num_heads = 0

    # Indices far beyond the block count are not rejected; they are all
    # classified as FFN blocks.
    far_out = torch.tensor([100, 200])
    split_far_out = halpe.pruned_blocks_indices_per_layer(far_out)
    assert split_far_out['head'] == []
    assert len(split_far_out['ffn']) == 2