"""
Correctness tests for Soft DTW (Dynamic Time Warping).
"""

import pytest
import torch

from reference import soft_dtw_forward_naive, soft_dtw_naive

# d2p is the compiled extension under test; the d2p-dependent test classes below are
# skipped when it cannot be imported.
try:
    import d2p
except ImportError:
    D2P_AVAILABLE = False
else:
    D2P_AVAILABLE = True

# Gates the CPU-vs-CUDA test class and selects the device returned by the `device` fixture.
CUDA_AVAILABLE = torch.cuda.is_available()


def allclose(a, b, rtol=1e-4, atol=1e-5):
    """Return True when ``a`` and ``b`` agree elementwise within ``rtol``/``atol``.

    Both tensors are copied to the CPU first, so results produced on different
    devices can be compared directly.
    """
    lhs = a.cpu()
    rhs = b.cpu()
    return torch.allclose(lhs, rhs, rtol=rtol, atol=atol)


def max_diff(a, b):
    """Return the largest absolute elementwise difference of ``a`` and ``b`` as a Python scalar.

    Used only to build assertion messages; both tensors are moved to the CPU first.
    """
    delta = torch.sub(a.cpu(), b.cpu())
    return delta.abs().max().item()


@pytest.fixture(params=(1, 4))
def batch_size(request):
    """Batch sizes exercised by the parametrized tests: a single item and a small batch."""
    value = request.param
    return value


@pytest.fixture(params=((8, 10), (16, 16), (5, 20)))
def seq_lengths(request):
    """``(L1, L2)`` cost-matrix sizes: mildly rectangular, square, and strongly rectangular."""
    value = request.param
    return value


@pytest.fixture(params=(0.1, 1.0, 2.0))
def temperature(request):
    """Soft-DTW temperatures passed to both the d2p kernel and the naive reference."""
    value = request.param
    return value


