import argparse
import random
import copy
import datetime
import time
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import warnings

# ======================
# 0. 日志工具函数
# ======================
def log_print(*args, **kwargs):
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"[{timestamp}]", *args, **kwargs, flush=True)

# Optional dependency: try to import clip. When the package is not installed,
# `clip` is bound to None instead of failing at import time.
try:
    import clip
except ImportError:
    clip = None

# Optional dependency: try to import timm. When the package is not installed,
# `timm` is bound to None instead of failing at import time.
try:
    import timm
except ImportError:
    timm = None

# Ignore some torchvision compatibility warnings.
# NOTE(review): this call passes no category/module filter, so it silences
# every Python warning in the process, not only torchvision ones — confirm
# that this is intended.
warnings.filterwarnings("ignore")

# ======================
# 1. 工具函数 & 数据集管理
# ======================
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, every CUDA device is seeded as well and cuDNN is
    switched to deterministic mode with benchmarking turned off.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_dataset(name, root='./data'):
    """Build the train/test torchvision datasets for a supported benchmark.

    Every image ends up as a 224x224 tensor:

    * CIFAR-10/100: random crop + horizontal flip (train only), then
      ``ToTensor``, per-channel normalisation and a bicubic resize to 224.
    * MNIST / FashionMNIST: resize to 224, ``ToTensor``, then normalisation
      with mean 0.5 and std 0.5.

    Args:
        name: Dataset name, case-insensitive. One of ``'cifar10'``,
            ``'cifar100'``, ``'mnist'`` or ``'fashionmnist'``.
        root: Directory the data is stored in (downloaded there if missing).

    Returns:
        Tuple ``(train_dataset, test_dataset, num_classes, in_channels)``.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    name = name.lower()

    # name -> (torchvision.datasets class name, num_classes, in_channels, mean, std)
    specs = {
        'cifar10': ('CIFAR10', 10, 3, (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        'cifar100': ('CIFAR100', 100, 3, (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
        'mnist': ('MNIST', 10, 1, (0.5,), (0.5,)),
        'fashionmnist': ('FashionMNIST', 10, 1, (0.5,), (0.5,)),
    }
    # Validate first so an unsupported name fails before any transform is built.
    if name not in specs:
        raise ValueError(f"Dataset {name} not supported.")
    cls_name, num_classes, in_channels, mean, std = specs[name]

    if in_channels == 3:
        # CIFAR: augment at native 32x32, normalise, then upsample to 224.
        def tensor_pipeline():
            return [
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
                transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            ]

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            *tensor_pipeline(),
        ])
        transform_test = transforms.Compose(tensor_pipeline())
    else:
        # Grayscale datasets: upsample the PIL image first, then normalise.
        def tensor_pipeline():
            return [
                transforms.Resize(224),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]

        transform_train = transforms.Compose(tensor_pipeline())
        transform_test = transforms.Compose(tensor_pipeline())

    dataset_cls = getattr(torchvision.datasets, cls_name)
    train_dataset = dataset_cls(root=root, train=True, download=True, transform=transform_train)
    test_dataset = dataset_cls(root=root, train=False, download=True, transform=transform_test)

    return train_dataset, test_dataset, num_classes, in_channels

def split_dataset_non_iid(train_dataset, num_clients, alpha=0.5, num_classes=10):
    """Partition sample indices across clients using a per-class Dirichlet draw.

    Labels are read from ``train_dataset.targets`` or ``train_dataset.labels``
    when present, otherwise by iterating the dataset. For every class the
    class's indices are shuffled and cut into ``num_clients`` contiguous
    chunks whose sizes follow a ``Dirichlet(alpha)`` sample; the last client
    receives whatever remains after flooring. Each client's final index list
    is shuffled. Uses the global NumPy RNG.

    Returns:
        A list with one list of dataset indices per client.
    """
    for attr in ('targets', 'labels'):
        if hasattr(train_dataset, attr):
            labels = np.array(getattr(train_dataset, attr))
            break
    else:
        # No label attribute: fall back to reading labels sample by sample.
        labels = np.array([y for _, y in train_dataset])

    all_idxs = np.arange(len(train_dataset))
    client_indices = [[] for _ in range(num_clients)]
    concentration = np.repeat(alpha, num_clients)

    for cls in range(num_classes):
        cls_idxs = all_idxs[labels == cls]
        np.random.shuffle(cls_idxs)
        n_cls = len(cls_idxs)

        # Floor the fractional shares; the last client absorbs the remainder.
        counts = (np.random.dirichlet(concentration) * n_cls).astype(int)
        counts[-1] = max(n_cls - np.sum(counts[:-1]), 0)

        boundaries = np.cumsum(counts)[:-1]
        for bucket, chunk in zip(client_indices, np.split(cls_idxs, boundaries)):
            bucket.extend(chunk.tolist())

    for bucket in client_indices:
        np.random.shuffle(bucket)

    return client_indices

# ======================
# 2. 模型定义
# ======================

class CLIPModelWrapper(nn.Module):
    """Classifier made of a feature backbone followed by a linear head.

    The backbone is converted with ``.float()`` and ``head`` maps its
    ``input_dim``-dimensional output to ``num_classes`` logits.
    """

    def __init__(self, backbone, input_dim, num_classes):
        super(CLIPModelWrapper, self).__init__()
        self.backbone = backbone.float()
        self.head = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return class logits; single-channel input is tiled to 3 channels."""
        is_single_channel = x.shape[1] == 1
        inputs = x.repeat(1, 3, 1, 1) if is_single_channel else x
        return self.head(self.backbone(inputs))

