import torch

# a = torch.arange(0, 20).view(4, 5)
# print(a[None, :, :].repeat(4,1,1))

# Pairwise margin ranking loss: margin of 0.1, per-element losses summed.
loss_func1 = torch.nn.MarginRankingLoss(margin=0.1, reduction="sum")


# a = torch.randn(3, 10)
# b = torch.randn(3, 10)
# ones = torch.ones(a.size())
# print(torch.max(a, dim=1).values)
# batch_size = 2
# a = torch.tensor([[5, -1], [2, -2]]).long()
# b = torch.arange(0, batch_size).reshape(batch_size, 1).repeat(1, batch_size).long()
# c = torch.where(a >= 0, a, b)
# print(c)
# a = torch.randn(2, 4)
# print(a)
# a = torch.randn(2, 4, 5)
# b = torch.tensor([2, 3])
# b = b[:, None, None].repeat(1, 1, 5)
# print(a[0][2])
# print(a[1][3])
# print(torch.gather(a, 1, b))

# print(loss_func1(a, b, ones).sum())
# a = a.contiguous().view(-1)
# b = b.contiguous().view(-1)
# ones = torch.ones(a.size())
# print(loss_func1(a, b, ones).sum())
# model_dict = "skip_warmup_ckpts/iwslt14-de-en/model-warmed-up.pt"
# state_dict = torch.load(model_dict, map_location='cpu')
# # optimizer = FairseqAdam()
# print(state_dict.keys())

# max_len = torch.max(cand_len)
# print(max_len.item())
def strip_pad(tensor, pad):
    """Return the entries of ``tensor`` that are not equal to ``pad``.

    A boolean mask marks every element different from ``pad``; indexing
    with that mask yields the surviving elements as a flat 1-D tensor,
    in their original order.
    """
    keep_mask = tensor.ne(pad)
    return tensor[keep_mask]


# a = torch.ones(2,4)
# a[0][0] =2
# print(strip_pad(a, 1))
import numpy as np
# Scratch check: print 3 minus the natural log of 50.
print(3.0 - np.log(50))