import torch
import pytest
from sae import SAEConfig, VanillaSAE, BatchTopKSAE, TopKSAE, JumpReLUSAE


@pytest.fixture
def sample_config():
    """Provide a small SAEConfig shared by the folding tests.

    Uses act_size=64 / dict_size=128, with input_unit_norm enabled and
    non-default input_mean / input_std values.
    """
    overrides = {
        "act_size": 64,
        "dict_size": 128,
        "input_unit_norm": True,
        "input_mean": 0.5,
        "input_std": 1.2,
    }
    return SAEConfig(**overrides)


@pytest.fixture
def sample_input():
    """Provide a reproducible (32, 64) batch of standard-normal inputs.

    The global torch RNG is re-seeded on every call, so each test receives
    the same tensor. The width (64) matches act_size in ``sample_config``.
    """
    n_rows, n_cols = 32, 64
    torch.manual_seed(42)
    return torch.randn(n_rows, n_cols)


def test_fold_stats_into_weights(sample_config, sample_input):
    """Test that folding statistics preserves the output.

    Also verifies that, once the statistics are folded into the weights,
    ``input_unit_norm`` is switched off on the model's config.
    """
    sae = VanillaSAE(sample_config)

    # Reconstruction before any folding.
    with torch.no_grad():
        before = sae(sample_input)['sae_out']

    sae.fold_stats_into_weights()

    # Reconstruction after folding.
    with torch.no_grad():
        after = sae(sample_input)['sae_out']

    max_abs_diff = (before - after).abs().max()
    assert max_abs_diff < 1e-3

    # The folded model must not re-apply input normalization.
    assert not sae.config.input_unit_norm


def test_fold_W_dec_norm(sample_config, sample_input):
    """Test that folding decoder norms preserves the output.

    After ``fold_W_dec_norm`` the reconstruction must match the pre-fold
    reconstruction, and ``W_dec`` must have unit L2 norm along its last
    dimension.
    """
    sae = VanillaSAE(sample_config)

    # Reconstruction before folding.
    with torch.no_grad():
        original_output = sae(sample_input)

    sae.fold_W_dec_norm()

    # Reconstruction and decoder norms after folding.
    with torch.no_grad():
        folded_output = sae(sample_input)
        folded_norms = sae.W_dec.norm(dim=-1)

    # torch.allclose defaults to atol=1e-8, which makes near-zero output
    # entries fail on ordinary float32 rounding; give it a small explicit atol.
    assert torch.allclose(
        original_output['sae_out'],
        folded_output['sae_out'],
        rtol=1e-4,
        atol=1e-5,
    )

    # Decoder weights must be unit norm after folding.
    assert torch.allclose(folded_norms, torch.ones_like(folded_norms), rtol=1e-5)


@pytest.mark.parametrize("sae_class", [VanillaSAE, BatchTopKSAE, TopKSAE, JumpReLUSAE])
def test_folding_operations(sae_class, sample_config, sample_input):
    """Test both folding operations, applied one after the other.

    Stats folding is checked against the unfolded output. Decoder-norm
    folding is then applied on top of it and checked against the
    stats-folded output, followed by a unit-norm check on ``W_dec``.
    """
    sae = sae_class(sample_config)

    # Step 1: fold input statistics.
    with torch.no_grad():
        original_output = sae(sample_input)
        sae.fold_stats_into_weights()
        stats_folded_output = sae(sample_input)
        stats_diff = (original_output['sae_out'] - stats_folded_output['sae_out']).abs().max()
        assert stats_diff < 1e-3, f"Stats folding changed output by {stats_diff}"

    # Step 2: fold decoder norms on top of the stats-folded model.
    with torch.no_grad():
        sae.fold_W_dec_norm()
        both_folded_output = sae(sample_input)
        norm_diff = (stats_folded_output['sae_out'] - both_folded_output['sae_out']).abs().max()
        assert norm_diff < 1e-3, f"W_dec_norm folding changed output by {norm_diff}"

        # Verify decoder norms: one per dictionary feature, all ~1.
        # ones_like matches the device/dtype of W_dec (a bare torch.ones would
        # be CPU float32); the explicit shape check keeps the dict_size check.
        decoder_norms = sae.W_dec.norm(dim=-1)
        assert decoder_norms.shape == (sae.config.dict_size,)
        assert torch.allclose(
            decoder_norms,
            torch.ones_like(decoder_norms),
            rtol=1e-3,
        )


def test_jumprelu_threshold_scaling(sample_config, sample_input):
    """Test that JumpReLU thresholds are properly scaled.

    After ``fold_W_dec_norm`` each threshold (``exp(log_threshold)``) should
    equal its pre-fold value multiplied by the corresponding pre-fold decoder
    norm (``W_dec.norm(dim=-1)``).
    """
    sae = JumpReLUSAE(sample_config)

    # Snapshot thresholds and decoder norms before folding.
    with torch.no_grad():
        original_thresholds = torch.exp(sae.jumprelu.log_threshold).clone()
        original_norms = sae.W_dec.norm(dim=-1)

    sae.fold_W_dec_norm()

    with torch.no_grad():
        new_thresholds = torch.exp(sae.jumprelu.log_threshold)

    # A large absolute tolerance (previously 1e-1) lets every threshold below
    # ~0.1 pass whether or not it was rescaled. Keep atol tiny so the relative
    # tolerance actually constrains the result.
    assert torch.allclose(
        new_thresholds,
        original_thresholds * original_norms,
        rtol=1e-3,
        atol=1e-6,
    )