#!/usr/bin/env python3
"""
Simple test runner for candidate block selection indexing tests.
This script can be run without pytest to verify the indexing logic.
"""

import torch
import sys
import os

# Add the project root (the parent of the directory containing this file) to
# the import path. The tests below import `utils.layer_utils` and `prune.halpe`,
# which are expected to be importable from there.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

def test_initial_importances_indexing():
    """Test that get_initial_importances returns correct global indices.

    Builds a ``TransformerHALPE`` around a tiny dummy block (2 heads, 4 FFN
    units), replaces its two magnitude methods with fixed tensors, and checks
    the tensor returned by ``get_initial_importances(layer_idx=0)``:

    * it has one row per block and three columns;
    * column 1 holds the layer index (0 for every row);
    * column 2 numbers the blocks sequentially, heads first
      (``0 .. num_heads-1``) and FFN units after them.

    Returns:
        True when every assertion holds.

    Raises:
        AssertionError: if the shape, layer column or block column is wrong.
    """
    print("Testing get_initial_importances indexing...")

    # Imported inside the function so that an import failure is reported as a
    # failure of this test when it is run through main().
    import logging

    from utils.layer_utils import TransformerConfig, TransformerLayerSchema, LayerType, LayerSpec
    from prune.halpe import TransformerHALPE

    # Create test config
    config = TransformerConfig(
        layer_type=LayerType.transformer,
        hidden_size=8,
        head_size=4,
        num_heads=2,
        intermediate_dimension=4
    )

    # Create test schema
    schema = TransformerLayerSchema(
        layer_name="TestBlock",
        layer_type=LayerType.transformer,
        norm_type="pre_norm",
        layers={
            "q": LayerSpec("attn", "q", "row"),
            "k": LayerSpec("attn", "k", "row"),
            "v": LayerSpec("attn", "v", "row"),
            "o": LayerSpec("attn", "o", "column"),
            "fc1": LayerSpec("ffn", "fc1", "row"),
            "fc2": LayerSpec("ffn", "fc2", "column")
        }
    )

    # Create dummy block
    class DummyBlock(torch.nn.Module):
        """Minimal module exposing ``attn`` and ``ffn`` sub-layers named as in the schema."""

        def __init__(self):
            super().__init__()
            self.attn = torch.nn.ModuleDict({
                "q": torch.nn.Linear(8, 8, bias=False),
                "k": torch.nn.Linear(8, 8, bias=False),
                "v": torch.nn.Linear(8, 8, bias=False),
                "o": torch.nn.Linear(8, 8, bias=False),
            })
            self.ffn = torch.nn.ModuleDict({
                "fc1": torch.nn.Linear(8, 4, bias=False),
                "fc2": torch.nn.Linear(4, 8, bias=False),
            })

        def forward(self, x):
            x = self.attn["q"](x) + self.attn["k"](x) + self.attn["v"](x)
            x = self.attn["o"](x)
            x = self.ffn["fc1"](x)
            x = self.ffn["fc2"](x)
            return x

    # Create HALPE instance
    logger = logging.getLogger(__name__)
    halpe = TransformerHALPE(DummyBlock(), config, schema, logger, device="cpu")

    with torch.no_grad():
        # Replace the magnitude computations with fixed values so the result
        # does not depend on the randomly initialised Linear weights.
        # Head magnitudes (2 heads)
        halpe.compute_head_magnitudes = lambda: torch.tensor([1.0, 2.0])
        # FFN magnitudes (4 units)
        halpe.compute_ffn_magnitudes = lambda: torch.tensor([10.0, 20.0, 30.0, 40.0])

        result = halpe.get_initial_importances(layer_idx=0)

        # Check shape
        expected_total_blocks = config.num_heads + config.intermediate_dimension
        assert result.shape == (expected_total_blocks, 3), f"Expected shape ({expected_total_blocks}, 3), got {result.shape}"

        # Check layer indices
        assert torch.all(result[:, 1] == 0), "All layer indices should be 0"

        # Check block indices are sequential
        expected_block_indices = torch.arange(expected_total_blocks)
        assert torch.all(result[:, 2] == expected_block_indices), "Block indices should be sequential"

        # Check head blocks come first
        head_block_indices = result[:config.num_heads, 2]
        expected_head_indices = torch.arange(config.num_heads)
        assert torch.all(head_block_indices == expected_head_indices), "Head blocks should have indices 0 to num_heads-1"

        # Check FFN blocks come second
        ffn_block_indices = result[config.num_heads:, 2]
        expected_ffn_indices = torch.arange(config.num_heads, expected_total_blocks)
        assert torch.all(ffn_block_indices == expected_ffn_indices), "FFN blocks should have indices num_heads to total_blocks-1"

        print("✓ get_initial_importances indexing test passed")
        return True

