import torch


def extract_image_tokens(features: torch.FloatTensor, mask: torch.Tensor):
    """Pack the masked tokens of every sample to the front of a padded tensor.

    For each sample, the positions where ``mask`` is set are gathered in their
    original order and left-aligned. Samples with fewer selected tokens than
    the batch maximum are zero-padded on the right, so the number of selected
    tokens may differ between samples.

    Args:
        features: Tensor of shape ``[B, L, C]``.
        mask: Tensor of shape ``[B, L]``; ``True`` / non-zero marks a token to
            keep. It is moved to the device of ``features`` if necessary.

    Returns:
        A tuple ``(tokens, token_mask)``. ``tokens`` has shape ``[B, R, C]``
        and ``token_mask`` is a bool tensor of shape ``[B, R]`` that is
        ``True`` on real tokens and ``False`` on padding, where ``R`` is the
        largest number of selected tokens in any sample. When nothing is
        selected (or the batch is empty) ``R`` is 1, ``tokens`` is all zeros
        and ``token_mask`` is all ``False``.
    """
    B, L, C = features.shape
    device = features.device
    # Use the device-aligned boolean mask everywhere below; mixing devices
    # would fail in the comparison against the arange built on `device`.
    keep = mask.to(device=device, dtype=torch.bool)

    # Number of selected tokens per sample, shape [B].
    valid_counts = keep.sum(dim=1)
    # `.max()` raises on an empty tensor, so an empty batch counts as
    # "nothing selected".
    R = int(valid_counts.max().item()) if B > 0 else 0

    if R == 0:  # Edge case: no selected tokens at all
        return torch.zeros((B, 1, C), dtype=features.dtype, device=device), torch.zeros(
            (B, 1), dtype=torch.bool, device=device
        )

    # A *stable* descending sort moves selected positions to the front while
    # keeping them in their original left-to-right order; a non-stable sort
    # gives no ordering guarantee among equal keys.
    order = torch.sort(
        keep.to(torch.uint8), dim=1, descending=True, stable=True
    ).indices
    front = order[:, :R]  # Shape: [B, R]

    # Gather only the R leading positions instead of the whole sequence.
    batch_indices = torch.arange(B, device=device).unsqueeze(1)  # Shape: [B, 1]
    tokens = features[batch_indices, front]  # Shape: [B, R, C]

    # True where a slot holds a real token, False where it is padding.
    token_mask = torch.arange(R, device=device).unsqueeze(0) < valid_counts.unsqueeze(
        1
    )  # Shape: [B, R]
    # Out-of-place fill keeps autograd simple and never touches `features`.
    tokens = tokens.masked_fill(~token_mask.unsqueeze(-1), 0)

    return tokens, token_mask


def _test_extract_image_tokens(features, mask):
    """Validate ``extract_image_tokens`` against a per-sample reference.

    Checks the output shapes, that every selected row of ``features`` appears
    at the front of the packed output in order, that the remaining slots are
    zero, and that the returned mask flags exactly the real tokens.

    Args:
        features: Input tensor indexed as ``[batch, position, channel]``.
        mask: Boolean tensor indexed as ``[batch, position]``.

    Raises:
        AssertionError: If any of the checks fails.
    """
    packed, packed_mask = extract_image_tokens(features, mask)

    # --- shape checks -----------------------------------------------------
    assert packed.shape[0] == features.shape[0], "Batch size should remain the same"
    if mask.sum() > 0:
        longest = mask.sum(dim=1).max().item()
        assert (
            packed.shape[1] == longest
        ), "Output length should match max valid selections"
    assert (
        packed.shape[2] == features.shape[2]
    ), "Feature dimension should remain the same"
    assert (
        packed_mask.shape == packed.shape[:2]
    ), "Mask shape should match first two dimensions of output"

    # --- content checks, one sample at a time -----------------------------
    out_len = packed.shape[1]
    for i in range(features.shape[0]):
        # Reference selection done with plain boolean indexing.
        reference = features[i][mask[i]]
        n_valid = len(reference)
        front = packed[i, :n_valid]

        # Every selected row must be copied, in order, to the front.
        for j in range(n_valid):
            assert torch.allclose(
                reference[j], front[j], atol=1e-6
            ), f"Mismatch in selected values for batch {i}, position {j}"

        # Whatever follows the real tokens must be zero padding.
        if n_valid < out_len:
            tail = packed[i, n_valid:]
            assert torch.all(tail == 0), f"Padding values incorrect for batch {i}"

        # The returned mask is True exactly on the first n_valid slots.
        reference_mask = torch.arange(out_len) < n_valid
        assert torch.equal(
            packed_mask[i], reference_mask
        ), f"Mismatch in mask for batch {i}"


