import torch
from torch.utils.data import Subset, DataLoader
import numpy as np
import random
import torchvision
import torchvision.transforms as transforms
import argparse

def getDirichletData(data, targets, psizes, alpha):
    """Split ``data`` across ``psizes`` clients with a per-class Dirichlet(alpha) draw.

    For every class, that class's indices are shuffled and divided among the clients
    according to proportions sampled from ``Dirichlet(alpha, ..., alpha)``. Clients
    that already hold at least ``N / psizes`` samples get no further samples (the
    "balance" step). The whole draw is repeated until every client holds at least
    ``K`` samples, where ``K`` is the number of distinct labels.

    Args:
        data: Dataset to wrap; each partition is ``Subset(data, indices)``.
        targets: Integer labels for ``data`` (torch CPU tensor, numpy array or any
            array-like). Labels are assumed to be ``0 .. K-1``.
        psizes: Number of clients (partitions).
        alpha: Dirichlet concentration; smaller values give more skewed splits.

    Returns:
        ``(partitions, weights)``: a list of ``Subset`` objects and a numpy array
        holding each client's share of the total number of samples.
    """
    n_nets = psizes
    # Work on a numpy view so torch tensors, numpy arrays and lists are all accepted
    # (``torch.unique`` used to reject anything but a tensor).
    labels = np.asarray(targets)
    K = len(np.unique(labels))
    N = len(labels)
    # The per-class index sets do not change between retries; compute them once.
    class_indices = [np.where(labels == k)[0] for k in range(K)]

    min_size = 0
    while min_size < K:
        idx_batch = [[] for _ in range(n_nets)]
        for k in range(K):
            idx_k = class_indices[k].copy()
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(alpha, n_nets))
            # Balance: clients that already hold their fair share receive nothing more.
            proportions = np.array([p * (len(idx_j) < N / n_nets) for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions / proportions.sum()
            cut_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, cut_points))]
        min_size = min(len(idx_j) for idx_j in idx_batch)

    net_dataidx_map = {}
    for j in range(n_nets):
        np.random.shuffle(idx_batch[j])
        net_dataidx_map[j] = idx_batch[j]

    net_cls_counts = {}
    for net_i, dataidx in net_dataidx_map.items():
        unq, unq_cnt = np.unique(labels[dataidx], return_counts=True)
        net_cls_counts[net_i] = {unq[i]: unq_cnt[i] for i in range(len(unq))}
    print('Data statistics: %s' % str(net_cls_counts))

    local_sizes = np.array([len(net_dataidx_map[i]) for i in range(n_nets)])
    weights = local_sizes / np.sum(local_sizes)
    print(weights)
    partitions = [Subset(data, idx_batch[i]) for i in range(n_nets)]
    return partitions, weights

