"""
Correctness tests for Soft Damerau (True Damerau-Levenshtein with Unrestricted Transpositions).
"""

import pytest
import torch

from reference import soft_damerau_forward_naive, soft_damerau_naive

# The compiled extension is optional: every test class that needs it is
# skipped via D2P_AVAILABLE when the import fails.
try:
    import d2p
except ImportError:
    D2P_AVAILABLE = False
else:
    D2P_AVAILABLE = True

# Gates the CPU-vs-CUDA consistency tests and picks the default test device.
CUDA_AVAILABLE = torch.cuda.is_available()


def allclose(a, b, rtol=1e-4, atol=1e-5):
    """Return ``torch.allclose`` of ``a`` and ``b`` after moving both to the CPU.

    Moving to the CPU lets tensors that live on different devices be compared.
    """
    lhs = a.cpu()
    rhs = b.cpu()
    return torch.allclose(lhs, rhs, rtol=rtol, atol=atol)


def max_diff(a, b):
    """Return the largest absolute element-wise difference as a Python number."""
    delta = a.cpu() - b.cpu()
    return torch.max(torch.abs(delta)).item()


def create_trans_src(B, L1, L2, density=0.3, device='cpu', adjacent_prob=0.7):
    """Create a random transposition source index tensor.

    Unlike OSA's trans_mask, ``trans_src[b, i, j, :] = (k, l)`` specifies the
    source indices for a transposition; ``k == -1`` means "no transposition".

    For true Damerau, a transposition at DP position ``(i + 1, j + 1)`` may come
    from any ``(k, l)`` with ``k < i + 1`` and ``l < j + 1``. This generator
    only fills entries with ``i >= 1`` and ``j >= 1``. Each filled entry draws
    ``k`` from ``[0, i)`` and ``l`` from ``[0, j)``. It uses the adjacent,
    OSA-like source ``(i - 1, j - 1)`` with probability ``adjacent_prob``;
    otherwise it uses a uniformly random earlier source.

    Args:
        B: Batch size.
        L1: Length of the first sequence.
        L2: Length of the second sequence.
        density: Probability that an eligible entry receives a source.
        device: Device of the returned tensor.
        adjacent_prob: Probability that a filled entry uses the adjacent
            source rather than a random earlier one.

    Returns:
        An int32 tensor of shape ``(B, L1, L2, 2)`` on ``device``.

    Random numbers come from torch's global CPU generator and are drawn in a
    fixed order, so a preceding ``torch.manual_seed`` makes the result
    reproducible.
    """
    # Fill on the CPU and move once at the end. Element-wise writes into a
    # CUDA tensor would each issue a separate tiny device operation.
    trans_src = torch.full((B, L1, L2, 2), -1, dtype=torch.int32)

    for b in range(B):
        # Row 0 and column 0 have no earlier index to transpose from, so they
        # stay -1. No random number is drawn for them, which keeps the draw
        # order unchanged.
        for i in range(1, L1):
            for j in range(1, L2):
                if torch.rand(1).item() >= density:
                    continue
                if torch.rand(1).item() < adjacent_prob:
                    src_i, src_j = i - 1, j - 1
                else:
                    src_i = torch.randint(0, i, (1,)).item()
                    src_j = torch.randint(0, j, (1,)).item()
                trans_src[b, i, j, 0] = src_i
                trans_src[b, i, j, 1] = src_j

    return trans_src.to(device)


def create_adjacent_trans_src(B, L1, L2, density=0.3, device='cpu'):
    """Create trans_src with only adjacent transpositions (like OSA).

    Every entry with ``i >= 1`` and ``j >= 1`` independently gets the source
    ``(i - 1, j - 1)`` with probability ``density``. All other entries stay -1.
    Restricting to adjacent sources gives OSA-shaped inputs for the Damerau
    kernel.

    Args:
        B: Batch size.
        L1: Length of the first sequence.
        L2: Length of the second sequence.
        density: Probability that an eligible entry receives a source.
        device: Device of the returned tensor.

    Returns:
        An int32 tensor of shape ``(B, L1, L2, 2)`` on ``device``.
    """
    # Build on the CPU and move once; per-element writes into a CUDA tensor
    # would each cost a separate device operation.
    trans_src = torch.full((B, L1, L2, 2), -1, dtype=torch.int32)

    for b in range(B):
        for i in range(1, L1):  # Need i >= 1 for adjacent
            for j in range(1, L2):  # Need j >= 1 for adjacent
                if torch.rand(1).item() < density:
                    trans_src[b, i, j, 0] = i - 1
                    trans_src[b, i, j, 1] = j - 1

    return trans_src.to(device)