def test_extract_image_tokens():
    """Run ``_test_extract_image_tokens`` over a set of hand-written masks.

    The cases cover a mixed mask, an all-True mask, an all-False mask, rows
    with different numbers of selected tokens, and a batch of size one. The
    name of each case is printed before it runs.

    Raises:
        AssertionError: Propagated from ``_test_extract_image_tokens`` when a
            case fails.
    """
    test_cases = [
        {
            "name": "General Case",
            "X": torch.randn(3, 6, 4),
            "M": torch.tensor(
                [[1, 0, 1, 1, 0, 0], [0, 1, 1, 0, 1, 1], [1, 1, 0, 0, 1, 1]],
                dtype=torch.bool,
            ),
        },
        {
            "name": "All True Mask",
            "X": torch.randn(3, 6, 4),
            "M": torch.ones((3, 6), dtype=torch.bool),
        },
        {
            "name": "All False Mask",
            "X": torch.randn(3, 6, 4),
            "M": torch.zeros((3, 6), dtype=torch.bool),
        },
        {
            "name": "Different Number of Selected Elements",
            "X": torch.randn(3, 6, 4),
            "M": torch.tensor(
                [[1, 1, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
                dtype=torch.bool,
            ),
        },
        {
            "name": "Single Batch Case",
            "X": torch.randn(1, 6, 4),
            "M": torch.tensor([[1, 0, 1, 0, 1, 0]], dtype=torch.bool),
        },
    ]
    for test in test_cases:
        name, X, M = test["name"], test["X"], test["M"]
        print(f"Running test: {name}")

        # Run function
        _test_extract_image_tokens(X, M)


def apply_vocab_mask(logits, mask, V_x):
    """
    Applies a vocabulary mask to the second part of logits without constructing a large BLV mask,
    ensuring proper gradient flow.

    The first ``V_x`` entries of the last dimension (the fixed vocabulary) are
    passed through untouched. The remaining entries (the dynamic vocabulary)
    are kept where ``mask`` is set and replaced by a large negative constant
    elsewhere. The same per-sample mask is broadcast over the sequence
    dimension.

    Args:
        logits (torch.Tensor): Logit tensor of shape [B, L, (V_x + V_y_max)].
        mask (torch.Tensor): Binary mask of shape [B, V_y_max] (1 = valid, 0 = invalid).
            Boolean and 0/1 numeric masks are both accepted; the mask is moved
            to the device of ``logits`` if necessary.
        V_x (int): Size of fixed vocabulary.

    Returns:
        logits (torch.Tensor): Masked logits with invalid positions set to a large negative value
            (-1e9, or the most negative finite value of the dtype when -1e9 is
            not representable, e.g. for float16).
    """
    # torch.where needs a boolean condition on the same device as its inputs.
    # Shape after unsqueeze: [B, 1, V_y_max], broadcasting over L.
    keep = mask.to(device=logits.device, dtype=torch.bool).unsqueeze(1)

    # -1e9 overflows to -inf in narrow float types; clamp it to the dtype's
    # finite range so the result stays a "large negative value".
    fill_value = -1e9
    if logits.is_floating_point():
        fill_value = max(fill_value, torch.finfo(logits.dtype).min)
    fill = torch.tensor(fill_value, device=logits.device, dtype=logits.dtype)

    # Split logits into fixed (X) and dynamic (Y) parts
    logits_x = logits[:, :, :V_x]  # Shape: [B, L, V_x]
    logits_y = logits[:, :, V_x:]  # Shape: [B, L, V_y_max]

    # Out-of-place selection: gradients flow to the kept logits only.
    logits_y = torch.where(keep, logits_y, fill)

    # Concatenate back
    return torch.cat([logits_x, logits_y], dim=-1)  # Shape: [B, L, (V_x + V_y_max)]


def test_apply_vocab_mask():
    """
    Exercises apply_vocab_mask() on a small seeded example.

    Verifies the output shape, that the fixed-vocabulary slice is untouched,
    that dynamic-vocabulary entries are either preserved or replaced by -1e9
    according to the mask, that gradients reach the input without NaNs, and
    that a softmax over the masked logits contains no NaNs.
    """
    print("Running test: vocab_mask")
    # Batch size, sequence length, fixed vocab size, dynamic vocab size
    batch, seq_len, fixed_vocab, dynamic_vocab = 2, 4, 3, 5
    total_vocab = fixed_vocab + dynamic_vocab

    # Seeded random input that tracks gradients
    torch.manual_seed(42)
    logits = torch.randn(batch, seq_len, total_vocab, requires_grad=True)

    # Per-sample validity of each dynamic vocabulary entry, shape [B, V_y_max]
    mask = torch.tensor([[1, 0, 1, 1, 0], [1, 1, 0, 0, 1]], dtype=torch.bool)

    # Detached snapshot of the input used as the reference
    reference = logits.clone().detach()

    masked_logits = apply_vocab_mask(logits, mask, fixed_vocab)

    # Test 1: Output shape matches input shape
    assert masked_logits.shape == (batch, seq_len, total_vocab), "Output shape mismatch"

    # Fixed (X) and dynamic (Y) slices of the output and of the reference
    got_x = masked_logits[:, :, :fixed_vocab]
    got_y = masked_logits[:, :, fixed_vocab:]
    ref_x = reference[:, :, :fixed_vocab]
    ref_y = reference[:, :, fixed_vocab:]

    # Test 2: Ensure original logits_x remain unchanged
    assert torch.allclose(
        got_x, ref_x
    ), "Fixed vocabulary logits (logits_x) should remain unchanged"

    # Test 3: Masked-out entries hold -1e9, kept entries are unchanged
    for b in range(batch):
        for v in range(dynamic_vocab):
            column = got_y[b, :, v]
            if mask[b, v]:
                assert torch.allclose(
                    column, ref_y[b, :, v]
                ), f"Valid logit at B={b}, V_y={v} was incorrectly changed"
            else:
                assert torch.allclose(
                    column, torch.tensor(-1e9, dtype=logits.dtype)
                ), f"Invalid logit at B={b}, V_y={v} not properly masked"

    # Test 4: Ensure gradients flow properly
    masked_logits.sum().backward()  # Dummy loss
    assert logits.grad is not None, "Gradients did not propagate"
    assert not torch.isnan(logits.grad).any(), "NaN gradient detected"

    # Test 5: Ensure softmax does not produce NaNs
    probabilities = torch.nn.functional.softmax(masked_logits, dim=-1)
    assert not torch.isnan(probabilities).any(), "Softmax produced NaN values"


# Self-test entry point: these calls sit at module level, so they run whenever
# this file is executed or imported. A failing assertion in either test raises
# before "done" is printed.
test_extract_image_tokens()
test_apply_vocab_mask()
print("done")
