"""
Correctness tests for Soft MAS (Monotonic Alignment Search).
"""

import pytest
import torch

from reference import soft_mas_forward_naive, soft_mas_naive

# d2p is the module under test. It is optional at import time so that this
# file can still be collected when d2p is missing; the test classes below
# are skipped in that case.
try:
    import d2p
except ImportError:
    D2P_AVAILABLE = False
else:
    D2P_AVAILABLE = True

# Evaluated once at import; selects the device fixture and gates CPU/CUDA tests.
CUDA_AVAILABLE = torch.cuda.is_available()


def allclose(a, b, rtol=1e-4, atol=1e-5):
    """Return True when ``a`` and ``b`` agree elementwise within tolerance.

    Both tensors are moved to the CPU before the comparison, so results that
    live on different devices can be compared directly.

    Args:
        a: First tensor.
        b: Second tensor.
        rtol: Relative tolerance forwarded to ``torch.allclose``.
        atol: Absolute tolerance forwarded to ``torch.allclose``.
    """
    lhs, rhs = a.cpu(), b.cpu()
    return torch.allclose(lhs, rhs, rtol=rtol, atol=atol)


def max_diff(a, b):
    """Return the largest absolute elementwise difference as a Python scalar.

    Both tensors are moved to the CPU before subtracting. Used to build the
    assertion failure messages in this file.
    """
    delta = torch.abs(a.cpu() - b.cpu())
    return delta.max().item()


@pytest.fixture(params=[1, 4])
def batch_size(request):
    """Batch size for the current run; one run per value in ``params``."""
    size = request.param
    return size


@pytest.fixture(params=[(10, 8), (16, 16), (20, 5)])
def seq_lengths(request):
    """Yield a ``(T, S)`` pair: T frames and S text tokens, with T >= S.

    Each dependent test runs once per pair listed in ``params``.
    """
    frames_and_tokens = request.param
    return frames_and_tokens


@pytest.fixture(params=[0.1, 1.0, 2.0])
def temperature(request):
    """Temperature for the current run; one run per value in ``params``."""
    value = request.param
    return value


