"""
Simple, focused tests for GCG attack that can run quickly without GPU/large models.

These tests focus on:
- Configuration validation
- Loss function correctness
- Basic functionality
- Mathematical properties
"""

import pytest
import torch
from src.attacks.gcg import GCGAttack, GCGConfig, compute_loss
from src.attacks.attack import GenerationConfig


class TestGCGConfigValidation:
    """Test GCG configuration validation and defaults."""

    def test_default_values(self):
        """Default ``GCGConfig`` values should be within their documented domains."""
        config = GCGConfig()

        # Core attack parameters
        assert config.name == "gcg"
        assert config.type == "discrete"
        assert config.placement in ["suffix", "prompt"]
        assert config.num_steps > 0
        assert config.search_width > 0
        assert config.topk > 0
        assert config.n_replace > 0

        # Optimization parameters
        assert config.loss in ["mellowmax", "cw", "ce"]
        assert config.mellowmax_alpha > 0
        assert config.grad_smoothing >= 1
        assert config.grad_momentum >= 0.0

        # Boolean flags must be real bools, not truthy stand-ins
        assert isinstance(config.early_stop, bool)
        assert isinstance(config.use_prefix_cache, bool)
        assert isinstance(config.allow_non_ascii, bool)
        assert isinstance(config.filter_ids, bool)

    def test_generation_config_inheritance(self):
        """A ``GenerationConfig`` passed to ``GCGConfig`` should be kept as given."""
        gen_config = GenerationConfig(max_new_tokens=50, temperature=0.7)
        config = GCGConfig(generation_config=gen_config)

        assert config.generation_config.max_new_tokens == 50
        assert config.generation_config.temperature == 0.7

    def test_custom_values(self):
        """Explicit constructor arguments should override the defaults."""
        config = GCGConfig(
            num_steps=100,
            search_width=256,
            topk=128,
            loss="mellowmax",
            mellowmax_alpha=2.0,
            placement="prompt",
            early_stop=True,
        )

        assert config.num_steps == 100
        assert config.search_width == 256
        assert config.topk == 128
        assert config.loss == "mellowmax"
        assert config.mellowmax_alpha == 2.0
        assert config.placement == "prompt"
        # Identity check: the flag must be the bool singleton, not merely == True.
        assert config.early_stop is True


