"""
Correctness tests for Soft CKY parsing.
"""

import pytest
import torch

from reference import soft_cky_forward_naive, soft_cky_naive

try:
    import d2p
except ImportError:
    # The compiled d2p extension is optional; every test class below is
    # guarded by a skipif on this flag.
    D2P_AVAILABLE = False
else:
    D2P_AVAILABLE = True

# Evaluated once at import time; used by the skipif on TestCPUCUDA and by the
# ``device`` fixture.
CUDA_AVAILABLE = torch.cuda.is_available()


def allclose(a, b, rtol=1e-4, atol=1e-5):
    """Return True when tensors ``a`` and ``b`` agree elementwise within tolerance.

    Both tensors are copied to CPU first, so inputs living on different
    devices can be compared. The comparison itself is ``torch.allclose``
    (``|a - b| <= atol + rtol * |b|``), so argument order matters.
    """
    lhs = a.cpu()
    rhs = b.cpu()
    return torch.allclose(lhs, rhs, rtol=rtol, atol=atol)


def max_diff(a, b):
    """Return the largest absolute elementwise difference of ``a`` and ``b`` as a Python float.

    Both tensors are copied to CPU before subtracting; used to build
    assertion messages.
    """
    delta = torch.abs(a.cpu() - b.cpu())
    return torch.max(delta).item()


@pytest.fixture(params=(1, 4))
def batch_size(request):
    """Parametrized batch dimension: runs each dependent test with 1 and with 4."""
    size = request.param
    return size


@pytest.fixture(params=list(range(4, 13, 4)))
def seq_length(request):
    """Parametrized sequence length N: 4, 8 and 12."""
    length = request.param
    return length


@pytest.fixture(params=(0.1, 1.0, 2.0))
def temperature(request):
    """Parametrized soft-CKY temperature: 0.1, 1.0 and 2.0."""
    temp = request.param
    return temp


