import argparse
import random
import copy
import datetime
import time
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import warnings

# ======================
# 0. 日志工具函数
# ======================
def log_print(*args, **kwargs):
    """Print ``args`` prefixed with a ``[YYYY-MM-DD HH:MM:SS]`` local timestamp.

    Keyword arguments are forwarded to :func:`print`. Output is flushed by
    default; a caller-supplied ``flush`` overrides that default instead of
    colliding with it.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # setdefault (rather than a hard-coded flush=True) avoids
    # "got multiple values for keyword argument 'flush'".
    kwargs.setdefault("flush", True)
    print(f"[{timestamp}]", *args, **kwargs)

# Try to import clip (optional dependency); `clip` is None when it is missing.
try:
    import clip
except ImportError:
    clip = None

# Try to import timm (optional dependency); `timm` is None when it is missing.
try:
    import timm
except ImportError:
    timm = None

# Intended to hide some torchvision compatibility warnings.
# NOTE: called without a category, this silences every warning in the process.
warnings.filterwarnings("ignore")

# ======================
# 1. 工具函数 & 数据集管理
# ======================
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs with ``seed``.

    When CUDA is available, all CUDA generators are seeded as well and cuDNN
    is switched to deterministic mode with benchmarking disabled.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_dataset(name, root='./data'):
    """Build the train/test torchvision datasets for ``name``.

    Args:
        name: Dataset name, case-insensitive. One of ``cifar10``,
            ``cifar100``, ``mnist`` or ``fashionmnist``.
        root: Directory passed to the torchvision dataset as ``root``
            (datasets are created with ``download=True``).

    Returns:
        ``(train_dataset, test_dataset, num_classes, in_channels)``.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    name = name.lower()

    # name -> (dataset class, num_classes, in_channels, normalisation mean, std)
    specs = {
        'cifar10': (torchvision.datasets.CIFAR10, 10, 3,
                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        'cifar100': (torchvision.datasets.CIFAR100, 100, 3,
                     (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
        'mnist': (torchvision.datasets.MNIST, 10, 1, (0.5,), (0.5,)),
        'fashionmnist': (torchvision.datasets.FashionMNIST, 10, 1, (0.5,), (0.5,)),
    }
    # Reject unknown names before building any transform.
    if name not in specs:
        raise ValueError(f"Dataset {name} not supported.")
    dataset_cls, num_classes, in_channels, mean, std = specs[name]
    is_colour = in_channels == 3

    def eval_steps():
        """Return fresh instances of the deterministic transform steps."""
        if is_colour:
            # Colour pipeline: convert and normalise first, then upsample the
            # normalised tensor to 224 with bicubic interpolation.
            return [
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
                transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            ]
        # Grayscale pipeline: resize the image first, then convert and normalise.
        return [
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]

    # Random crop/flip augmentation is applied to colour training data only.
    augmentation = []
    if is_colour:
        augmentation = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]

    transform_train = transforms.Compose(augmentation + eval_steps())
    transform_test = transforms.Compose(eval_steps())

    train_dataset = dataset_cls(root=root, train=True, download=True, transform=transform_train)
    test_dataset = dataset_cls(root=root, train=False, download=True, transform=transform_test)

    return train_dataset, test_dataset, num_classes, in_channels

def split_dataset_non_iid(train_dataset, num_clients, alpha=0.5, num_classes=10):
    """Partition sample indices across clients in a non-IID way.

    For every class, client shares are drawn from a symmetric
    Dirichlet(``alpha``) distribution using NumPy's global RNG, and the
    shuffled indices of that class are cut accordingly. The last client
    absorbs the rounding remainder.

    Labels are read from ``train_dataset.targets``, else
    ``train_dataset.labels``, else by iterating ``(sample, label)`` pairs.

    Returns:
        A list with ``num_clients`` lists of integer sample indices, each
        shuffled.
    """
    if hasattr(train_dataset, 'targets'):
        raw_labels = train_dataset.targets
    elif hasattr(train_dataset, 'labels'):
        raw_labels = train_dataset.labels
    else:
        raw_labels = [target for _, target in train_dataset]
    labels = np.array(raw_labels)

    all_indices = np.arange(len(train_dataset))
    client_indices = [[] for _ in range(num_clients)]

    for cls in range(num_classes):
        cls_indices = all_indices[labels == cls]
        np.random.shuffle(cls_indices)
        n_cls = len(cls_indices)

        shares = np.random.dirichlet(np.repeat(alpha, num_clients))
        # Truncate each share to a whole count; the final client takes what is left.
        counts = (shares * n_cls).astype(int)
        counts[-1] = max(n_cls - np.sum(counts[:-1]), 0)

        boundaries = np.cumsum(counts)[:-1]
        for client_id, chunk in enumerate(np.split(cls_indices, boundaries)):
            client_indices[client_id].extend(chunk.tolist())

    for bucket in client_indices:
        np.random.shuffle(bucket)

    return client_indices

# ======================
# 1.5 优化器工具 (CG, Neumann, L-BFGS 等)
# ======================

def conjugate_gradient(mvp_fn, b, max_iter=10, tol=1e-4, device='cpu'):
    """Approximately solve ``A x = b`` with the conjugate gradient method.

    Args:
        mvp_fn: Callable returning ``A @ v`` for a flat vector ``v``.
        b: Flat right-hand-side vector.
        max_iter: Maximum number of CG iterations.
        tol: Iteration stops once the residual norm drops below this value.
        device: Device the solve runs on; ``b`` is moved there first.

    Returns:
        The approximate solution, a flat tensor on ``device``.
    """
    # Move the right-hand side once so x, r and p all live on the same device
    # (previously only x was moved, which broke when b lived elsewhere).
    b = b.to(device)
    x = torch.zeros_like(b)
    r = b.clone()
    p = r.clone()
    rsold = torch.dot(r, r)

    for _ in range(max_iter):
        Ap = mvp_fn(p)
        # The small constant guards against a zero curvature term p^T A p.
        alpha = rsold / (torch.dot(p, Ap) + 1e-8)
        x = x + alpha * p
        r = r - alpha * Ap
        rsnew = torch.dot(r, r)
        if torch.sqrt(rsnew) < tol:
            break
        p = r + (rsnew / rsold) * p
        rsold = rsnew
    return x

def neumann_series_solver(mvp_fn, b, k=5, scale=1.0, device='cpu'):
    """
    Approximately solve Ax = b with a truncated Neumann series.

    Returns scale * sum_{i=0}^{k} (I - scale*A)^i b, where ``mvp_fn(v)``
    supplies the product A @ v. The result lives on ``device``.
    """
    term = b.clone().to(device)
    # i = 0 contribution of the series.
    x = b.clone().to(device) * scale

    for _ in range(k):
        # Advance to the next power: term <- (I - scale*A) term.
        term = term - scale * mvp_fn(term)
        x = x + term * scale
    return x

def lbfgs_solver(g_flat, s_list, y_list, rho_list, gamma=1.0):
    """
    Apply the L-BFGS inverse-Hessian approximation to ``g_flat``.

    Uses the two-loop recursion over the stored (s, y, rho) triples, taking
    ``gamma * I`` as the initial inverse Hessian. With an empty history the
    result is simply ``gamma * g_flat``.
    """
    m = len(s_list)
    q = g_flat.clone()

    # Backward pass, newest pair first; alphas[i] belongs to pair i.
    alphas = [None] * m
    for i in reversed(range(m)):
        alphas[i] = rho_list[i] * torch.dot(s_list[i], q)
        q = q - alphas[i] * y_list[i]

    # Initial approximation H_0 = gamma * I.
    r = gamma * q

    # Forward pass, oldest pair first.
    for i in range(m):
        beta = rho_list[i] * torch.dot(y_list[i], r)
        r = r + s_list[i] * (alphas[i] - beta)

    return r

# ======================
# 2. 模型定义
# ======================

class CLIPModelWrapper(nn.Module):
    """Feature backbone cast to float32, followed by a new linear classification head."""

    def __init__(self, backbone, input_dim, num_classes):
        """Wrap ``backbone`` (output width ``input_dim``) with a ``num_classes``-way linear head."""
        super().__init__()
        self.backbone = backbone.float()
        self.head = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        """Return class logits; single-channel input is tiled to three channels first."""
        single_channel = x.shape[1] == 1
        inputs = x.repeat(1, 3, 1, 1) if single_channel else x
        return self.head(self.backbone(inputs))

class TimmModelWrapper(nn.Module):
    """timm backbone created with ``num_classes=0`` plus a separate linear head."""

    def __init__(self, model_name, num_classes, pretrained=True):
        """Create the timm backbone ``model_name`` and a ``num_classes``-way linear head."""
        super().__init__()
        create_kwargs = dict(pretrained=pretrained, num_classes=0)
        try:
            # First attempt passes an explicit 224 input size.
            backbone = timm.create_model(model_name, **create_kwargs, img_size=224)
        except TypeError:
            # Retry without `img_size` when that keyword is rejected.
            backbone = timm.create_model(model_name, **create_kwargs)
        self.backbone = backbone
        self.in_features = backbone.num_features
        self.head = nn.Linear(self.in_features, num_classes)

    def forward(self, x):
        """Return class logits; single-channel input is tiled to three channels first."""
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        return self.head(self.backbone(x))

def build_model(model_name, num_classes, in_channels, force_random=False):
    """Build a classifier for ``model_name``.

    Names containing ``clip`` are first loaded through the CLIP library and
    wrapped in ``CLIPModelWrapper``; if that fails, or for any other name,
    the model is created through ``TimmModelWrapper``.

    Args:
        model_name: CLIP or timm model identifier.
        num_classes: Output size of the classification head.
        in_channels: Accepted for interface compatibility; not consulted here.
        force_random: Accepted for interface compatibility; not consulted here.

    Returns:
        A ``CLIPModelWrapper`` or ``TimmModelWrapper`` instance.

    Raises:
        ValueError: If the timm wrapper cannot be built; the underlying error
            is attached as ``__cause__``.
    """
    # NOTE(review): pretrained weights are always requested, even when the
    # caller passes force_random=True - confirm whether that is intended.
    use_pretrained = True
    if "clip" in model_name.lower():
        try:
            import clip
            model, _ = clip.load(model_name, device="cpu")
            if hasattr(model, 'visual'):
                input_dim = model.visual.output_dim
                return CLIPModelWrapper(model.visual, input_dim, num_classes)
        except Exception as e:
            # Best effort: any CLIP failure (including a missing package)
            # falls through to the timm path below.
            log_print(f"[Warning] Failed to load via CLIP library: {e}. Trying timm...")

    try:
        return TimmModelWrapper(
            model_name=model_name,
            num_classes=num_classes,
            pretrained=use_pretrained
        )
    except Exception as e:
        # Chain the original exception so its traceback is not lost.
        raise ValueError(f"Model '{model_name}' not found in torchvision or timm.\nError: {e}") from e


# ======================
# 3. 客户端逻辑
# ======================
class Client:
    """A federated-learning client: a data shard, a DataLoader over it and a local model.

    Every public method starts by loading the supplied global state dict into
    the local model and re-applying the freeze policy selected by
    ``args.training_mode``.
    """

    def __init__(self, client_id, dataset, indices, device, args, num_classes, in_channels):
        """Set up the client's data loader and local model.

        Args:
            client_id: Identifier used in log messages.
            dataset: Full training dataset; this client only sees ``indices``.
            indices: Sample indices owned by this client.
            device: Device the model and the batches are moved to.
            args: Run configuration (``batch_size`` and ``model_name`` are read here).
            num_classes: Number of output classes.
            in_channels: Number of input channels.
        """
        self.client_id = client_id
        self.device = device
        self.args = args
        self.num_classes = num_classes
        self.in_channels = in_channels

        # Drop the trailing partial batch only when the shard holds more than one batch.
        should_drop = len(indices) > args.batch_size
        self.loader = DataLoader(
            Subset(dataset, indices),
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=should_drop
        )
        self.model = build_model(args.model_name, num_classes, in_channels, force_random=True).to(self.device)

    def _get_head_module(self):
        """Return the classifier submodule (``head``, ``fc`` or ``classifier``), or None."""
        for attr in ('head', 'fc', 'classifier'):
            if hasattr(self.model, attr):
                return getattr(self.model, attr)
        return None

    def _set_freeze_status(self):
        """Apply ``args.training_mode``: train only the head, or the whole model."""
        if self.args.training_mode == 'head':
            for param in self.model.parameters():
                param.requires_grad = False

            head_module = self._get_head_module()
            if head_module is not None:
                for param in head_module.parameters():
                    param.requires_grad = True
            else:
                log_print(f"[Error] Client {self.client_id}: Could not find head layer to unfreeze!")

            # The frozen body runs in eval mode; whichever head was found is
            # put back in train mode (not only one named `head`).
            self.model.eval()
            if head_module is not None:
                head_module.train()
        else:
            self.model.train()
            for param in self.model.parameters():
                param.requires_grad = True

    def local_train(self, global_state_dict):
        """Run ``args.local_epochs`` epochs of local SGD/Adam from the global weights.

        With ``args.method == "FedProx"`` and ``args.mu > 0`` a proximal term
        ``mu/2 * ||w - w_global||^2`` over the trainable parameters is added
        to the loss.

        Returns:
            ``(state_dict, num_samples, mean_loss)``; ``num_samples`` counts
            every processed sample across all local epochs.
        """
        self.model.load_state_dict(global_state_dict)
        self._set_freeze_status()

        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        if len(trainable_params) == 0:
            return self.model.state_dict(), 0, 0.0

        if self.args.optimizer == 'sgd':
            optimizer = optim.SGD(trainable_params, lr=self.args.lr, momentum=0.9)
        else:
            optimizer = optim.Adam(trainable_params, lr=self.args.lr)

        criterion = nn.CrossEntropyLoss()
        # Snapshot of the just-loaded global weights, used as the FedProx anchor.
        global_params_dict = {}
        if self.args.method == "FedProx":
            global_params_dict = {n: p.detach().clone() for n, p in self.model.named_parameters() if p.requires_grad}

        num_samples = 0
        loss_sum = 0.0

        for _ in range(self.args.local_epochs):
            for x, y in self.loader:
                x, y = x.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                logits = self.model(x)
                loss = criterion(logits, y)

                if self.args.method == "FedProx" and self.args.mu > 0:
                    prox_term = 0.0
                    for n, p in self.model.named_parameters():
                        if p.requires_grad:
                            prox_term += (p - global_params_dict[n]).norm(2)**2
                    loss += (self.args.mu / 2) * prox_term

                loss.backward()
                optimizer.step()

                num_samples += x.size(0)
                loss_sum += loss.item() * x.size(0)

        return self.model.state_dict(), num_samples, loss_sum / max(num_samples, 1)

    def compute_diag_hessian(self, global_state_dict):
        """Estimate the diagonal of the Hessian from one batch (FedSophia / FedNew / FedNewton-Approx).

        Labels are sampled from the model's own softmax output; the estimate
        for each trainable parameter is the squared gradient of the summed
        loss divided by the batch size.

        Returns:
            ``(h_diag, batch_size)`` where ``h_diag`` maps parameter names to
            tensors; ``({}, 0)`` when the loader is empty or nothing is
            trainable.
        """
        self.model.load_state_dict(global_state_dict)
        self._set_freeze_status()
        # NOTE(review): this puts the whole model in train mode, overriding the
        # eval mode that head-only training selects - confirm this is intended.
        self.model.train()

        named_trainable = [(n, p) for n, p in self.model.named_parameters() if p.requires_grad]
        if not named_trainable:
            # Nothing to differentiate against; mirrors the guards in
            # local_train / compute_gradient instead of failing inside autograd.
            return {}, 0

        criterion = nn.CrossEntropyLoss(reduction='sum')

        try:
            x, y = next(iter(self.loader))
        except StopIteration:
            return {}, 0

        x, y = x.to(self.device), y.to(self.device)
        bs = x.size(0)

        logits = self.model(x)
        probs = torch.softmax(logits, dim=1)
        # The true labels are not used: targets are drawn from the predicted distribution.
        y_sampled = torch.multinomial(probs, num_samples=1).squeeze(1)

        loss = criterion(logits, y_sampled)

        grads = torch.autograd.grad(loss, [p for _, p in named_trainable], create_graph=False)

        h_diag_local = {
            name: (g * g).clone().detach() / bs
            for (name, _), g in zip(named_trainable, grads)
        }

        return h_diag_local, bs

    def compute_gradient(self, global_state_dict):
        """Compute the full-shard gradient at the global weights (FedNewton / FedNew / FedDANE).

        Returns:
            ``(grads, num_samples, mean_loss)`` where ``grads`` maps parameter
            names to the per-sample-averaged gradient; ``({}, 0, 0.0)`` when
            nothing is trainable.
        """
        self.model.load_state_dict(global_state_dict)
        self._set_freeze_status()

        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        if len(trainable_params) == 0:
            return {}, 0, 0.0

        criterion = nn.CrossEntropyLoss()
        # Clear gradients left over from earlier calls (no optimizer is involved here).
        for p in trainable_params:
            if p.grad is not None:
                p.grad.zero_()

        total_loss = 0.0
        total_samples = 0

        for x, y in self.loader:
            x, y = x.to(self.device), y.to(self.device)
            logits = self.model(x)
            loss = criterion(logits, y)
            # Scale the batch mean back to a batch sum so that the accumulated
            # gradient can be divided by the total sample count below.
            (loss * x.size(0)).backward()
            total_loss += loss.item() * x.size(0)
            total_samples += x.size(0)

        local_grads = {}
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                local_grads[name] = param.grad.clone().detach() / max(total_samples, 1)

        return local_grads, total_samples, total_loss / max(total_samples, 1)

    def _lbfgs_step(self, global_grads, lbfgs_state):
        """L-BFGS branch of ``compute_newton_step``; expects the model to be loaded and frozen already."""
        head_only = self.args.training_mode == 'head'

        # Parameters taking part in the flattened vector. The comments in the
        # original code state this filter must match the server-side one.
        targets = []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            is_head = ('head' in name) or ('fc' in name) or ('classifier' in name)
            if head_only and not is_head:
                continue
            targets.append((name, param))

        def gradient_fallback():
            # Plain gradient direction, reported with a sample count of 0.
            return {n: global_grads[n].cpu() for n, _ in targets if n in global_grads}, 0

        # Without curvature history, degrade to gradient descent.
        if lbfgs_state is None or len(lbfgs_state['s_list']) == 0:
            return gradient_fallback()

        if not targets:
            return {}, 0  # Nothing to update

        g_parts = []
        for name, param in targets:
            if name in global_grads:
                g_parts.append(global_grads[name].view(-1).to(self.device))
            else:
                g_parts.append(torch.zeros_like(param).view(-1).to(self.device))
        flat_g = torch.cat(g_parts)

        # A history built for a different parameter set cannot be applied.
        if flat_g.shape[0] != lbfgs_state['s_list'][0].shape[0]:
            return gradient_fallback()

        delta_flat = lbfgs_solver(
            flat_g,
            lbfgs_state['s_list'],
            lbfgs_state['y_list'],
            lbfgs_state['rho_list'],
            gamma=1.0
        )
        delta_flat = delta_flat * self.args.newton_lr

        # Un-flatten using the parameters collected above (no per-key rebuild
        # of the named_parameters dict).
        local_update_s, ptr = {}, 0
        for name, param in targets:
            numel = param.numel()
            local_update_s[name] = delta_flat[ptr:ptr + numel].view(param.shape).cpu()
            ptr += numel

        return local_update_s, 1

    def compute_newton_step(self, global_state_dict, global_grads, lbfgs_state=None):
        r"""Turn the aggregated gradient into a (quasi-)Newton update direction.

        ``args.newton_solver`` selects the solver: ``exact``, ``cg``, ``diag``,
        ``neumann``, ``lbfgs`` or ``lowrank``. Except for ``lbfgs``, curvature
        is computed for the classifier head only, from features captured at
        the head's input over at most ``args.hessian_batches`` batches. In
        ``full`` training mode the remaining parameters receive
        ``args.lr * gradient``.

        Args:
            global_state_dict: Global weights to evaluate at.
            global_grads: Mapping of parameter name to aggregated gradient.
            lbfgs_state: Dict with ``s_list``, ``y_list`` and ``rho_list``;
                only used by the ``lbfgs`` solver.

        Returns:
            ``(update, count)`` where ``update`` maps parameter names to CPU
            tensors. ``count`` is the number of samples used for the
            curvature; the ``lbfgs`` solver returns 1 for an L-BFGS step and 0
            when it falls back to the raw gradient.
        """
        self.model.load_state_dict(global_state_dict)
        self._set_freeze_status()
        self.model.eval()

        # --- L-BFGS special path ---
        if self.args.newton_solver == 'lbfgs':
            return self._lbfgs_step(global_grads, lbfgs_state)

        # --- Preparation for second-order methods ---
        # NOTE(review): head parameters are recognised by substring, so backbone
        # parameters whose names contain 'fc' (e.g. MLP layers) would be counted
        # as head keys in 'full' mode - verify against the models in use.
        head_keys, head_params, body_keys = [], [], []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if 'head' in name or 'fc' in name or 'classifier' in name:
                head_keys.append(name)
                head_params.append(param)
            else:
                body_keys.append(name)

        head_module = self._get_head_module()
        if head_module is None:
            return {}, 0

        # Capture the features fed into the head with a forward hook.
        feats_container = []

        def hook_fn(module, input, output):
            feats_container.append(input[0].detach())

        features_list, targets_list, total_samples = [], [], 0
        handle = head_module.register_forward_hook(hook_fn)
        try:
            for i, (x, y) in enumerate(self.loader):
                if i >= self.args.hessian_batches:
                    break
                x, y = x.to(self.device), y.to(self.device)
                total_samples += x.size(0)
                feats_container.clear()
                with torch.no_grad():
                    self.model(x)
                if len(feats_container) > 0:
                    features_list.append(feats_container[0])
                    targets_list.append(y)
        finally:
            # Always detach the hook, even if a forward pass raises.
            handle.remove()

        if total_samples == 0 or not features_list:
            return {}, 0
        all_features = torch.cat(features_list, dim=0)
        all_targets = torch.cat(targets_list, dim=0)

        curr_w = head_module.weight.detach()
        curr_b = head_module.bias.detach() if head_module.bias is not None else None
        inputs = (curr_w, curr_b) if curr_b is not None else (curr_w,)

        # Flatten the aggregated gradient in head_keys order.
        g_parts = []
        for k, param in zip(head_keys, head_params):
            if k in global_grads:
                g_parts.append(global_grads[k].view(-1).to(self.device))
            else:
                g_parts.append(torch.zeros_like(param.detach()).view(-1).to(self.device))
        flat_g = torch.cat(g_parts)
        delta_flat = None

        # Summed cross-entropy of the head on the captured features.
        def head_loss_func(w, b):
            out = torch.nn.functional.linear(all_features, w, b)
            return torch.nn.functional.cross_entropy(out, all_targets, reduction='sum')

        def head_loss_func_no_bias(w):
            out = torch.nn.functional.linear(all_features, w, None)
            return torch.nn.functional.cross_entropy(out, all_targets, reduction='sum')

        loss_fn = head_loss_func if curr_b is not None else head_loss_func_no_bias

        # ==========================================
        # Solver branching
        # ==========================================

        # 1. Exact: build the damped head Hessian and solve the linear system.
        if self.args.newton_solver == 'exact':
            # `inputs` is always a tuple, so hessian() already returns a tuple
            # of tuples of blocks - also for a bias-free head. No extra wrapping.
            H_tuple = torch.autograd.functional.hessian(loss_fn, inputs)

            H_blocks_rows = []
            for i in range(len(inputs)):
                row_blocks = []
                for j in range(len(inputs)):
                    block = H_tuple[i][j].reshape(inputs[i].numel(), inputs[j].numel())
                    row_blocks.append(block)
                H_blocks_rows.append(torch.cat(row_blocks, dim=1))
            H_mat = torch.cat(H_blocks_rows, dim=0) / max(total_samples, 1)

            I = torch.eye(H_mat.size(0), device=self.device)
            reg_H = H_mat + self.args.damping * I
            try:
                delta_flat = torch.linalg.solve(reg_H, flat_g)
            except RuntimeError:
                # Singular system or shape mismatch: fall back to the gradient.
                delta_flat = flat_g

        # 2. Diagonal approximation: reuse the squared-gradient estimate.
        elif self.args.newton_solver == 'diag':
            # Side effect: this reloads the weights, draws another batch and
            # leaves the model in train mode.
            h_diag, _ = self.compute_diag_hessian(global_state_dict)

            h_parts = []
            for k, param in zip(head_keys, head_params):
                if k in h_diag:
                    h_parts.append(h_diag[k].view(-1).to(self.device))
                else:
                    # Unit curvature of the parameter's full length, so that
                    # flat_h stays aligned with flat_g.
                    h_parts.append(torch.ones_like(param.detach()).view(-1).to(self.device))

            flat_h = torch.cat(h_parts)
            delta_flat = flat_g / (flat_h + self.args.damping)

        # 3. CG, Neumann and low-rank (matrix-free methods)
        elif self.args.newton_solver in ['cg', 'neumann', 'lowrank']:

            def mvp_fn(v_flat):
                # Damped Hessian-vector product: (H / n + damping * I) v.
                v_list, ptr = [], 0
                for inp in inputs:
                    numel = inp.numel()
                    v_list.append(v_flat[ptr: ptr + numel].view(inp.shape))
                    ptr += numel
                v_tuple = tuple(v_list)

                _, hvp_tuple = torch.autograd.functional.hvp(loss_fn, inputs, v_tuple)

                hvp_flat = torch.cat([t.reshape(-1) for t in hvp_tuple])
                return (hvp_flat / max(total_samples, 1)) + self.args.damping * v_flat

            if self.args.newton_solver == 'cg':
                delta_flat = conjugate_gradient(mvp_fn, flat_g, max_iter=self.args.newton_cg_max_iter, device=self.device)

            elif self.args.newton_solver == 'neumann':
                # The scale must be small enough that ||I - scale*H|| < 1;
                # 0.1 is a fixed heuristic here.
                delta_flat = neumann_series_solver(mvp_fn, flat_g, k=5, scale=0.1, device=self.device)

            elif self.args.newton_solver == 'lowrank':
                # Simplified stand-in for a low-rank solve: CG stopped after two iterations.
                delta_flat = conjugate_gradient(mvp_fn, flat_g, max_iter=2, device=self.device)

        # Post-processing
        if delta_flat is None:
            delta_flat = flat_g

        # Clip the step to max_norm (disabled when max_norm <= 0).
        if self.args.max_norm > 0:
            total_norm = torch.norm(delta_flat)
            if total_norm > self.args.max_norm:
                delta_flat = delta_flat * (self.args.max_norm / (total_norm + 1e-6))

        delta_flat = delta_flat * self.args.newton_lr

        local_update_s, ptr = {}, 0
        for name, param in zip(head_keys, head_params):
            numel = param.numel()
            local_update_s[name] = delta_flat[ptr: ptr + numel].view(param.shape).cpu()
            ptr += numel

        # In full mode the non-head parameters take a plain gradient step scaled by lr.
        if self.args.training_mode == 'full':
            for name in body_keys:
                if name in global_grads:
                    local_update_s[name] = (self.args.lr * global_grads[name].to(self.device)).cpu()

        return local_update_s, total_samples

# ======================
# 4. 服务端逻辑
# ======================
def average_weights(w_list, s_list):
    """Return the sample-weighted average of a list of state dicts.

    Args:
        w_list: State dicts with identical keys, one per client.
        s_list: Weight (sample count) of each state dict.

    Returns:
        A new state dict. Entries whose key contains ``num_batches_tracked``
        are copied from the first state dict instead of being averaged.
    """
    w_avg = copy.deepcopy(w_list[0])
    total_samples = sum(s_list)
    if total_samples <= 0:
        # All weights are zero: dividing would silently yield NaN/inf, so
        # fall back to a plain uniform average.
        s_list = [1] * len(w_list)
        total_samples = len(w_list)
    for k in w_avg.keys():
        if 'num_batches_tracked' in k:
            continue
        weighted_sum = sum(w[k] * s for w, s in zip(w_list, s_list))
        w_avg[k] = weighted_sum / total_samples
    return w_avg

def average_gradients(g_list, s_list):
    """Return the sample-weighted average of per-client gradient dicts.

    Args:
        g_list: One ``{name: tensor}`` dict per client; dicts may have
            different key sets.
        s_list: Weight (sample count) of each dict.

    Returns:
        A dict over the union of all keys. A client that lacks a key
        contributes zero for it, while the divisor is always the total weight
        of all clients. An empty ``g_list`` gives ``{}``.
    """
    if not g_list:
        return {}

    all_keys = set()
    for g in g_list:
        all_keys.update(g.keys())

    total_samples = sum(s_list)
    if total_samples <= 0:
        # All weights are zero: dividing would silently yield NaN/inf, so
        # fall back to a plain uniform average.
        s_list = [1] * len(g_list)
        total_samples = len(g_list)

    g_avg = {}
    for k in all_keys:
        weighted_sum = 0.0
        for g, s in zip(g_list, s_list):
            if k in g:
                weighted_sum += g[k] * s
        g_avg[k] = weighted_sum / total_samples
    return g_avg

def test_global(model, global_weights, test_loader, device):
    model.load_state_dict(global_weights)
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            loss_sum += loss.item() * x.size(0)
            _, pred = torch.max(out, 1)
            correct += (pred == y).sum().item()
            total += x.size(0)
    return loss_sum / max(total, 1), correct / max(total, 1)

# ======================
# 5. 主流程
# ======================
if __name__ == "__main__":
    # Entry point: parse the CLI options, build dataset / model / clients,
    # run the federated rounds for the selected method, then save the
    # per-round metrics to an .npz file under --log_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument('--method', type=str, default='FedAvg',
                        choices=['FedAvg', 'FedProx', 'FedNewton', 'FedNewton-Approx', 'FedSophia', 'FedNew', 'FedDANE'])
    parser.add_argument('--training_mode', type=str, default='full', choices=['full', 'head'], help='full: update all; head: update head only')
    parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10, cifar100, mnist, femnist')
    parser.add_argument('--model_name', type=str, default='resnet18')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--rounds', type=int, default=20)
    parser.add_argument('--num_clients', type=int, default=10)
    parser.add_argument('--fraction', type=float, default=0.5)
    parser.add_argument('--local_epochs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=32)
    # LR settings
    parser.add_argument('--lr', type=float, default=0.01, help="Local SGD LR or Global Learning Rate for Newton methods")
    parser.add_argument('--newton_lr', type=float, default=0.1, help="For FedNewton / FedDANE / FedNewton-Approx")
    # FedSophia / FedNew Params
    parser.add_argument('--sophia_lr', type=float, default=0.05, help="Server side LR for FedSophia/FedNew")
    parser.add_argument('--rho', type=float, default=0.04, help="Hessian smoothing parameter")
    parser.add_argument('--betas', type=str, default='0.9,0.99', help="Betas for Sophia/New (Momentum, Hessian average)")
    parser.add_argument('--mu', type=float, default=0.01, help="FedProx mu or FedDANE regularization")
    parser.add_argument('--log_dir', type=str, default='./logs')
    parser.add_argument('--optimizer', type=str, default='sgd')
    parser.add_argument('--hessian_batches', type=int, default=32)
    parser.add_argument('--damping', type=float, default=0.01, help="Damping for FedNewton and FedNewton-Approx")
    parser.add_argument('--alpha', type=float, default=0.5, help="Paramter to control data heterogeneity")
    parser.add_argument('--max_norm', type=float, default=0, help='Maximum L2 norm for the Newton update step. Set to 0 to disable.')

    # Solver options for the FedNewton linear system
    parser.add_argument('--newton_solver', type=str, default='exact',
                        choices=['exact', 'cg', 'diag', 'neumann', 'lbfgs', 'lowrank'],
                        help='Solver for linear system Hx=g')
    parser.add_argument('--newton_cg_max_iter', type=int, default=10, help='Max iterations for CG solver')
    parser.add_argument('--lbfgs_m', type=int, default=5, help='Memory size for L-BFGS')

    args = parser.parse_args()

    set_seed(args.seed)
    log_print(args)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs(args.log_dir, exist_ok=True)

    log_print(f"Start: {args.method} | Mode: {args.training_mode} | Dataset: {args.dataset} | Solver: {args.newton_solver}")

    train_ds, test_ds, num_classes, in_channels = get_dataset(args.dataset)

    test_loader = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=2)
    client_indices = split_dataset_non_iid(train_ds, args.num_clients, alpha=args.alpha, num_classes=num_classes)

    global_model = build_model(args.model_name, num_classes, in_channels).to(device)
    global_weights = global_model.state_dict()

    # Server-side state for FedSophia / FedNew: momentum m and the running
    # Hessian estimate h, kept only for floating-point entries.
    sophia_m = {}
    sophia_h = {}
    beta1, beta2 = [float(x) for x in args.betas.split(',')]

    for k, v in global_weights.items():
        if v.dtype == torch.float32 or v.dtype == torch.float64:
            sophia_m[k] = torch.zeros_like(v)
            sophia_h[k] = torch.zeros_like(v)

    # L-BFGS state shared across rounds
    lbfgs_state = {
        's_list': [], # parameter updates
        'y_list': [], # gradient differences
        'rho_list': [],
        'prev_w': None,
        'prev_g': None
    }

    clients = [Client(i, train_ds, client_indices[i], device, args, num_classes, in_channels) for i in range(args.num_clients)]

    logs = {"rounds": [], "test_acc": [], "test_loss": [], "train_loss": [], "wall_time": [], "max_gpu_mem": []}
    start_time = time.time()

    for rnd in range(1, args.rounds + 1):
        start_rnd_time = time.time()

        # Reset the peak-memory counter so max_gpu_mem is per round
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(device)

        # At least one client per round; clamp to num_clients because sampling
        # without replacement raises when asked for more items than exist
        # (e.g. --fraction > 1).
        m = min(max(int(args.fraction * args.num_clients), 1), args.num_clients)
        selected_users = np.random.choice(range(args.num_clients), m, replace=False)

        local_losses = []
        local_samples_list = []

        # ==========================================
        # Method Implementation
        # ==========================================
        if args.method in ['FedSophia', 'FedNew']:
            local_deltas = []
            local_hessians = []

            for idx in selected_users:
                w_before = copy.deepcopy(global_weights)
                w_after, n, loss = clients[idx].local_train(w_before)

                # Per-client parameter change produced by local training
                diff = {}
                for k in w_before.keys():
                    if k in w_after:
                        diff[k] = (w_after[k] - w_before[k]).cpu()

                local_deltas.append(diff)
                local_losses.append(loss)
                local_samples_list.append(n)

                h_local, _ = clients[idx].compute_diag_hessian(w_before)
                local_hessians.append(h_local)

            avg_diff = average_weights(local_deltas, local_samples_list)
            avg_h = average_gradients(local_hessians, local_samples_list)

            for k in global_weights.keys():
                if k not in avg_diff or k not in sophia_m: continue
                # The negated average parameter change acts as the gradient
                # fed into the momentum average.
                g = -avg_diff[k].to(device)
                sophia_m[k] = beta1 * sophia_m[k] + (1 - beta1) * g

                if k in avg_h:
                    h = avg_h[k].to(device)
                    sophia_h[k] = beta2 * sophia_h[k] + (1 - beta2) * h

                if args.method == 'FedSophia':
                    # Element-wise floor of the Hessian estimate at rho
                    preconditioner = torch.maximum(sophia_h[k], torch.tensor(args.rho, device=device))
                    step = args.sophia_lr * (sophia_m[k] / preconditioner)
                else: # FedNew
                    # Additive smoothing of the Hessian estimate by rho
                    preconditioner = sophia_h[k] + args.rho
                    step = args.sophia_lr * (sophia_m[k] / preconditioner)

                global_weights[k] -= step

        # --- Strategy: FedAvg / FedProx ---
        elif args.method in ['FedAvg', 'FedProx']:
            local_weights = []
            for idx in selected_users:
                w, n, loss = clients[idx].local_train(copy.deepcopy(global_weights))
                local_weights.append(w)
                local_samples_list.append(n)
                local_losses.append(loss)

            global_weights = average_weights(local_weights, local_samples_list)

        # --- Strategy: FedNewton (The Master Solver) ---
        elif args.method == 'FedNewton':
            local_grads_list = []
            for idx in selected_users:
                g, n, loss = clients[idx].compute_gradient(copy.deepcopy(global_weights))
                local_grads_list.append(g)
                local_samples_list.append(n)
                local_losses.append(loss)

            global_grads = average_gradients(local_grads_list, local_samples_list)

            # --- L-BFGS Update (Global State) ---
            if args.newton_solver == 'lbfgs':
                current_w_flat = []
                current_g_flat = []

                # 1. Select the parameters to track
                target_keys = []
                for k, p in global_model.named_parameters():
                    if not p.requires_grad: continue

                    # In 'head' mode keep only parameters whose name marks
                    # them as part of the head
                    is_head = ('head' in k) or ('fc' in k) or ('classifier' in k)
                    if args.training_mode == 'head' and not is_head:
                        continue

                    target_keys.append(k)

                # 2. Flatten the selected parameters and their gradients;
                #    a parameter without an averaged gradient contributes zeros
                for k in target_keys:
                    p = global_weights[k]
                    current_w_flat.append(p.view(-1).to(device))

                    if k in global_grads:
                        current_g_flat.append(global_grads[k].view(-1).to(device))
                    else:
                        current_g_flat.append(torch.zeros_like(p).view(-1).to(device))

                if len(current_w_flat) > 0:
                    curr_w_vec = torch.cat(current_w_flat)
                    curr_g_vec = torch.cat(current_g_flat)

                    if lbfgs_state['prev_w'] is not None:
                        # The stored history is only usable when the vector
                        # size is unchanged; otherwise reset the state.
                        if curr_w_vec.shape != lbfgs_state['prev_w'].shape:
                            log_print(f"[Warning] L-BFGS dimension mismatch ({lbfgs_state['prev_w'].shape} vs {curr_w_vec.shape}). Resetting state.")
                            lbfgs_state['s_list'] = []
                            lbfgs_state['y_list'] = []
                            lbfgs_state['rho_list'] = []
                            lbfgs_state['prev_w'] = None
                            lbfgs_state['prev_g'] = None
                        else:
                            s = curr_w_vec - lbfgs_state['prev_w']
                            y = curr_g_vec - lbfgs_state['prev_g']

                            # Store the (s, y) pair only when y.s is clearly
                            # positive, so rho = 1 / (y.s) is well defined.
                            denom = torch.dot(y, s)
                            if denom > 1e-10:
                                rho = 1.0 / denom
                                lbfgs_state['s_list'].append(s)
                                lbfgs_state['y_list'].append(y)
                                lbfgs_state['rho_list'].append(rho)
                                # Keep at most lbfgs_m pairs, dropping the oldest
                                if len(lbfgs_state['s_list']) > args.lbfgs_m:
                                    lbfgs_state['s_list'].pop(0)
                                    lbfgs_state['y_list'].pop(0)
                                    lbfgs_state['rho_list'].pop(0)

                    lbfgs_state['prev_w'] = curr_w_vec.clone()
                    lbfgs_state['prev_g'] = curr_g_vec.clone()

            # Calculate Step
            local_steps_list = []
            for idx in selected_users:
                s_j, _ = clients[idx].compute_newton_step(copy.deepcopy(global_weights), global_grads, lbfgs_state)
                local_steps_list.append(s_j)

            global_step = average_gradients(local_steps_list, local_samples_list)

            # The averaged step is subtracted as-is: no server-side learning
            # rate is applied in this branch.
            for name, param in global_weights.items():
                if name in global_step:
                    if param.dtype == torch.long: continue
                    global_weights[name] = param - global_step[name].to(device)

        # --- Legacy: FedNewton-Approx (Can be replaced by FedNewton + solver='diag') ---
        elif args.method == 'FedNewton-Approx':
            local_grads_list = []
            for idx in selected_users:
                g, n, loss = clients[idx].compute_gradient(copy.deepcopy(global_weights))
                local_grads_list.append(g)
                local_samples_list.append(n)
                local_losses.append(loss)

            global_grads = average_gradients(local_grads_list, local_samples_list)

            local_steps_list = []
            for idx in selected_users:
                s_j, _ = clients[idx].compute_newton_approx_step(copy.deepcopy(global_weights), global_grads)
                local_steps_list.append(s_j)

            global_step = average_gradients(local_steps_list, local_samples_list)
            for name, param in global_weights.items():
                if name in global_step:
                    if param.dtype == torch.long: continue
                    global_weights[name] = param - args.newton_lr * global_step[name].to(device)

        # --- Legacy: FedDANE ---
        elif args.method == 'FedDANE':
            local_grads_list = []
            for idx in selected_users:
                g, n, loss = clients[idx].compute_gradient(copy.deepcopy(global_weights))
                local_grads_list.append(g)
                local_samples_list.append(n)
                local_losses.append(loss)

            global_grads = average_gradients(local_grads_list, local_samples_list)

            local_steps_list = []
            for idx in selected_users:
                s_j, _ = clients[idx].compute_dane_step(copy.deepcopy(global_weights), global_grads)
                local_steps_list.append(s_j)

            global_step = average_gradients(local_steps_list, local_samples_list)
            for name, param in global_weights.items():
                if name in global_step:
                    if param.dtype == torch.long: continue
                    global_weights[name] = param - args.newton_lr * global_step[name].to(device)

        # Sample-weighted mean of the client losses. Every branch above fills
        # local_losses / local_samples_list, so this is computed once here.
        train_loss_avg = sum([l*n for l,n in zip(local_losses, local_samples_list)]) / max(sum(local_samples_list), 1)

        test_loss, test_acc = test_global(global_model, global_weights, test_loader, device)

        # Timing
        total_elapsed = time.time() - start_time
        round_elapsed = time.time() - start_rnd_time

        # Peak GPU memory for this round, in MiB (0.0 without CUDA)
        max_mem = 0.0
        if torch.cuda.is_available():
            max_mem = torch.cuda.max_memory_allocated(device) / (1024 * 1024)

        logs["rounds"].append(rnd)
        logs["test_acc"].append(test_acc)
        logs["test_loss"].append(test_loss)
        logs["train_loss"].append(train_loss_avg)
        logs["wall_time"].append(total_elapsed)
        logs["max_gpu_mem"].append(max_mem)

        log_print(f"[Round {rnd:03d}] {args.method} (Solver: {args.newton_solver}) | Acc: {test_acc*100:.2f}% | Loss: {test_loss:.4f} | Time: {round_elapsed:.2f}s | Max GPU Mem: {max_mem:.2f} MB")

        # NaN compares False against both thresholds below, so a diverged run
        # would otherwise keep training; stop on it explicitly.
        if np.isnan(test_loss):
            log_print("Early stopping as test loss is NaN.")
            break
        elif test_loss < 1e-6:
            log_print("Early stopping as test loss is very low.")
            break
        elif test_loss > 10.0:
            log_print("Early stopping as test loss is very high.")
            break

    save_filename = f"{args.method}_{args.training_mode}_{args.dataset}_{args.model_name}_{args.newton_solver}_s{args.seed}.npz"
    save_path = os.path.join(args.log_dir, save_filename)
    np.savez(save_path, method=args.method, solver=args.newton_solver, training_mode=args.training_mode, **logs)
    log_print(f"Saved to {save_path}")
