import torch.nn as nn
import torch


class MaskedCELoss(nn.Module):
    """Masked cross-entropy loss with activity and weight penalties.

    The returned scalar is the sum of:

    * the per-item cross-entropy, weighted by ``mask`` and normalised by the
      mask total;
    * ``cfgs['l2_h']``  * ``|mean(final_rnn_outputs ** 2) - lhTargVar|``;
    * ``cfgs['l2_wI']`` * mean square of the first ``num_input`` columns of
      ``dense.weight``;
    * ``cfgs['l2_wR']`` * mean square of the next ``num_rnn`` columns of
      ``dense.weight``;
    * ``cfgs['l2_wO']`` * mean square of ``out.weight``.

    Attributes:
        lhTargVar: Value the mean squared RNN output is pulled toward.
        hNorm: Mean squared RNN output from the most recent ``forward`` call
            (``None`` before the first call).
    """

    def __init__(self):
        super().__init__()
        # Not used by forward() (nn.CrossEntropyLoss applies log-softmax
        # itself); kept so external code reading this attribute still works.
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        self.ce = nn.CrossEntropyLoss(reduction='none')
        self.lhTargVar = 0
        self.hNorm = None

    def forward(self, cfgs, model, y_hat, final_rnn_outputs, target, mask):
        """Compute the regularised, masked loss.

        Args:
            cfgs: Mapping providing ``'num_input'``, ``'num_rnn'``,
                ``'l2_wR'``, ``'l2_h'``, ``'l2_wI'`` and ``'l2_wO'``.
            model: Module whose state dict contains ``'dense.weight'`` and
                ``'out.weight'``.
            y_hat: Logits passed to ``nn.CrossEntropyLoss``
                (presumably ``(batch, num_classes)`` -- TODO confirm).
            final_rnn_outputs: Tensor whose mean square is penalised.
            target: Class scores; reshaped to ``(size(0), size(-1))`` and
                arg-maxed over the last axis to obtain class indices
                (presumably one-hot -- TODO confirm).
            mask: Per-item weights applied to the cross-entropy terms.

        Returns:
            Scalar tensor holding the total loss.
        """
        class_idx = target.reshape(target.size(0), target.size(-1)).argmax(dim=-1)
        per_item = self.ce(y_hat, class_idx)

        mask_f = mask.float()
        masked_sum = (per_item * mask_f).sum()
        # Clamp the denominator so an all-zero mask yields 0 rather than NaN;
        # any realistic non-empty mask sums far above eps and is unaffected.
        eps = torch.finfo(masked_sum.dtype).eps
        ce_loss = masked_sum / mask_f.sum().clamp_min(eps)

        self.hNorm = torch.mean(torch.square(final_rnn_outputs))
        loss_wh = torch.abs(self.hNorm - self.lhTargVar)

        # keep_vars=True returns the live parameters. The default state_dict()
        # detaches them from autograd, which would make the weight penalties
        # below contribute to the loss value but never to the gradients.
        params = model.state_dict(keep_vars=True)
        n_in = cfgs['num_input']
        n_rnn = cfgs['num_rnn']
        dense_w = params['dense.weight']

        loss_wi = torch.mean(torch.square(dense_w[:, :n_in]))
        loss_wr = torch.mean(torch.square(dense_w[:, n_in:n_in + n_rnn]))
        loss_wo = torch.mean(torch.square(params['out.weight']))

        return (
            ce_loss
            + loss_wr * cfgs['l2_wR']
            + loss_wh * cfgs['l2_h']
            + loss_wi * cfgs['l2_wI']
            + loss_wo * cfgs['l2_wO']
        )
