import torch
import torch.nn as nn

def gather_kl_tensor(tensor, index):
    """Remove one entry per selected position along the last dimension.

    For every position ``(b, l)`` with ``index[b, l] != 0`` the element
    ``tensor[b, l, index[b, l]]`` is dropped and the remaining ``V - 1``
    values are kept in their original order. Positions whose index is 0
    are skipped entirely and contribute no row to the result.

    The selection is done with a single boolean mask instead of a Python
    loop over every ``(b, l)`` pair, so there is no per-element ``.item()``
    call and gradients flow back to ``tensor`` through ordinary indexing.

    Args:
        tensor: Tensor of size ``[bsz, L, V]``.
        index: Tensor of size ``[bsz, L]``. Every non-zero entry names the
            position along the last dimension to drop and must lie in
            ``[1, V)``.

    Returns:
        Tensor of size ``[N, V - 1]`` with the dtype and device of
        ``tensor``, where ``N`` is the number of non-zero entries in
        ``index``. Rows follow the row-major order of the selected
        ``(b, l)`` positions.

    Raises:
        ValueError: If ``tensor`` does not have exactly three dimensions.
        RuntimeError: If a non-zero entry of ``index`` is outside ``[0, V)``.
    """
    # Unpacking doubles as the check that the input is three-dimensional.
    _, _, V = tensor.size()

    # Only positions with a non-zero index contribute a row.
    mask = index != 0
    rows = tensor[mask]  # [N, V]
    drop = index[mask]  # [N]
    n_rows = rows.size(0)

    # keep[i, j] is False exactly at the column that row i has to lose.
    columns = torch.arange(V, device=drop.device)
    keep = columns.unsqueeze(0) != drop.unsqueeze(1)  # [N, V]
    kept = rows[keep]

    # An out-of-range index matches no column, so its row keeps all V values
    # and the total element count no longer fits [N, V - 1].
    if kept.numel() != n_rows * (V - 1):
        raise RuntimeError(
            "gather_kl_tensor: non-zero index values must lie in [0, %d)" % V
        )

    return kept.view(n_rows, V - 1)  # [N, V - 1]

class MyModel(nn.Module):
    """Small module that feeds the output of ``gather_kl_tensor`` to a linear layer."""

    def __init__(self):
        super().__init__()
        # Single projection from 5 input features to 10 output features.
        self.linear = nn.Linear(in_features=5, out_features=10)

    def forward(self, tensor, index):
        """Apply ``gather_kl_tensor`` to the inputs, then the linear layer.

        Args:
            tensor: Passed through to ``gather_kl_tensor`` unchanged.
            index: Passed through to ``gather_kl_tensor`` unchanged.

        Returns:
            The linear layer applied to the result of ``gather_kl_tensor``.
        """
        return self.linear(gather_kl_tensor(tensor, index))

def check_gather_kl_grad():
    """Run one forward/backward pass through ``MyModel`` and print the input gradient.

    Builds a random ``[2, 3, 6]`` input that tracks gradients together with a
    fixed ``[2, 3]`` index tensor, sums the model output into a scalar,
    back-propagates it and prints ``tensor.grad``.
    """
    # The model is built before the random input so the random number
    # generator is consumed in the same order on every run.
    net = MyModel()

    inputs = torch.randn(2, 3, 6, requires_grad=True)
    positions = torch.tensor([[0, 2, 5], [1, 0, 3]])

    # Reduce the output to a scalar so backward() needs no explicit gradient.
    net(inputs, positions).sum().backward()

    print("Gradients of the input tensor:", inputs.grad, sep="\n")


if __name__ == '__main__':
    # Disabled example of calling gather_kl_tensor directly, kept for reference:
    # tensor = torch.randn(2, 3, 6)  # Example tensor of size [2, 3, 6]
    # index = torch.tensor([[0, 2, 5], [1, 0, 3]])  # Example index tensor of size [2, 3]

    # modified_tensor = gather_kl_tensor(tensor, index)
    # print(modified_tensor.size())  # Output: torch.Size([4, 5]) -- one row per non-zero index entry

    # Run the gradient smoke test when the file is executed as a script.
    check_gather_kl_grad()