import torch

class FedSophia:
    """Local trainer that applies a clipped, curvature-preconditioned update.

    For every trainable parameter the instance keeps two exponential moving
    averages, keyed by parameter name and persisted across ``train`` calls:

    * ``mG`` -- EMA of the gradient (decay ``beta1``).
    * ``mH`` -- EMA of a diagonal curvature estimate (decay ``beta2``).

    NOTE(review): the class name suggests this follows the Sophia optimizer;
    the curvature estimate used here is ``batch_size * grad ** 2``.
    """

    # Lower bound applied to the curvature EMA before dividing by it.
    _MIN_CURVATURE = 1e-3
    # Element-wise bound applied to the preconditioned step.
    _STEP_CLIP = 1e4

    def __init__(self, lr=0.01, eps=0.1, beta1=0.9, beta2=0.95, device='cuda'):
        """Store hyper-parameters and create empty optimizer state.

        Args:
            lr: Step size applied to each preconditioned update.
            eps: Coefficient of the L2 penalty; ``(eps / 2) * sum(p ** 2)`` is
                added to the loss.
            beta1: Decay of the gradient moving average.
            beta2: Decay of the curvature moving average.
            device: Device the input batches are moved to. The model itself is
                not moved; the caller must place it on the same device.
        """
        self.lr = lr
        self.eps = eps
        self.beta1 = beta1
        self.beta2 = beta2
        self.device = device
        self.mG = {}  # parameter name -> gradient EMA
        self.mH = {}  # parameter name -> curvature EMA

    def train(self, model, dataloader, criterion, epochs, rnd, batch_size):
        """Run local training and return the accumulated update directions.

        The model is stepped in place while training so that later batches see
        the effect of earlier ones, then restored to the weights (and buffers)
        it had on entry. Only the optimizer state (``mG`` / ``mH``) and the
        returned dictionary carry the result of the call.

        Args:
            model: ``torch.nn.Module`` to train; left in ``train()`` mode.
            dataloader: Iterable yielding ``(x, y)`` tensor pairs.
            criterion: Callable ``criterion(output, target)`` returning a
                scalar loss tensor.
            epochs: Number of passes over ``dataloader``.
            rnd: Round index. The curvature EMA is refreshed only when ``rnd``
                is even; on odd rounds the stored estimate is reused.
            batch_size: Scale factor of the curvature estimate.

        Returns:
            Dict mapping each trainable parameter name to the sum, over all
            processed batches, of its clipped preconditioned step (before
            multiplication by ``lr``). Empty if no batch was processed or the
            model has no trainable parameters.
        """
        model.train()

        # state_dict() hands back tensors that share storage with the live
        # parameters, so the snapshot has to be cloned; otherwise the in-place
        # steps below would overwrite it and the final restore would do nothing.
        snapshot = {
            key: value.detach().clone()
            for key, value in model.state_dict().items()
        }

        trainable = [
            (name, param)
            for name, param in model.named_parameters()
            if param.requires_grad
        ]
        accumulated = {}
        if not trainable:
            return accumulated
        params = [param for _, param in trainable]
        refresh_curvature = rnd % 2 == 0

        for _ in range(epochs):
            for x, y in dataloader:
                x, y = x.to(self.device), y.to(self.device)
                model.zero_grad()
                loss = criterion(model(x), y)
                l2_reg = sum((param ** 2).sum() for param in params)
                loss = loss + (self.eps / 2) * l2_reg

                # One backward pass for every parameter, taken before any
                # weight is modified, so all gradients refer to the same point.
                grads = torch.autograd.grad(loss, params)

                with torch.no_grad():
                    for (name, param), grad in zip(trainable, grads):
                        step = self._preconditioned_step(
                            name, grad, refresh_curvature, batch_size
                        )
                        param.sub_(self.lr * step)
                        if name in accumulated:
                            accumulated[name] += step
                        else:
                            accumulated[name] = step.clone()

        model.load_state_dict(snapshot)
        return accumulated

    def _preconditioned_step(self, name, grad, refresh_curvature, batch_size):
        """Update the moving averages for ``name`` and return the clipped step.

        Must be called under ``torch.no_grad()``.
        """
        # On first sight of a parameter the EMA is seeded with the gradient.
        prev_mg = self.mG[name] if name in self.mG else grad
        mg = self.beta1 * prev_mg + (1 - self.beta1) * grad

        if refresh_curvature:
            # Curvature proxy from the squared training-loss gradient (which
            # includes the L2 term), not from a sampled-label gradient.
            curvature = batch_size * grad.pow(2)
            prev_mh = self.mH[name] if name in self.mH else curvature
            mh = self.beta2 * prev_mh + (1 - self.beta2) * curvature
        elif name in self.mH:
            mh = self.mH[name]
        else:
            # No estimate yet: fall back to an identity preconditioner.
            mh = torch.ones_like(grad)

        self.mG[name] = mg
        self.mH[name] = mh

        # clamp(min=...) avoids building a constant tensor on self.device,
        # which may differ from where the optimizer state lives.
        step = mg / mh.clamp(min=self._MIN_CURVATURE)
        return step.clamp(-self._STEP_CLIP, self._STEP_CLIP)