from torch import nn
import torch
# Default compute device for this module: the GPU when CUDA is usable, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class CustomLoss(nn.Module):
    """Margin-based hinge loss over a similarity matrix.

    For every row ``i`` of ``sim`` with target column ``t = index[i]``, each
    non-target column ``j`` contributes ``max(0, margin + sim[i, j] - sim[i, t])``
    and the target column contributes 0. The returned value is the mean over
    all ``batch * num_columns`` entries (target entries count as zeros).

    Args:
        beta: Margin added to every non-target similarity.
        batch_size: Stored on the instance; not used by ``forward``.
        weight_word: If True, the loss of every row whose ``word`` value equals
            ``word_weight_index`` is multiplied by ``weight_value``; all other
            rows keep weight 1.
        word_weight_index: The ``word`` value that selects rows to reweight.
        weight_value: Multiplier applied to the selected rows' losses.
    """

    def __init__(self, beta, batch_size, weight_word=False, word_weight_index=0, weight_value=0):
        super().__init__()
        self.margin = beta
        self.batch_size = batch_size
        self.weight_word = weight_word
        self.word_weight_index = word_weight_index
        self.weight_value = weight_value

    def forward(self, sim, index, word):
        """Compute the (optionally row-weighted) hinge loss.

        Args:
            sim: 2-D similarity tensor of shape ``(batch, num_columns)``.
            index: Integer tensor of shape ``(batch,)``; ``index[i]`` is the
                target column of row ``i``.
            word: Sequence (or 1-D tensor) of length ``batch`` whose entries
                are compared with ``word_weight_index``. Only consulted when
                ``weight_word`` is True.

        Returns:
            Scalar tensor: mean of the masked hinge-loss matrix.

        Raises:
            ValueError: if ``weight_word`` is True and ``len(word)`` differs
                from ``len(index)``.
        """
        # Keep every tensor on sim's device instead of assuming CUDA.
        index = index.to(sim.device)
        num_columns = sim.size(1)

        # Boolean (batch, num_columns) mask with exactly one True per row,
        # at that row's target column.
        mask = torch.eye(num_columns, device=sim.device)[index] > 0.5
        # Target similarity per row as a (batch, 1) column; it broadcasts
        # across all columns in the subtraction below.
        target_sim = sim[mask].unsqueeze(1)

        loss = (self.margin + sim - target_sim).clamp(min=0.0)
        # A row's target column is never penalized against itself.
        loss = loss.masked_fill(mask, 0.0)

        if self.weight_word:
            row_weights = self._row_weights(word, len(index), sim)
            loss = loss * row_weights.unsqueeze(1)

        return loss.mean()

    def _row_weights(self, word, n_rows, sim):
        """Return a (n_rows,) weight vector: weight_value where word matches, else 1."""
        # bool(...) accepts plain Python values as well as 0-dim tensors.
        selected = torch.tensor(
            [bool(w == self.word_weight_index) for w in word],
            dtype=torch.bool,
            device=sim.device,
        )
        if selected.numel() != n_rows:
            raise ValueError(
                f"word has {selected.numel()} entries but index has {n_rows}"
            )
        weights = torch.ones(n_rows, device=sim.device, dtype=sim.dtype)
        weights[selected] = self.weight_value
        return weights