import torch.nn as nn
import torch


class MaskedCELoss(nn.Module):
    """Masked squared-error loss with activity and weight penalties.

    Despite the class name, ``forward`` is built on ``nn.MSELoss`` (not a
    cross-entropy): the per-element squared error is summed over the last
    dimension, weighted by ``mask`` and normalised by ``mask.sum()``.  Four
    penalty terms, scaled by coefficients read from ``cfgs``, are then added:
    the distance of the mean squared hidden activity from ``lhTargVar`` and
    the mean squared input, recurrent and output weights of ``model``.
    """

    def __init__(self):
        super(MaskedCELoss, self).__init__()
        # Not used by forward(); kept so existing attribute access still works.
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        # reduction='none' keeps per-element errors so the mask can be applied.
        self.mse = nn.MSELoss(reduction='none')
        # Target value for the mean squared hidden activity (see forward()).
        self.lhTargVar = 0

    @staticmethod
    def _select_weights(cfgs, model):
        """Return the (input, recurrent, output) weight tensors of ``model``.

        The state-dict keys depend on ``cfgs['model_type']``.  Tensors are
        fetched with ``keep_vars=True`` so they stay attached to the autograd
        graph; the default ``state_dict()`` returns detached tensors, which
        would make every weight penalty contribute zero gradient.
        """
        tensors = model.state_dict(keep_vars=True)
        model_type = cfgs['model_type']

        if model_type == 'custom-rnn':
            return tensors['in2rnn'], tensors['kernel'], tensors['w_rnn_out']

        if model_type == 'custom-snn-optim':
            # A single dense matrix holds the input columns first, followed by
            # the recurrent columns.
            dense = tensors['dense.weight']
            n_in = cfgs['num_input']
            w_in = dense[:, :n_in]
            w_rec = dense[:, n_in:n_in + cfgs['num_rnn']]
            return w_in, w_rec, tensors['out.weight']

        return (
            tensors['in2rnn.weight'],
            tensors['rec.weight'],
            tensors['out.weight'],
        )

    def forward(self, cfgs, model, y_hat, final_rnn_outputs, target, mask):
        """Compute the masked MSE plus regularisation penalties.

        Args:
            cfgs: Mapping providing ``'model_type'`` and the coefficients
                ``'l2_wR'``, ``'l2_h'``, ``'l2_wI'`` and ``'l2_wO'``; for
                ``'custom-snn-optim'`` also ``'num_input'`` and ``'num_rnn'``.
            model: Module whose ``state_dict`` contains the weight keys used
                for the selected model type.
            y_hat: Predictions, compared element-wise against the reshaped
                target.
            final_rnn_outputs: Hidden activity whose mean square is penalised.
            target: Targets; reshaped to
                ``(target.size(0), target.size(-1))`` before comparison.
            mask: Weights multiplied onto the per-row summed error; the sum of
                the mask normalises the data term.

        Returns:
            Scalar tensor holding the total loss.

        Side effects:
            Stores the mean squared hidden activity in ``self.hNorm``.
        """
        target = torch.reshape(target, (target.size(0), target.size(-1)))

        per_row_error = self.mse(y_hat, target).sum(-1) * mask.float()
        # NOTE: a mask that sums to zero makes this a division by zero.
        loss = per_row_error.sum() / mask.sum()

        self.hNorm = torch.mean(torch.square(final_rnn_outputs))
        loss_wh = torch.abs(self.hNorm - self.lhTargVar)

        wi, wr, wo = self._select_weights(cfgs, model)
        loss_wi = torch.mean(torch.square(wi))
        loss_wr = torch.mean(torch.square(wr))
        loss_wo = torch.mean(torch.square(wo))

        # Accumulate in a fixed order so the floating-point result is stable.
        loss = loss + loss_wr * cfgs['l2_wR']
        loss = loss + loss_wh * cfgs['l2_h']
        loss = loss + loss_wi * cfgs['l2_wI']
        loss = loss + loss_wo * cfgs['l2_wO']
        return loss
