"""
Correctness tests for Soft Hamming Distance.
"""

import pytest
import torch

from reference import soft_hamming_forward_naive, soft_hamming_naive

# d2p is the module under test. When it cannot be imported, the test classes
# below are skipped (skip reason: "d2p not built") instead of failing at
# collection time.
try:
    import d2p
    D2P_AVAILABLE = True
except ImportError:
    D2P_AVAILABLE = False

# Evaluated once at import; selects the device used by the `device` fixture
# and gates the CPU-vs-CUDA consistency test.
CUDA_AVAILABLE = torch.cuda.is_available()


def allclose(a, b, rtol=1e-4, atol=1e-5):
    """Return True when tensors *a* and *b* match within tolerance.

    Both tensors are copied to the CPU first, so inputs living on different
    devices can be compared. ``rtol`` and ``atol`` are forwarded unchanged to
    ``torch.allclose``.
    """
    lhs = a.cpu()
    rhs = b.cpu()
    return torch.allclose(lhs, rhs, rtol=rtol, atol=atol)


def max_diff(a, b):
    """Return the largest absolute element-wise difference as a Python float.

    Both tensors are moved to the CPU before subtracting.
    """
    delta = a.cpu() - b.cpu()
    largest = delta.abs().max()
    return largest.item()


@pytest.fixture(params=[1, 4])
def batch_size(request) -> int:
    """Parametrized fixture: requesting tests run once per batch size (1 and 4)."""
    return request.param


@pytest.fixture(params=[8, 16, 32])
def seq_length(request) -> int:
    """Parametrized fixture: requesting tests run once per sequence length (8, 16, 32)."""
    return request.param


@pytest.fixture(params=[0.1, 1.0, 2.0])
def temperature(request) -> float:
    """Parametrized fixture: requesting tests run once per temperature (0.1, 1.0, 2.0)."""
    return request.param