def test_index_mapping():
    """Test that pruned_blocks_indices_per_layer correctly maps global to local indices.

    Builds a ``TransformerHALPE`` around a tiny dummy block (2 heads, 4 FFN
    units, i.e. global block indices 0..5) and checks the dict returned by
    ``pruned_blocks_indices_per_layer`` for three selections:

    * mixed head and FFN indices;
    * head indices only;
    * FFN indices only.

    Indices below ``num_heads`` are expected unchanged under ``'head'``;
    the others are expected under ``'ffn'`` with ``num_heads`` subtracted.

    Returns:
        True when every assertion holds.

    Raises:
        AssertionError: if any returned list differs from the expected one.
    """
    print("Testing index mapping...")

    # Imported inside the function so that an import failure is reported as a
    # failure of this test when it is run through main().
    import logging

    from utils.layer_utils import TransformerConfig, TransformerLayerSchema, LayerType, LayerSpec
    from prune.halpe import TransformerHALPE

    # Create test config
    config = TransformerConfig(
        layer_type=LayerType.transformer,
        hidden_size=8,
        head_size=4,
        num_heads=2,
        intermediate_dimension=4
    )

    # Create test schema
    schema = TransformerLayerSchema(
        layer_name="TestBlock",
        layer_type=LayerType.transformer,
        norm_type="pre_norm",
        layers={
            "q": LayerSpec("attn", "q", "row"),
            "k": LayerSpec("attn", "k", "row"),
            "v": LayerSpec("attn", "v", "row"),
            "o": LayerSpec("attn", "o", "column"),
            "fc1": LayerSpec("ffn", "fc1", "row"),
            "fc2": LayerSpec("ffn", "fc2", "column")
        }
    )

    # Create dummy block
    class DummyBlock(torch.nn.Module):
        """Minimal module exposing ``attn`` and ``ffn`` sub-layers named as in the schema."""

        def __init__(self):
            super().__init__()
            self.attn = torch.nn.ModuleDict({
                "q": torch.nn.Linear(8, 8, bias=False),
                "k": torch.nn.Linear(8, 8, bias=False),
                "v": torch.nn.Linear(8, 8, bias=False),
                "o": torch.nn.Linear(8, 8, bias=False),
            })
            self.ffn = torch.nn.ModuleDict({
                "fc1": torch.nn.Linear(8, 4, bias=False),
                "fc2": torch.nn.Linear(4, 8, bias=False),
            })

        def forward(self, x):
            x = self.attn["q"](x) + self.attn["k"](x) + self.attn["v"](x)
            x = self.attn["o"](x)
            x = self.ffn["fc1"](x)
            x = self.ffn["fc2"](x)
            return x

    # Create HALPE instance
    logger = logging.getLogger(__name__)
    halpe = TransformerHALPE(DummyBlock(), config, schema, logger, device="cpu")

    # Test mixed head and FFN indices. Every index stays within the fixture's
    # 6 blocks (2 heads + 4 FFN units) so the call only sees valid input.
    global_indices = torch.tensor([1, 3, 4, 5])  # head: 1; FFN: 3,4,5
    result = halpe.pruned_blocks_indices_per_layer(global_indices)

    # Head indices should remain the same
    assert result['head'] == [1], f"Expected head indices [1], got {result['head']}"

    # FFN indices should be converted to local (3->1, 4->2, 5->3)
    expected_ffn_local = [3 - config.num_heads, 4 - config.num_heads, 5 - config.num_heads]
    assert result['ffn'] == expected_ffn_local, f"Expected FFN indices {expected_ffn_local}, got {result['ffn']}"

    # Test all head indices
    head_indices = torch.tensor([0, 1])
    head_result = halpe.pruned_blocks_indices_per_layer(head_indices)
    assert head_result['head'] == [0, 1], f"Expected head indices [0, 1], got {head_result['head']}"
    assert head_result['ffn'] == [], f"Expected empty FFN indices, got {head_result['ffn']}"

    # Test all FFN indices
    ffn_indices = torch.tensor([2, 3, 5])
    ffn_result = halpe.pruned_blocks_indices_per_layer(ffn_indices)
    assert ffn_result['head'] == [], f"Expected empty head indices, got {ffn_result['head']}"
    expected_ffn = [2 - config.num_heads, 3 - config.num_heads, 5 - config.num_heads]
    assert ffn_result['ffn'] == expected_ffn, f"Expected FFN indices {expected_ffn}, got {ffn_result['ffn']}"

    print("✓ Index mapping test passed")
    return True

