"""
Correctness tests for Soft Smith-Waterman (affine gap).
"""

import pytest
import torch

from reference import soft_sw_affine_forward_naive, soft_sw_affine_naive

# d2p is the optional, separately built module under test; every test class
# that needs it is skipped when this import fails.
try:
    import d2p
except ImportError:
    D2P_AVAILABLE = False
else:
    D2P_AVAILABLE = True

# Whether a CUDA device is visible to torch; selects the test device and gates
# the CPU-vs-CUDA consistency test.
CUDA_AVAILABLE = torch.cuda.is_available()


def allclose(a, b, rtol=1e-4, atol=1e-5):
    """Return True if ``a`` and ``b`` have the same shape and are element-wise close.

    Both tensors are moved to CPU first so results computed on different
    devices can be compared directly. Shapes must match exactly:
    ``torch.allclose`` broadcasts its arguments, so without this check a
    ``(1,)`` result would silently "match" a ``(B,)`` one.

    Args:
        a: First tensor.
        b: Second tensor.
        rtol: Relative tolerance forwarded to ``torch.allclose``.
        atol: Absolute tolerance forwarded to ``torch.allclose``.

    Returns:
        ``False`` on a shape mismatch, otherwise the result of ``torch.allclose``.
    """
    if a.shape != b.shape:
        return False
    return torch.allclose(a.cpu(), b.cpu(), rtol=rtol, atol=atol)


def max_diff(a, b):
    """Return the largest absolute element-wise difference between ``a`` and ``b``.

    Used to build assertion messages, so it must not raise or mislead:

    * empty tensors give ``0.0``, because ``Tensor.max`` raises on an empty
      tensor and that error would hide the real assertion failure;
    * mismatched shapes give ``inf``, so a broadcast difference is never
      reported as if it were a real element-wise diff.

    Both tensors are moved to CPU before subtracting, so inputs may live on
    different devices.
    """
    if a.shape != b.shape:
        return float('inf')
    if a.numel() == 0:
        return 0.0
    return (a.cpu() - b.cpu()).abs().max().item()


@pytest.fixture(params=[1, 4])
def batch_size(request):
    """Batch dimension of the randomly generated score tensors."""
    size = request.param
    return size


@pytest.fixture(params=[(8, 10), (16, 16), (5, 20)])
def seq_lengths(request):
    """Pair of sequence lengths: a rectangular, a square and an elongated case."""
    lengths = request.param
    return lengths


@pytest.fixture(params=[0.1, 1.0, 2.0])
def temperature(request):
    """Softmax temperature passed to both the reference and d2p."""
    temp = request.param
    return temp


