import torch
import torch.nn as nn
import torch.nn.functional as F


class GEM(nn.Module):
    """Mean cross-entropy loss plus two weighted moment terms.

    With ``ce`` denoting the per-sample cross-entropy values, the loss is::

        mean(ce) + a * mean(ce ** 2) + b * mean(ce) ** 2

    When ``b == -a`` the two extra terms combine into ``a`` times the
    (population) variance of the per-sample losses.

    Args:
        a: Weight applied to the mean of the squared per-sample losses.
        b: Weight applied to the square of the mean per-sample loss.
    """

    def __init__(self, a, b):
        super().__init__()
        # The weights are stored exactly as passed in.
        self.a, self.b = a, b

    def forward(self, z, y):
        """Compute the combined loss for one batch.

        Args:
            z: Predictions forwarded to ``F.cross_entropy`` as its input
                (presumably raw logits of shape (batch, classes) -- confirm
                against the caller).
            y: Targets forwarded to ``F.cross_entropy``.

        Returns:
            A ``(loss, ce_mean)`` tuple: the combined loss and the plain
            mean cross-entropy that forms its first term.
        """
        # Keep the losses unreduced; both moments need per-sample values.
        per_sample = F.cross_entropy(z, y, reduction="none")

        mean_ce = per_sample.mean()
        mean_of_squares = per_sample.pow(2).mean()
        square_of_mean = mean_ce.pow(2)

        weighted_second_moment = self.a * mean_of_squares
        weighted_squared_mean = self.b * square_of_mean

        total = mean_ce + weighted_second_moment + weighted_squared_mean
        return total, mean_ce