class TimmModelWrapper(nn.Module):
    """Classifier made of a timm backbone (built headless) and a linear head.

    The backbone is created with ``num_classes=0`` and its ``num_features``
    attribute gives the input width of the ``head`` linear layer.

    Args:
        model_name: Name passed to ``timm.create_model``.
        num_classes: Number of output logits of the head.
        pretrained: Forwarded to ``timm.create_model``.

    Raises:
        ImportError: If the optional ``timm`` package is not installed.
    """

    def __init__(self, model_name, num_classes, pretrained=True):
        super(TimmModelWrapper, self).__init__()
        # `timm` is bound to None at import time when the package is missing;
        # fail with a clear message instead of an AttributeError on None.
        if timm is None:
            raise ImportError(
                f"timm is required to build model '{model_name}' but is not installed."
            )

        common_kwargs = dict(pretrained=pretrained, num_classes=0)
        try:
            self.backbone = timm.create_model(model_name, img_size=224, **common_kwargs)
        except TypeError:
            # Retry without img_size — presumably for architectures whose
            # constructor does not accept that argument.
            self.backbone = timm.create_model(model_name, **common_kwargs)
        self.in_features = self.backbone.num_features
        self.head = nn.Linear(self.in_features, num_classes)

    def forward(self, x):
        """Return class logits; single-channel input is tiled to 3 channels."""
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        features = self.backbone(x)
        return self.head(features)

def build_model(model_name, num_classes, in_channels, force_random=False):
    """Build a classifier for ``model_name`` with a fresh linear head.

    If the name contains ``"clip"``, loading through the ``clip`` package is
    attempted first (on CPU) and its visual encoder is wrapped in
    ``CLIPModelWrapper``. If that fails, or the name is not a CLIP name, the
    model is built through timm via ``TimmModelWrapper``.

    Args:
        model_name: Model identifier understood by ``clip.load`` or timm.
        num_classes: Number of output classes of the head.
        in_channels: Number of input channels. Currently unused here; the
            wrappers tile single-channel inputs to three channels themselves.
        force_random: When True, the timm backbone is created without
            pretrained weights. Useful for replicas whose weights are
            overwritten by ``load_state_dict`` before use. Has no effect on
            the CLIP path, where ``clip.load`` is called the same way
            regardless.

    Returns:
        An ``nn.Module`` exposing ``backbone`` and ``head``.

    Raises:
        ValueError: If the model cannot be built by either path.
    """
    # Previously hard-coded to True, which silently ignored `force_random`.
    use_pretrained = not force_random
    if "clip" in model_name.lower():
        try:
            import clip
            model, _ = clip.load(model_name, device="cpu")
            if hasattr(model, 'visual'):
                input_dim = model.visual.output_dim
                return CLIPModelWrapper(model.visual, input_dim, num_classes)
        except Exception as e:
            # Best effort: any failure here falls through to the timm path.
            log_print(f"[Warning] Failed to load via CLIP library: {e}. Trying timm...")

    try:
        return TimmModelWrapper(
            model_name=model_name,
            num_classes=num_classes,
            pretrained=use_pretrained
        )
    except Exception as e:
        # Only the CLIP package and timm are tried; chain the cause so the
        # original traceback stays available.
        raise ValueError(
            f"Model '{model_name}' could not be built with CLIP or timm.\nError: {e}"
        ) from e