@pytest.fixture
def device():
    """Device that test tensors are allocated on: CUDA when available, else CPU."""
    if CUDA_AVAILABLE:
        return torch.device('cuda')
    return torch.device('cpu')


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestForward:
    """Forward-pass checks of d2p against the naive reference implementation."""

    def test_partition(self, batch_size, seq_length, temperature, device):
        """Partition (output 0 of d2p.soft_cky_float) must match the reference forward pass."""
        torch.manual_seed(42)
        n = seq_length
        merge = torch.randn(batch_size, n, n, n, device=device)
        leaf = torch.randn(batch_size, n, device=device)

        expected, _ = soft_cky_forward_naive(merge, leaf, temperature)
        outputs = d2p.soft_cky_float(merge, leaf, temperature)
        actual = outputs[0]

        assert allclose(expected, actual), \
            f"Partition mismatch: max diff = {max_diff(expected, actual)}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestBackward:
    """Backward-pass checks: span marginals and gradients flowing through them."""

    def test_marginals(self, batch_size, seq_length, temperature, device):
        """Test that span marginals (output 1 of d2p.soft_cky_float) match the reference."""
        torch.manual_seed(42)
        merge_scores = torch.randn(batch_size, seq_length, seq_length, seq_length, device=device)
        leaf_scores = torch.randn(batch_size, seq_length, device=device)

        marginals_ref = soft_cky_naive(merge_scores, leaf_scores, temperature)
        marginals_d2p = d2p.soft_cky_float(merge_scores, leaf_scores, temperature)[1]

        assert allclose(marginals_ref, marginals_d2p, rtol=1e-3, atol=1e-4), \
            f"Marginals mismatch: max diff = {max_diff(marginals_ref, marginals_d2p)}"

    def test_gradients(self, batch_size, temperature, device):
        """Test gradients of sum(marginals**2) w.r.t. BOTH merge and leaf scores."""
        seq_length = 6

        torch.manual_seed(42)
        merge_scores = torch.randn(batch_size, seq_length, seq_length, seq_length, device=device, requires_grad=True)
        leaf_scores = torch.randn(batch_size, seq_length, device=device, requires_grad=True)

        # d2p.soft_cky is given the temperature as a tensor. It does not
        # require grad, so only the two score tensors are differentiated.
        temp_tensor = torch.tensor([temperature], device=device)
        marginals = d2p.soft_cky(merge_scores, leaf_scores, temp_tensor)[1]
        loss = (marginals ** 2).sum()
        loss.backward()
        # Fail with a clear message instead of AttributeError on None.grad.clone().
        assert merge_scores.grad is not None, "d2p produced no gradient for merge_scores"
        assert leaf_scores.grad is not None, "d2p produced no gradient for leaf_scores"
        grad_d2p_merge = merge_scores.grad.clone()
        grad_d2p_leaf = leaf_scores.grad.clone()

        # Independent leaf copies so the reference graph accumulates into its own .grad.
        merge_scores_ref = merge_scores.detach().clone().requires_grad_(True)
        leaf_scores_ref = leaf_scores.detach().clone().requires_grad_(True)
        marginals_ref = soft_cky_naive(merge_scores_ref, leaf_scores_ref, temperature)
        loss_ref = (marginals_ref ** 2).sum()
        loss_ref.backward()
        grad_ref_merge = merge_scores_ref.grad
        grad_ref_leaf = leaf_scores_ref.grad
        assert grad_ref_leaf is not None, "reference produced no gradient for leaf_scores"

        assert allclose(grad_ref_merge, grad_d2p_merge, rtol=1e-2, atol=1e-3), \
            f"Gradient mismatch: max diff = {max_diff(grad_ref_merge, grad_d2p_merge)}"
        # leaf_scores requires grad on both paths; previously its gradient was
        # computed but never checked.
        assert allclose(grad_ref_leaf, grad_d2p_leaf, rtol=1e-2, atol=1e-3), \
            f"Leaf gradient mismatch: max diff = {max_diff(grad_ref_leaf, grad_d2p_leaf)}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestHVP:
    """Checks of d2p.soft_cky_hvp."""

    def test_hvp_finite_diff(self, device):
        """Test HVP against finite differences.

        ``d2p.soft_cky_hvp`` is compared with the central finite-difference
        directional derivative of the reference marginals along
        ``(V_merge, V_leaf)``:
        ``(m(s + eps*V) - m(s - eps*V)) / (2*eps)``.
        """
        B, N = 2, 6
        temperature = 1.0
        eps = 1e-4

        torch.manual_seed(42)
        merge_scores = torch.randn(B, N, N, N, device=device)
        leaf_scores = torch.randn(B, N, device=device)
        V_merge = torch.randn(B, N, N, N, device=device)
        V_leaf = torch.randn(B, N, device=device)

        hvp_d2p = d2p.soft_cky_hvp(merge_scores, leaf_scores, V_merge, V_leaf, temperature)

        # Evaluate the finite difference in float64. In float32 the rounding
        # error of the quotient is about float32_eps * |m| / (2 * eps), roughly
        # 1e-3 for eps=1e-4, which is the same order as atol below and makes
        # the comparison flaky.
        # NOTE(review): relies on soft_cky_naive computing in its inputs' dtype.
        merge64, leaf64 = merge_scores.double(), leaf_scores.double()
        v_merge64, v_leaf64 = V_merge.double(), V_leaf.double()
        marginals_plus = soft_cky_naive(merge64 + eps * v_merge64, leaf64 + eps * v_leaf64, temperature)
        marginals_minus = soft_cky_naive(merge64 - eps * v_merge64, leaf64 - eps * v_leaf64, temperature)
        # Cast back so torch.allclose compares tensors of the same dtype.
        hvp_fd = ((marginals_plus - marginals_minus) / (2 * eps)).to(hvp_d2p.dtype)

        assert allclose(hvp_fd, hvp_d2p, rtol=1e-2, atol=2e-3), \
            f"HVP mismatch: max diff = {max_diff(hvp_fd, hvp_d2p)}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available")
class TestCPUCUDA:
    """Cross-device consistency of d2p.soft_cky_float."""

    def test_consistency(self):
        """Test that CPU and CUDA results agree within ``allclose`` tolerance.

        Bitwise equality is not required, since the two backends may
        accumulate in a different order. Both outputs this file relies on
        are compared: index 0 (partition) and index 1 (marginals).
        """
        B, N = 2, 8
        temperature = 1.0

        torch.manual_seed(42)
        merge_scores_cpu = torch.randn(B, N, N, N)
        leaf_scores_cpu = torch.randn(B, N)
        merge_scores_cuda = merge_scores_cpu.cuda()
        leaf_scores_cuda = leaf_scores_cpu.cuda()

        out_cpu = d2p.soft_cky_float(merge_scores_cpu, leaf_scores_cpu, temperature)
        out_cuda = d2p.soft_cky_float(merge_scores_cuda, leaf_scores_cuda, temperature)
        partition_cpu, marginals_cpu = out_cpu[0], out_cpu[1]
        partition_cuda, marginals_cuda = out_cuda[0], out_cuda[1]

        # The partition was previously never compared across devices.
        assert allclose(partition_cpu, partition_cuda), \
            f"CPU/CUDA partition mismatch: max diff = {max_diff(partition_cpu, partition_cuda)}"
        assert allclose(marginals_cpu, marginals_cuda), \
            f"CPU/CUDA mismatch: max diff = {max_diff(marginals_cpu, marginals_cuda)}"


if __name__ == '__main__':
    # pytest.main returns the exit code rather than exiting. Propagate it so
    # that running this file directly exits non-zero when tests fail.
    raise SystemExit(pytest.main([__file__, '-v']))
