import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import auraloss


class Criterion(nn.Module):
    """Loss criterion summing a time-domain MSE and a multi-resolution STFT loss.

    The forward pass returns a dict of scalar statistics. Note that the single
    key ``'mse_loss'`` holds the *combined* loss (MSE + MR-STFT), not the MSE
    term alone; the key name is kept as-is for compatibility with callers.
    """

    def __init__(self, cfg):
        super().__init__()
        # Stored for external use; not read anywhere inside this class.
        self.cfg = cfg

        # Time-domain term: mean squared error between prediction and target.
        self.mse_loss = nn.MSELoss()

        # Spectral term from auraloss with w_lin_mag=1 and w_phs=1; all other
        # weights/resolutions fall back to auraloss defaults.
        self.mrft_loss = auraloss.freq.MultiResolutionSTFTLoss(w_lin_mag=1, w_phs=1)

    def forward(self, pred_ir, gt_ir):
        """Compute the combined loss between ``pred_ir`` and ``gt_ir``.

        Args:
            pred_ir: predicted signal tensor (presumably an impulse response;
                auraloss STFT losses presumably expect (batch, channels,
                samples) — TODO confirm against callers).
            gt_ir: target tensor, same shape as ``pred_ir``.

        Returns:
            dict with one entry, ``'mse_loss'``: the sum of the MSE term and
            the multi-resolution STFT term.
        """
        time_term = self.mse_loss(pred_ir, gt_ir)
        spectral_term = self.mrft_loss(pred_ir, gt_ir)
        # Despite the key name, this is the total (time + spectral) loss.
        return {'mse_loss': time_term + spectral_term}