@pytest.fixture(params=[1, 4])
def batch_size(request):
    return request.param


@pytest.fixture(params=[(8, 10), (16, 16), (5, 20)])
def seq_lengths(request):
    return request.param


@pytest.fixture(params=[0.1, 1.0, 2.0])
def temperature(request):
    return request.param


@pytest.fixture(params=[(1.0, 1.0, 1.5), (0.5, 1.5, 1.0), (2.0, 0.5, 0.8)])
def cost_params(request):
    """Different (ins_cost, del_cost, trans_cost) combinations."""
    return request.param


@pytest.fixture
def device():
    return torch.device('cuda' if CUDA_AVAILABLE else 'cpu')


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestForward:
    """Forward pass: d2p distances must agree with the naive reference."""

    @staticmethod
    def _check_distance(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature, tag=""):
        """Run reference and d2p on the same inputs and compare the distances.

        ``tag`` is inserted into the failure message right after
        "Distance mismatch" (for example " (gaps)").
        """
        expected, _ = soft_damerau_forward_naive(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature)
        actual = d2p.soft_damerau_float(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature, None)[0]

        assert allclose(expected, actual), \
            f"Distance mismatch{tag}: max diff = {max_diff(expected, actual)}"

    def test_distance(self, batch_size, seq_lengths, temperature, cost_params, device):
        """Distances agree for random costs with mixed adjacent/distant sources."""
        n1, n2 = seq_lengths
        ins_cost, del_cost, trans_cost = cost_params

        torch.manual_seed(42)
        costs = torch.rand(batch_size, n1, n2, device=device) * 2
        src = create_trans_src(batch_size, n1, n2, device=device)

        self._check_distance(costs, src, ins_cost, del_cost, trans_cost, temperature)

    def test_distance_no_transpositions(self, batch_size, temperature, device):
        """With every trans_src entry at -1 there are no transpositions at all."""
        n1, n2 = 10, 12
        unit = 1.0

        torch.manual_seed(42)
        costs = torch.rand(batch_size, n1, n2, device=device)
        src = torch.full((batch_size, n1, n2, 2), -1, dtype=torch.int32, device=device)

        self._check_distance(costs, src, unit, unit, unit, temperature, " (no trans)")

    def test_distance_adjacent_transpositions(self, batch_size, temperature, device):
        """Distances agree when only adjacent (OSA-style) sources are present."""
        n1, n2 = 8, 10

        torch.manual_seed(42)
        costs = torch.rand(batch_size, n1, n2, device=device)
        src = create_adjacent_trans_src(batch_size, n1, n2, density=0.4, device=device)

        self._check_distance(costs, src, 1.0, 1.0, 0.8, temperature, " (adjacent)")

    def test_transposition_with_gaps(self, device):
        """Distances agree when transposition sources skip intermediate indices."""
        batch, size = 2, 6

        torch.manual_seed(42)
        costs = torch.ones(batch, size, size, device=device)

        src = torch.full((batch, size, size, 2), -1, dtype=torch.int32, device=device)
        # Index (3, 3) takes source (1, 1) and index (4, 4) takes source (0, 0).
        # Both components of each pair are equal, so a scalar fill sets both.
        src[:, 3, 3] = 1
        src[:, 4, 4] = 0

        self._check_distance(costs, src, 1.0, 1.0, 0.5, 1.0, " (gaps)")


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestBackward:
    """Backward pass: alignment posteriors and gradients taken through them."""

    def test_posteriors(self, batch_size, seq_lengths, temperature, cost_params, device):
        """d2p posteriors (output index 1) agree with the naive reference."""
        n1, n2 = seq_lengths
        ins_cost, del_cost, trans_cost = cost_params
        params = (ins_cost, del_cost, trans_cost, temperature)

        torch.manual_seed(42)
        costs = torch.rand(batch_size, n1, n2, device=device) * 2
        src = create_trans_src(batch_size, n1, n2, device=device)

        expected = soft_damerau_naive(costs, src, *params)
        actual = d2p.soft_damerau_float(costs, src, *params, None)[1]

        assert allclose(expected, actual, rtol=1e-3, atol=1e-4), \
            f"Posteriors mismatch: max diff = {max_diff(expected, actual)}"

    def test_gradients(self, batch_size, temperature, device):
        """Gradients of sum(posteriors**2) w.r.t. sub_costs agree with the reference."""
        n1, n2 = 6, 8
        params = (1.0, 1.0, 1.5, temperature)

        torch.manual_seed(42)
        costs = torch.rand(batch_size, n1, n2, device=device)
        src = create_trans_src(batch_size, n1, n2, device=device)
        costs.requires_grad_(True)

        # d2p path.
        d2p_post = d2p.soft_damerau_float(costs, src, *params, None)[1]
        (d2p_post ** 2).sum().backward()
        grad_d2p = costs.grad.clone()

        # Reference path on an independent leaf with the same values.
        ref_costs = costs.detach().clone().requires_grad_(True)
        ref_post = soft_damerau_naive(ref_costs, src, *params)
        (ref_post ** 2).sum().backward()
        grad_ref = ref_costs.grad

        assert allclose(grad_ref, grad_d2p, rtol=1e-2, atol=1e-3), \
            f"Gradient mismatch: max diff = {max_diff(grad_ref, grad_d2p)}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestWithGrads:
    """Shape checks for d2p.soft_damerau_with_grads."""

    def test_with_grads_returns_param_grads(self, device):
        """The six outputs are per-batch scalars, except posteriors, which is (B, L1, L2)."""
        batch, n1, n2 = 2, 6, 8

        torch.manual_seed(42)
        costs = torch.rand(batch, n1, n2, device=device)
        src = create_trans_src(batch, n1, n2, device=device)

        outputs = d2p.soft_damerau_with_grads(
            costs, src, 1.0, 0.8, 1.2, 1.0, None
        )
        distance, posteriors, grad_T, grad_ins, grad_del, grad_trans = outputs

        assert distance.shape == (batch,)
        assert posteriors.shape == (batch, n1, n2)
        # One gradient value per batch element for each scalar parameter.
        for per_batch_grad in (grad_T, grad_ins, grad_del, grad_trans):
            assert per_batch_grad.shape == (batch,)


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestHVP:
    """d2p.soft_damerau_hvp versus central differences of reference posteriors."""

    @staticmethod
    def _central_difference(sub_costs, trans_src, V, ins_cost, del_cost, trans_cost, temperature, eps):
        """Approximate the directional derivative of the posteriors along V.

        Uses ``(P(x + eps*V) - P(x - eps*V)) / (2*eps)`` with the naive reference.
        """
        forward = soft_damerau_naive(sub_costs + eps * V, trans_src, ins_cost, del_cost, trans_cost, temperature)
        backward = soft_damerau_naive(sub_costs - eps * V, trans_src, ins_cost, del_cost, trans_cost, temperature)
        return (forward - backward) / (2 * eps)

    def test_hvp_finite_diff(self, device):
        """HVP agrees with finite differences when transpositions are present."""
        batch, n1, n2 = 2, 5, 6
        ins_cost, del_cost, trans_cost = 1.0, 1.0, 1.5

        torch.manual_seed(42)
        costs = torch.rand(batch, n1, n2, device=device)
        src = create_trans_src(batch, n1, n2, device=device)
        direction = torch.randn(batch, n1, n2, device=device)

        hvp_d2p = d2p.soft_damerau_hvp(costs, src, direction, ins_cost, del_cost, trans_cost, 1.0, None)
        hvp_fd = self._central_difference(costs, src, direction, ins_cost, del_cost, trans_cost, 1.0, 1e-4)

        # NOTE(review): the finite-difference side is computed at the inputs'
        # dtype with eps=1e-4; if that is float32 the estimate is noisy, which
        # is presumably why the tolerances are loose.
        assert allclose(hvp_fd, hvp_d2p, rtol=2e-2, atol=5e-3), \
            f"HVP mismatch: max diff = {max_diff(hvp_fd, hvp_d2p)}"

    def test_hvp_no_transpositions(self, device):
        """HVP agrees with finite differences when trans_src is all -1."""
        batch, n1, n2 = 2, 6, 7

        torch.manual_seed(42)
        costs = torch.rand(batch, n1, n2, device=device)
        src = torch.full((batch, n1, n2, 2), -1, dtype=torch.int32, device=device)
        direction = torch.randn(batch, n1, n2, device=device)

        hvp_d2p = d2p.soft_damerau_hvp(costs, src, direction, 1.0, 1.0, 1.0, 1.0, None)
        hvp_fd = self._central_difference(costs, src, direction, 1.0, 1.0, 1.0, 1.0, 1e-4)

        assert allclose(hvp_fd, hvp_d2p, rtol=2e-2, atol=5e-3), \
            f"HVP mismatch (no trans): max diff = {max_diff(hvp_fd, hvp_d2p)}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available")
