import torch.nn as nn
import torch


class MaskedCELoss(nn.Module):
    """Masked squared-error loss plus a penalty on RNN activation magnitude.

    NOTE: despite the class name, the data term is a sum of squared errors
    (``nn.MSELoss`` with ``reduction='none'``), not a cross-entropy. The
    ``logsoftmax`` attribute is not used by ``forward``.
    """

    def __init__(self):
        super().__init__()
        # Not used by forward(); kept so code that accesses the attribute
        # keeps working.
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        # Element-wise squared error; reduction is done manually in forward()
        # so the mask can be applied per sample.
        self.mse = nn.MSELoss(reduction='none')
        # Value the mean squared RNN output is pulled towards.
        self.lhTargVar = 0
        # Mean squared RNN output from the most recent forward() call
        # (None until forward() has run once).
        self.hNorm = None

    def forward(self, cfgs, model, y_hat, final_rnn_outputs, target, mask):
        """Return the masked squared-error loss plus the activation penalty.

        Args:
            cfgs: Mapping; only ``cfgs['l2_h']`` (weight of the activation
                penalty) is read.
            model: Unused; accepted for interface compatibility.
            y_hat: Predictions, compared element-wise against the reshaped
                target.
            final_rnn_outputs: Tensor whose mean squared value is penalised.
            target: Targets; reshaped to
                ``(target.size(0), target.size(-1))`` — presumably any
                intermediate dimensions are singletons (TODO confirm).
            mask: Per-sample weights/flags; samples with mask 0 do not
                contribute to the data term.

        Returns:
            Scalar tensor: ``sum(mask * squared_error) / sum(mask)`` plus
            ``cfgs['l2_h'] * |mean(final_rnn_outputs ** 2) - lhTargVar|``.
            If the mask sums to zero the data term is 0 rather than NaN.
        """
        target = torch.reshape(target, (target.size(0), target.size(-1)))

        # Squared error summed over the last (feature) axis, then masked.
        per_sample = self.mse(y_hat, target).sum(-1) * mask.float()

        # Guard against a fully masked batch: dividing by a zero mask sum
        # would yield NaN and poison the gradients. Non-zero sums are left
        # untouched so the normalisation is otherwise unchanged.
        mask_total = mask.sum()
        safe_total = torch.where(
            mask_total != 0, mask_total, torch.ones_like(mask_total)
        )
        loss = per_sample.sum() / safe_total

        # Stored on the module so callers can inspect it after the call.
        self.hNorm = torch.mean(torch.square(final_rnn_outputs))
        loss_wh = torch.abs(self.hNorm - self.lhTargVar)

        # Out-of-place add: avoids mutating a tensor that is part of the
        # autograd graph.
        return loss + loss_wh * cfgs['l2_h']
