"""
Candidate Block Selection Indexing Tests

This module contains tests specifically for verifying the candidate block selection
indexing logic in the HALPE pruning pipeline.
"""

import torch
import pytest
from unittest.mock import Mock

from utils.layer_utils import TransformerConfig, TransformerLayerSchema, LayerType, LayerSpec
from prune.halpe import TransformerHALPE
from prune.utils import compute_candidate_blocks, select_least_important_globally


class DummyBlock(torch.nn.Module):
    """Minimal stand-in for a transformer block used by the pruning tests.

    Exposes an ``attn`` ModuleDict holding bias-free ``q``/``k``/``v``/``o``
    linear projections and an ``ffn`` ModuleDict holding bias-free
    ``fc1``/``fc2`` projections, matching the (module, name) pairs declared
    in the ``test_schema`` fixture.

    ``num_heads`` is accepted for signature parity with real blocks but is
    not used to build any layer.
    """

    def __init__(self, hidden_size=8, num_heads=2, intermediate_dimension=4):
        super().__init__()

        def linear(n_in, n_out):
            return torch.nn.Linear(n_in, n_out, bias=False)

        # Layers are created in the fixed order q, k, v, o, fc1, fc2 so that
        # seeded random initialisation is reproducible.
        self.attn = torch.nn.ModuleDict(
            {name: linear(hidden_size, hidden_size) for name in ("q", "k", "v", "o")}
        )
        self.ffn = torch.nn.ModuleDict({
            "fc1": linear(hidden_size, intermediate_dimension),
            "fc2": linear(intermediate_dimension, hidden_size),
        })

    def forward(self, x):
        """Sum the q/k/v projections, then apply o, fc1 and fc2 in sequence."""
        proj = self.attn
        mixed = proj["q"](x) + proj["k"](x) + proj["v"](x)
        hidden = self.ffn["fc1"](proj["o"](mixed))
        return self.ffn["fc2"](hidden)


@pytest.fixture
def test_config():
    """Provide the small transformer config shared by the tests in this module.

    Two heads and four intermediate FFN units give a layer of
    ``num_heads + intermediate_dimension`` = 6 prunable blocks.
    """
    settings = dict(
        layer_type=LayerType.transformer,
        hidden_size=8,
        head_size=4,
        num_heads=2,
        intermediate_dimension=4,
    )
    return TransformerConfig(**settings)


@pytest.fixture
def test_schema():
    """Provide the layer schema that maps DummyBlock's projections.

    Each LayerSpec names the submodule (``attn``/``ffn``), the layer key
    inside it, and a ``"row"``/``"column"`` tag. q/k/v/fc1 are tagged
    ``"row"`` and o/fc2 are tagged ``"column"``.
    """
    attn_specs = {name: LayerSpec("attn", name, "row") for name in ("q", "k", "v")}
    attn_specs["o"] = LayerSpec("attn", "o", "column")
    ffn_specs = {
        "fc1": LayerSpec("ffn", "fc1", "row"),
        "fc2": LayerSpec("ffn", "fc2", "column"),
    }
    return TransformerLayerSchema(
        layer_name="TestBlock",
        layer_type=LayerType.transformer,
        norm_type="pre_norm",
        layers={**attn_specs, **ffn_specs},
    )


@pytest.fixture
def test_halpe(test_config, test_schema):
    """Wrap a freshly built DummyBlock in a TransformerHALPE.

    Uses the shared config and schema fixtures plus a logger named after
    this module.
    """
    import logging

    return TransformerHALPE(
        DummyBlock(), test_config, test_schema, logging.getLogger(__name__)
    )