# ======================
# 3. 客户端逻辑
# ======================
class Client:
    def __init__(self, client_id, dataset, indices, device, args, num_classes, in_channels):
        """Create a client with its own data loader and local model replica.

        The loader iterates over ``Subset(dataset, indices)`` in shuffled
        batches of ``args.batch_size``. The trailing partial batch is dropped
        only when the client holds more than one batch worth of samples. The
        model is built with ``build_model`` and moved to ``device``.
        """
        self.client_id, self.device, self.args = client_id, device, args
        self.num_classes, self.in_channels = num_classes, in_channels

        local_subset = Subset(dataset, indices)
        self.loader = DataLoader(
            local_subset,
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=len(indices) > args.batch_size,
        )

        local_model = build_model(args.model_name, num_classes, in_channels, force_random=True)
        self.model = local_model.to(self.device)

    def _set_freeze_status(self):
        """Set ``requires_grad`` flags and train/eval mode per ``training_mode``.

        * ``'head'``: every parameter is frozen, then the classification head
          (the first of ``head`` / ``fc`` / ``classifier`` present on the
          model) is unfrozen. The model is put in eval mode and only the head
          is switched back to train mode. If no head attribute exists, an
          error is logged and all parameters stay frozen.
        * any other mode: the whole model is put in train mode and every
          parameter is made trainable.
        """
        if self.args.training_mode != 'head':
            self.model.train()
            for param in self.model.parameters():
                param.requires_grad = True
            return

        for param in self.model.parameters():
            param.requires_grad = False

        head_module = next(
            (getattr(self.model, attr)
             for attr in ('head', 'fc', 'classifier')
             if hasattr(self.model, attr)),
            None,
        )
        if head_module is None:
            log_print(f"[Error] Client {self.client_id}: Could not find head layer to unfreeze!")
        else:
            for param in head_module.parameters():
                param.requires_grad = True

        self.model.eval()
        # Put whichever head was unfrozen back in train mode. Previously only
        # an attribute literally named `head` was handled, so an `fc` or
        # `classifier` head was trained while in eval mode.
        if head_module is not None:
            head_module.train()

    def local_train(self, global_state_dict):
        """Train locally from the global weights (FedAvg / FedProx / FedSophia step 1).

        Loads ``global_state_dict``, applies the freeze policy, then runs
        ``args.local_epochs`` passes over the client's loader with SGD
        (momentum 0.9) when ``args.optimizer == 'sgd'`` and Adam otherwise.
        For FedProx with ``mu > 0`` the proximal penalty
        ``(mu / 2) * sum ||p - p_global||^2`` over trainable parameters is
        added to each batch loss.

        Returns:
            ``(state_dict, num_samples, mean_loss)`` where ``state_dict`` is
            the model's state dict after training, ``num_samples`` counts the
            samples processed over all epochs, and ``mean_loss`` is the
            sample-weighted mean of the batch losses (penalty included).
            Returns ``(state_dict, 0, 0.0)`` if nothing is trainable.
        """
        self.model.load_state_dict(global_state_dict)
        self._set_freeze_status()

        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        if not trainable_params:
            return self.model.state_dict(), 0, 0.0

        args = self.args
        if args.optimizer == 'sgd':
            optimizer = optim.SGD(trainable_params, lr=args.lr, momentum=0.9)
        else:
            optimizer = optim.Adam(trainable_params, lr=args.lr)
        criterion = nn.CrossEntropyLoss()

        # Snapshot of the starting (global) trainable weights for the proximal term.
        is_fedprox = args.method == "FedProx"
        anchors = {}
        if is_fedprox:
            anchors = {
                name: p.detach().clone()
                for name, p in self.model.named_parameters()
                if p.requires_grad
            }
        use_prox = is_fedprox and args.mu > 0

        seen, weighted_loss = 0, 0.0
        for _epoch in range(args.local_epochs):
            for x, y in self.loader:
                x, y = x.to(self.device), y.to(self.device)
                batch_size = x.size(0)

                optimizer.zero_grad()
                loss = criterion(self.model(x), y)

                if use_prox:
                    prox_term = 0.0
                    for name, p in self.model.named_parameters():
                        if p.requires_grad:
                            prox_term += (p - anchors[name]).norm(2)**2
                    loss = loss + (args.mu / 2) * prox_term

                loss.backward()
                optimizer.step()

                seen += batch_size
                weighted_loss += loss.item() * batch_size

        return self.model.state_dict(), seen, weighted_loss / max(seen, 1)

    def compute_diag_hessian(self, global_state_dict):
        """Estimate a diagonal Hessian at the global weights from one mini-batch.

        Used by FedSophia / FedNew / FedNewton-Approx / FedDANE. Labels are
        sampled from the model's own softmax output, the summed cross-entropy
        against those labels is differentiated, and the squared gradient
        divided by the batch size is returned for each trainable parameter.

        Returns:
            ``(h_diag, batch_size)`` where ``h_diag`` maps parameter names to
            detached tensors. Returns ``({}, 0)`` when the loader yields no
            batch or when the model has no trainable parameter.
        """
        self.model.load_state_dict(global_state_dict)
        self._set_freeze_status()
        # NOTE(review): this puts the whole model in train mode, overriding
        # the eval mode `_set_freeze_status` selects in 'head' mode — confirm
        # that this is intended.
        self.model.train()

        named_trainable = [
            (name, param)
            for name, param in self.model.named_parameters()
            if param.requires_grad
        ]
        # Same guard as local_train / compute_gradient: with nothing to
        # differentiate, autograd.grad would raise.
        if not named_trainable:
            return {}, 0

        criterion = nn.CrossEntropyLoss(reduction='sum')

        try:
            x, _ = next(iter(self.loader))
        except StopIteration:
            return {}, 0

        x = x.to(self.device)
        bs = x.size(0)

        logits = self.model(x)
        # Sampling is not differentiable; keep it out of the autograd graph.
        with torch.no_grad():
            probs = torch.softmax(logits, dim=1)
            y_sampled = torch.multinomial(probs, num_samples=1).squeeze(1)

        loss = criterion(logits, y_sampled)
        grads = torch.autograd.grad(
            loss, [param for _, param in named_trainable], create_graph=False
        )

        h_diag_local = {
            name: (g * g).clone().detach() / bs
            for (name, _), g in zip(named_trainable, grads)
        }
        return h_diag_local, bs

    def compute_gradient(self, global_state_dict):
        """Compute the local full-pass gradient at the global weights.

        Used by FedNewton / FedNew / FedDANE. Loads ``global_state_dict``,
        applies the freeze policy, clears existing gradients of the trainable
        parameters and accumulates the sample-weighted gradient of the
        cross-entropy loss over one pass of the loader (no optimizer step).

        Returns:
            ``(local_grads, num_samples, mean_loss)`` where ``local_grads``
            maps parameter names to the per-sample average gradient.
            Returns ``({}, 0, 0.0)`` if nothing is trainable.
        """
        self.model.load_state_dict(global_state_dict)
        self._set_freeze_status()

        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        if not trainable_params:
            return {}, 0, 0.0

        # Clear gradients by hand; no optimizer is involved here.
        for p in trainable_params:
            if p.grad is not None:
                p.grad.zero_()

        criterion = nn.CrossEntropyLoss()
        loss_total, seen = 0.0, 0
        for x, y in self.loader:
            x, y = x.to(self.device), y.to(self.device)
            n_batch = x.size(0)
            batch_loss = criterion(self.model(x), y)
            # Scale the batch-mean loss back to a sum so the accumulated
            # gradient is weighted by sample count.
            (batch_loss * n_batch).backward()
            loss_total += batch_loss.item() * n_batch
            seen += n_batch

        denom = max(seen, 1)
        local_grads = {
            name: param.grad.clone().detach() / denom
            for name, param in self.model.named_parameters()
            if param.grad is not None
        }
        return local_grads, seen, loss_total / denom

    def compute_newton_step(self, global_state_dict, global_grads):
        """FedNewton step 2: damped Newton direction for the linear head.

        The features entering the head (the first of ``head`` / ``fc`` /
        ``classifier`` found on the model) are captured with a forward hook
        over at most ``args.hessian_batches`` batches. The exact Hessian of
        the summed cross-entropy w.r.t. the head's weight (and bias, if any)
        is computed on those features, averaged over the samples seen, and
        ``(H + damping * I) delta = g`` is solved with the aggregated head
        gradient ``g``. The damping is multiplied by 10 after each failed
        solve (up to 5 attempts); if all fail, ``delta = g``. The result is
        scaled by ``args.newton_lr``. In ``'full'`` training mode every other
        trainable parameter gets a plain gradient step ``args.lr * g``.

        Args:
            global_state_dict: Weights to load before computing the step.
            global_grads: Mapping from parameter name to aggregated gradient.

        Returns:
            ``(updates, num_samples)`` where ``updates`` maps parameter names
            to CPU tensors to subtract from the global weights. ``({}, 0)`` if
            there is no head or no data; ``({}, num_samples)`` if the Hessian
            computation fails.
        """
        self.model.load_state_dict(global_state_dict)
        self._set_freeze_status()
        self.model.eval()

        head_attr = next(
            (attr for attr in ('head', 'fc', 'classifier') if hasattr(self.model, attr)),
            None,
        )
        if head_attr is None:
            return {}, 0
        head_module = getattr(self.model, head_attr)
        head_prefix = head_attr + '.'

        # Everything trainable outside the head module is "body". Selecting by
        # the head's own prefix (not by substrings such as 'fc') keeps
        # backbone layers like `mlp.fc1` out of the head's Newton system.
        body_keys = [
            name for name, param in self.model.named_parameters()
            if param.requires_grad and not name.startswith(head_prefix)
        ]

        captured = []

        def _capture_head_input(module, hook_inputs, output):
            captured.append(hook_inputs[0].detach())

        handle = head_module.register_forward_hook(_capture_head_input)
        features_list, targets_list, total_samples = [], [], 0
        try:
            with torch.no_grad():
                for i, (x, y) in enumerate(self.loader):
                    if i >= self.args.hessian_batches:
                        break
                    x, y = x.to(self.device), y.to(self.device)
                    total_samples += x.size(0)
                    captured.clear()
                    self.model(x)
                    if captured:
                        features_list.append(captured[0])
                        targets_list.append(y)
        finally:
            # Always detach the hook, even if a forward pass raises.
            handle.remove()

        if total_samples == 0 or not features_list:
            return {}, 0

        all_features = torch.cat(features_list, dim=0)
        all_targets = torch.cat(targets_list, dim=0)

        # Head parameters and their names, in the order used for both the
        # Hessian blocks and the flattened gradient.
        head_keys = [head_prefix + 'weight']
        head_inputs = [head_module.weight.detach()]
        if head_module.bias is not None:
            head_keys.append(head_prefix + 'bias')
            head_inputs.append(head_module.bias.detach())
        head_inputs = tuple(head_inputs)

        def head_loss(w, b=None):
            out = torch.nn.functional.linear(all_features, w, b)
            return torch.nn.functional.cross_entropy(out, all_targets, reduction='sum')

        try:
            # With a tuple of inputs, hessian() returns a tuple of tuples of
            # blocks — also for a 1-tuple, so no extra wrapping is needed in
            # the bias-free case.
            H_tuple = torch.autograd.functional.hessian(head_loss, head_inputs)
        except Exception as e:
            log_print(f"[Warning] Client {self.client_id}: head Hessian computation failed: {e}")
            return {}, total_samples

        flat_g = torch.cat([
            global_grads[key].reshape(-1).to(self.device)
            if key in global_grads
            else torch.zeros_like(ref).reshape(-1).to(self.device)
            for key, ref in zip(head_keys, head_inputs)
        ])

        # Assemble the block Hessian into one dense square matrix.
        n_inputs = len(head_inputs)
        H_mat = torch.cat([
            torch.cat([
                H_tuple[i][j].reshape(head_inputs[i].numel(), head_inputs[j].numel())
                for j in range(n_inputs)
            ], dim=1)
            for i in range(n_inputs)
        ], dim=0)
        H_avg = H_mat / max(total_samples, 1)

        identity = torch.eye(H_avg.size(0), device=self.device)
        current_damping = self.args.damping
        delta_flat = None
        for _attempt in range(5):
            try:
                delta_flat = torch.linalg.solve(H_avg + current_damping * identity, flat_g)
                break
            except RuntimeError:
                current_damping *= 10

        if delta_flat is None:
            # All solves failed: fall back to the plain gradient direction.
            delta_flat = flat_g

        delta_flat = delta_flat * self.args.newton_lr

        local_update_s, ptr = {}, 0
        for key, ref in zip(head_keys, head_inputs):
            numel = ref.numel()
            local_update_s[key] = delta_flat[ptr:ptr + numel].reshape(ref.shape).cpu()
            ptr += numel

        if self.args.training_mode == 'full':
            for name in body_keys:
                if name in global_grads:
                    local_update_s[name] = (self.args.lr * global_grads[name].to(self.device)).cpu()

        return local_update_s, total_samples

    def compute_newton_approx_step(self, global_state_dict, global_grads):
        """FedNewton-Approx: Newton-like step with a diagonal Hessian estimate.

        Replaces the full Hessian inverse by an element-wise division:
        ``delta = g / max(h_diag + damping, 1e-6)``, where ``h_diag`` comes
        from ``compute_diag_hessian`` and ``g`` is the aggregated gradient.
        Trainable parameters without a Hessian estimate fall back to the raw
        gradient; parameters absent from ``global_grads`` are skipped.

        Returns:
            ``(updates, num_samples)`` with CPU tensors keyed by parameter
            name, or ``({}, 0)`` when no Hessian sample was available.
        """
        # 1. Local diagonal Hessian (Monte Carlo estimate).
        h_diag, n_samples = self.compute_diag_hessian(global_state_dict)
        if n_samples == 0:
            return {}, 0

        damping = self.args.damping
        local_updates = {}

        # 2. Approximate Newton direction, parameter by parameter.
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and name in global_grads):
                continue

            grad = global_grads[name].to(self.device)
            h_vals = h_diag.get(name)
            if h_vals is None:
                # No curvature information: fall back to the plain gradient.
                local_updates[name] = grad.cpu()
                continue

            # The lower bound guards against division by zero.
            denom = torch.clamp(h_vals.to(self.device) + damping, min=1e-6)
            local_updates[name] = (grad / denom).cpu()

        return local_updates, n_samples

    def compute_dane_step(self, global_state_dict, global_grads):
        """FedDANE: approximate Newton step from a local diagonal Hessian.

        For every trainable parameter present in ``global_grads`` the update
        is ``g / max(h_diag + mu, 1e-6)``, where ``h_diag`` comes from
        ``compute_diag_hessian`` and ``mu`` is ``args.mu``. Parameters without
        a Hessian estimate fall back to the raw gradient.

        Returns:
            ``(updates, num_samples)`` with CPU tensors keyed by parameter
            name, or ``({}, 0)`` when no Hessian sample was available.
        """
        h_diag, n_samples = self.compute_diag_hessian(global_state_dict)

        if n_samples == 0:
            return {}, 0

        local_updates = {}
        mu = self.args.mu

        for name, param in self.model.named_parameters():
            if not param.requires_grad or name not in global_grads:
                continue

            update = global_grads[name].to(self.device)

            if name in h_diag:
                # Lower-bound the preconditioner (as compute_newton_approx_step
                # does) so mu == 0 with zero curvature cannot produce inf/NaN.
                preconditioner = torch.clamp(h_diag[name].to(self.device) + mu, min=1e-6)
                update = update / preconditioner

            local_updates[name] = update.cpu()

        return local_updates, n_samples

