import torch
import torch.nn.functional as F

# Sanity check of CLIP-style symmetric cross-entropy with soft (probability)
# targets. Two variants are printed:
#   1. the raw 0/1 match matrix passed straight to CrossEntropyLoss;
#   2. the match matrix row-normalized so every target row sums to 1.
# With the identity targets below both variants print the same losses; they
# differ once a row has more than one positive (see the note on `gt`).

torch.manual_seed(0)  # fixed seed so the printed numbers are reproducible

N = 3  # rows of the similarity matrix (images)
M = N  # columns of the similarity matrix (texts)
logits_per_image = torch.randn(N, M)
# The text-side view of the same similarities: row j scores text j against
# every image.
logits_per_text = logits_per_image.T

# gt[i, j] == 1 marks image i and text j as a matching pair. Built as N x M so
# it always has the logits' shape; with M == N this is the identity.
# Multi-positive variant (each image matched by two texts):
#     M = 2 * N
#     gt = torch.eye(N).repeat_interleave(2, dim=1)  # rows like [1, 1, 0, 0, 0, 0]
gt = torch.eye(N, M)

loss_img = torch.nn.CrossEntropyLoss()
loss_txt = torch.nn.CrossEntropyLoss()


def symmetric_ce(target_per_img, target_per_txt):
    """Return ``(image_loss, text_loss, mean)`` for the module-level logits.

    ``target_per_img`` must have the shape of ``logits_per_image`` and
    ``target_per_txt`` the shape of ``logits_per_text``; because the targets
    are floating-point tensors shaped like the logits, CrossEntropyLoss reads
    each row as per-class target probabilities rather than a class index.
    """
    img_term = loss_img(logits_per_image, target_per_img)
    txt_term = loss_txt(logits_per_text, target_per_txt)
    return img_term, txt_term, (img_term + txt_term) / 2


def row_normalize(targets, eps=1e-12):
    """Scale each row of ``targets`` so it sums to 1.

    All-zero rows (an item with no positive) stay all-zero instead of turning
    into NaN; ``eps`` only guards the division and leaves rows whose sum is
    at least ``eps`` unchanged from a plain division.
    """
    return targets / targets.sum(dim=1, keepdim=True).clamp(min=eps)


# Variant 1: raw 0/1 targets. The text side needs the transposed matrix so that
# row j lists the images matching text j (and so its shape is M x N).
_loss_img, _loss_txt, loss = symmetric_ce(gt, gt.T)

print(logits_per_image)
print(_loss_img, _loss_txt)
print(loss)


# Variant 2: row-normalized targets, so each row is a proper probability
# distribution regardless of how many positives it has.
gt_per_img = row_normalize(gt)
gt_per_txt = row_normalize(gt.T)

_loss_img, _loss_txt, loss = symmetric_ce(gt_per_img, gt_per_txt)

print(_loss_img, _loss_txt)
print(loss)