def test_layer_filtering():
    """Check that selecting one layer's rows keeps their block indices intact.

    Works on a hand-written importance table whose columns are
    (importance, layer index, block index), keeps only the rows of layer 0
    with a boolean mask, and verifies the surviving rows.

    Returns:
        True when every assertion holds.

    Raises:
        AssertionError: if the row count, layer column or block column of the
            filtered table is not the expected one.
    """
    print("Testing layer filtering...")

    # Three rows for layer 0, two for layer 1 and one for layer 2.
    table_rows = [
        [1.0, 0, 0],
        [2.0, 0, 1],
        [3.0, 0, 2],
        [4.0, 1, 0],
        [5.0, 1, 1],
        [6.0, 2, 0],
    ]
    initial_importances = torch.tensor(table_rows)

    # Boolean mask over the layer column selects the rows of layer 0.
    belongs_to_layer_0 = initial_importances[:, 1] == 0
    layer_0_importances = initial_importances[belongs_to_layer_0]

    # The row count is checked first so the element-wise checks below compare
    # tensors of equal length.
    assert layer_0_importances.shape[0] == 3, f"Expected 3 rows, got {layer_0_importances.shape[0]}"

    layer_column = layer_0_importances[:, 1]
    assert (layer_column == 0).all(), "All layer indices should be 0"

    block_column = layer_0_importances[:, 2]
    assert (block_column == torch.arange(3)).all(), "Block indices should be preserved"

    print("✓ Layer filtering test passed")
    return True

def test_candidate_selection_logic():
    """Test the complete candidate block selection logic.

    Works on a hand-written single-layer importance table whose columns are
    (importance, layer index, block index). The two least important blocks
    are selected by sorting on importance, then split into head indices
    (global index below ``num_heads``) and local FFN indices (global index
    minus ``num_heads``).

    Comparisons go through ``tolist()`` so that both the length and the
    values must match; an element-wise ``==`` followed by ``torch.all`` would
    broadcast and accept an empty or differently sized selection.

    Returns:
        True when every assertion holds.

    Raises:
        AssertionError: if the selected candidates or their head/FFN split
            differ from the expected values.
    """
    print("Testing candidate selection logic...")

    # Global block indices below num_heads are heads; the rest are FFN units.
    num_heads = 2

    # Create mock initial importances for a single layer
    layer_importances = torch.tensor([
        [5.0, 0, 0],  # importance=5.0, layer=0, block=0 (head)
        [1.0, 0, 1],  # importance=1.0, layer=0, block=1 (head) - least important
        [3.0, 0, 2],  # importance=3.0, layer=0, block=2 (FFN, local 0)
        [4.0, 0, 8],  # importance=4.0, layer=0, block=8 (FFN, local 6)
        [2.0, 0, 9],  # importance=2.0, layer=0, block=9 (FFN, local 7) - second least important
    ])

    # Sort by importance (ascending: least important first)
    sorted_indices = layer_importances[:, 0].argsort()
    block_indices = layer_importances[sorted_indices, 2]

    # Select top 2 candidates (least important)
    candidate_blocks = block_indices[:2]

    # Expected: blocks 1 and 9 (least important)
    expected_candidates = torch.tensor([1, 9])
    assert candidate_blocks.tolist() == expected_candidates.tolist(), f"Expected candidates {expected_candidates}, got {candidate_blocks}"

    # Test mapping to local indices (simulating HALPE logic)
    head_indices = candidate_blocks[candidate_blocks < num_heads]
    ffn_indices = candidate_blocks[candidate_blocks >= num_heads] - num_heads

    assert head_indices.tolist() == [1], f"Expected head indices [1], got {head_indices}"
    assert ffn_indices.tolist() == [7], f"Expected FFN indices [7], got {ffn_indices}"

    print("✓ Candidate selection logic test passed")
    return True