# ======================
# 4. 服务端逻辑
# ======================
def average_weights(w_list, s_list):
    """Return the sample-weighted average of several state dicts.

    Args:
        w_list: Non-empty list of state dicts sharing the same keys.
        s_list: Per-dict sample counts used as averaging weights.

    Returns:
        A deep copy of ``w_list[0]`` whose entries are replaced by the
        weighted average. Keys containing ``'num_batches_tracked'`` are not
        averaged and keep the value from ``w_list[0]``. If the sample counts
        sum to zero, a plain (unweighted) mean is used instead of dividing by
        zero.

    Raises:
        ValueError: If ``w_list`` is empty.
    """
    if not w_list:
        raise ValueError("average_weights() requires at least one state dict")

    total_samples = sum(s_list)
    if total_samples > 0:
        weights = list(s_list)
    else:
        # Every client reported zero samples: dividing by the total would fill
        # the result with NaN/inf, so fall back to uniform weights.
        weights = [1] * len(w_list)
        total_samples = len(w_list)

    # Deep copy keeps the container type (and any attributes on it) of the
    # first state dict and never mutates the inputs.
    w_avg = copy.deepcopy(w_list[0])
    for k in w_avg.keys():
        if 'num_batches_tracked' in k:
            continue
        weighted_sum = sum(w[k] * s for w, s in zip(w_list, weights))
        w_avg[k] = weighted_sum / total_samples
    return w_avg