@pytest.fixture
def device():
    """Device the parametrized tests run on: CUDA when available, else CPU."""
    if CUDA_AVAILABLE:
        return torch.device('cuda')
    return torch.device('cpu')


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestForward:
    """Forward-pass checks: d2p's partition output against the naive reference."""

    def test_partition(self, batch_size, seq_lengths, temperature, device):
        """Check that the d2p partition function agrees with the reference."""
        len_a, len_b = seq_lengths
        open_pen, ext_pen = -2.0, -0.5

        torch.manual_seed(42)
        sim = torch.randn(batch_size, len_a, len_b, device=device)

        # The reference returns four values; only the first (partition) is compared.
        expected, _, _, _ = soft_sw_affine_forward_naive(
            sim, open_pen, ext_pen, temperature
        )
        outputs = d2p.soft_sw_affine_float(sim, open_pen, ext_pen, temperature, None)
        actual = outputs[0]

        assert allclose(expected, actual), (
            f"Partition mismatch: max diff = {max_diff(expected, actual)}"
        )


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestBackward:
    """Backward-pass checks: d2p posteriors and their gradients vs. the reference."""

    def test_posteriors(self, batch_size, seq_lengths, temperature, device):
        """Check that d2p's alignment posteriors agree with the reference."""
        n, m = seq_lengths
        open_pen, ext_pen = -2.0, -0.5

        torch.manual_seed(42)
        sim = torch.randn(batch_size, n, m, device=device)

        expected = soft_sw_affine_naive(sim, open_pen, ext_pen, temperature)
        actual = d2p.soft_sw_affine_float(sim, open_pen, ext_pen, temperature, None)[1]

        assert allclose(expected, actual, rtol=1e-3, atol=1e-4), (
            f"Posteriors mismatch: max diff = {max_diff(expected, actual)}"
        )

    def test_gradients(self, batch_size, temperature, device):
        """Check gradients of a sum-of-squared-posteriors loss against the reference."""
        n, m = 6, 8
        open_pen, ext_pen = -2.0, -0.5

        def squared_sum(t):
            return t.pow(2).sum()

        torch.manual_seed(42)
        sim = torch.randn(batch_size, n, m, device=device, requires_grad=True)

        # d2p path: backpropagate the loss into the leaf tensor `sim`.
        d2p_post = d2p.soft_sw_affine_float(sim, open_pen, ext_pen, temperature, None)[1]
        squared_sum(d2p_post).backward()
        grad_d2p = sim.grad.clone()

        # Reference path on an independent leaf holding identical values.
        sim_ref = sim.detach().clone().requires_grad_()
        squared_sum(soft_sw_affine_naive(sim_ref, open_pen, ext_pen, temperature)).backward()
        grad_ref = sim_ref.grad

        assert allclose(grad_ref, grad_d2p, rtol=1e-2, atol=1e-3), (
            f"Gradient mismatch: max diff = {max_diff(grad_ref, grad_d2p)}"
        )


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestHVP:
    """Checks d2p's Hessian-vector product against finite differences."""

    def test_hvp_finite_diff(self, device):
        """Test HVP against central finite differences of the naive posteriors.

        The finite-difference reference is evaluated in float64. A central
        difference with ``eps = 1e-4`` in float32 suffers cancellation error of
        roughly ``float32_eps / eps ~ 1e-3``, which is the same order as the
        comparison's ``atol`` and makes the test flaky.
        """
        B, L1, L2 = 2, 5, 6
        gap_open = -2.0
        gap_ext = -0.5
        temperature = 1.0
        eps = 1e-4

        torch.manual_seed(42)
        scores = torch.randn(B, L1, L2, device=device)
        V = torch.randn(B, L1, L2, device=device)

        hvp_d2p = d2p.soft_sw_affine_hvp(scores, V, gap_open, gap_ext, temperature, None)

        # Reference: (p(s + eps*V) - p(s - eps*V)) / (2*eps) in double precision.
        # NOTE(review): assumes soft_sw_affine_naive preserves the input dtype
        # (float64 here) — confirm against reference.py.
        scores64 = scores.double()
        V64 = V.double()
        posteriors_plus = soft_sw_affine_naive(
            scores64 + eps * V64, gap_open, gap_ext, temperature
        )
        posteriors_minus = soft_sw_affine_naive(
            scores64 - eps * V64, gap_open, gap_ext, temperature
        )
        # Cast back so the comparison is done at d2p's dtype.
        hvp_fd = ((posteriors_plus - posteriors_minus) / (2 * eps)).to(hvp_d2p.dtype)

        assert allclose(hvp_fd, hvp_d2p, rtol=1e-2, atol=1e-3), \
            f"HVP mismatch: max diff = {max_diff(hvp_fd, hvp_d2p)}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available")
class TestCPUCUDA:
    """Checks that the CPU and CUDA backends of d2p agree."""

    def test_consistency(self):
        """Test CPU and CUDA produce matching partitions and posteriors.

        Outputs are compared within ``allclose``'s default tolerances rather
        than bit-for-bit, since floating-point results from different backends
        may legitimately differ in the last bits.
        """
        B, L1, L2 = 2, 8, 10
        gap_open = -2.0
        gap_ext = -0.5
        temperature = 1.0

        torch.manual_seed(42)
        scores_cpu = torch.randn(B, L1, L2)
        scores_cuda = scores_cpu.cuda()

        out_cpu = d2p.soft_sw_affine_float(scores_cpu, gap_open, gap_ext, temperature, None)
        out_cuda = d2p.soft_sw_affine_float(scores_cuda, gap_open, gap_ext, temperature, None)

        # Index 0 is the partition and index 1 the posteriors (as used by the
        # forward/backward tests in this file).
        partition_cpu, posteriors_cpu = out_cpu[0], out_cpu[1]
        partition_cuda, posteriors_cuda = out_cuda[0], out_cuda[1]

        assert allclose(partition_cpu, partition_cuda), \
            f"CPU/CUDA partition mismatch: max diff = {max_diff(partition_cpu, partition_cuda)}"
        assert allclose(posteriors_cpu, posteriors_cuda), \
            f"CPU/CUDA mismatch: max diff = {max_diff(posteriors_cpu, posteriors_cuda)}"


if __name__ == '__main__':
    # Propagate pytest's exit code so running this file directly fails (non-zero
    # exit status) when any test fails; previously the code was discarded.
    raise SystemExit(pytest.main([__file__, '-v']))