class TestCPUCUDA:
    """The CPU and CUDA builds of d2p must give matching posteriors (within tolerance)."""

    @staticmethod
    def _posteriors_both_devices(sub_costs_cpu, trans_src_cpu, ins_cost, del_cost, trans_cost, temperature):
        """Copy CPU inputs to CUDA, then return (cpu_posteriors, cuda_posteriors)."""
        sub_costs_gpu = sub_costs_cpu.cuda()
        trans_src_gpu = trans_src_cpu.cuda()

        on_cpu = d2p.soft_damerau_float(sub_costs_cpu, trans_src_cpu, ins_cost, del_cost, trans_cost, temperature, None)[1]
        on_gpu = d2p.soft_damerau_float(sub_costs_gpu, trans_src_gpu, ins_cost, del_cost, trans_cost, temperature, None)[1]
        return on_cpu, on_gpu

    def test_consistency(self):
        """Posteriors agree across devices for random mixed transpositions."""
        batch, n1, n2 = 2, 8, 10

        torch.manual_seed(42)
        costs = torch.rand(batch, n1, n2)
        src = create_trans_src(batch, n1, n2)

        on_cpu, on_gpu = self._posteriors_both_devices(costs, src, 1.0, 1.0, 1.5, 1.0)

        assert allclose(on_cpu, on_gpu), \
            f"CPU/CUDA mismatch: max diff = {max_diff(on_cpu, on_gpu)}"

    def test_consistency_with_gaps(self):
        """Posteriors agree across devices with denser, gapped transpositions."""
        batch, n1, n2 = 2, 10, 12

        torch.manual_seed(42)
        costs = torch.rand(batch, n1, n2)
        src = create_trans_src(batch, n1, n2, density=0.5)

        on_cpu, on_gpu = self._posteriors_both_devices(costs, src, 0.5, 1.5, 0.8, 1.0)

        assert allclose(on_cpu, on_gpu), \
            f"CPU/CUDA mismatch (gaps): max diff = {max_diff(on_cpu, on_gpu)}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestVariableLength:
    """Padded batches whose per-item lengths are passed via ``lengths``."""

    def test_variable_lengths(self, device):
        """Each item's valid region matches the reference run on the unpadded item."""
        batch = 4
        pad1, pad2 = 10, 12
        params = (1.0, 1.0, 1.5, 1.0)  # ins_cost, del_cost, trans_cost, temperature

        torch.manual_seed(42)
        costs = torch.rand(batch, pad1, pad2, device=device)
        src = create_trans_src(batch, pad1, pad2, device=device)

        # One (l1, l2) row per batch item.
        lengths = torch.tensor(
            [[8, 10], [10, 12], [6, 8], [9, 11]],
            device=device,
            dtype=torch.int32,
        )

        distance, posteriors = d2p.soft_damerau_float(costs, src, *params, lengths)

        for b, (l1, l2) in enumerate(lengths.tolist()):
            # Crop padding off this item and run the reference on it alone.
            item_costs = costs[b:b+1, :l1, :l2]
            item_src = src[b:b+1, :l1, :l2, :]

            distance_ref, _ = soft_damerau_forward_naive(item_costs, item_src, *params)
            posteriors_ref = soft_damerau_naive(item_costs, item_src, *params)

            assert allclose(distance_ref, distance[b:b+1]), \
                f"Distance mismatch for batch {b}: {distance_ref.item()} vs {distance[b].item()}"

            # Only the valid (unpadded) region is compared.
            assert allclose(posteriors_ref, posteriors[b:b+1, :l1, :l2], rtol=1e-3, atol=1e-4), \
                f"Posteriors mismatch for batch {b}"


