"""
Correctness tests for Soft Eisner (Projective Dependency Parsing).
"""

import pytest
import torch

from reference import soft_eisner_forward_naive, soft_eisner_naive

# d2p is treated as optional: if the import fails, the flag below is False and
# every test class decorated with skipif(not D2P_AVAILABLE) is skipped.
try:
    import d2p
    D2P_AVAILABLE = True
except ImportError:
    D2P_AVAILABLE = False

# Whether torch reports CUDA as available; used to pick the test device and to
# gate the CPU-vs-CUDA consistency tests.
CUDA_AVAILABLE = torch.cuda.is_available()


def allclose(a, b, rtol=1e-4, atol=1e-5):
    """Return True when *a* and *b* agree within the given tolerances.

    Both tensors are moved to the CPU first, so they may live on different
    devices. The comparison itself is delegated to ``torch.allclose``.
    """
    lhs = a.cpu()
    rhs = b.cpu()
    return torch.allclose(lhs, rhs, rtol=rtol, atol=atol)


def max_diff(a, b):
    """Return the largest absolute element-wise difference as a Python scalar.

    Both tensors are moved to the CPU before subtracting.
    """
    delta = a.cpu() - b.cpu()
    largest = delta.abs().max()
    return largest.item()


@pytest.fixture(params=[1, 4])
def batch_size(request):
    """Parametrized batch size (1 or 4) for tests that request it."""
    size = request.param
    return size


@pytest.fixture(params=[4, 8, 12])
def seq_length(request):
    """Parametrized sequence length (4, 8 or 12) for tests that request it."""
    length = request.param
    return length


@pytest.fixture(params=[0.1, 1.0, 2.0])
def temperature(request):
    """Parametrized temperature (0.1, 1.0 or 2.0) for tests that request it."""
    value = request.param
    return value


