import torch

class FedAVG:
    """Local SGD routine that accumulates gradients and then rewinds the model.

    ``train`` runs plain SGD with an L2 penalty over the given data, sums the
    gradients taken at every step, and finally restores the model's
    parameters and buffers to the values they had on entry.  The summed
    gradients are returned to the caller.  NOTE(review): presumably this is
    the client side of federated averaging (per the class name) — confirm how
    the caller consumes the returned sums.
    """

    def __init__(self, eps, lr=0.01, beta=0.0, device='cuda'):
        """Store the optimisation hyper-parameters.

        Args:
            eps: L2 regularisation strength; ``train`` adds
                ``eps / 2 * sum(||p||^2)`` over trainable parameters to the loss.
            lr: Step size of the local SGD updates.
            beta: Stored but not read by ``train``.  NOTE(review): presumably
                intended as a momentum coefficient — confirm with callers.
            device: Device each ``(x, y)`` batch is moved to before the forward pass.
        """
        self.lr = lr
        self.beta = beta
        self.device = device
        self.eps = eps

    def train(self, model, dataloader, criterion, epochs):
        """Run ``epochs`` passes of local SGD and return the summed gradients.

        Args:
            model: Module to train.  It is put in training mode and left in
                training mode.  Its parameters and buffers are restored to
                their entry values before returning.
            dataloader: Iterable yielding ``(x, y)`` batches.
            criterion: Callable ``criterion(output, y)`` returning the scalar loss.
            epochs: Number of passes over ``dataloader``.

        Returns:
            Dict mapping each trainable parameter name to the sum, over every
            step, of the gradient of ``criterion + eps/2 * ||params||^2``.
            Each gradient is evaluated at the weights current at that step.
            The dict is empty when no step was taken or nothing is trainable.
        """
        g_dict = {}
        model.train()
        # state_dict() hands back references to the live tensors; without a
        # clone the in-place updates below would also rewrite the snapshot and
        # the final load_state_dict() would restore nothing.
        state = {k: v.clone() for k, v in model.state_dict().items()}

        trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
        names = [n for n, _ in trainable]
        params = [p for _, p in trainable]

        for _ in range(epochs):
            for x, y in dataloader:
                x, y = x.to(self.device), y.to(self.device)
                model.zero_grad()
                output = model(x)
                loss = criterion(output, y)
                l2_reg = sum((p ** 2).sum() for p in params)
                loss = loss + (self.eps / 2) * l2_reg
                if not params:
                    # autograd.grad rejects an empty input list; nothing to update.
                    continue
                # One backward pass for all parameters.  Every gradient is
                # evaluated at the same (pre-step) weights, giving a true
                # simultaneous SGD update.
                grads = torch.autograd.grad(loss, params)
                with torch.no_grad():
                    for name, param, g in zip(names, params, grads):
                        param.sub_(self.lr * g)
                        if name not in g_dict:
                            g_dict[name] = torch.zeros_like(g)
                        g_dict[name] += g
        model.load_state_dict(state)
        return g_dict