import torch

def randomly_sample_triplets(b, f, h, w, num_triplets, neighborhood_size, sampling_weights=None):
    """Sample random (A, B, C) location triplets for neighbouring-frame comparison.

    For each triplet a batch index and a frame index ``i`` in ``[0, f - 2]``
    are drawn uniformly (so that frame ``i + 1`` always exists; this requires
    ``f >= 2``).  An anchor location ``(x_A, y_A)`` is then drawn, and two
    further locations B and C are produced by adding independent random
    offsets in ``[-neighborhood_size, neighborhood_size]`` to the anchor and
    clamping the result to the valid grid.

    Coordinate convention: ``x`` indexes the ``h`` axis (range ``[0, h - 1]``)
    and ``y`` indexes the ``w`` axis (range ``[0, w - 1]``).

    Args:
        b: Batch size; batch indices are drawn from ``[0, b - 1]``.
        f: Number of frames; frame indices are drawn from ``[0, f - 2]``.
        h: Extent of the first spatial axis (indexed by ``x``).
        w: Extent of the second spatial axis (indexed by ``y``).
        num_triplets: Number of triplets to sample.
        neighborhood_size: Maximum absolute offset (inclusive) of B and C
            from A along each spatial axis, before clamping.
        sampling_weights: Optional tensor of shape ``(b, f, h, w)`` holding
            non-negative weights.  When given, the anchor location of each
            triplet is drawn with probability proportional to the weights of
            the selected ``(batch, frame)`` slice; otherwise anchors are
            uniform over the grid.

    Returns:
        Tuple ``(batch_idx, i, x_A, y_A, x_B, y_B, x_C, y_C)`` of 1-D integer
        tensors, each of length ``num_triplets``.  The coordinate tensors
        live on the device of ``sampling_weights`` when it is provided.
    """
    if sampling_weights is not None:
        assert sampling_weights.shape == (b, f, h, w), (
            f"sampling_weights must have shape {(b, f, h, w)}, "
            f"got {tuple(sampling_weights.shape)}"
        )
    batch_idx = torch.randint(0, b, (num_triplets,))
    # Upper bound is exclusive, so i <= f - 2 and frame i + 1 is always valid.
    i = torch.randint(0, f - 1, (num_triplets,))

    if sampling_weights is not None:
        selected_weights = sampling_weights[batch_idx, i, :, :]
        # Row-major flatten of the (h, w) grid: flat index = x * w + y.
        selected_weights_flat = selected_weights.reshape(num_triplets, -1)
        # squeeze(1) (not squeeze()) keeps a 1-D result when num_triplets == 1.
        sampled_indices = torch.multinomial(selected_weights_flat, 1, replacement=False).squeeze(1)
        # Invert the row-major flattening so (x_A, y_A) is exactly the
        # location whose weight was sampled.
        x_A = sampled_indices // w
        y_A = sampled_indices % w
    else:
        x_A = torch.randint(0, h, (num_triplets,))
        y_A = torch.randint(0, w, (num_triplets,))

    # Column 0 holds the offsets for B, column 1 the offsets for C.
    offsets_x = torch.randint(-neighborhood_size, neighborhood_size + 1, (num_triplets, 2)).to(x_A.device)
    offsets_y = torch.randint(-neighborhood_size, neighborhood_size + 1, (num_triplets, 2)).to(y_A.device)

    # Clamp so neighbours that would fall outside the grid land on its border.
    x_B = torch.clamp(x_A + offsets_x[:, 0], 0, h - 1)
    y_B = torch.clamp(y_A + offsets_y[:, 0], 0, w - 1)
    x_C = torch.clamp(x_A + offsets_x[:, 1], 0, h - 1)
    y_C = torch.clamp(y_A + offsets_y[:, 1], 0, w - 1)

    return batch_idx, i, x_A, y_A, x_B, y_B, x_C, y_C


def compute_distances(feature_map, batch_idx, i, x_A, y_A, x_B, y_B, x_C, y_C):
    """Return the feature distances from anchor A to candidates B and C.

    ``feature_map`` is indexed as ``feature_map[batch, :, frame, x, y]``
    (presumably laid out as ``(batch, channels, frames, h, w)`` -- confirm
    against the caller).  A is read from frame ``i``; B and C are read from
    frame ``i + 1``.

    Args:
        feature_map: Tensor holding the features to compare.
        batch_idx: Per-triplet batch indices.
        i: Per-triplet frame indices of the anchor.
        x_A, y_A: Per-triplet location of A.
        x_B, y_B: Per-triplet location of B.
        x_C, y_C: Per-triplet location of C.

    Returns:
        Tuple ``(d_AB, d_AC)``: the norms of ``A - B`` and ``A - C`` taken
        over dimension 1 of the gathered feature vectors.
    """

    def gather(frame, xs, ys):
        # Pull one feature vector per triplet at (batch, frame, x, y).
        return feature_map[batch_idx, :, frame, xs, ys]

    following_frame = i + 1

    anchor = gather(i, x_A, y_A)
    diff_to_b = anchor - gather(following_frame, x_B, y_B)
    diff_to_c = anchor - gather(following_frame, x_C, y_C)

    return torch.norm(diff_to_b, dim=1), torch.norm(diff_to_c, dim=1)