@pytest.fixture
def device():
    """Provide the torch device for the tests: CUDA when available, otherwise CPU."""
    backend = 'cpu'
    if CUDA_AVAILABLE:
        backend = 'cuda'
    return torch.device(backend)


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestForward:
    """Checks on the distance (first output) of ``d2p.soft_hamming_float``."""

    def test_distance(self, batch_size, seq_length, temperature, device):
        """The d2p distance agrees with the naive reference on random costs."""
        torch.manual_seed(42)
        # Uniform random costs scaled into [0, 2).
        costs = 2 * torch.rand(batch_size, seq_length, device=device)

        distance_ref = soft_hamming_forward_naive(costs, temperature)[0]
        distance_d2p, _ = d2p.soft_hamming_float(costs, temperature, None)

        assert allclose(distance_ref, distance_d2p), (
            f"Distance mismatch: max diff = {max_diff(distance_ref, distance_d2p)}"
        )

    def test_distance_is_sum(self, batch_size, device):
        """The distance equals the sum of the costs along dim 1."""
        n_positions = 10

        torch.manual_seed(42)
        costs = torch.rand(batch_size, n_positions, device=device)

        distance, _ = d2p.soft_hamming_float(costs, 1.0, None)
        expected = torch.sum(costs, dim=1)

        assert allclose(distance, expected), (
            f"Distance should be sum: max diff = {max_diff(distance, expected)}"
        )

    def test_zero_costs(self, batch_size, temperature, device):
        """All-zero costs give zero distance and all-ones posteriors."""
        costs = torch.zeros((batch_size, 10), device=device)

        distance, posteriors = d2p.soft_hamming_float(costs, temperature, None)

        zero_distance = torch.zeros(batch_size, device=device)
        assert allclose(distance, zero_distance), (
            "Zero costs should give zero distance"
        )
        all_ones = torch.ones_like(posteriors)
        assert allclose(posteriors, all_ones), (
            "Posteriors should be all ones"
        )

    def test_all_ones_costs(self, batch_size, device):
        """All-ones costs give a distance equal to the number of positions."""
        n_positions = 10
        costs = torch.ones((batch_size, n_positions), device=device)

        distance, _ = d2p.soft_hamming_float(costs, 1.0, None)

        # Summing n_positions ones yields n_positions for every batch row.
        expected = torch.full((batch_size,), float(n_positions), device=device)
        assert allclose(distance, expected), (
            f"All ones should give L: max diff = {max_diff(distance, expected)}"
        )


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestBackward:
    """Checks on the posteriors (second output) and on autograd gradients."""

    def test_posteriors(self, batch_size, seq_length, temperature, device):
        """The d2p posteriors agree with the output of ``soft_hamming_naive``."""
        torch.manual_seed(42)
        # Uniform random costs scaled into [0, 2).
        costs = 2 * torch.rand(batch_size, seq_length, device=device)

        posteriors_ref = soft_hamming_naive(costs, temperature)
        _, posteriors_d2p = d2p.soft_hamming_float(costs, temperature, None)

        assert allclose(posteriors_ref, posteriors_d2p), (
            f"Posteriors mismatch: max diff = {max_diff(posteriors_ref, posteriors_d2p)}"
        )

    def test_posteriors_are_ones(self, batch_size, device):
        """Posteriors are all ones (within tolerance) for random costs."""
        n_positions = 10

        torch.manual_seed(42)
        costs = torch.rand(batch_size, n_positions, device=device)

        _, posteriors = d2p.soft_hamming_float(costs, 1.0, None)
        expected = torch.ones_like(posteriors)

        assert allclose(posteriors, expected), (
            "Posteriors should be all ones"
        )

    def test_gradients(self, batch_size, temperature, device):
        """Backpropagating the summed distance gives a gradient of ones."""
        n_positions = 10

        torch.manual_seed(42)
        costs = torch.rand(batch_size, n_positions, device=device).requires_grad_(True)

        distance, _ = d2p.soft_hamming_float(costs, temperature, None)
        distance.sum().backward()
        grad_d2p = costs.grad.clone()

        # d(sum of costs) / d(cost element) is 1 everywhere.
        grad_ref = torch.ones_like(costs)

        assert allclose(grad_ref, grad_d2p), (
            f"Gradient mismatch: max diff = {max_diff(grad_ref, grad_d2p)}"
        )


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestWithGrads:
    """Checks on the three outputs of ``d2p.soft_hamming_with_grads``."""

    def test_with_grads_returns_param_grads(self, device):
        """Outputs have the expected shapes and the temperature gradient is zero."""
        B, L = 2, 10

        torch.manual_seed(42)
        costs = torch.rand(B, L, device=device)

        outputs = d2p.soft_hamming_with_grads(costs, 1.0, None)
        distance, posteriors, grad_T = outputs

        # One distance and one temperature gradient per batch row,
        # one posterior per cost entry.
        assert tuple(distance.shape) == (B,)
        assert tuple(posteriors.shape) == (B, L)
        assert tuple(grad_T.shape) == (B,)

        # The expected temperature gradient is zero for every batch row.
        zero_grad = torch.zeros(B, device=device)
        assert allclose(grad_T, zero_grad), (
            "Temperature gradient should be zero"
        )


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestHVP:
    """Checks on the Hessian-vector product from ``d2p.soft_hamming_hvp``."""

    def test_hvp_is_zero(self, device):
        """The HVP is zero for a random direction vector."""
        B, L = 2, 10

        torch.manual_seed(42)
        # Draw order matters for reproducibility: costs first, then direction.
        costs = torch.rand(B, L, device=device)
        direction = torch.randn(B, L, device=device)

        hvp = d2p.soft_hamming_hvp(costs, direction, 1.0, None)

        # A linear function has a zero Hessian, so every entry should vanish.
        expected = torch.zeros_like(hvp)
        assert allclose(hvp, expected), (
            "HVP should be zero for linear function"
        )


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available")
class TestCPUCUDA:
    """Cross-device consistency of ``d2p.soft_hamming_float``."""

    def test_consistency(self):
        """CPU and CUDA runs on the same costs agree within tolerance."""
        B, L = 2, 16
        temperature = 1.0

        torch.manual_seed(42)
        costs_cpu = torch.rand(B, L)
        costs_cuda = costs_cpu.cuda()

        out_cpu = d2p.soft_hamming_float(costs_cpu, temperature, None)
        out_cuda = d2p.soft_hamming_float(costs_cuda, temperature, None)
        distance_cpu, posteriors_cpu = out_cpu
        distance_cuda, posteriors_cuda = out_cuda

        assert allclose(distance_cpu, distance_cuda), (
            f"CPU/CUDA distance mismatch: max diff = {max_diff(distance_cpu, distance_cuda)}"
        )
        assert allclose(posteriors_cpu, posteriors_cuda), (
            f"CPU/CUDA posteriors mismatch: max diff = {max_diff(posteriors_cpu, posteriors_cuda)}"
        )


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestVariableLength:
    """Checks ``d2p.soft_hamming_float`` when a per-row ``lengths`` tensor is given."""

    def test_variable_lengths(self, device):
        """Test with variable sequence lengths in batch.

        For every batch row the distance must match the naive reference
        evaluated on the first ``length`` costs only, and the posteriors must
        be 1 inside the valid prefix and 0 in the padded tail.
        """
        B = 4
        max_L = 20
        temperature = 1.0

        torch.manual_seed(42)
        costs = torch.rand(B, max_L, device=device)

        # Variable lengths
        lengths = torch.tensor([15, 20, 10, 18], device=device, dtype=torch.int32)

        distance, posteriors = d2p.soft_hamming_float(costs, temperature, lengths)

        # Transfer everything inspected below to the host once instead of
        # issuing one `.item()` sync per element. float64 matches the
        # precision of the Python-float comparison this replaces.
        lengths_host = lengths.tolist()
        posteriors_host = posteriors.detach().cpu().to(torch.float64)
        positions = torch.arange(max_L)

        # Check each batch element individually
        for b, length in enumerate(lengths_host):
            costs_b = costs[b:b+1, :length]

            distance_ref, _ = soft_hamming_forward_naive(costs_b, temperature)

            assert allclose(distance_ref, distance[b:b+1]), \
                f"Distance mismatch for batch {b}: {distance_ref.item()} vs {distance[b].item()}"

            # Posteriors should be 1 for valid positions, 0 for invalid
            expected = (positions < length).to(torch.float64)
            within_tol = (posteriors_host[b, :max_L] - expected).abs() < 1e-5
            # Negating the "< tol" mask (rather than testing ">= tol") also
            # flags NaN entries as mismatches.
            bad_positions = torch.nonzero(~within_tol).flatten().tolist()
            i = bad_positions[0] if bad_positions else None
            assert i is None, \
                f"Posterior mismatch at batch {b}, position {i}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestEdgeCases:
    """Boundary sizes, temperature invariance and the explicit backward call."""

    def test_single_element(self, device):
        """A 1x1 cost matrix gives that cost as distance and a posterior of one."""
        costs = torch.tensor([[0.5]], device=device)

        distance, posteriors = d2p.soft_hamming_float(costs, 1.0, None)

        expected_distance = torch.tensor([0.5], device=device)
        assert allclose(distance, expected_distance), (
            f"Single element distance wrong: {distance.item()}"
        )
        assert allclose(posteriors, torch.ones_like(posteriors)), (
            "Single element posterior should be 1"
        )

    def test_large_batch(self, device):
        """A 64x32 batch yields correctly shaped, NaN-free outputs."""
        B, L = 64, 32

        torch.manual_seed(42)
        costs = torch.rand(B, L, device=device)

        distance, posteriors = d2p.soft_hamming_float(costs, 1.0, None)

        # Only shapes and the absence of NaN are checked here, not values.
        assert tuple(distance.shape) == (B,)
        assert tuple(posteriors.shape) == (B, L)
        distance_has_nan = torch.isnan(distance).any()
        assert not distance_has_nan, "Distance contains NaN"
        posteriors_have_nan = torch.isnan(posteriors).any()
        assert not posteriors_have_nan, "Posteriors contain NaN"

    def test_long_sequence(self, device):
        """A length-256 sequence still matches the naive reference distance."""
        B, L = 2, 256
        temperature = 1.0

        torch.manual_seed(42)
        costs = torch.rand(B, L, device=device)

        distance, posteriors = d2p.soft_hamming_float(costs, temperature, None)

        # Compare against the reference implementation.
        distance_ref = soft_hamming_forward_naive(costs, temperature)[0]
        assert allclose(distance, distance_ref), (
            f"Long sequence distance mismatch: max diff = {max_diff(distance, distance_ref)}"
        )

    def test_temperature_independence(self, device):
        """Distance and posteriors are the same at temperatures 0.1, 1.0 and 10.0."""
        B, L = 2, 10

        torch.manual_seed(42)
        costs = torch.rand(B, L, device=device)

        # Evaluate at each temperature, in increasing order.
        results = [
            d2p.soft_hamming_float(costs, temp, None)
            for temp in (0.1, 1.0, 10.0)
        ]
        distances = [result[0] for result in results]
        posteriors = [result[1] for result in results]

        # Neighbouring temperatures are compared pairwise: 0.1 vs 1.0, 1.0 vs 10.0.
        for lower, higher in zip(distances, distances[1:]):
            assert allclose(lower, higher), (
                "Distance should not depend on temperature"
            )
        for lower, higher in zip(posteriors, posteriors[1:]):
            assert allclose(lower, higher), (
                "Posteriors should not depend on temperature"
            )

    def test_backward_full(self, device):
        """``soft_hamming_backward_full`` gives unit cost grads and zero grad_T."""
        B, L = 2, 10

        torch.manual_seed(42)
        costs = torch.rand(B, L, device=device)
        grad_output = torch.ones(B, device=device)

        grads = d2p.soft_hamming_backward_full(costs, grad_output, 1.0, None)
        grad_costs, grad_T = grads

        # Each position should receive grad_output[b]; grad_output is all
        # ones, so the expected gradient is all ones as well.
        expected_grad = torch.ones(B, L, device=device)
        assert allclose(grad_costs, expected_grad), (
            f"backward_full grad_costs mismatch: max diff = {max_diff(grad_costs, expected_grad)}"
        )

        # The temperature gradient is expected to vanish.
        zero_grad_T = torch.zeros(B, device=device)
        assert allclose(grad_T, zero_grad_T), (
            "backward_full grad_T should be zero"
        )


if __name__ == '__main__':
    # pytest.main returns the session exit code. Propagate it so that running
    # this file directly exits non-zero when a test fails, instead of always
    # exiting with status 0.
    raise SystemExit(pytest.main([__file__, '-v']))
