import torch

# 1. Simulate a batch of 3 samples with 4 possible classes.
# Suppose these are the softmax probabilities from your model, so every row
# is a probability distribution over the 4 classes and sums to 1.
probs = torch.tensor([
    [0.1, 0.8, 0.05, 0.05],   # Sample 0: Model thinks it's Class 1 (0.8)
    [0.2, 0.2, 0.5, 0.1],     # Sample 1: Model thinks it's Class 2 (0.5)
    [0.9, 0.0, 0.05, 0.05],   # Sample 2: Model thinks it's Class 0 (0.9)
])
print(f"Probabilities Shape: {probs.shape}")  # [3, 4]

# 2. Simulate the ground-truth labels for these 3 samples
# (this is the label provided by your DataLoader). Sample 1's true class (3)
# differs from the model's top prediction (2), so its selected score is low.
label = torch.tensor([1, 3, 0])

print(f"Original Label Shape: {label.shape}")  # [3]

# 3. Step-by-step 'gather' operation
# Step A: gather requires the index tensor to have the same number of
# dimensions as 'probs', so add a trailing class axis: [Batch] -> [Batch, 1].
label_expanded = label.unsqueeze(1)
print(f"Expanded Label Shape: {label_expanded.shape}")  # [3, 1]

# Step B: Gather along dimension 1 (the class dimension).
# For each row i it picks probs[i, label_expanded[i, 0]].
print("Label Expanded for Gathering:", label_expanded)
true_class_probs_raw = probs.gather(1, label_expanded)

# Step C: Squeeze only the class axis to get back a 1D tensor of scores.
# A bare .squeeze() would also drop the batch axis when the batch size is 1,
# producing a 0-d tensor (and .tolist() would then return a float, not a list).
true_class_probs = true_class_probs_raw.squeeze(1)

print("\nResults:")
print(f"Probabilities Matrix:\n{probs}")
print(f"Ground Truth Labels: {label.tolist()}")
print(f"Selected Confidence Scores: {true_class_probs.tolist()}")

# Verification Logic
# Sample 0: Label 1 -> picks 0.8
# Sample 1: Label 3 -> picks 0.1
# Sample 2: Label 0 -> picks 0.9
# gather(1, label.unsqueeze(1)).squeeze(1) is equivalent to advanced indexing
# probs[row_index, label]; check that both paths select the same values.
indexed_probs = probs[torch.arange(label.numel()), label]
if not torch.equal(true_class_probs, indexed_probs):
    raise RuntimeError(
        f"gather mismatch: {true_class_probs.tolist()} != {indexed_probs.tolist()}"
    )