@pytest.fixture
def device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    if CUDA_AVAILABLE:
        return torch.device('cuda')
    return torch.device('cpu')


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestForward:
    """Compare the partition value returned by d2p with the naive reference."""

    def test_partition(self, batch_size, seq_length, temperature, device):
        """Partition values agree on seeded Gaussian arc scores."""
        torch.manual_seed(42)
        shape = (batch_size, seq_length, seq_length)
        arc_scores = torch.randn(*shape, device=device)

        # Reference first, then the implementation under test.
        partition_ref, _, _ = soft_eisner_forward_naive(arc_scores, temperature)
        partition_d2p, _ = d2p.soft_eisner_with_grads(arc_scores, temperature, None)

        assert allclose(partition_ref, partition_d2p), (
            f"Partition mismatch: max diff = {max_diff(partition_ref, partition_d2p)}"
        )

    def test_partition_positive_scores(self, batch_size, device):
        """Partition values agree when every arc score is non-negative."""
        seq_length, temperature = 6, 1.0

        torch.manual_seed(42)
        raw_scores = torch.randn(batch_size, seq_length, seq_length, device=device)
        arc_scores = raw_scores.abs()

        partition_ref, _, _ = soft_eisner_forward_naive(arc_scores, temperature)
        partition_d2p, _ = d2p.soft_eisner_with_grads(arc_scores, temperature, None)

        assert allclose(partition_ref, partition_d2p), (
            f"Partition mismatch with positive scores: max diff = {max_diff(partition_ref, partition_d2p)}"
        )


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestBackward:
    """Checks on arc marginals and autograd gradients produced by d2p."""

    def test_arc_marginals(self, batch_size, seq_length, temperature, device):
        """Arc marginals from d2p agree with the naive reference."""
        torch.manual_seed(42)
        shape = (batch_size, seq_length, seq_length)
        arc_scores = torch.randn(*shape, device=device)

        marginals_ref = soft_eisner_naive(arc_scores, temperature)
        _, marginals_d2p = d2p.soft_eisner_with_grads(arc_scores, temperature, None)

        assert allclose(marginals_ref, marginals_d2p, rtol=1e-3, atol=1e-4), (
            f"Arc marginals mismatch: max diff = {max_diff(marginals_ref, marginals_d2p)}"
        )

    def test_gradients(self, batch_size, device):
        """Gradient of the summed partition agrees between d2p and the reference."""
        seq_length, temperature = 6, 1.0

        torch.manual_seed(42)
        arc_scores = torch.randn(
            batch_size, seq_length, seq_length, device=device
        ).requires_grad_(True)

        # soft_eisner_float is the autograd-enabled entry point.
        d2p.soft_eisner_float(arc_scores, temperature, None).sum().backward()
        grad_d2p = arc_scores.grad.clone()

        # Independent leaf tensor so the reference builds its own graph.
        arc_scores_ref = arc_scores.detach().clone().requires_grad_(True)
        partition_ref, _, _ = soft_eisner_forward_naive(arc_scores_ref, temperature)
        partition_ref.sum().backward()
        grad_ref = arc_scores_ref.grad

        assert allclose(grad_ref, grad_d2p, rtol=1e-2, atol=1e-3), (
            f"Gradient mismatch: max diff = {max_diff(grad_ref, grad_d2p)}"
        )

    def test_marginal_sum(self, batch_size, seq_length, device):
        """For every non-root word, incoming arc marginals sum to roughly 1."""
        temperature = 1.0

        torch.manual_seed(42)
        arc_scores = torch.randn(batch_size, seq_length, seq_length, device=device)

        _, marginals = d2p.soft_eisner_with_grads(arc_scores, temperature, None)

        # The target is the same for every word, so build it once.
        expected = torch.ones(batch_size, device=device)

        # Index 0 is the root and is skipped; for each j > 0 the column
        # marginals[:, :, j] summed over dim 1 should be close to 1.
        for j in range(1, seq_length):
            incoming_sum = marginals[:, :, j].sum(dim=1)
            assert allclose(incoming_sum, expected, rtol=1e-2, atol=1e-2), (
                f"Incoming arc sum for word {j} should be ~1, got {incoming_sum}"
            )


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestHVP:
    """Compare d2p.soft_eisner_hvp with central finite differences.

    The finite-difference estimate is taken over the reference arc marginals
    (soft_eisner_naive) perturbed by +/- eps along the direction V.
    """

    @staticmethod
    def _hvp_pair(device, temperature, B=2, n=6, eps=1e-4):
        """Return ``(hvp_fd, hvp_d2p)`` for seeded random inputs.

        Args:
            device: torch device on which the inputs are created.
            temperature: temperature passed to both implementations.
            B: batch size of the random inputs.
            n: sequence length of the random inputs.
            eps: step size of the central finite difference.

        Returns:
            The finite-difference estimate built from the reference marginals
            and the value returned by d2p.soft_eisner_hvp, in that order.
        """
        # Re-seed on every call so each temperature sees the same inputs.
        torch.manual_seed(42)
        arc_scores = torch.randn(B, n, n, device=device)
        V = torch.randn(B, n, n, device=device)

        hvp_d2p = d2p.soft_eisner_hvp(arc_scores, V, temperature, None)

        marginals_plus = soft_eisner_naive(arc_scores + eps * V, temperature)
        marginals_minus = soft_eisner_naive(arc_scores - eps * V, temperature)
        hvp_fd = (marginals_plus - marginals_minus) / (2 * eps)

        return hvp_fd, hvp_d2p

    def test_hvp_finite_diff(self, device):
        """Test HVP against finite differences."""
        hvp_fd, hvp_d2p = self._hvp_pair(device, temperature=1.0)

        assert allclose(hvp_fd, hvp_d2p, rtol=1e-2, atol=5e-3), \
            f"HVP mismatch: max diff = {max_diff(hvp_fd, hvp_d2p)}"

    # Direct parametrization: each temperature is reported as its own test
    # case, so a failure at one value no longer hides the remaining ones.
    # These values take precedence over the module-level `temperature` fixture.
    @pytest.mark.parametrize("temperature", [0.5, 1.0, 2.0])
    def test_hvp_various_temps(self, device, temperature):
        """Test HVP with various temperatures."""
        hvp_fd, hvp_d2p = self._hvp_pair(device, temperature)

        assert allclose(hvp_fd, hvp_d2p, rtol=1e-2, atol=5e-3), \
            f"HVP mismatch for temp={temperature}: max diff = {max_diff(hvp_fd, hvp_d2p)}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available")