def average_gradients(g_list, s_list):
    """Return the sample-weighted average of per-client tensor dicts.

    Unlike ``average_weights``, the dicts may have different key sets. A key
    missing from a client contributes nothing to the sum, but the divisor is
    still the total weight of all clients.

    Args:
        g_list: List of dicts mapping names to tensors (may be empty).
        s_list: Per-dict sample counts used as averaging weights.

    Returns:
        Dict with the union of all keys, in first-seen order. An empty dict
        when ``g_list`` is empty. If the sample counts sum to zero, uniform
        weights are used instead of dividing by zero.
    """
    if not g_list:
        return {}

    total_samples = sum(s_list)
    if total_samples > 0:
        weights = list(s_list)
    else:
        # Avoid NaN/inf from a zero divisor: fall back to a plain mean.
        weights = [1] * len(g_list)
        total_samples = len(g_list)

    # A dict (not a set) keeps the key order deterministic across runs.
    all_keys = {}
    for g in g_list:
        for k in g:
            all_keys.setdefault(k, None)

    g_avg = {}
    for k in all_keys:
        weighted_sum = sum((g[k] * s for g, s in zip(g_list, weights) if k in g), 0.0)
        g_avg[k] = weighted_sum / total_samples
    return g_avg

def test_global(model, global_weights, test_loader, device):
    model.load_state_dict(global_weights)
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            loss_sum += loss.item() * x.size(0)
            _, pred = torch.max(out, 1)
            correct += (pred == y).sum().item()
            total += x.size(0)
    return loss_sum / max(total, 1), correct / max(total, 1)