def dataload(size, isNonIID = False, alpha = 1):
    """Load MNIST and split its training set across ``size`` clients.

    Images are converted to tensors, normalized with mean 0.5 / std 0.5 and
    flattened to 1-D. With ``isNonIID`` the split is delegated to
    ``getDirichletData``. Otherwise the shuffled index list is cut into ``size``
    contiguous chunks of ``int(len(trainset) / size)`` samples; any remainder is
    left unused.

    Returns:
        ``(partitions, ratio, testset, classnum)`` where ``ratio`` holds the
        aggregation weight of each client and ``classnum`` is 10.
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize([0.5], [0.5])
    flatten = transforms.Lambda(lambda x: x.view(-1))
    data_tf = transforms.Compose([to_tensor, normalize, flatten])

    root = './../data'
    trainset = torchvision.datasets.MNIST(root=root, train=True, download=False, transform=data_tf)
    testset = torchvision.datasets.MNIST(root=root, train=False, download=False, transform=data_tf)
    classnum = 10

    if not isNonIID:
        order = list(range(len(trainset)))
        random.shuffle(order)
        chunk = int(len(trainset) / size)
        partitions = [Subset(trainset, order[c * chunk:(c + 1) * chunk]) for c in range(size)]
        ratio = [1 / size] * size
        return partitions, ratio, testset, classnum

    partitions, ratio = getDirichletData(trainset, trainset.targets, size, alpha)
    return partitions, ratio, testset, classnum

def softmax(z):
    """Softmax along dim 1 of ``z``.

    The per-row maximum is subtracted before exponentiating so large logits do not
    overflow.
    """
    stabilized = z - z.amax(dim=1, keepdim=True)
    numerators = stabilized.exp()
    return numerators / numerators.sum(dim=1, keepdim=True)

def cross_entropy_loss(probs, labels):
    """Mean negative log-likelihood of ``labels`` under row-wise class ``probs``.

    Only each sample's true-class probability is read (via ``gather``). A zero
    probability on some other class therefore no longer produces NaN, which the
    previous one-hot product did because ``0 * log(0) = 0 * -inf = NaN``.

    Args:
        probs: ``(batch, num_classes)`` probabilities.
        labels: Integer class indices, one per row of ``probs``.

    Returns:
        0-dim tensor ``-sum_b log(probs[b, labels[b]]) / batch``.
    """
    picked = probs.gather(1, labels.view(-1, 1))
    return -torch.sum(torch.log(picked)) / labels.size(0)

def federated_aggregate(local_weights, weights):
    """Weighted sum ``sum_i weights[i] * local_weights[i]`` of client models.

    ``weights`` is indexed in step with ``local_weights``. With an empty
    ``local_weights`` the integer 0 is returned.
    """
    total = 0
    for idx, local in enumerate(local_weights):
        total = total + local * weights[idx]
    return total

def federated_aggregate_grad_hess(local_grads, local_hessians, weights):
    """Weighted sums of client gradients and client Hessians.

    Returns:
        ``(sum_i weights[i] * local_grads[i], sum_i weights[i] * local_hessians[i])``.
    """
    def _weighted_sum(items):
        acc = 0
        for pos, item in enumerate(items):
            acc = acc + item * weights[pos]
        return acc

    return _weighted_sum(local_grads), _weighted_sum(local_hessians)

def sparse_compression(H, sparsity=0.1):
    """Randomly sparsify a square matrix with a symmetric keep-mask.

    Each strictly-upper-triangular entry is kept with probability ``sparsity``, and
    the same decision is mirrored to the lower triangle. Every other entry,
    including the whole diagonal, becomes zero. ``H`` itself is not modified.
    """
    n = H.size(0)
    keep = torch.rand((n, n), device=H.device) < sparsity
    keep = torch.triu(keep, diagonal=1)
    keep = keep | keep.T
    return H.masked_fill(~keep, 0.0)
def low_rank_compression(H, rank_k=10):
    """Truncated-SVD (best rank-``rank_k``) approximation of ``H``.

    Args:
        H: 2-D tensor; it need not be square.
        rank_k: Number of leading singular triplets to keep. Values larger than the
            number of singular values are clamped (they used to raise IndexError).

    Returns:
        ``(H_compressed, delta)``. ``H_compressed`` has the same shape, dtype and
        device as ``H`` (previously it was always a float32 square matrix).
        ``delta`` is the fraction of the squared Frobenius norm captured by the
        kept singular values, as a 0-dim tensor.
    """
    U, S, Vh = torch.linalg.svd(H, full_matrices=False)
    k = max(0, min(int(rank_k), S.numel()))
    # Scaling the kept columns of U by S[:k] and multiplying by Vh[:k] sums the
    # k rank-one terms S[i] * outer(U[:, i], Vh[i]) in a single matmul.
    H_compressed = (U[:, :k] * S[:k]) @ Vh[:k, :]
    total_variance = torch.sum(S ** 2) + 1e-15
    delta = torch.sum(S[:k] ** 2) / total_variance
    return H_compressed, delta
def unbiased_compressor(matrix, sparsity=0.1):
    """Random-sparsification compressor with E[C(M)] = M.

    Each entry is kept independently with probability ``sparsity`` and rescaled by
    ``1 / sparsity``; dropped entries become zero.

    Raises:
        ValueError: if ``sparsity`` is not in ``(0, 1]``. Outside that range the
            scaling is undefined (0) or the compressor is no longer unbiased.
    """
    if not 0 < sparsity <= 1:
        raise ValueError(f"sparsity must be in (0, 1], got {sparsity}")
    # Build the mask in the input's dtype so float64/float16 inputs keep their dtype.
    mask = (torch.rand_like(matrix) < sparsity).to(matrix.dtype)
    return matrix * mask * (1.0 / sparsity)
def contractive_compressor(matrix, delta=0.5):
    """Randomized contractive compressor with ‖C(M) - M‖² ≤ (1 - δ)‖M‖².

    A Gaussian direction is drawn, normalized to unit Frobenius norm and scaled to
    ``sqrt(1 - delta) * ‖M‖_F`` before being subtracted from ``M``. The bound
    therefore holds (with equality, up to rounding) for every draw.

    Raises:
        ValueError: if ``delta`` is not in ``[0, 1]``.
    """
    if not 0 <= delta <= 1:
        raise ValueError(f"delta must be in [0, 1], got {delta}")
    direction = torch.randn_like(matrix)
    direction_norm = torch.norm(direction)
    if direction_norm > 0:
        # Normalize to unit norm. The old code scaled every entry by ‖M‖·sqrt(1-δ),
        # so the perturbation grew with sqrt(numel) and broke the bound.
        direction = direction / direction_norm
    noise = direction * (torch.norm(matrix) * (1 - delta) ** 0.5)
    return matrix - noise


def hessian(loss, param, sampling_ratio=0.0003):
    """Sample rows of the Hessian of ``loss`` w.r.t. ``param`` (Nyström-style sketch).

    Row indices are drawn without replacement with probability proportional to the
    squared gradient entries, or uniformly if the gradient is identically zero.
    Each sampled row is obtained by differentiating the corresponding gradient
    entry once more. Afterwards the sampled indices' columns are moved to the
    front, so ``HdD[:, :d]`` equals ``H[N_idx][:, N_idx]``.

    Args:
        loss: Scalar tensor whose graph reaches ``param``.
        param: Tensor requiring grad; its ``D = param.numel()`` entries are flattened.
        sampling_ratio: Fraction of the ``D`` rows to sample. At least one row is
            sampled, and never more than there are non-zero sampling probabilities.

    Returns:
        ``(HdD, g)``. ``HdD`` is the ``(d, D)`` sampled-row matrix. ``g`` is the
        gradient of ``loss`` w.r.t. ``param`` (shape of ``param``), still attached
        to the graph.
    """
    g = torch.autograd.grad(loss, param, create_graph=True)[0]
    grad = g.view(-1)
    D = g.numel()

    # Squared-gradient importance scores; detached so sampling stays out of the graph.
    # An all-zero gradient used to give 0/0 = NaN probabilities and crash multinomial.
    g_squared = grad.detach() ** 2
    total = g_squared.sum()
    if total > 0:
        sampling_probabilities = g_squared / total
    else:
        sampling_probabilities = torch.full_like(g_squared, 1.0 / D)

    d = max(1, min(int(sampling_ratio * D), D))
    # Sampling without replacement cannot draw more indices than non-zero probabilities.
    d = min(d, int(torch.count_nonzero(sampling_probabilities)))
    N_idx = torch.multinomial(sampling_probabilities, d, replacement=False)

    # Allocate on the gradient's device/dtype instead of relying on a module-level
    # ``device`` that only exists when this file is executed as a script.
    HdD = torch.empty((d, D), device=grad.device, dtype=grad.dtype)
    for i, idx in enumerate(N_idx):
        HdD[i] = torch.autograd.grad(grad[idx], param, retain_graph=True, create_graph=False)[0].view(-1)

    # Move the sampled columns to the front (H_term keeps their original contents).
    # NOTE(review): the in-place column copies below are order-dependent; confirm
    # the resulting permutation of the non-sampled columns is what callers expect.
    H_term = HdD[:, N_idx].clone()
    for i, idx in enumerate(N_idx):
        HdD[:, idx] = HdD[:, i]
    HdD[:, :d] = H_term

    return HdD, g

def woodbury_update(grad, HdD, eps=1e-5, clp=1e-4):
    """Apply ``(eps*I + C W^{-1} C^T)^{-1}`` to ``grad`` via the Woodbury identity.

    ``HdD`` is the ``(d, D)`` matrix of sampled Hessian rows (``C^T``), with the
    sampled columns moved to the front. ``W = HdD[:, :d]`` is therefore the
    ``d x d`` Hessian block at the sampled indices. With
    ``H_mid = W + (1/eps) HdD HdD^T`` the identity gives

        (eps*I + C W^{-1} C^T)^{-1} g = (1/eps) (g - (1/eps) HdD^T H_mid^{-1} HdD g).

    ``H_mid`` is inverted through its SVD, with singular values regularized to
    ``sqrt(s**2 + clp)``. ``grad`` is flattened as-is, so its entries must follow
    the column ordering of ``HdD``.

    Returns:
        Tensor with the same shape as ``grad``.
    """
    g = grad.view(-1)
    d = HdD.shape[0]
    # W = H[S, S]: the first d columns. The old ``HdD[:, d]`` picked a single
    # column, which was then broadcast into H_mid.
    Hdd = HdD[:, :d]
    H_mid = Hdd + 1 / eps * torch.matmul(HdD, HdD.T)
    U, S, Vh = torch.linalg.svd(H_mid)
    S = torch.sqrt(S ** 2 + clp)
    Hdd_inv = Vh.T @ torch.diag(1.0 / S) @ U.T
    # H^{-1} g via the Woodbury identity.
    p1 = torch.matmul(HdD, g)
    p2 = torch.matmul(Hdd_inv, p1)
    Hg = 1 / eps * torch.matmul(HdD.T, p2)
    Hg = torch.reshape(Hg, grad.shape)
    return 1 / eps * (grad - Hg)

def sherman(v, g, eps):
    """Solve ``(eps*I + v v^T / v[0]) x = g`` with the Sherman–Morrison formula.

    ``v`` is expected to be a 1-D vector whose pivot ``v[0]`` is non-zero, so
    ``v v^T / v[0]`` is a rank-one approximation built from a single row. ``g`` is
    flattened for the computation and the result is returned with ``g``'s shape.
    """
    flat_g = g.view(-1)
    scaled = v / v[0]
    vz = torch.matmul(v, scaled.t())      # v . v / v[0]
    vg = torch.matmul(v, flat_g)          # v . g
    correction = scaled.t() * vg / eps / (eps + vz)
    return (flat_g / eps - correction).reshape(g.shape)

def train_nl(rank, weight, train_loader, epochs=1, reg_lambda=1e-4, compression_operator=None, alpha=1, lr=1):
    """Client step of a FedNL-style Newton method for softmax regression.

    ``weight`` is a ``(num_features + 1, num_classes)`` matrix whose last row is the
    bias. Each batch goes through four steps:

    1. Build the block-diagonal Hessian approximation (cross-class terms dropped).
    2. Compress its difference to the learned Hessian ``local_hessian`` with
       ``compression_operator``.
    3. Add the compressed difference with step ``alpha``.
    4. Apply a local Newton-type step ``(local_hessian + l*I)^{-1} g`` to a
       private copy of the weights.

    Parameters are flattened row-major from the ``(num_features, num_classes)``
    gradient, i.e. index ``f * num_classes + c``. The Hessian is built in that same
    layout; it used to be class-major, which misaligned it with ``g`` in the solve.

    Args:
        rank: Client id (unused).
        weight: Current model; not modified.
        train_loader: Iterable of ``(images, labels)`` with flattened images.
        epochs: Passes over ``train_loader``.
        reg_lambda: Unused; kept for interface compatibility.
        compression_operator: Callable ``M -> (compressed_M, info)``; required.
        alpha: Step size of the Hessian-learning update.
        lr: Step size of the local Newton step.

    Returns:
        ``(grad_accumulator, local_hessian, lk)``:
        - the sum of per-batch weight gradients, shape ``(num_features, num_classes)``;
        - the learned ``(P, P)`` Hessian, with ``P = num_features * num_classes``;
        - the last ``‖H - local_hessian‖_F`` (0-dim tensor, zero if no batch was seen).

    Raises:
        ValueError: if ``compression_operator`` is None.
    """
    if compression_operator is None:
        raise ValueError("train_nl requires a compression_operator")
    # Use the model's own device rather than a global defined only under __main__.
    dev = weight.device
    num_features, num_classes = weight[:-1].shape
    num_params = num_features * num_classes
    grad_accumulator = torch.zeros_like(weight[:-1])
    local_hessian = torch.zeros((num_params, num_params), device=dev, dtype=weight.dtype)
    lk = torch.zeros((), device=dev, dtype=weight.dtype)
    weight_copy = weight.clone()
    for _ in range(epochs):
        for images, labels in train_loader:
            images, labels = images.to(dev), labels.to(dev)
            batch_size = images.size(0)

            probs = torch.softmax(images @ weight_copy[:-1] + weight_copy[-1], dim=1)
            one_hot_labels = torch.zeros_like(probs).scatter_(1, labels.view(-1, 1), 1)
            grad = images.T @ (probs - one_hot_labels) / batch_size   # (F, C)
            g = grad.reshape(-1)                                      # index f * C + c

            # H[f*C + c, f'*C + c'] = δ_cc' (1/B) sum_b p_bc (1 - p_bc) x_bf x_bf'
            R_diag = probs * (1 - probs)  # (batch_size, num_classes)
            H = torch.zeros((num_features, num_classes, num_features, num_classes), device=dev, dtype=probs.dtype)
            for c in range(num_classes):
                R_weighted_images = images * R_diag[:, c:c + 1]
                H[:, c, :, c] = R_weighted_images.T @ images / batch_size
            H = H.reshape(num_params, num_params)

            H_diff = H - local_hessian
            H_shift, _ = compression_operator(H_diff)
            local_hessian += alpha * H_shift

            # l_i^k := ‖H_i - ∇²f_i(x^k)‖_F, used as a shift: step = [H + l I]^{-1} g
            lk = torch.norm(H_diff, p='fro')
            modified_hessian = local_hessian + lk * torch.eye(num_params, device=dev, dtype=local_hessian.dtype)
            delta_weight = torch.linalg.solve(modified_hessian, g).clamp(-1e4, 1e4)

            weight_copy[:-1] -= lr * delta_weight.view_as(weight_copy[:-1])
            # NOTE(review): the bias moves by the scalar mean of the weight gradient,
            # not by the bias gradient (probs - one_hot).mean(0); confirm intent.
            weight_copy[-1] -= lr * g.mean()

            grad_accumulator += grad
    return grad_accumulator, local_hessian, lk

def train_avg(rank, weight, train_loader, epochs=1, lr=0.1):
    """FedAvg client: plain mini-batch gradient descent on softmax regression.

    ``weight`` is ``(num_features + 1, num_classes)`` with the bias in the last row.
    It is not modified; the trained copy is returned. Tensors are moved to
    ``weight.device`` rather than to a module-level ``device`` that only exists
    when the file runs as a script.
    """
    dev = weight.device
    weight_copy = weight.clone()
    for _ in range(epochs):
        for images, labels in train_loader:
            images, labels = images.to(dev), labels.to(dev)
            probs = torch.softmax(images @ weight_copy[:-1] + weight_copy[-1], dim=1)

            one_hot_labels = torch.zeros_like(probs).scatter_(1, labels.view(-1, 1), 1)
            residual = probs - one_hot_labels
            g_weight = images.T @ residual / images.size(0)  # (num_features, num_classes)
            g_bias = residual.mean(dim=0)                     # (num_classes,)

            weight_copy[:-1] -= lr * g_weight
            weight_copy[-1] -= lr * g_bias
    return weight_copy

def train_now(rank, weight, train_loader, b, mHG=None, epochs=1, lr=0.1, eps=5e-4, clp=5e-4, sampling_ratio=0.01):
    """Client step with a sampled-row (Nyström) Newton direction plus momentum.

    For each batch:

    1. Compute the weight gradient ``g`` (``(F, C)``, flattened feature-major as
       index ``f * C + c``).
    2. Sample ``d`` rows of the block-diagonal Hessian approximation with
       probability proportional to ``g**2``.
    3. Evaluate the direction ``(eps*I + C W^{-1} C^T)^{-1} g`` with the Woodbury
       identity, where ``C^T`` is the sampled rows and ``W`` their ``d x d``
       sub-block.
    4. Smooth it as ``mHG <- b*mHG + (1-b)*direction`` and step the weights.

    Args:
        rank: Client id (unused).
        weight: ``(F + 1, C)`` model with the bias in the last row; not modified.
        train_loader: Iterable of ``(images, labels)``.
        b: Momentum coefficient.
        mHG: Momentum state of shape ``(F, C)``; zeros when None.
        epochs: Local passes over ``train_loader``.
        lr: Local step size.
        eps: Damping added to the Hessian approximation.
        clp: Regularizer for the singular values of the Woodbury middle matrix.
        sampling_ratio: Fraction of the ``F * C`` rows to sample (at least one).

    Returns:
        ``(weight_copy, mHG)``: the updated model and the updated momentum state.
    """
    dev = weight.device
    weight_copy = weight.clone().detach().requires_grad_(False)
    num_features, num_classes = weight_copy[:-1].shape
    num_params = num_features * num_classes
    # Was ``min(1, d)``, which capped the sketch at a single row.
    d = max(1, min(int(sampling_ratio * num_params), num_params))
    if mHG is None:
        mHG = torch.zeros_like(weight_copy[:-1])

    for _ in range(epochs):
        for images, labels in train_loader:
            images, labels = images.to(dev), labels.to(dev)
            batch_size = images.size(0)
            probs = torch.softmax(images @ weight_copy[:-1] + weight_copy[-1], dim=1)
            one_hot_labels = torch.nn.functional.one_hot(labels, num_classes=num_classes).to(probs.dtype)
            g = images.T @ (probs - one_hot_labels) / batch_size   # (F, C)
            grad = g.reshape(-1)                                 # index f * C + c
            R_diag = probs * (1 - probs)                         # (batch_size, C)

            # Importance sampling by squared gradient; uniform if the gradient vanishes.
            scores = grad.pow(2)
            total = scores.sum()
            if total > 0:
                sample_p = scores / total
            else:
                sample_p = torch.full_like(scores, 1.0 / num_params)
            k = min(d, int(torch.count_nonzero(sample_p)))
            sampled = torch.multinomial(sample_p, num_samples=k, replacement=False)
            feat = torch.div(sampled, num_classes, rounding_mode='floor')
            cls = sampled % num_classes

            # Row f*C + c of the block-diagonal Hessian is non-zero only in columns
            # f'*C + c, where it equals (1/B) sum_b R[b, c] x_b[f] x_b[f'].
            row_weights = R_diag[:, cls] * images[:, feat]   # (batch_size, k)
            row_vals = row_weights.T @ images / batch_size   # (k, F)
            HdD = torch.zeros((k, num_features, num_classes), device=dev, dtype=probs.dtype)
            HdD[torch.arange(k, device=dev), :, cls] = row_vals
            HdD = HdD.reshape(k, num_params)
            Hdd = HdD[:, sampled]                            # W = H[S, S]

            # Woodbury: (eps I + C W^-1 C^T)^-1 g
            #         = (1/eps) (g - (1/eps) C (W + (1/eps) C^T C)^-1 C^T g)
            H_mid = Hdd + HdD @ HdD.T / eps
            U, S, Vh = torch.linalg.svd(H_mid)
            S = torch.sqrt(S ** 2 + clp)
            H_mid_inv = Vh.T @ torch.diag(1.0 / S) @ U.T
            Hg = (grad - HdD.T @ (H_mid_inv @ (HdD @ grad)) / eps) / eps

            # The momentum state is now stored (previously the new value was discarded).
            mHG = b * mHG + (1 - b) * Hg.view_as(mHG)
            with torch.no_grad():
                weight_copy[:-1] -= lr * mHG
                # NOTE(review): bias moves by the scalar mean of the weight gradient.
                weight_copy[-1] -= lr * grad.mean()
    return weight_copy, mHG

def train_fagh(rank, weight, train_loader, mG=None, mH=None, epochs=1, lr=0.1, eps=1e-5):
    """FAGH client: momentum gradient plus a rank-one (single Hessian row) Newton step.

    Per batch:

    1. Form the weight gradient ``g`` and update ``mG <- b1*mG + (1-b1)*g``.
    2. For every class ``c``, take the curvature row of feature 0 restricted to that
       class's block, ``rows[f', c] = (1/B) sum_b p_bc (1 - p_bc) x_b0 x_bf'``.
    3. Smooth the rows with ``mH <- b2*mH + (1-b2)*v``.
    4. Step the weights by ``sherman(mH, mG, eps)``, i.e.
       ``(eps I + mH mH^T / mH[0])^{-1} mG``.

    ``v``/``mH`` are flattened feature-major (index ``f' * C + c``) so they line up
    with ``mG.view(-1)``; they used to be filled class-major.

    Returns:
        ``(grad_accumulator, hessian_row_accumulator, mG, mH)``: the sums over
        batches of ``mG`` (``(F, C)``) and of ``mH`` (flat, length ``F*C``), plus the
        final momentum states.
    """
    dev = weight.device
    weight_copy = weight.clone().detach().requires_grad_(False)
    num_features, num_classes = weight_copy[:-1].shape
    b1, b2 = 0.5, 0.55

    grad_accumulator = torch.zeros_like(weight_copy[:-1])                   # (F, C)
    hessian_row_accumulator = torch.zeros_like(weight_copy[:-1].reshape(-1))
    if mG is None:
        mG = torch.zeros_like(weight_copy[:-1])
    if mH is None:
        mH = torch.zeros((1, num_features * num_classes), device=dev, dtype=weight_copy.dtype)
    for _ in range(epochs):
        for images, labels in train_loader:
            images, labels = images.to(dev), labels.to(dev)
            batch_size = images.size(0)

            probs = torch.softmax(images @ weight_copy[:-1] + weight_copy[-1], dim=1)
            one_hot_labels = torch.nn.functional.one_hot(labels, num_classes=num_classes).to(probs.dtype)
            g = (images.T @ (probs - one_hot_labels)) / batch_size
            mG = b1 * mG + (1 - b1) * g

            # rows[f', c] = (1/B) sum_b R[b, c] * x_b[0] * x_b[f'] for all classes at once.
            R_diag = probs * (1 - probs)
            rows = images.T @ (R_diag * images[:, :1]) / batch_size   # (F, C)
            v = rows.reshape(1, -1)
            mH = b2 * mH + (1 - b2) * v

            grad_accumulator += mG
            hessian_row_accumulator += mH.view(-1)

            Hg = sherman(mH.view(-1), mG, eps)
            with torch.no_grad():
                weight_copy[:-1] -= lr * Hg.view_as(weight_copy[:-1])
                # NOTE(review): bias moves by the scalar mean of the gradient momentum.
                weight_copy[-1] -= lr * mG.mean()

    return grad_accumulator, hessian_row_accumulator, mG, mH

def train_sophia(rank, weight, train_loader, mG=None, mH=None, epochs=1, lr=0.02, b1=0.95, b2=0.99, eps=1e-3, clip_value=1e4, use_pseudo_labels=False):
    """Sophia-style client: momentum gradient divided by a diagonal curvature EMA.

    Per batch:
        mG   <- b1*mG + (1-b1)*g
        mH   <- b2*mH + (1-b2)*H
        W    <- W - lr * clip(mG / max(mH, eps), ±clip_value)
        bias <- bias - lr * mG.mean(dim=0)

    By default ``H = g**2``, which is what the original code effectively used.
    With ``use_pseudo_labels=True``, ``H`` is the square of the gradient taken
    against labels sampled from the model's own predictions. The original computed
    that estimate and then overwrote it.

    Args:
        mG, mH: Momentum states of shape ``(F, C)``. Each defaults to zeros
            independently (previously ``mH`` was only created together with ``mG``).

    Returns:
        ``(weight_copy, mG, mH)``.
    """
    dev = weight.device
    weight_copy = weight.clone().detach().requires_grad_(False)
    if mG is None:
        mG = torch.zeros_like(weight_copy[:-1])
    if mH is None:
        mH = torch.zeros_like(weight_copy[:-1])

    for _ in range(epochs):
        for images, labels in train_loader:
            images, labels = images.to(dev), labels.to(dev)
            batch_size = images.size(0)

            probs = torch.softmax(images @ weight_copy[:-1] + weight_copy[-1], dim=1)
            one_hot_labels = torch.zeros_like(probs).scatter_(1, labels.view(-1, 1), 1)
            g = (images.T @ (probs - one_hot_labels)) / batch_size
            mG = b1 * mG + (1 - b1) * g

            if use_pseudo_labels:
                pseudo_labels = torch.multinomial(probs, num_samples=1).squeeze(-1)
                pseudo_one_hot = torch.zeros_like(probs).scatter_(1, pseudo_labels.view(-1, 1), 1)
                H = (images.T @ (probs - pseudo_one_hot) / batch_size).pow(2)
            else:
                H = g.pow(2)
            mH = b2 * mH + (1 - b2) * H

            Hg = (mG / mH.clamp(min=eps)).clamp(-clip_value, clip_value)
            with torch.no_grad():
                weight_copy[:-1] -= lr * Hg
                weight_copy[-1] -= lr * mG.mean(dim=0)
    return weight_copy, mG, mH

def train_done(rank, weight, train_loader, epochs=1, lr=0.1):
    """Return the client's full local gradient of the softmax-regression loss.

    The weights are never updated here. The result is therefore the sample-weighted
    average of the per-batch gradients over the whole ``train_loader``; previously
    only the last batch's gradient was returned. Repeated ``epochs`` see the same
    weights and yield the same average. ``rank`` and ``lr`` are unused.

    Returns:
        Tensor of shape ``(num_features, num_classes)``.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    dev = weight.device
    weight_copy = weight.clone().detach().requires_grad_(False)
    num_classes = weight_copy.shape[1]
    grad_sum = torch.zeros_like(weight_copy[:-1])   # (num_features, num_classes)
    seen = 0
    for _ in range(epochs):
        for images, labels in train_loader:
            images, labels = images.to(dev), labels.to(dev)
            probs = torch.softmax(images @ weight_copy[:-1] + weight_copy[-1], dim=1)
            one_hot_labels = torch.nn.functional.one_hot(labels, num_classes=num_classes).to(probs.dtype)
            # Sum (not mean) so batches of different size are weighted by their size.
            grad_sum += images.T @ (probs - one_hot_labels)
            seen += images.size(0)
    if seen == 0:
        raise ValueError("train_done: train_loader produced no samples")
    return grad_sum / seen

def train_scaffold(rank, weight, train_loader, c_local, c_global, epochs=1, lr=0.1):
    """SCAFFOLD client update (control-variate option II).

    Local steps: ``y <- y - lr * (g(y) - c_local + c_global)``.
    After ``K`` steps: ``c_local+ = c_local - c_global + (weight - y) / (K * lr)``.
    If the loader yields no batches, ``weight`` and ``c_local`` are returned as
    copies; previously this divided by ``K = 0``.

    Returns:
        ``(weight_copy, c_local_copy)``; the inputs are not modified.
    """
    dev = weight.device
    weight_copy = weight.clone()
    c_local_copy = c_local.clone()
    steps = 0
    for _ in range(epochs):
        for images, labels in train_loader:
            images, labels = images.to(dev), labels.to(dev)
            probs = torch.softmax(images @ weight_copy[:-1] + weight_copy[-1], dim=1)

            one_hot_labels = torch.zeros_like(probs).scatter_(1, labels.view(-1, 1), 1)
            residual = probs - one_hot_labels
            g_weight = images.T @ residual / images.size(0)
            g_bias = residual.mean(dim=0)

            # y_i <- y_i - lr * (g_i(y_i) - c_i + c)
            weight_copy[:-1] -= lr * (g_weight - c_local_copy[:-1] + c_global[:-1])
            weight_copy[-1] -= lr * (g_bias - c_local_copy[-1] + c_global[-1])
            steps += 1
    if steps == 0:
        return weight_copy, c_local_copy
    # c_i <- c_i - c + (x - y_i) / (K * lr)
    c_local_copy[:-1] += (weight[:-1] - weight_copy[:-1]) / (steps * lr) - c_global[:-1]
    c_local_copy[-1] += (weight[-1] - weight_copy[-1]) / (steps * lr) - c_global[-1]
    return weight_copy, c_local_copy

def test(weight, test_loader):
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            logits = images @ weight[:-1] + weight[-1]
            predicted = logits.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return accuracy

def update_global_model(global_weight, global_grad, global_hessian, lr=0.01):
    """Newton step on the global model, applied in place.

    Solves ``global_hessian @ delta = global_grad.view(-1)``, clamps ``delta`` to
    ``[-1e5, 1e5]`` and moves the weight rows by ``-lr * delta``. The bias row moves
    by ``-lr`` times the scalar mean of ``global_grad``. ``global_weight`` is
    modified in place and also returned.
    """
    direction = torch.linalg.solve(global_hessian, global_grad.view(-1))
    direction = direction.clamp(-1e5, 1e5).view_as(global_weight[:-1])
    global_weight[:-1] -= lr * direction
    global_weight[-1] -= lr * global_grad.mean()
    return global_weight

def update_global_done(global_weight, weight, local_grads, partitions, num_clients, R=5, alpha=0.01, eta=0.1):
    """DONE server round: approximate Newton direction via local Richardson iterations.

    With the aggregated gradient ``g = sum_i weight[i] * local_grads[i]``, every
    client ``i`` runs ``R`` iterations of

        d <- (I - alpha * H_i) d - alpha * g,   starting from d = 0,

    where ``H_i`` is the block-diagonal (per-class) Hessian of the softmax loss on
    the client's whole partition at ``global_weight``. The server then applies
    ``W <- W + eta * sum_i weight[i] * d_i`` and moves the bias by
    ``-eta * g.mean(dim=0)``. ``global_weight`` is modified in place and returned.

    ``H_i d`` is evaluated matrix-free in the ``(F, C)`` layout of ``g``:
    ``(H_i d)[:, c] = X^T (R[:, c] * (X d)[:, c]) / n_i``. The ``(F*C)^2`` matrix is
    never formed, and the layout cannot drift from the gradient's.

    Fixes relative to the previous version:
    - the iteration used an all-zero ``local_H`` instead of the computed Hessian;
    - the aggregate started from ``-g`` despite its ``d = 0`` comment;
    - the bias update had the wrong sign (``+=``);
    - a stray full-width parenthesis line broke parsing.
    """
    global_grad = torch.zeros_like(global_weight[:-1])
    for i, grad in enumerate(local_grads):
        global_grad += grad * weight[i]

    w_now = global_weight.detach()
    d_R_global = torch.zeros_like(global_weight[:-1])

    for i in range(num_clients):
        local_dataset = partitions[i]
        n_i = len(local_dataset)
        if n_i == 0:
            continue  # an empty client contributes d = 0
        # A single batch holding the client's whole partition.
        images, _ = next(iter(DataLoader(local_dataset, batch_size=n_i, shuffle=False)))
        images = images.to(global_weight.device)
        probs = torch.softmax(images @ w_now[:-1] + w_now[-1], dim=1)
        R_diag = probs * (1 - probs)  # (n_i, num_classes)

        d_R = torch.zeros_like(global_weight[:-1])
        for _ in range(R):
            Hd = images.T @ ((images @ d_R) * R_diag) / n_i
            d_R = d_R - alpha * Hd - alpha * global_grad

        d_R_global += d_R * weight[i]

    global_weight[:-1] += eta * d_R_global
    global_weight[-1] -= eta * global_grad.mean(dim=0)
    return global_weight

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.empty_cache()

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=0, help='random seed for Python, NumPy and torch')
    parser.add_argument('--lr', type=float, default=0.0015, help='Value of learning rate')
    parser.add_argument('--bs', type=int, default=512, help='batch size')
    parser.add_argument('--eps', type=float, default=0.005, help='param of hessian')
    parser.add_argument('--donea', type=float, default=0.008, help='alpha of done')
    parser.add_argument('--epoch', type=int, default=1, help='epoch num')
    parser.add_argument('--nonIID', action='store_true', help='--nonIID means True nothing means False')
    parser.add_argument('--alpha', type=float, default=0.1, help='nonIID')
    parser.add_argument('--size', type=int, default=30, help='size of client')
    parser.add_argument('--round', type=int, default=60, help='rounds num')
    parser.add_argument('--opt', type=str, default='avg', help='optimizer')
    parser.add_argument('--sr', type=float, default=0.01, help='sample rate')

    args = parser.parse_args()
    lr = args.lr
    eps = args.eps
    bs = args.bs
    local_epochs = args.epoch
    NonIID = args.nonIID
    alpha = args.alpha
    size = args.size
    rounds = args.round
    optimizer = args.opt
    sr = args.sr
    donea = args.donea

    # --seed used to be parsed but never applied. Seed every RNG the script uses:
    # Python (IID split), NumPy (Dirichlet split), torch (shuffling/sampling).
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    partitions, weights, testset, classnum = dataload(size, isNonIID=NonIID, alpha=alpha)
    test_loader = DataLoader(testset, batch_size=512, shuffle=False)

    # (784 features + 1 bias row, 10 classes)
    global_weight = torch.zeros(785, 10, device=device)
    local_weights = [torch.zeros(785, 10, device=device) for _ in range(size)]

    global_hessian = None
    mG, mH, mHG = [None for _ in range(size)], [None for _ in range(size)], [None for _ in range(size)]
    global_c = torch.zeros_like(global_weight, device=device)
    local_cs = [torch.zeros_like(global_weight, device=device) for _ in range(size)]
    # Momentum for "now": lowered by 0.05 at rounds 0, 10, ..., 50 (0.50 down to 0.25).
    b = 0.55
    for r in range(rounds):
        if r % 10 == 0 and r < 60:
            b -= 0.05
        local_grads, local_hessians, l = [], [], []

        for i in range(size):
            train_loader = DataLoader(partitions[i], batch_size=bs, shuffle=True)
            if optimizer == "avg":
                local_weights[i] = train_avg(i, local_weights[i], train_loader, epochs=local_epochs, lr=lr)
            if optimizer == "nl":
                # Returns the gradient sum, the learned Hessian and the shift l_i.
                grad, hessian_shift, lk = train_nl(i, local_weights[i], train_loader, epochs=local_epochs, compression_operator=low_rank_compression, lr=lr)
                local_grads.append(grad)
                local_hessians.append(hessian_shift)
                l.append(lk)
            if optimizer == "now":
                local_weights[i], mHG[i] = train_now(i, local_weights[i], train_loader, b, mHG=mHG[i], epochs=local_epochs, lr=lr, eps=eps, clp=1e-4, sampling_ratio=sr)
            if optimizer == "fagh":
                g, v, mG[i], mH[i] = train_fagh(i, local_weights[i], train_loader, mG=mG[i], mH=mH[i], epochs=local_epochs, lr=lr, eps=eps)
                local_grads.append(g)
                local_hessians.append(v)
            if optimizer == "sophia":
                local_weights[i], mG[i], mH[i] = train_sophia(i, global_weight, train_loader, mG=mG[i], mH=mH[i], epochs=1, lr=lr)
            if optimizer == "done":
                grad = train_done(i, local_weights[i], train_loader, epochs=1, lr=lr)
                local_grads.append(grad)
            if optimizer == "scaffold":
                local_weights[i], local_cs[i] = train_scaffold(rank=i, weight=global_weight, train_loader=train_loader, c_local=local_cs[i], c_global=global_c, epochs=local_epochs, lr=lr)

        if optimizer in ("avg", "now", "sophia"):
            global_weight = federated_aggregate(local_weights, weights)
        if optimizer in ("nl", "fagh"):
            global_grad = sum(local_grads[i] * weights[i] for i in range(size))
            global_s = sum(local_hessians[i] * weights[i] for i in range(size))
        if optimizer == "nl":
            # Only nl fills ``l``; running this for fagh used to raise IndexError.
            l_avg = sum(l[i] * weights[i] for i in range(size))
            if global_hessian is None:
                global_hessian = torch.clone(global_s)
            global_hessian += 1 * global_s + l_avg * torch.eye(global_hessian.size(0), device=device)
        if optimizer == "done":
            global_weight = update_global_done(global_weight, weights, local_grads, partitions, size, R=40, alpha=donea, eta=lr)
        if optimizer == "scaffold":
            global_weight = federated_aggregate(local_weights, weights)
            # Data-weighted average of the clients' updated control variates. The old
            # expression scaled by ``weights[i]`` with ``i`` leaked from the client loop
            # (always the last client) and subtracted ``global_c`` once per client.
            global_c = sum(local_cs[i] * weights[i] for i in range(size))

        if optimizer == "nl":
            global_weight = update_global_model(global_weight, global_grad, global_hessian, lr=lr)
        if optimizer == "fagh":
            # Use the aggregated gradient / Hessian row, not the last client's ``v``/``g``
            # left over from the client loop.
            Hg = sherman(global_s, global_grad, eps)
            global_weight[:-1] -= lr * Hg
            global_weight[-1] -= lr * global_grad.mean(dim=0)

        for i in range(size):
            local_weights[i] = global_weight.clone()

        accuracy = test(global_weight, test_loader)
        print(f"Global Model Accuracy after round {r + 1}: {accuracy * 100:.2f}%")
