import torch
import torch.nn as nn
import torch.nn.functional as F
from tools.EarlyStopping import Determine_Inf_Nan

def loss_function(net_E, net_H, real_E, real_H):
    """
    Return the sum of the E-field and H-field mean-squared errors.

    ``real_E`` and then ``net_E`` are first passed to ``Determine_Inf_Nan``
    (a project helper; its exact behavior is defined elsewhere).

    Args:
        net_E: predicted E values.
        net_H: predicted H values.
        real_E: target E values.
        real_H: target H values.

    Returns:
        ``F.mse_loss(net_E, real_E) + F.mse_loss(net_H, real_H)``.
    """
    # NOTE(review): only the E tensors are checked for Inf/NaN; net_H and
    # real_H are not — confirm this asymmetry is intentional.
    for tensor, label in ((real_E, 'real_E'), (net_E, 'net_E')):
        Determine_Inf_Nan(tensor, label)

    error_E = F.mse_loss(net_E, real_E)
    error_H = F.mse_loss(net_H, real_H)
    return error_E + error_H

def evaluate_model(model, loader):
    """
    Return the mean of ``loss_function`` over a data list or loader.

    Each item is moved to the GPU with ``.cuda()`` and passed through
    ``model``, which must return the pair ``(net_E, net_H)``. The result is
    scored against the item's ``y_E`` / ``y_H`` attributes. The per-item
    losses are summed and divided by the number of items iterated.
    Gradients are disabled throughout. The model is put in eval mode for
    the pass, and its previous train/eval mode is restored afterwards.

    Args:
        model: module returning ``(net_E, net_H)`` for one data item.
        loader: iterable of data items (e.g. a list or a DataLoader). Each
            item must provide ``.cuda()``, ``.y_E`` and ``.y_H``.

    Returns:
        The averaged loss, of whatever type ``loss_function`` returns
        (a tensor when the targets are tensors).

    Raises:
        ZeroDivisionError: if ``loader`` yields no items.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            total = 0
            count = 0
            for data in loader:
                # Rebind: for tensors (and other non-in-place types) .cuda()
                # returns a new object instead of moving the original.
                data = data.cuda()
                net_E, net_H = model(data)
                total += loss_function(net_E, net_H, data.y_E, data.y_H)
                count += 1
            # Divide by items actually iterated, so loaders without __len__
            # work too; an empty loader still raises ZeroDivisionError.
            return total / count
    finally:
        # Restore the caller's mode so a training loop is not left in eval.
        model.train(was_training)