# ======================
# 5. 主流程
# ======================
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Command-line interface
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--method', type=str, default='FedAvg', 
                        choices=['FedAvg', 'FedProx', 'FedNewton', 'FedNewton-Approx', 'FedSophia', 'FedNew', 'FedDANE'])
    parser.add_argument('--training_mode', type=str, default='full', choices=['full', 'head'], help='full: update all; head: update head only')
    # The help text lists the names get_dataset() actually accepts.
    parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10, cifar100, mnist, fashionmnist')
    parser.add_argument('--model_name', type=str, default='resnet18')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--rounds', type=int, default=20)
    parser.add_argument('--num_clients', type=int, default=10)
    parser.add_argument('--fraction', type=float, default=0.5)
    parser.add_argument('--local_epochs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=32)
    # LR settings
    parser.add_argument('--lr', type=float, default=0.01, help="Local SGD LR or Global Learning Rate for Newton methods")
    parser.add_argument('--newton_lr', type=float, default=0.1, help="For FedNewton / FedDANE / FedNewton-Approx")
    # FedSophia / FedNew Params
    parser.add_argument('--sophia_lr', type=float, default=0.05, help="Server side LR for FedSophia/FedNew")
    parser.add_argument('--rho', type=float, default=0.04, help="Hessian smoothing parameter")
    parser.add_argument('--betas', type=str, default='0.9,0.99', help="Betas for Sophia/New (Momentum, Hessian average)")
    parser.add_argument('--mu', type=float, default=0.01, help="FedProx mu or FedDANE regularization")
    parser.add_argument('--log_dir', type=str, default='./logs')
    parser.add_argument('--optimizer', type=str, default='sgd')
    parser.add_argument('--hessian_batches', type=int, default=32)
    parser.add_argument('--damping', type=float, default=0.01, help="Damping for FedNewton and FedNewton-Approx")
    parser.add_argument('--alpha', type=float, default=0.5, help="Paramter to control data heterogeneity")

    args = parser.parse_args()

    # ------------------------------------------------------------------
    # Setup: seeding, data, global model, clients
    # ------------------------------------------------------------------
    set_seed(args.seed)
    log_print(args)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs(args.log_dir, exist_ok=True)

    log_print(f"Start: {args.method} | Mode: {args.training_mode} | Dataset: {args.dataset} | Device: {device}")

    train_ds, test_ds, num_classes, in_channels = get_dataset(args.dataset)

    test_loader = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=2)
    client_indices = split_dataset_non_iid(train_ds, args.num_clients, alpha=args.alpha, num_classes=num_classes)

    global_model = build_model(args.model_name, num_classes, in_channels).to(device)
    global_weights = global_model.state_dict()

    # Momentum (m) and Hessian accumulator (h) used by FedSophia / FedNew.
    # Only floating-point entries of the state dict get optimizer state.
    sophia_m = {}
    sophia_h = {}
    beta1, beta2 = [float(x) for x in args.betas.split(',')]

    for k, v in global_weights.items():
        if v.dtype == torch.float32 or v.dtype == torch.float64:
            sophia_m[k] = torch.zeros_like(v)
            sophia_h[k] = torch.zeros_like(v)

    clients = [Client(i, train_ds, client_indices[i], device, args, num_classes, in_channels) for i in range(args.num_clients)]

    logs = {"rounds": [], "test_acc": [], "test_loss": [], "train_loss": [], "wall_time": [], "max_gpu_mem": []}
    start_time = time.time()

    # Two-phase second-order methods share one server loop:
    #   method -> (client method computing the step, server-side step scale).
    # FedNewton already applies newton_lr on the client, hence scale 1.0.
    two_phase_methods = {
        'FedNewton': ('compute_newton_step', 1.0),
        'FedNewton-Approx': ('compute_newton_approx_step', args.newton_lr),
        'FedDANE': ('compute_dane_step', args.newton_lr),
    }

    # ------------------------------------------------------------------
    # Federated rounds
    # ------------------------------------------------------------------
    for rnd in range(1, args.rounds + 1):
        start_rnd_time = time.time()

        # Reset the GPU peak-memory statistics for this round.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(device)

        # At least one client, and never more than exist (fraction > 1 would
        # otherwise make sampling without replacement fail).
        m = min(max(int(args.fraction * args.num_clients), 1), args.num_clients)
        selected_users = np.random.choice(range(args.num_clients), m, replace=False)

        local_losses = []
        local_samples_list = []

        # --- Strategy: FedSophia / FedNew (local training + server-side
        #     preconditioned momentum step) ---
        if args.method in ['FedSophia', 'FedNew']:
            local_deltas = []
            local_hessians = []

            for idx in selected_users:
                w_before = copy.deepcopy(global_weights)
                w_after, n, loss = clients[idx].local_train(w_before)

                # The delta must be taken before compute_diag_hessian reloads
                # the client model, because w_after is the model's own state dict.
                diff = {}
                for k in w_before.keys():
                    if k in w_after:
                        diff[k] = (w_after[k] - w_before[k]).cpu()

                local_deltas.append(diff)
                local_losses.append(loss)
                local_samples_list.append(n)

                h_local, _ = clients[idx].compute_diag_hessian(w_before)
                local_hessians.append(h_local)

            avg_diff = average_weights(local_deltas, local_samples_list)
            avg_h = average_gradients(local_hessians, local_samples_list)

            for k in global_weights.keys():
                if k not in avg_diff or k not in sophia_m:
                    continue
                # The negated average weight delta acts as pseudo-gradient.
                g = -avg_diff[k].to(device)
                sophia_m[k] = beta1 * sophia_m[k] + (1 - beta1) * g

                if k in avg_h:
                    h = avg_h[k].to(device)
                    sophia_h[k] = beta2 * sophia_h[k] + (1 - beta2) * h

                if args.method == 'FedSophia':
                    preconditioner = torch.maximum(sophia_h[k], torch.tensor(args.rho, device=device))
                else:  # FedNew
                    preconditioner = sophia_h[k] + args.rho
                step = args.sophia_lr * (sophia_m[k] / preconditioner)

                global_weights[k] -= step

        # --- Strategy: FedAvg / FedProx ---
        elif args.method in ['FedAvg', 'FedProx']:
            local_weights = []
            for idx in selected_users:
                w, n, loss = clients[idx].local_train(copy.deepcopy(global_weights))
                local_weights.append(w)
                local_samples_list.append(n)
                local_losses.append(loss)

            global_weights = average_weights(local_weights, local_samples_list)

        # --- Strategy: FedNewton / FedNewton-Approx / FedDANE ---
        elif args.method in two_phase_methods:
            step_method, step_scale = two_phase_methods[args.method]

            # Phase 1: aggregate the local gradients at the global weights.
            local_grads_list = []
            for idx in selected_users:
                g, n, loss = clients[idx].compute_gradient(copy.deepcopy(global_weights))
                local_grads_list.append(g)
                local_samples_list.append(n)
                local_losses.append(loss)

            global_grads = average_gradients(local_grads_list, local_samples_list)

            # Phase 2: each client turns the global gradient into a step.
            local_steps_list = []
            for idx in selected_users:
                s_j, _ = getattr(clients[idx], step_method)(copy.deepcopy(global_weights), global_grads)
                local_steps_list.append(s_j)

            global_step = average_gradients(local_steps_list, local_samples_list)

            # Phase 3: apply the averaged step; integer buffers are left alone.
            for name, param in global_weights.items():
                if name not in global_step or param.dtype == torch.long:
                    continue
                global_weights[name] = param - step_scale * global_step[name].to(device)

        else:
            raise ValueError(f"Unknown method: {args.method}")

        train_loss_avg = sum([l*n for l,n in zip(local_losses, local_samples_list)]) / max(sum(local_samples_list), 1)

        test_loss, test_acc = test_global(global_model, global_weights, test_loader, device)

        # Timing
        total_elapsed = time.time() - start_time
        round_elapsed = time.time() - start_rnd_time

        # Peak GPU memory of this round, in MiB.
        max_mem = 0.0
        if torch.cuda.is_available():
            max_mem = torch.cuda.max_memory_allocated(device) / (1024 * 1024)

        logs["rounds"].append(rnd)
        logs["test_acc"].append(test_acc)
        logs["test_loss"].append(test_loss)
        logs["train_loss"].append(train_loss_avg)
        logs["wall_time"].append(total_elapsed)
        logs["max_gpu_mem"].append(max_mem)

        log_print(f"[Round {rnd:03d}] {args.method} | Acc: {test_acc*100:.2f}% | Loss: {test_loss:.4f} | Time: {round_elapsed:.2f}s | Max GPU Mem: {max_mem:.2f} MB")

        if test_loss < 1e-6:
            log_print("Early stopping as test loss is very low.")
            break
        elif test_loss > 10.0:
            log_print("Early stopping as test loss is very high.")
            break
        elif not np.isfinite(test_loss):
            # NaN compares False against both thresholds above; without this
            # branch a diverged run would keep going until the last round.
            log_print("Early stopping as test loss is not finite.")
            break

    save_filename = f"{args.method}_{args.training_mode}_{args.dataset}_{args.model_name}_s{args.seed}.npz"
    save_path = os.path.join(args.log_dir, save_filename)
    np.savez(save_path, method=args.method, training_mode=args.training_mode, **logs)
    log_print(f"Saved to {save_path}")