@pytest.fixture
def device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    if CUDA_AVAILABLE:
        return torch.device('cuda')
    return torch.device('cpu')


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestForward:
    """Compare the score returned by d2p with the naive reference score."""

    def test_score(self, batch_size, seq_lengths, temperature, device):
        """Scores agree on seeded random input for every fixture combination."""
        num_frames, num_tokens = seq_lengths

        torch.manual_seed(42)
        scores = torch.randn(batch_size, num_frames, num_tokens, device=device)

        # First element of each result is the score; the second is ignored here.
        expected, _ = soft_mas_forward_naive(scores, temperature)
        actual, _ = d2p.soft_mas_with_grads(scores, temperature, None)

        assert allclose(expected, actual), \
            f"Score mismatch: max diff = {max_diff(expected, actual)}"

    def test_score_positive_scores(self, batch_size, device):
        """Scores agree when every input entry is non-negative."""
        num_frames, num_tokens = 10, 6
        temperature = 1.0

        torch.manual_seed(42)
        # Absolute value of a normal draw gives a non-negative score matrix.
        scores = torch.abs(torch.randn(batch_size, num_frames, num_tokens, device=device))

        expected, _ = soft_mas_forward_naive(scores, temperature)
        actual, _ = d2p.soft_mas_with_grads(scores, temperature, None)

        assert allclose(expected, actual), \
            f"Score mismatch with positive scores: max diff = {max_diff(expected, actual)}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestBackward:
    """Compare d2p posteriors and autograd gradients with the naive reference."""

    @staticmethod
    def _score_grads(batch_size, device):
        """Return ``(grad_ref, grad_d2p)`` of the summed score w.r.t. the scores.

        Both gradients are computed on the same seeded random input of shape
        ``(batch_size, 8, 5)`` at temperature 1.0: once through
        ``d2p.soft_mas_float`` and once through ``soft_mas_forward_naive``.
        """
        T, S = 8, 5
        temperature = 1.0

        torch.manual_seed(42)
        scores = torch.randn(batch_size, T, S, device=device)
        scores.requires_grad_(True)

        # Use soft_mas_float which has autograd support
        score, _ = d2p.soft_mas_float(scores, temperature, None)
        score.sum().backward()
        grad_d2p = scores.grad.clone()

        # Independent leaf tensor so the reference gradient does not
        # accumulate into the gradient computed above.
        scores_ref = scores.detach().clone().requires_grad_(True)
        score_ref, _ = soft_mas_forward_naive(scores_ref, temperature)
        score_ref.sum().backward()

        return scores_ref.grad, grad_d2p

    def test_posteriors(self, batch_size, seq_lengths, temperature, device):
        """Test that alignment posteriors match."""
        T, S = seq_lengths

        torch.manual_seed(42)
        scores = torch.randn(batch_size, T, S, device=device)

        posteriors_ref = soft_mas_naive(scores, temperature)
        _, posteriors_d2p = d2p.soft_mas_with_grads(scores, temperature, None)

        assert allclose(posteriors_ref, posteriors_d2p, rtol=1e-3, atol=1e-4), \
            f"Posteriors mismatch: max diff = {max_diff(posteriors_ref, posteriors_d2p)}"

    def test_gradients(self, batch_size, device):
        """Test autograd gradients of the summed score (loose tolerance).

        NOTE(review): the loss here is ``score.sum()``; the second output of
        ``soft_mas_float`` is not differentiated, so this is the same
        computation as ``test_score_gradients`` with a looser tolerance.
        """
        grad_ref, grad_d2p = self._score_grads(batch_size, device)

        assert allclose(grad_ref, grad_d2p, rtol=1e-2, atol=1e-3), \
            f"Gradient mismatch: max diff = {max_diff(grad_ref, grad_d2p)}"

    def test_score_gradients(self, batch_size, device):
        """Test gradients through the score output (tight tolerance)."""
        grad_ref, grad_d2p = self._score_grads(batch_size, device)

        assert allclose(grad_ref, grad_d2p, rtol=1e-3, atol=1e-4), \
            f"Score gradient mismatch: max diff = {max_diff(grad_ref, grad_d2p)}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestHVP:
    """Check ``d2p.soft_mas_hvp`` against central finite differences."""

    @staticmethod
    def _hvp_pair(B, T, S, temperature, device, eps=1e-4):
        """Return ``(hvp_fd, hvp_d2p)`` for seeded random scores and direction.

        ``hvp_fd`` is the central finite difference of the reference
        posteriors along direction ``V`` with step ``eps``; ``hvp_d2p`` is the
        value returned by ``d2p.soft_mas_hvp`` for the same inputs.
        """
        torch.manual_seed(42)
        scores = torch.randn(B, T, S, device=device)
        V = torch.randn(B, T, S, device=device)

        hvp_d2p = d2p.soft_mas_hvp(scores, V, temperature, None)

        posteriors_plus = soft_mas_naive(scores + eps * V, temperature)
        posteriors_minus = soft_mas_naive(scores - eps * V, temperature)
        hvp_fd = (posteriors_plus - posteriors_minus) / (2 * eps)

        return hvp_fd, hvp_d2p

    def test_hvp_finite_diff(self, device):
        """Test HVP against finite differences."""
        hvp_fd, hvp_d2p = self._hvp_pair(2, 8, 5, 1.0, device)

        assert allclose(hvp_fd, hvp_d2p, rtol=1e-2, atol=5e-3), \
            f"HVP mismatch: max diff = {max_diff(hvp_fd, hvp_d2p)}"

    # Parametrized rather than looped so each temperature is reported as its
    # own test and one failure does not hide the others. Direct
    # parametrization takes precedence over the module-level ``temperature``
    # fixture of the same name.
    @pytest.mark.parametrize("temperature", [0.5, 1.0, 2.0])
    def test_hvp_various_temps(self, temperature, device):
        """Test HVP with various temperatures."""
        hvp_fd, hvp_d2p = self._hvp_pair(2, 10, 6, temperature, device)

        assert allclose(hvp_fd, hvp_d2p, rtol=1e-2, atol=5e-3), \
            f"HVP mismatch for temp={temperature}: max diff = {max_diff(hvp_fd, hvp_d2p)}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available")
