import torch

class SCAFFOLD:
    """Client-side SCAFFOLD local training with control variates.

    Holds one global control variate (``c_global``) and one local control
    variate per client rank (``c_local[rank]``).  Each is a dict mapping a
    trainable parameter name to a tensor shaped like that parameter.

    Local training takes drift-corrected steps
    ``p <- p - lr * (g - c_local + c_global)``.  Afterwards ``c_local`` is
    refreshed with the SCAFFOLD "option II" rule
    ``c_i <- c_i - c + (x - y_i) / (K * lr)``.

    Args:
        eps: Coefficient of the L2 penalty ``(eps / 2) * sum(||p||^2)`` that
            is added to the loss.
        lr: Local step size.
        max_norm: Each parameter's gradient tensor is rescaled so its norm is
            at most this value.  Clipping is per tensor, not a global norm.
        device: Device that each ``(x, y)`` batch is moved to.
    """

    def __init__(self, eps, lr=0.01, max_norm=15, device='cuda'):
        self.lr = lr
        self.max_norm = max_norm
        self.device = device
        self.c_global = {}  # key: parameter name -> tensor
        self.c_local = {}  # key: rank -> dict[name -> tensor]
        self.eps = eps

    def init_global(self, model):
        """Add a zero global control variate for every trainable parameter.

        Names that already have an entry in ``c_global`` are left untouched,
        so calling this repeatedly is safe.
        """
        for name, param in model.named_parameters():
            if param.requires_grad and name not in self.c_global:
                self.c_global[name] = torch.zeros_like(param)

    def train(self, rank, model, dataloader, criterion, epochs):
        """Run local SCAFFOLD training for client ``rank``.

        Before returning, the model's parameters and buffers are restored to
        their values on entry.  The local update is reported only through
        the returned deltas.  The client's local control variate
        ``self.c_local[rank]`` is updated in place.

        Args:
            rank: Key identifying the client's local control variate.
            model: Module to train.  Its trainable parameters are the ones
                updated.
            dataloader: Iterable of ``(x, y)`` batches.  It is iterated once
                per epoch.
            criterion: Callable ``criterion(output, y)`` returning the loss.
            epochs: Number of passes over ``dataloader``.

        Returns:
            ``(x_del, c_del)``.  Both are dicts mapping a parameter name to a
            tensor.  ``x_del`` is the accumulated parameter change over all
            local steps.  ``c_del`` is the change in this client's local
            control variate.  If no step was taken (empty dataloader or
            ``epochs <= 0``), both are all zeros and ``c_local`` is unchanged.
        """
        # Idempotent; lets train() work even if init_global was not called.
        self.init_global(model)
        self._init_local_if_needed(rank, model)
        c_loc = self.c_local[rank]

        model.train()
        # state_dict() tensors share storage with the live parameters.  Without
        # clone(), the in-place updates below would also overwrite this
        # snapshot, and the restore at the end would be a no-op.
        state = {key: value.detach().clone()
                 for key, value in model.state_dict().items()}

        trainable = [(name, param) for name, param in model.named_parameters()
                     if param.requires_grad]
        params = [param for _, param in trainable]
        # Pre-initialised so the result is well defined even with zero steps.
        x_del = {name: torch.zeros_like(param) for name, param in trainable}
        k = 0

        for _ in range(epochs):
            for x, y in dataloader:
                x, y = x.to(self.device), y.to(self.device)
                model.zero_grad()
                output = model(x)
                loss = criterion(output, y)
                l2_reg = sum((param ** 2).sum() for param in params)
                loss = loss + (self.eps / 2) * l2_reg

                # A single backward pass yields every parameter's gradient.
                # All gradients exist before any parameter is modified below.
                grads = torch.autograd.grad(loss, params) if params else ()

                with torch.no_grad():
                    for (name, param), g in zip(trainable, grads):
                        norm = g.norm()
                        if norm > self.max_norm:
                            g = g * (self.max_norm / norm)
                        update = -self.lr * (g - c_loc[name] + self.c_global[name])
                        param.add_(update)
                        x_del[name] += update
                k += 1

        c_del = {}
        with torch.no_grad():
            for name, c_i in c_loc.items():
                if k == 0:
                    # No steps were taken, so there is no information to fold
                    # in; this also avoids dividing by k == 0.
                    c_del[name] = torch.zeros_like(c_i)
                    continue
                c_old = c_i.clone()
                # Option II: c_i <- c_i - c + (x - y_i) / (k * lr),
                # where x - y_i == -x_del.
                c_i -= self.c_global[name] + x_del[name] / (k * self.lr)
                c_del[name] = c_i - c_old

        model.load_state_dict(state)
        return x_del, c_del

    def _init_local_if_needed(self, rank, model):
        """Create client ``rank``'s local control variates the first time it is seen.

        Each entry is a copy of the current global control variate.  This
        requires ``c_global`` to already hold every trainable parameter name.
        """
        if rank not in self.c_local:
            self.c_local[rank] = {
                name: self.c_global[name].clone()
                for name, param in model.named_parameters()
                if param.requires_grad
            }