def test_edge_cases():
    """Test edge cases and boundary conditions.

    Builds a ``TransformerHALPE`` around a tiny dummy block (2 heads, 4 FFN
    units) and checks ``pruned_blocks_indices_per_layer`` for:

    * an empty index tensor, expected to give empty ``'head'`` and ``'ffn'``
      lists;
    * the two indices on either side of the head/FFN boundary
      (``num_heads - 1`` and ``num_heads``), expected to map to the last head
      and to local FFN index 0.

    Returns:
        True when every assertion holds.

    Raises:
        AssertionError: if any returned list differs from the expected one.
    """
    print("Testing edge cases...")

    # Imported inside the function so that an import failure is reported as a
    # failure of this test when it is run through main().
    import logging

    from utils.layer_utils import TransformerConfig, TransformerLayerSchema, LayerType, LayerSpec
    from prune.halpe import TransformerHALPE

    # Create test config
    config = TransformerConfig(
        layer_type=LayerType.transformer,
        hidden_size=8,
        head_size=4,
        num_heads=2,
        intermediate_dimension=4
    )

    # Create test schema
    schema = TransformerLayerSchema(
        layer_name="TestBlock",
        layer_type=LayerType.transformer,
        norm_type="pre_norm",
        layers={
            "q": LayerSpec("attn", "q", "row"),
            "k": LayerSpec("attn", "k", "row"),
            "v": LayerSpec("attn", "v", "row"),
            "o": LayerSpec("attn", "o", "column"),
            "fc1": LayerSpec("ffn", "fc1", "row"),
            "fc2": LayerSpec("ffn", "fc2", "column")
        }
    )

    # Create dummy block
    class DummyBlock(torch.nn.Module):
        """Minimal module exposing ``attn`` and ``ffn`` sub-layers named as in the schema."""

        def __init__(self):
            super().__init__()
            self.attn = torch.nn.ModuleDict({
                "q": torch.nn.Linear(8, 8, bias=False),
                "k": torch.nn.Linear(8, 8, bias=False),
                "v": torch.nn.Linear(8, 8, bias=False),
                "o": torch.nn.Linear(8, 8, bias=False),
            })
            self.ffn = torch.nn.ModuleDict({
                "fc1": torch.nn.Linear(8, 4, bias=False),
                "fc2": torch.nn.Linear(4, 8, bias=False),
            })

        def forward(self, x):
            x = self.attn["q"](x) + self.attn["k"](x) + self.attn["v"](x)
            x = self.attn["o"](x)
            x = self.ffn["fc1"](x)
            x = self.ffn["fc2"](x)
            return x

    # Create HALPE instance
    logger = logging.getLogger(__name__)
    halpe = TransformerHALPE(DummyBlock(), config, schema, logger, device="cpu")

    # Test empty indices. torch.tensor([]) alone would be float32; the dtype is
    # given explicitly so the empty case uses the same integer dtype as the
    # other index tensors passed to this method.
    empty_indices = torch.tensor([], dtype=torch.long)
    empty_result = halpe.pruned_blocks_indices_per_layer(empty_indices)
    assert empty_result['head'] == [], "Empty indices should result in empty head list"
    assert empty_result['ffn'] == [], "Empty indices should result in empty FFN list"

    # Test boundary indices
    boundary_indices = torch.tensor([config.num_heads - 1, config.num_heads])  # Last head, first FFN
    boundary_result = halpe.pruned_blocks_indices_per_layer(boundary_indices)

    assert boundary_result['head'] == [config.num_heads - 1], f"Expected head index [{config.num_heads - 1}], got {boundary_result['head']}"
    assert boundary_result['ffn'] == [0], f"Expected FFN index [0], got {boundary_result['ffn']}"

    print("✓ Edge cases test passed")
    return True

def main():
    """Run every test in this file and print a summary.

    Each test is called in turn. A test counts as passed when it returns a
    truthy value; if it raises, the exception message is printed and the run
    continues with the next test.

    Returns:
        0 when all tests passed, 1 otherwise (suitable as a process exit code).
    """
    print("Running candidate block selection indexing tests...\n")

    tests = (
        test_initial_importances_indexing,
        test_index_mapping,
        test_layer_filtering,
        test_candidate_selection_logic,
        test_edge_cases,
    )
    total = len(tests)
    passed = 0

    for test in tests:
        try:
            succeeded = bool(test())
        except Exception as e:
            # Report and move on so one failure does not hide the other results.
            print(f"✗ {test.__name__} failed: {e}")
            continue
        if succeeded:
            passed += 1

    print(f"\nTest Results: {passed}/{total} tests passed")

    if passed != total:
        print("❌ Some tests failed. Please check the indexing logic.")
        return 1

    print("🎉 All tests passed! The indexing logic is working correctly.")
    return 0

if __name__ == "__main__":
    # Use sys.exit rather than the builtin exit(): the latter is added by the
    # site module for interactive use and is missing when Python is started
    # without site (e.g. with -S). main() returns the process exit code.
    sys.exit(main())