class TestCPUCUDA:
    """Run d2p on the same input on CPU and CUDA and compare the outputs."""

    def test_consistency(self):
        """Posteriors from CPU and CUDA agree within the default tolerance."""
        batch, frames, tokens = 2, 12, 8
        temperature = 1.0

        torch.manual_seed(42)
        host_scores = torch.randn(batch, frames, tokens)
        gpu_scores = host_scores.cuda()

        _, post_host = d2p.soft_mas_with_grads(host_scores, temperature, None)
        _, post_gpu = d2p.soft_mas_with_grads(gpu_scores, temperature, None)

        assert allclose(post_host, post_gpu), \
            f"CPU/CUDA mismatch: max diff = {max_diff(post_host, post_gpu)}"

    def test_consistency_score(self):
        """Scores from CPU and CUDA agree within the default tolerance."""
        batch, frames, tokens = 2, 15, 10
        temperature = 1.0

        torch.manual_seed(42)
        host_scores = torch.randn(batch, frames, tokens)
        gpu_scores = host_scores.cuda()

        total_host, _ = d2p.soft_mas_with_grads(host_scores, temperature, None)
        total_gpu, _ = d2p.soft_mas_with_grads(gpu_scores, temperature, None)

        assert allclose(total_host, total_gpu), \
            f"CPU/CUDA score mismatch: max diff = {max_diff(total_host, total_gpu)}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestVariableLength:
    """Batched d2p call with per-item lengths versus per-item reference runs."""

    def test_variable_lengths(self, device):
        """Each batch item matches the reference run on its cropped scores."""
        B = 4
        max_T, max_S = 15, 10
        temperature = 1.0

        torch.manual_seed(42)
        scores = torch.randn(B, max_T, max_S, device=device)

        # One [T, S] row per batch item.
        lengths = torch.tensor(
            [[12, 8], [15, 10], [10, 6], [14, 9]],
            device=device,
            dtype=torch.int32,
        )

        score, posteriors = d2p.soft_mas_with_grads(scores, temperature, lengths)

        # Convert the lengths to Python ints once, then check item by item.
        for b, (t_len, s_len) in enumerate(lengths.tolist()):
            item = slice(b, b + 1)  # keeps the batch dimension
            cropped = scores[item, :t_len, :s_len]

            score_ref, _ = soft_mas_forward_naive(cropped, temperature)
            posteriors_ref = soft_mas_naive(cropped, temperature)

            assert allclose(score_ref, score[item]), \
                f"Score mismatch for batch {b}: {score_ref.item()} vs {score[b].item()}"

            # Only the valid [t_len, s_len] region of the batched output is compared.
            assert allclose(posteriors_ref, posteriors[item, :t_len, :s_len], rtol=1e-3, atol=1e-4), \
                f"Posteriors mismatch for batch {b}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestEdgeCases:
    """Degenerate shapes and extreme temperatures."""

    def test_single_element(self, device):
        """Test 1x1 score matrix."""
        scores = torch.tensor([[[0.5]]], device=device)
        temperature = 1.0

        score, posteriors = d2p.soft_mas_with_grads(scores, temperature, None)
        score_ref, _ = soft_mas_forward_naive(scores, temperature)
        posteriors_ref = soft_mas_naive(scores, temperature)

        # For 1x1 there is only one path; the check below is against the
        # reference implementation rather than a hard-coded value.
        assert allclose(score, score_ref), f"Single element score wrong: {score.item()} vs {score_ref.item()}"
        assert allclose(posteriors, posteriors_ref), f"Single element posterior wrong: {posteriors.item()} vs {posteriors_ref.item()}"

    def test_single_text_token(self, device):
        """Test Tx1 score matrix (single text token)."""
        scores = torch.tensor([[[0.1], [0.2], [0.3], [0.4]]], device=device)
        temperature = 1.0

        score, posteriors = d2p.soft_mas_with_grads(scores, temperature, None)
        score_ref, _ = soft_mas_forward_naive(scores, temperature)
        posteriors_ref = soft_mas_naive(scores, temperature)

        # All frames must align to the single text token
        assert allclose(score, score_ref), "Single text token score mismatch"
        assert allclose(posteriors, posteriors_ref, rtol=1e-3), "Single text token posteriors mismatch"

    def test_equal_length(self, device):
        """Test TxT score matrix (one-to-one alignment)."""
        T = 5
        temperature = 1.0

        # Seed like every other random test in this file so a failure is
        # reproducible instead of depending on global RNG state.
        torch.manual_seed(42)
        scores = torch.randn(1, T, T, device=device)

        score, posteriors = d2p.soft_mas_with_grads(scores, temperature, None)
        score_ref, _ = soft_mas_forward_naive(scores, temperature)
        posteriors_ref = soft_mas_naive(scores, temperature)

        assert allclose(score, score_ref), "Equal length score mismatch"
        assert allclose(posteriors, posteriors_ref, rtol=1e-3), "Equal length posteriors mismatch"

    def test_low_temperature(self, device):
        """Test with low temperature (approaches hard alignment).

        Only a range check is performed: every posterior must lie in
        ``[-0.1, 1.1]``. No comparison with the reference is made here.
        """
        B, T, S = 2, 10, 6
        temperature = 0.01

        torch.manual_seed(42)
        scores = torch.randn(B, T, S, device=device)

        _, posteriors = d2p.soft_mas_with_grads(scores, temperature, None)

        # With low temperature, posteriors should be close to 0 or 1; the
        # bounds allow a 0.1 slack on either side.
        assert posteriors.min() >= -0.1, "Low temp posteriors should be >= 0"
        assert posteriors.max() <= 1.1, "Low temp posteriors should be <= 1"

    def test_high_temperature(self, device):
        """Test with high temperature (more uniform distribution)."""
        B, T, S = 2, 10, 6
        temperature = 10.0

        torch.manual_seed(42)
        scores = torch.randn(B, T, S, device=device)

        _, posteriors = d2p.soft_mas_with_grads(scores, temperature, None)
        posteriors_ref = soft_mas_naive(scores, temperature)

        assert allclose(posteriors_ref, posteriors, rtol=1e-3, atol=1e-4), \
            "High temperature posteriors mismatch"

    def test_monotonicity_constraint(self, device):
        """Test posterior mass per frame on scores favoring a diagonal path.

        The assertion checks that every frame's posteriors sum to more than
        0.9 over the text axis; it does not inspect the ordering of the
        alignment directly.
        """
        B, T, S = 1, 8, 4
        temperature = 0.1  # Low temp for clear alignment

        # Create scores that favor diagonal alignment
        scores = torch.zeros(B, T, S, device=device)
        for t in range(T):
            s = min(t // 2, S - 1)  # Expected alignment: each text gets ~2 frames
            scores[0, t, s] = 5.0

        _, posteriors = d2p.soft_mas_with_grads(scores, temperature, None)

        # Posteriors should concentrate on the diagonal region
        # (monotonicity is enforced by the DP structure)
        assert posteriors.sum(dim=-1).min() > 0.9, "Each frame should align somewhere"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestWithGrads:
    """Output checks for ``soft_mas_with_grads`` and ``soft_mas_backward_full``."""

    def test_with_grads_output(self, device):
        """Both outputs of soft_mas_with_grads match the naive reference."""
        batch, frames, tokens = 2, 10, 6
        temperature = 1.0

        torch.manual_seed(42)
        scores = torch.randn(batch, frames, tokens, device=device)

        total, post = d2p.soft_mas_with_grads(scores, temperature, None)

        # Reference values for the same input.
        total_ref, _ = soft_mas_forward_naive(scores, temperature)
        post_ref = soft_mas_naive(scores, temperature)

        assert allclose(total, total_ref), "with_grads score mismatch"
        assert allclose(post, post_ref, rtol=1e-3), "with_grads posteriors mismatch"

    def test_backward_full_output(self, device):
        """soft_mas_backward_full agrees with soft_mas_with_grads.

        Also checks that the third return value has shape ``(B,)``.
        """
        batch, frames, tokens = 2, 10, 6
        temperature = 1.0

        torch.manual_seed(42)
        scores = torch.randn(batch, frames, tokens, device=device)

        total, post, grad_T = d2p.soft_mas_backward_full(scores, temperature, None)

        # The two d2p entry points are compared with each other here,
        # not with the naive reference.
        total_ref, post_ref = d2p.soft_mas_with_grads(scores, temperature, None)

        assert allclose(total, total_ref), "backward_full score mismatch"
        assert allclose(post, post_ref), "backward_full posteriors mismatch"
        assert grad_T.shape == (batch,), f"grad_T shape wrong: {grad_T.shape}"


if __name__ == '__main__':
    # Propagate pytest's exit status so that running this file as a script
    # exits non-zero when any test fails (previously it always exited 0).
    raise SystemExit(pytest.main([__file__, '-v']))