class TestCPUCUDA:
    """Run d2p on the same inputs on CPU and CUDA and compare within tolerance."""

    def test_consistency(self):
        """Arc marginals computed on CPU and on CUDA agree."""
        temperature = 1.0
        shape = (2, 8, 8)

        torch.manual_seed(42)
        arc_scores_cpu = torch.randn(*shape)
        arc_scores_cuda = arc_scores_cpu.cuda()

        _, marginals_cpu = d2p.soft_eisner_with_grads(arc_scores_cpu, temperature, None)
        _, marginals_cuda = d2p.soft_eisner_with_grads(arc_scores_cuda, temperature, None)

        assert allclose(marginals_cpu, marginals_cuda), (
            f"CPU/CUDA mismatch: max diff = {max_diff(marginals_cpu, marginals_cuda)}"
        )

    def test_consistency_partition(self):
        """Partition values computed on CPU and on CUDA agree."""
        temperature = 1.0
        shape = (2, 10, 10)

        torch.manual_seed(42)
        arc_scores_cpu = torch.randn(*shape)
        arc_scores_cuda = arc_scores_cpu.cuda()

        partition_cpu, _ = d2p.soft_eisner_with_grads(arc_scores_cpu, temperature, None)
        partition_cuda, _ = d2p.soft_eisner_with_grads(arc_scores_cuda, temperature, None)

        assert allclose(partition_cpu, partition_cuda), (
            f"CPU/CUDA partition mismatch: max diff = {max_diff(partition_cpu, partition_cuda)}"
        )


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestVariableLength:
    """Batched call with per-sentence lengths versus per-sentence reference runs."""

    def test_variable_lengths(self, device):
        """Each batch element matches the reference run on its own n x n slice."""
        temperature = 1.0
        batch, max_n = 4, 10

        torch.manual_seed(42)
        arc_scores = torch.randn(batch, max_n, max_n, device=device)

        # One length per batch element.
        lengths = torch.tensor([8, 10, 6, 9], device=device, dtype=torch.int32)

        partition, marginals = d2p.soft_eisner_with_grads(arc_scores, temperature, lengths)

        # Walk the batch one sentence at a time, keeping a leading batch
        # dimension of 1 so the reference sees a batched input.
        for b, n in enumerate(lengths.tolist()):
            row = slice(b, b + 1)
            arc_scores_b = arc_scores[row, :n, :n]

            partition_ref, _, _ = soft_eisner_forward_naive(arc_scores_b, temperature)
            marginals_ref = soft_eisner_naive(arc_scores_b, temperature)

            assert allclose(partition_ref, partition[row]), (
                f"Partition mismatch for batch {b}: {partition_ref.item()} vs {partition[b].item()}"
            )

            assert allclose(marginals_ref, marginals[row, :n, :n], rtol=1e-3, atol=1e-4), (
                f"Marginals mismatch for batch {b}"
            )


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestEdgeCases:
    """Small inputs, extreme temperatures and the marginal diagonal."""

    def test_two_words(self, device):
        """A hand-written 2x2 score matrix matches the reference."""
        temperature = 1.0
        arc_scores = torch.tensor([[[0.0, 0.5], [-0.3, 0.0]]], device=device)

        partition, marginals = d2p.soft_eisner_with_grads(arc_scores, temperature, None)
        partition_ref, _, _ = soft_eisner_forward_naive(arc_scores, temperature)
        marginals_ref = soft_eisner_naive(arc_scores, temperature)

        assert allclose(partition, partition_ref), (
            f"Two words partition wrong: {partition.item()} vs {partition_ref.item()}"
        )
        assert allclose(marginals, marginals_ref), (
            "Two words marginals mismatch"
        )

    def test_three_words(self, device):
        """A seeded random 3x3 score matrix matches the reference."""
        torch.manual_seed(42)
        arc_scores = torch.randn(1, 3, 3, device=device)
        temperature = 1.0

        partition, marginals = d2p.soft_eisner_with_grads(arc_scores, temperature, None)
        partition_ref, _, _ = soft_eisner_forward_naive(arc_scores, temperature)
        marginals_ref = soft_eisner_naive(arc_scores, temperature)

        assert allclose(partition, partition_ref), (
            "Three words partition mismatch"
        )
        assert allclose(marginals, marginals_ref, rtol=1e-3), (
            "Three words marginals mismatch"
        )

    def test_low_temperature(self, device):
        """At temperature 0.01 the marginals stay within [-0.1, 1.1]."""
        temperature = 0.01
        shape = (2, 6, 6)

        torch.manual_seed(42)
        arc_scores = torch.randn(*shape, device=device)

        _, marginals = d2p.soft_eisner_with_grads(arc_scores, temperature, None)

        # Only a loose range check is made here; the bounds leave 0.1 of
        # slack on either side of [0, 1].
        smallest = marginals.min()
        largest = marginals.max()
        assert smallest >= -0.1, "Low temp marginals should be >= 0"
        assert largest <= 1.1, "Low temp marginals should be <= 1"

    def test_high_temperature(self, device):
        """At temperature 10.0 the marginals match the reference."""
        temperature = 10.0
        shape = (2, 6, 6)

        torch.manual_seed(42)
        arc_scores = torch.randn(*shape, device=device)

        _, marginals = d2p.soft_eisner_with_grads(arc_scores, temperature, None)
        marginals_ref = soft_eisner_naive(arc_scores, temperature)

        assert allclose(marginals_ref, marginals, rtol=1e-3, atol=1e-4), (
            "High temperature marginals mismatch"
        )

    def test_diagonal_zeros(self, device):
        """The diagonal of the marginals (self-loops) is zero within 1e-5."""
        temperature = 1.0
        shape = (2, 6, 6)

        torch.manual_seed(42)
        arc_scores = torch.randn(*shape, device=device)

        _, marginals = d2p.soft_eisner_with_grads(arc_scores, temperature, None)

        # Take the diagonal over the last two dimensions of each batch element.
        diag_marginals = torch.diagonal(marginals, dim1=1, dim2=2)
        zeros = torch.zeros_like(diag_marginals)
        assert allclose(diag_marginals, zeros, atol=1e-5), (
            f"Self-loop marginals should be 0, got max {diag_marginals.abs().max().item()}"
        )


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestWithGrads:
    """Output checks for soft_eisner_with_grads and soft_eisner_backward_full."""

    def test_with_grads_output(self, device):
        """Both outputs of soft_eisner_with_grads match the naive reference."""
        temperature = 1.0
        shape = (2, 8, 8)

        torch.manual_seed(42)
        arc_scores = torch.randn(*shape, device=device)

        partition, marginals = d2p.soft_eisner_with_grads(arc_scores, temperature, None)

        # Reference values from the naive implementation.
        partition_ref, _, _ = soft_eisner_forward_naive(arc_scores, temperature)
        marginals_ref = soft_eisner_naive(arc_scores, temperature)

        assert allclose(partition, partition_ref), (
            "with_grads partition mismatch"
        )
        assert allclose(marginals, marginals_ref, rtol=1e-3), (
            "with_grads marginals mismatch"
        )

    def test_backward_full_output(self, device):
        """soft_eisner_backward_full agrees with soft_eisner_with_grads."""
        temperature = 1.0
        B = 2
        n = 8

        torch.manual_seed(42)
        arc_scores = torch.randn(B, n, n, device=device)

        partition, marginals, grad_T = d2p.soft_eisner_backward_full(arc_scores, temperature, None)

        # The two-output entry point serves as the baseline here.
        partition_ref, marginals_ref = d2p.soft_eisner_with_grads(arc_scores, temperature, None)

        assert allclose(partition, partition_ref), (
            "backward_full partition mismatch"
        )
        assert allclose(marginals, marginals_ref), (
            "backward_full marginals mismatch"
        )
        # Only the shape of grad_T is checked: one entry per batch element.
        assert grad_T.shape == (B,), f"grad_T shape wrong: {grad_T.shape}"


if __name__ == '__main__':
    # Propagate pytest's exit code so that running this file as a script
    # exits with a non-zero status when any test fails.
    raise SystemExit(pytest.main([__file__, '-v']))