@pytest.fixture
def device():
    """Torch device the tests allocate on: CUDA when available, otherwise CPU."""
    name = 'cuda' if CUDA_AVAILABLE else 'cpu'
    return torch.device(name)


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestForward:
    """Soft-DTW scores (first output of ``d2p.soft_dtw_float``) versus the naive reference."""

    def test_score(self, batch_size, seq_lengths, temperature, device):
        """Unbanded scores agree with ``soft_dtw_forward_naive``."""
        n_rows, n_cols = seq_lengths

        torch.manual_seed(42)
        # abs() keeps every cost non-negative.
        costs = torch.randn(batch_size, n_rows, n_cols, device=device).abs()

        expected, _ = soft_dtw_forward_naive(costs, temperature)
        actual = d2p.soft_dtw_float(costs, temperature, None, -1)[0]

        assert allclose(expected, actual), \
            f"Score mismatch: max diff = {max_diff(expected, actual)}"

    def test_score_with_bandwidth(self, batch_size, temperature, device):
        """Scores restricted to a Sakoe-Chiba band agree with the reference."""
        n_rows, n_cols, band = 12, 15, 3

        torch.manual_seed(42)
        costs = torch.randn(batch_size, n_rows, n_cols, device=device).abs()

        expected, _ = soft_dtw_forward_naive(costs, temperature, bandwidth=band)
        actual = d2p.soft_dtw_float(costs, temperature, None, band)[0]

        assert allclose(expected, actual), \
            f"Score mismatch with bandwidth: max diff = {max_diff(expected, actual)}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestBackward:
    """Alignment posteriors (second output of ``d2p.soft_dtw_float``) and their gradients."""

    def test_posteriors(self, batch_size, seq_lengths, temperature, device):
        """Unbanded alignment posteriors agree with ``soft_dtw_naive``."""
        L1, L2 = seq_lengths

        torch.manual_seed(42)
        costs = torch.randn(batch_size, L1, L2, device=device).abs()

        posteriors_ref = soft_dtw_naive(costs, temperature)
        posteriors_d2p = d2p.soft_dtw_float(costs, temperature, None, -1)[1]

        assert allclose(posteriors_ref, posteriors_d2p, rtol=1e-3, atol=1e-4), \
            f"Posteriors mismatch: max diff = {max_diff(posteriors_ref, posteriors_d2p)}"

    def test_posteriors_with_bandwidth(self, batch_size, temperature, device):
        """Posteriors restricted to a Sakoe-Chiba band agree with the reference."""
        L1, L2 = 12, 15
        bandwidth = 3

        torch.manual_seed(42)
        costs = torch.randn(batch_size, L1, L2, device=device).abs()

        posteriors_ref = soft_dtw_naive(costs, temperature, bandwidth=bandwidth)
        posteriors_d2p = d2p.soft_dtw_float(costs, temperature, None, bandwidth)[1]

        assert allclose(posteriors_ref, posteriors_d2p, rtol=1e-3, atol=1e-4), \
            f"Posteriors mismatch with bandwidth: max diff = {max_diff(posteriors_ref, posteriors_d2p)}"

    def test_gradients(self, batch_size, temperature, device):
        """Gradients of ``sum(posteriors ** 2)`` w.r.t. the costs match the reference."""
        L1, L2 = 6, 8

        torch.manual_seed(42)
        # Enable requires_grad only after abs(): if randn itself required grad, abs()
        # would return a non-leaf tensor and costs.grad would not be populated.
        costs = torch.randn(batch_size, L1, L2, device=device).abs()
        costs.requires_grad_(True)

        posteriors = d2p.soft_dtw_float(costs, temperature, None, -1)[1]
        loss = (posteriors ** 2).sum()
        loss.backward()
        # Fail with a clear message (instead of AttributeError on None.clone())
        # when the d2p posteriors are not connected to the autograd graph.
        assert costs.grad is not None, \
            "d2p.soft_dtw_float posteriors did not propagate a gradient to costs"
        grad_d2p = costs.grad.clone()

        costs_ref = costs.detach().clone().requires_grad_(True)
        posteriors_ref = soft_dtw_naive(costs_ref, temperature)
        loss_ref = (posteriors_ref ** 2).sum()
        loss_ref.backward()
        assert costs_ref.grad is not None, \
            "soft_dtw_naive posteriors did not propagate a gradient to costs"
        grad_ref = costs_ref.grad

        assert allclose(grad_ref, grad_d2p, rtol=1e-2, atol=1e-3), \
            f"Gradient mismatch: max diff = {max_diff(grad_ref, grad_d2p)}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestHVP:
    """``d2p.soft_dtw_hvp`` versus a central finite difference of the reference posteriors."""

    @staticmethod
    def _finite_diff_hvp(costs, V, temperature, eps, bandwidth=None):
        """Central-difference estimate of d(posteriors)/d(costs) applied to direction ``V``.

        ``bandwidth`` is forwarded to ``soft_dtw_naive`` only when given, so the
        unbanded case calls the reference exactly as before (with its own default).
        """
        kwargs = {} if bandwidth is None else {"bandwidth": bandwidth}
        posteriors_plus = soft_dtw_naive(costs + eps * V, temperature, **kwargs)
        posteriors_minus = soft_dtw_naive(costs - eps * V, temperature, **kwargs)
        return (posteriors_plus - posteriors_minus) / (2 * eps)

    def test_hvp_finite_diff(self, device):
        """Unbanded HVP agrees with finite differences."""
        B, L1, L2 = 2, 5, 6
        temperature = 1.0
        eps = 1e-4

        torch.manual_seed(42)
        costs = torch.randn(B, L1, L2, device=device).abs()
        V = torch.randn(B, L1, L2, device=device)

        hvp_d2p = d2p.soft_dtw_hvp(costs, V, temperature, None, -1)
        hvp_fd = self._finite_diff_hvp(costs, V, temperature, eps)

        # Central differences have only O(eps**2) truncation error (~1e-8 here), but the
        # tensors are float32 (torch.randn's default dtype), so subtracting two nearly equal
        # posteriors costs precision: round-off is roughly machine_eps / eps ~ 1e-7 / 1e-4
        # = 1e-3. That round-off, not truncation, is what the ~2e-3 atol absorbs.
        assert allclose(hvp_fd, hvp_d2p, rtol=1e-2, atol=2e-3), \
            f"HVP mismatch: max diff = {max_diff(hvp_fd, hvp_d2p)}"

    def test_hvp_with_bandwidth(self, device):
        """HVP restricted to a Sakoe-Chiba band agrees with finite differences."""
        B, L1, L2 = 2, 8, 10
        temperature = 1.0
        bandwidth = 3
        eps = 1e-4

        torch.manual_seed(42)
        costs = torch.randn(B, L1, L2, device=device).abs()
        V = torch.randn(B, L1, L2, device=device)

        hvp_d2p = d2p.soft_dtw_hvp(costs, V, temperature, None, bandwidth)
        hvp_fd = self._finite_diff_hvp(costs, V, temperature, eps, bandwidth=bandwidth)

        # Same float32 cancellation argument as the unbanded test: the tolerance covers
        # ~machine_eps / eps round-off in the finite difference.
        assert allclose(hvp_fd, hvp_d2p, rtol=1e-2, atol=2e-3), \
            f"HVP mismatch with bandwidth: max diff = {max_diff(hvp_fd, hvp_d2p)}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available")