class TestCandidateBlockSelection:
    """Test suite for candidate block selection indexing logic.

    With the module fixtures (num_heads=2, intermediate_dimension=4) a layer
    has 6 prunable blocks. Global indices 0-1 are attention heads and 2-5 are
    FFN units. Where possible, expected values are derived from
    ``layer_config`` so they stay consistent with the fixture.

    Tensor results are compared via ``.tolist()`` instead of
    ``torch.all(a == b)``. The latter broadcasts, so an empty tensor compared
    against a non-empty expectation would pass vacuously.
    """

    def test_get_initial_importances_global_indexing(self, test_halpe):
        """get_initial_importances returns one [importance, layer, block] row per block."""
        num_heads = test_halpe.layer_config.num_heads
        total_blocks = num_heads + test_halpe.layer_config.intermediate_dimension

        result = test_halpe.get_initial_importances(layer_idx=0)

        # Shape: [total_blocks, 3] where total_blocks = num_heads + intermediate_dimension
        assert result.shape == (total_blocks, 3)

        # Every row belongs to the requested layer.
        assert result[:, 1].tolist() == [0] * total_blocks

        # Block indices are global and sequential.
        block_indices = result[:, 2].tolist()
        assert block_indices == list(range(total_blocks))

        # Head blocks come first (0 .. num_heads-1) ...
        assert block_indices[:num_heads] == list(range(num_heads))
        # ... followed by FFN blocks (num_heads .. total_blocks-1).
        assert block_indices[num_heads:] == list(range(num_heads, total_blocks))

    def test_pruned_blocks_indices_per_layer_mapping(self, test_halpe):
        """Mixed global indices split into local head and local FFN indices."""
        num_heads = test_halpe.layer_config.num_heads
        total_blocks = num_heads + test_halpe.layer_config.intermediate_dimension

        # One head plus two FFN blocks, all inside the layer's valid range
        # (with the fixture: [1, 3, 5]). Out-of-range inputs are covered in
        # test_edge_cases.
        global_indices = torch.tensor([num_heads - 1, num_heads + 1, total_blocks - 1])

        result = test_halpe.pruned_blocks_indices_per_layer(global_indices)

        # Head indices are already local and pass through unchanged.
        assert result['head'] == [num_heads - 1]

        # FFN indices become local by subtracting num_heads.
        assert result['ffn'] == [1, total_blocks - 1 - num_heads]

    def test_pruned_blocks_indices_per_layer_all_heads(self, test_halpe):
        """Test mapping when all indices are head indices."""
        num_heads = test_halpe.layer_config.num_heads
        global_indices = torch.arange(num_heads)  # every index < num_heads

        result = test_halpe.pruned_blocks_indices_per_layer(global_indices)

        assert result['head'] == list(range(num_heads))
        assert result['ffn'] == []

    def test_pruned_blocks_indices_per_layer_all_ffn(self, test_halpe):
        """Test mapping when all indices are FFN indices."""
        num_heads = test_halpe.layer_config.num_heads
        global_indices = torch.tensor([num_heads, num_heads + 1, num_heads + 3])  # All >= num_heads

        result = test_halpe.pruned_blocks_indices_per_layer(global_indices)

        assert result['head'] == []
        assert result['ffn'] == [0, 1, 3]  # global - num_heads

    def test_layer_importances_filtering(self):
        """Test that layer-specific filtering preserves block indices."""
        # Rows are [importance, layer, block] for three layers.
        initial_importances = torch.tensor([
            [1.0, 0, 0],  # layer 0, block 0
            [2.0, 0, 1],  # layer 0, block 1
            [3.0, 0, 2],  # layer 0, block 2
            [4.0, 1, 0],  # layer 1, block 0
            [5.0, 1, 1],  # layer 1, block 1
            [6.0, 2, 0],  # layer 2, block 0
        ])

        # Keep only the rows belonging to layer 0.
        layer_0_importances = initial_importances[initial_importances[:, 1] == 0]

        assert layer_0_importances.shape[0] == 3
        assert layer_0_importances[:, 1].tolist() == [0, 0, 0]

        # Block indices (and importances) survive the filter unchanged.
        assert layer_0_importances[:, 2].tolist() == [0, 1, 2]
        assert layer_0_importances[:, 0].tolist() == [1.0, 2.0, 3.0]

    def test_candidate_block_selection_logic(self, test_halpe):
        """Sorting one layer's importances picks the least important blocks first."""
        num_heads = test_halpe.layer_config.num_heads

        # Rows are [importance, layer, global block index]. Block indices are
        # valid for the fixture layer: with num_heads=2, blocks 0-1 are heads
        # and blocks 2-5 are FFN units.
        layer_importances = torch.tensor([
            [5.0, 0, 0],  # head
            [1.0, 0, 1],  # head - least important
            [3.0, 0, 2],  # FFN
            [4.0, 0, 3],  # FFN
            [2.0, 0, 5],  # FFN - second least important
        ])

        # Sort by importance (ascending: least important first).
        sorted_indices = layer_importances[:, 0].argsort()
        block_indices = layer_importances[sorted_indices, 2].long()

        # Select the 2 least important blocks.
        candidate_blocks = block_indices[:2]
        assert candidate_blocks.tolist() == [1, 5]

        # Split into local head / FFN indices (mirrors the HALPE mapping).
        head_indices = candidate_blocks[candidate_blocks < num_heads]
        ffn_indices = candidate_blocks[candidate_blocks >= num_heads] - num_heads

        assert head_indices.tolist() == [1]             # block 1 is a head
        assert ffn_indices.tolist() == [5 - num_heads]  # block 5 -> local index 3

        # The real mapping must agree with the simulation above.
        mapping = test_halpe.pruned_blocks_indices_per_layer(candidate_blocks)
        assert mapping['head'] == head_indices.tolist()
        assert mapping['ffn'] == ffn_indices.tolist()

    def test_compute_candidate_blocks_integration(self):
        """Test integration with compute_candidate_blocks function."""
        global_sensitivity = {0: 1.0, 1: 2.0}

        # Rows are [importance, layer, block] for 2 layers.
        initial_importances = torch.tensor([
            [1.0, 0, 0], [2.0, 0, 1], [3.0, 0, 2],  # layer 0: 3 blocks
            [4.0, 1, 0], [5.0, 1, 1], [6.0, 1, 2], [7.0, 1, 3]  # layer 1: 4 blocks
        ])

        candidate_blocks_per_layer = compute_candidate_blocks(global_sensitivity, initial_importances)

        # Result is keyed by layer index.
        assert isinstance(candidate_blocks_per_layer, dict)
        assert 0 in candidate_blocks_per_layer
        assert 1 in candidate_blocks_per_layer

        # Each layer gets at least one candidate. The upper bounds are
        # deliberately loose (twice each layer's block count).
        # NOTE(review): tighten once compute_candidate_blocks' upper bound is
        # pinned down.
        assert candidate_blocks_per_layer[0] >= 1
        assert candidate_blocks_per_layer[0] <= 6
        assert candidate_blocks_per_layer[1] >= 1
        assert candidate_blocks_per_layer[1] <= 8

    def test_select_least_important_globally(self):
        """Test the global block selection function."""
        # Per-layer candidate rows [importance, layer, block].
        exact_importances = [
            torch.tensor([[1.0, 0, 2], [2.0, 0, 1]]),  # layer 0: 2 candidates
            torch.tensor([[0.5, 1, 0], [3.0, 1, 1], [1.5, 1, 2]])  # layer 1: 3 candidates
        ]

        # Select the 2 globally least important blocks.
        selected_blocks = select_least_important_globally(exact_importances, k=2)

        # Result is keyed by layer index.
        # NOTE(review): the two lowest importances are 0.5 (layer 1) and 1.0
        # (layer 0). Consider asserting both layers once the return format is
        # pinned down.
        assert isinstance(selected_blocks, dict)
        assert 0 in selected_blocks or 1 in selected_blocks

        # Exactly k blocks are selected in total.
        total_selected = sum(len(blocks) for blocks in selected_blocks.values())
        assert total_selected == 2

    def test_end_to_end_indexing_consistency(self, test_halpe):
        """Test that indices remain consistent throughout the entire pipeline."""
        # Step 1: Get initial importances.
        initial_importances = test_halpe.get_initial_importances(layer_idx=0)

        # Step 2: Simulate candidate selection (take the first 3 blocks).
        candidate_blocks = initial_importances[:3, 2]
        assert candidate_blocks.tolist() == [0, 1, 2]

        # Step 3: Map to local indices.
        local_mapping = test_halpe.pruned_blocks_indices_per_layer(candidate_blocks)

        # With num_heads=2, indices 0 and 1 are heads and index 2 is FFN.
        assert local_mapping['head'] == [0, 1]
        assert local_mapping['ffn'] == [0]  # 2 - 2 = 0

        # Step 4: Mixed head and FFN blocks, all valid for the 6-block layer.
        mixed_candidates = torch.tensor([1, 3, 5])  # head: 1; FFN: 3, 5
        mixed_mapping = test_halpe.pruned_blocks_indices_per_layer(mixed_candidates)

        assert mixed_mapping['head'] == [1]
        assert mixed_mapping['ffn'] == [1, 3]  # 3-2=1, 5-2=3

    def test_edge_cases(self, test_halpe):
        """Test edge cases and boundary conditions."""
        # Empty input yields empty head and FFN lists.
        empty_result = test_halpe.pruned_blocks_indices_per_layer(torch.tensor([]))
        assert empty_result['head'] == []
        assert empty_result['ffn'] == []

        # Boundary: last head and first FFN block.
        num_heads = test_halpe.layer_config.num_heads
        boundary_indices = torch.tensor([num_heads - 1, num_heads])
        boundary_result = test_halpe.pruned_blocks_indices_per_layer(boundary_indices)

        assert boundary_result['head'] == [num_heads - 1]
        assert boundary_result['ffn'] == [0]  # num_heads - num_heads = 0

        # Indices beyond the layer's block count are not rejected. They are
        # treated as FFN blocks and shifted by num_heads like any other.
        invalid_indices = torch.tensor([100, 200])
        invalid_result = test_halpe.pruned_blocks_indices_per_layer(invalid_indices)

        assert invalid_result['head'] == []
        assert invalid_result['ffn'] == [100 - num_heads, 200 - num_heads]


if __name__ == "__main__":
    # Run this module's tests and propagate pytest's exit code, so that a
    # failing run makes the script exit non-zero.
    raise SystemExit(pytest.main([__file__, "-v"]))