@pytest.mark.skipif(not D2P_AVAILABLE, reason="d2p not built")
class TestEdgeCases:
    """Degenerate shapes, extreme temperatures and distant transpositions.

    Each test that receives posteriors from d2p also checks them, either
    against the naive reference or for finiteness.
    """

    def test_single_element(self, device):
        """Test 1x1 cost matrix."""
        sub_costs = torch.tensor([[[0.5]]], device=device)
        trans_src = torch.full((1, 1, 1, 2), -1, dtype=torch.int32, device=device)
        temperature = 1.0
        ins_cost = del_cost = trans_cost = 1.0

        distance, posteriors = d2p.soft_damerau_float(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature, None)
        distance_ref, _ = soft_damerau_forward_naive(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature)
        posteriors_ref = soft_damerau_naive(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature)

        assert allclose(distance, distance_ref), f"Single element distance wrong: {distance.item()} vs {distance_ref.item()}"
        assert allclose(posteriors, posteriors_ref), "Single element posterior wrong"

    def test_2x2_with_adjacent_transposition(self, device):
        """Test 2x2 cost matrix with adjacent transposition."""
        sub_costs = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]], device=device)
        trans_src = torch.full((1, 2, 2, 2), -1, dtype=torch.int32, device=device)
        # Only index (1, 1) has a transposition source, namely (0, 0).
        trans_src[0, 1, 1, 0] = 0  # k = 0
        trans_src[0, 1, 1, 1] = 0  # l = 0
        temperature = 1.0
        ins_cost = del_cost = 1.0
        trans_cost = 0.5

        distance, posteriors = d2p.soft_damerau_float(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature, None)
        distance_ref, _ = soft_damerau_forward_naive(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature)
        posteriors_ref = soft_damerau_naive(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature)

        assert allclose(distance, distance_ref), f"2x2 trans distance mismatch: {max_diff(distance, distance_ref)}"
        assert allclose(posteriors_ref, posteriors, rtol=1e-3, atol=1e-4), \
            f"2x2 trans posteriors mismatch: {max_diff(posteriors_ref, posteriors)}"

    def test_row_vector(self, device):
        """Test 1xN cost matrix."""
        sub_costs = torch.tensor([[[0.1, 0.2, 0.3]]], device=device)
        trans_src = torch.full((1, 1, 3, 2), -1, dtype=torch.int32, device=device)
        temperature = 1.0
        ins_cost = del_cost = trans_cost = 1.0

        distance, posteriors = d2p.soft_damerau_float(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature, None)
        distance_ref, _ = soft_damerau_forward_naive(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature)
        posteriors_ref = soft_damerau_naive(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature)

        assert allclose(distance, distance_ref), "Row vector distance mismatch"
        assert allclose(posteriors_ref, posteriors, rtol=1e-3, atol=1e-4), "Row vector posteriors mismatch"

    def test_col_vector(self, device):
        """Test Nx1 cost matrix."""
        sub_costs = torch.tensor([[[0.1], [0.2], [0.3]]], device=device)
        trans_src = torch.full((1, 3, 1, 2), -1, dtype=torch.int32, device=device)
        temperature = 1.0
        ins_cost = del_cost = trans_cost = 1.0

        distance, posteriors = d2p.soft_damerau_float(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature, None)
        distance_ref, _ = soft_damerau_forward_naive(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature)
        posteriors_ref = soft_damerau_naive(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature)

        assert allclose(distance, distance_ref), "Column vector distance mismatch"
        assert allclose(posteriors_ref, posteriors, rtol=1e-3, atol=1e-4), "Column vector posteriors mismatch"

    def test_low_temperature(self, device):
        """Test with low temperature (approaches hard Damerau)."""
        B, L1, L2 = 2, 6, 8
        temperature = 0.01
        ins_cost = del_cost = trans_cost = 1.0

        torch.manual_seed(42)
        sub_costs = torch.rand(B, L1, L2, device=device)
        trans_src = create_trans_src(B, L1, L2, device=device)

        posteriors = d2p.soft_damerau_float(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature, None)[1]

        # Range check only, with 0.1 of slack for numerical error. This does
        # not assert that the posteriors are close to 0 or 1.
        assert posteriors.min() >= -0.1, "Low temp posteriors should be >= 0"
        assert posteriors.max() <= 1.1, "Low temp posteriors should be <= 1"

    def test_high_temperature(self, device):
        """Test with high temperature (more uniform distribution)."""
        B, L1, L2 = 2, 6, 8
        temperature = 10.0
        ins_cost = del_cost = trans_cost = 1.0

        torch.manual_seed(42)
        sub_costs = torch.rand(B, L1, L2, device=device)
        trans_src = create_trans_src(B, L1, L2, device=device)

        posteriors = d2p.soft_damerau_float(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature, None)[1]
        posteriors_ref = soft_damerau_naive(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature)

        assert allclose(posteriors_ref, posteriors, rtol=1e-3, atol=1e-4), \
            "High temperature posteriors mismatch"

    def test_zero_temperature_clamp(self, device):
        """Test that very low temperature doesn't cause NaN or Inf."""
        B, L1, L2 = 2, 5, 6
        temperature = 1e-6
        ins_cost = del_cost = trans_cost = 1.0

        torch.manual_seed(42)
        sub_costs = torch.rand(B, L1, L2, device=device)
        trans_src = create_trans_src(B, L1, L2, device=device)

        distance, posteriors = d2p.soft_damerau_float(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature, None)

        assert not torch.isnan(distance).any(), "Distance contains NaN"
        assert not torch.isnan(posteriors).any(), "Posteriors contain NaN"
        assert not torch.isinf(distance).any(), "Distance contains Inf"
        assert not torch.isinf(posteriors).any(), "Posteriors contain Inf"

    def test_distant_transposition(self, device):
        """Test transposition from distant source (gap of 3)."""
        B, L = 2, 8
        temperature = 1.0
        ins_cost = del_cost = 1.0
        trans_cost = 0.5

        torch.manual_seed(42)
        sub_costs = torch.rand(B, L, L, device=device)
        trans_src = torch.full((B, L, L, 2), -1, dtype=torch.int32, device=device)
        # Position (6, 6) can transpose from (2, 2) - gap of 3 each
        trans_src[:, 5, 5, 0] = 2
        trans_src[:, 5, 5, 1] = 2

        distance, posteriors = d2p.soft_damerau_float(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature, None)
        distance_ref, _ = soft_damerau_forward_naive(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature)
        posteriors_ref = soft_damerau_naive(sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature)

        assert allclose(distance, distance_ref), \
            f"Distant trans distance mismatch: {max_diff(distance, distance_ref)}"
        assert allclose(posteriors_ref, posteriors, rtol=1e-3, atol=1e-4), \
            f"Distant trans posteriors mismatch: {max_diff(posteriors_ref, posteriors)}"


if __name__ == '__main__':
    # Propagate pytest's exit status so shells and CI see failures when this
    # file is run directly (a bare pytest.main(...) call would exit with 0).
    raise SystemExit(pytest.main([__file__, '-v']))