class TestCPUCUDA:
    """CPU and CUDA runs of ``d2p.soft_dtw_float`` agree within ``allclose`` tolerances."""

    @staticmethod
    def _check_devices_agree(costs_cpu, temperature, bandwidth, label):
        """Run d2p on ``costs_cpu`` and on a CUDA copy; assert both outputs agree.

        Both the score and the posteriors are compared. ``label`` is inserted into the
        failure messages (e.g. ``" with bandwidth"``).
        """
        costs_cuda = costs_cpu.cuda()

        score_cpu, posteriors_cpu = d2p.soft_dtw_float(costs_cpu, temperature, None, bandwidth)
        score_cuda, posteriors_cuda = d2p.soft_dtw_float(costs_cuda, temperature, None, bandwidth)

        assert allclose(score_cpu, score_cuda), \
            f"CPU/CUDA score mismatch{label}: max diff = {max_diff(score_cpu, score_cuda)}"
        assert allclose(posteriors_cpu, posteriors_cuda), \
            f"CPU/CUDA mismatch{label}: max diff = {max_diff(posteriors_cpu, posteriors_cuda)}"

    def test_consistency(self):
        """CPU and CUDA agree (within tolerance) without a band."""
        B, L1, L2 = 2, 8, 10
        temperature = 1.0

        torch.manual_seed(42)
        costs_cpu = torch.randn(B, L1, L2).abs()

        self._check_devices_agree(costs_cpu, temperature, -1, "")

    def test_consistency_with_bandwidth(self):
        """CPU and CUDA agree (within tolerance) with a Sakoe-Chiba band."""
        B, L1, L2 = 2, 10, 12
        temperature = 1.0
        bandwidth = 4

        torch.manual_seed(42)
        costs_cpu = torch.randn(B, L1, L2).abs()

        self._check_devices_agree(costs_cpu, temperature, bandwidth, " with bandwidth")


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestVariableLength:
    """Padded batches whose elements have different valid lengths."""

    def test_variable_lengths(self, device):
        """Each batch element matches the reference run on its own unpadded slice."""
        batch = 4
        max_rows, max_cols = 10, 12
        temperature = 1.0

        torch.manual_seed(42)
        costs = torch.randn(batch, max_rows, max_cols, device=device).abs()

        # Per-element valid (L1, L2); the region outside costs[b, :L1, :L2] is padding.
        per_item = [[8, 10], [10, 12], [6, 8], [9, 11]]
        lengths = torch.tensor(per_item, device=device, dtype=torch.int32)

        score, posteriors = d2p.soft_dtw_float(costs, temperature, lengths, -1)

        for b, (l1, l2) in enumerate(lengths.tolist()):
            costs_b = costs[b:b + 1, :l1, :l2]
            score_ref, _ = soft_dtw_forward_naive(costs_b, temperature)
            posteriors_ref = soft_dtw_naive(costs_b, temperature)

            assert allclose(score_ref, score[b:b + 1]), \
                f"Score mismatch for batch {b}: {score_ref.item()} vs {score[b].item()}"

            # Only the valid region is compared against the reference.
            valid = posteriors[b:b + 1, :l1, :l2]
            assert allclose(posteriors_ref, valid, rtol=1e-3, atol=1e-4), \
                f"Posteriors mismatch for batch {b}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestEdgeCases:
    """Degenerate shapes and extreme temperatures."""

    def test_single_element(self, device):
        """1x1 cost matrix: the score equals the cost and the posterior is 1."""
        costs = torch.tensor([[[0.5]]], device=device)
        temperature = 1.0

        score, posteriors = d2p.soft_dtw_float(costs, temperature, None, -1)

        # Single element: score = cost, posterior = 1.0
        assert allclose(score, torch.tensor([0.5], device=device)), "Single element score wrong"
        assert allclose(posteriors, torch.ones_like(costs)), "Single element posterior wrong"

    def test_row_vector(self, device):
        """1xN cost matrix."""
        costs = torch.tensor([[[0.1, 0.2, 0.3]]], device=device)
        temperature = 1.0

        score, posteriors = d2p.soft_dtw_float(costs, temperature, None, -1)
        score_ref, _ = soft_dtw_forward_naive(costs, temperature)

        assert allclose(score, score_ref), "Row vector score mismatch"
        # All posteriors should be 1 (only one path)
        assert allclose(posteriors, torch.ones_like(costs)), "Row vector posteriors wrong"

    def test_col_vector(self, device):
        """Nx1 cost matrix."""
        costs = torch.tensor([[[0.1], [0.2], [0.3]]], device=device)
        temperature = 1.0

        score, posteriors = d2p.soft_dtw_float(costs, temperature, None, -1)
        score_ref, _ = soft_dtw_forward_naive(costs, temperature)

        assert allclose(score, score_ref), "Column vector score mismatch"
        # All posteriors should be 1 (only one path)
        assert allclose(posteriors, torch.ones_like(costs)), "Column vector posteriors wrong"

    def test_low_temperature(self, device):
        """Low temperature (approaching hard DTW) yields bounded, well-formed posteriors."""
        B, L1, L2 = 2, 6, 8
        temperature = 0.01

        torch.manual_seed(42)
        costs = torch.randn(B, L1, L2, device=device).abs()

        posteriors = d2p.soft_dtw_float(costs, temperature, None, -1)[1]

        # Posteriors must lie in [0, 1]; the 0.1 slack tolerates numerical error when
        # dividing by a small temperature. torch.min/max propagate NaN and NaN fails
        # both comparisons, so these also catch overflow to NaN/inf.
        assert posteriors.min() >= -0.1, "Low temp posteriors should be >= 0"
        assert posteriors.max() <= 1.1, "Low temp posteriors should be <= 1"

        # Every alignment path starts at (0, 0) and ends at (L1-1, L2-1), so those cells
        # have posterior 1 at any temperature (same reasoning as the 1x1/1xN/Nx1 tests).
        corners = torch.stack([posteriors[:, 0, 0], posteriors[:, -1, -1]])
        assert allclose(corners, torch.ones_like(corners), rtol=1e-3, atol=1e-2), \
            f"Low temp start/end posteriors should be 1, got {corners.tolist()}"

    def test_high_temperature(self, device):
        """High temperature (smoother alignment) still matches the reference."""
        B, L1, L2 = 2, 6, 8
        temperature = 10.0

        torch.manual_seed(42)
        costs = torch.randn(B, L1, L2, device=device).abs()

        posteriors = d2p.soft_dtw_float(costs, temperature, None, -1)[1]
        posteriors_ref = soft_dtw_naive(costs, temperature)

        assert allclose(posteriors_ref, posteriors, rtol=1e-3, atol=1e-4), \
            "High temperature posteriors mismatch"


if __name__ == '__main__':
    # Propagate pytest's exit code so that running this file directly
    # (`python <file>`) exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, '-v']))
