import torch

def GroupMSE(X, y, network_output, group_ID):
    """Return the mean squared error between X and network_output for one group.

    Rows are selected where ``y == group_ID``. For each selected row, the
    squared error is averaged along dim 1. Those per-row values are then
    averaged into a single scalar tensor.

    Args:
        X: Target tensor. It is indexed as ``X[mask, :]``, so it needs at
            least two dimensions (presumably shape (N, D); confirm with
            callers).
        y: Per-row group labels. They are compared elementwise against
            ``group_ID`` to build a boolean row mask (assumed length N).
        network_output: Prediction tensor, indexed like ``X``. It is
            presumably the same shape as ``X`` (only broadcast-compatibility
            is strictly required by the subtraction).
        group_ID: Label value selecting which rows take part.

    Returns:
        A 0-dim tensor holding the group's MSE. If no row matches, the
        reduction runs over an empty tensor and yields NaN.
    """
    # Build the boolean row mask once and reuse it for both tensors.
    in_group = y == group_ID

    targets = X[in_group, :]
    predictions = network_output[in_group, :]

    squared_error = (targets - predictions) ** 2
    # Average within each row first, then across rows (mean of row means).
    per_row_mse = squared_error.mean(dim=1)
    return per_row_mse.mean()