class TestGCGLossFunctions:
    """Test mathematical correctness of loss functions.

    ``setup_method`` reseeds the global torch RNG before every test and builds
    ``(batch, seq, vocab)`` logits plus ``(batch, seq)`` integer labels, so each
    test sees the same random inputs.
    """

    def setup_method(self):
        """Set up test tensors for loss computation."""
        torch.manual_seed(42)  # For reproducible tests
        self.batch_size = 3
        self.seq_len = 4
        self.vocab_size = 10

        # Create deterministic test data
        self.logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
        self.labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))
        # Token ids handed to the entropy loss as ``disallowed_ids``.
        self.disallowed_ids = torch.tensor([0, 1])

    def test_cross_entropy_properties(self):
        """Cross-entropy should be per-sequence, float32, non-negative and finite."""
        loss = compute_loss(self.logits, self.labels, "ce")

        # Basic shape and type checks
        assert loss.shape == (self.batch_size,)
        assert loss.dtype == torch.float32

        # Mathematical properties; isfinite is False for NaN as well as +/-inf.
        assert (loss >= 0).all(), "Cross-entropy loss should be non-negative"
        assert torch.isfinite(loss).all(), "Loss should be finite"

        # Near-perfect predictions: put a large logit on every target token.
        # A single vectorised scatter replaces a per-position Python loop.
        perfect_logits = torch.zeros_like(self.logits).scatter_(
            -1, self.labels.unsqueeze(-1), 10.0
        )

        perfect_loss = compute_loss(perfect_logits, self.labels, "ce")
        assert (perfect_loss < loss).all(), "Perfect predictions should have lower loss"

    def test_mellowmax_properties(self):
        """Mellowmax loss should be finite and sensitive to ``mellowmax_alpha``."""
        loss = compute_loss(self.logits, self.labels, "mellowmax", mellowmax_alpha=1.0)

        # Basic checks
        assert loss.shape == (self.batch_size,)
        assert torch.isfinite(loss).all()

        # Alpha scaling: widely separated alphas must not collapse to one value.
        loss_small = compute_loss(self.logits, self.labels, "mellowmax", mellowmax_alpha=0.1)
        loss_large = compute_loss(self.logits, self.labels, "mellowmax", mellowmax_alpha=10.0)

        assert not torch.allclose(loss_small, loss_large, atol=1e-6)

    def test_cw_loss_properties(self):
        """Carlini-Wagner loss should be per-sequence and finite."""
        loss = compute_loss(self.logits, self.labels, "cw")

        assert loss.shape == (self.batch_size,)
        assert torch.isfinite(loss).all()
        # Deliberately no non-negativity check: unlike cross-entropy, CW loss
        # may legitimately be negative.

    def test_entropy_loss_properties(self):
        """Entropy loss should be finite and depend on the logits it is given."""
        loss = compute_loss(
            self.logits,
            self.labels,
            "entropy",
            disallowed_ids=self.disallowed_ids,
        )

        assert loss.shape == (self.batch_size,)
        assert torch.isfinite(loss).all()

        # All-zero logits give a uniform distribution (maximal entropy).
        uniform_logits = torch.zeros_like(self.logits)
        uniform_loss = compute_loss(
            uniform_logits,
            self.labels,
            "entropy",
            disallowed_ids=self.disallowed_ids,
        )

        # Uniform distribution should have different entropy than random logits
        assert not torch.allclose(loss, uniform_loss)

    def test_loss_consistency(self):
        """Losses should be deterministic and independent of batch composition."""
        # Same inputs should produce same outputs
        loss1 = compute_loss(self.logits, self.labels, "ce")
        loss2 = compute_loss(self.logits, self.labels, "ce")

        assert torch.allclose(loss1, loss2), "Loss should be deterministic"

        # Computing on a sub-batch must reproduce the matching slice.
        small_loss = compute_loss(self.logits[:1], self.labels[:1], "ce")

        assert small_loss.shape == (1,)
        assert torch.allclose(small_loss, loss1[:1]), "Loss should be consistent across batch sizes"

    def test_invalid_loss_type(self):
        """An unknown loss name should raise rather than silently fall back."""
        with pytest.raises((NotImplementedError, ValueError, KeyError)):
            compute_loss(self.logits, self.labels, "nonexistent_loss")

    def test_edge_case_inputs(self):
        """Cross-entropy should handle a one-token sequence and a large vocabulary."""
        # Single token sequence
        single_logits = torch.randn(1, 1, self.vocab_size)
        single_labels = torch.randint(0, self.vocab_size, (1, 1))

        loss = compute_loss(single_logits, single_labels, "ce")
        assert loss.shape == (1,)
        assert torch.isfinite(loss).all()

        # Large vocabulary
        large_vocab_logits = torch.randn(1, 2, 1000)
        large_vocab_labels = torch.randint(0, 1000, (1, 2))

        loss = compute_loss(large_vocab_logits, large_vocab_labels, "ce")
        assert loss.shape == (1,)
        assert torch.isfinite(loss).all()


class TestGCGMathematicalCorrectness:
    """Test mathematical correctness of GCG components."""

    def setup_method(self):
        """Reseed the global torch RNG so randomly generated inputs are reproducible."""
        torch.manual_seed(0)

    def test_mellowmax_approximation(self):
        """Mellowmax loss should stay finite across a wide range of alpha values."""
        # One sequence of three time steps over a four-token vocabulary.
        logits = torch.tensor([[[1.0, 4.0, 3.0, 2.0],
                                [2.0, 1.0, 4.0, 3.0],
                                [3.0, 2.0, 1.0, 4.0]]], dtype=torch.float32)
        labels = torch.tensor([[0, 1, 2]], dtype=torch.long)

        # .item() additionally asserts each call yields a single value (batch of 1).
        losses = torch.tensor([
            compute_loss(logits, labels, "mellowmax", mellowmax_alpha=alpha).item()
            for alpha in (0.1, 1.0, 10.0)
        ])

        # isfinite is False for NaN as well as +/-inf, so this covers both.
        assert torch.isfinite(losses).all(), "Mellowmax loss should be finite and not NaN"

    def test_gradient_properties(self):
        """Cross-entropy should backpropagate finite gradients of the input's shape."""
        logits = torch.randn(2, 3, 5, requires_grad=True)
        labels = torch.randint(0, 5, (2, 3))

        compute_loss(logits, labels, "ce").sum().backward()

        assert logits.grad is not None, "Gradients should be computed"
        assert logits.grad.shape == logits.shape, "Gradient shape should match input"
        assert torch.isfinite(logits.grad).all(), "Gradients should be finite"

    def test_loss_scaling(self):
        """Rescaling logits (changing softmax temperature) should change the loss."""
        logits = torch.randn(2, 3, 5)
        labels = torch.randint(0, 5, (2, 3))

        loss_original = compute_loss(logits, labels, "ce")
        loss_scaled = compute_loss(logits * 2.0, labels, "ce")

        assert not torch.allclose(loss_original, loss_scaled), "Scaling should affect loss"


class TestGCGAttackBasics:
    """Exercise ``GCGAttack`` construction and config sanity without loading a model."""

    def test_attack_initialization(self):
        """A freshly built attack keeps its config, owns a logger, and has no tokenizer."""
        cfg = GCGConfig(num_steps=10, search_width=32)
        gcg = GCGAttack(cfg)

        assert gcg.config == cfg
        assert hasattr(gcg, 'logger')
        # Nothing has been run yet, so no tokenizer should be attached.
        assert gcg.tokenizer is None

    def test_attack_logging_setup(self):
        """The attack's logger is named ``nanogcg`` and can emit records."""
        log = GCGAttack(GCGConfig()).logger

        assert log.name == "nanogcg"
        # hasHandlers() also consults ancestor loggers while propagation is on.
        assert log.hasHandlers()
        # Emitting a record must not raise.
        log.info("Test log message")

    def test_config_parameter_ranges(self):
        """Default hyper-parameters lie within broad sanity bounds."""
        cfg = GCGConfig()

        assert 0 < cfg.num_steps <= 10000
        assert 0 < cfg.search_width <= 100000
        assert 0 < cfg.topk <= cfg.search_width
        assert 0 < cfg.n_replace <= 100
        assert cfg.mellowmax_alpha > 0
        assert 0 <= cfg.grad_momentum <= 1.0
        assert cfg.grad_smoothing >= 1


# Module-level parameterized tests
@pytest.mark.parametrize("loss_type", ["ce", "mellowmax", "cw"])
def test_loss_function_basic_properties(loss_type):
    """Check the shape and finiteness invariants every supported loss must satisfy.

    Inputs come from a locally seeded ``torch.Generator`` so the test is
    reproducible without mutating the global RNG state other tests rely on.
    """
    batch_size, seq_len, vocab_size = 2, 3, 10
    gen = torch.Generator().manual_seed(0)
    logits = torch.randn(batch_size, seq_len, vocab_size, generator=gen)
    labels = torch.randint(0, vocab_size, (batch_size, seq_len), generator=gen)

    # Only mellowmax takes an extra hyper-parameter.
    extra_kwargs = {"mellowmax_alpha": 1.0} if loss_type == "mellowmax" else {}
    loss = compute_loss(logits, labels, loss_type, **extra_kwargs)

    # Universal properties; isfinite is False for NaN as well as +/-inf.
    assert loss.shape == (batch_size,), f"Loss shape incorrect for {loss_type}"
    assert torch.isfinite(loss).all(), f"Loss not finite for {loss_type}"


@pytest.mark.parametrize("alpha", [0.1, 1.0, 2.0, 10.0])
def test_mellowmax_alpha_scaling(alpha):
    """Mellowmax loss should be per-sequence and finite for a range of alphas.

    Inputs come from a locally seeded ``torch.Generator`` for reproducibility.
    """
    gen = torch.Generator().manual_seed(0)
    logits = torch.randn(2, 3, 5, generator=gen)
    labels = torch.randint(0, 5, (2, 3), generator=gen)

    loss = compute_loss(logits, labels, "mellowmax", mellowmax_alpha=alpha)

    assert loss.shape == (2,)
    assert torch.isfinite(loss).all()


@pytest.mark.parametrize("placement", ["suffix", "prompt"])
def test_placement_config(placement):
    """The attack keeps whichever placement its config was constructed with."""
    built = GCGAttack(GCGConfig(placement=placement))
    assert built.config.placement == placement


if __name__ == "__main__":
    # Run tests with verbose output and propagate pytest's exit code so that
    # running this file directly fails (non-zero status) when tests fail.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))