import argparse
import random
import copy
import datetime
import time
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import warnings

# ======================
# 0. Logging Utility
# ======================
def log_print(*args, **kwargs):
    """Print *args* prefixed by a local-time ``[YYYY-mm-dd HH:MM:SS]`` stamp.

    Extra keyword arguments are forwarded to ``print``; output is always flushed.
    """
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    prefix = "[" + stamp + "]"
    print(prefix, *args, flush=True, **kwargs)

# Try importing clip
try:
    import clip
except ImportError:
    clip = None

# Try importing timm
try:
    import timm
except ImportError:
    timm = None

# Silence ALL Python warnings process-wide (meant to hide torchvision compatibility warnings, but this also hides every other warning)
warnings.filterwarnings("ignore")

# ======================
# 1. Utility Functions & Dataset Management
# ======================
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with *seed*.

    When CUDA is available, all CUDA devices are seeded too and cuDNN is put in
    deterministic (non-benchmark) mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_dataset(name, root='./data'):
    """Load a torchvision image-classification dataset upsampled to 224 px.

    Supported names (case-insensitive): 'cifar10', 'cifar100', 'mnist',
    'fashionmnist'. Datasets are downloaded into *root* when missing.

    CIFAR: the training split gets RandomCrop(32, padding=4) + horizontal flip;
    both splits are normalized with per-dataset channel statistics and then
    resized to 224 with bicubic interpolation. MNIST / FashionMNIST: the PIL
    image is resized to 224 first, then normalized with mean 0.5 / std 0.5.

    Returns:
        (train_dataset, test_dataset, num_classes, in_channels)

    Raises:
        ValueError: unsupported dataset name (raised before any download).
    """
    name = name.lower()

    # name -> (torchvision.datasets class name, num_classes, in_channels, mean, std)
    specs = {
        'cifar10': ('CIFAR10', 10, 3, (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        'cifar100': ('CIFAR100', 100, 3, (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
        'mnist': ('MNIST', 10, 1, (0.5,), (0.5,)),
        'fashionmnist': ('FashionMNIST', 10, 1, (0.5,), (0.5,)),
    }
    if name not in specs:
        raise ValueError(f"Dataset {name} not supported.")
    class_name, num_classes, in_channels, mean, std = specs[name]

    def build_transform(train):
        """Return the transform pipeline for the train (True) or test (False) split."""
        if name.startswith('cifar'):
            # Augment at the native 32x32 resolution, then normalize the tensor
            # and upsample it (bicubic) to 224 px.
            augment = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] if train else []
            return transforms.Compose(augment + [
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
                transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            ])
        # Grayscale datasets: resize the PIL image first; no augmentation.
        return transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    dataset_cls = getattr(torchvision.datasets, class_name)
    train_dataset = dataset_cls(root=root, train=True, download=True, transform=build_transform(True))
    test_dataset = dataset_cls(root=root, train=False, download=True, transform=build_transform(False))

    return train_dataset, test_dataset, num_classes, in_channels

def split_dataset_non_iid(train_dataset, num_clients, alpha=0.5, num_classes=10):
    """Partition dataset indices across clients with a per-class Dirichlet split.

    For each label ``0 .. num_classes-1`` the indices of that class are shuffled and
    cut into ``num_clients`` contiguous chunks whose sizes come from a
    Dirichlet(alpha) draw (floored; the last client takes the remainder). Smaller
    ``alpha`` yields a more skewed (non-IID) split. Samples whose label is outside
    ``range(num_classes)`` are not assigned. All randomness comes from NumPy's
    global RNG, so the split is reproducible after ``np.random.seed``.

    Returns:
        list of ``num_clients`` lists of dataset indices, each list shuffled.
    """
    for attr in ('targets', 'labels'):
        if hasattr(train_dataset, attr):
            labels = np.array(getattr(train_dataset, attr))
            break
    else:
        # No label attribute: iterate the whole dataset (runs its transforms; slow).
        labels = np.array([target for _, target in train_dataset])

    positions = np.arange(len(train_dataset))
    shards = [[] for _ in range(num_clients)]

    for label in range(num_classes):
        members = positions[labels == label]
        np.random.shuffle(members)

        shares = np.random.dirichlet(np.repeat(alpha, num_clients))
        sizes = (shares * len(members)).astype(int)
        # The last client absorbs the flooring remainder (clamped at zero).
        sizes[-1] = max(len(members) - np.sum(sizes[:-1]), 0)

        for owner, chunk in enumerate(np.split(members, np.cumsum(sizes)[:-1])):
            shards[owner].extend(chunk.tolist())

    for shard in shards:
        np.random.shuffle(shard)

    return shards

# ======================
# 1.5 Optimizer Tools (CG, Neumann, L-BFGS, etc.)
# ======================

def conjugate_gradient(mvp_fn, b, max_iter=10, tol=1e-4, device='cpu'):
    """Approximately solve ``A x = b`` with conjugate gradient (matrix-free).

    Args:
        mvp_fn: callable returning ``A @ v`` for a 1-D tensor ``v``; ``A`` is
            assumed symmetric positive-definite (CG's requirement).
        b: 1-D right-hand side.
        max_iter: maximum number of CG iterations.
        tol: stop once the residual's Euclidean norm drops below this.
        device: device the initial iterate (zeros) is moved to.

    Returns:
        The iterate ``x`` after convergence or ``max_iter`` steps.
    """
    solution = torch.zeros_like(b).to(device)
    residual = b.clone()
    direction = residual.clone()
    res_sq = torch.dot(residual, residual)

    for _ in range(max_iter):
        a_dir = mvp_fn(direction)
        # 1e-8 guards against division by zero when the curvature vanishes.
        step = res_sq / (torch.dot(direction, a_dir) + 1e-8)
        solution = solution + step * direction
        residual = residual - step * a_dir

        new_res_sq = torch.dot(residual, residual)
        if torch.sqrt(new_res_sq) < tol:
            break

        direction = residual + (new_res_sq / res_sq) * direction
        res_sq = new_res_sq

    return solution

def neumann_series_solver(mvp_fn, b, k=5, scale=1.0, device='cpu'):
    r"""Approximate ``A^{-1} b`` with a truncated Neumann series.

    Computes ``x = scale * sum_{i=0}^{k} (I - scale*A)^i b``, which converges to
    ``A^{-1} b`` as ``k -> inf`` when the spectral radius of ``I - scale*A`` is < 1.

    Args:
        mvp_fn: callable returning ``A @ v`` for a 1-D tensor ``v``.
        b: right-hand side.
        k: number of series terms after the zeroth.
        scale: step size ``scale`` in the expansion.
        device: device ``b`` is copied to before the iteration.
    """
    power = b.clone().to(device)  # (I - scale*A)^i b, starting at i = 0
    acc = power * scale
    for _ in range(k):
        power = power - scale * mvp_fn(power)
        acc = acc + scale * power
    return acc

def lbfgs_solver(g_flat, s_list, y_list, rho_list, gamma=1.0):
    """Apply the L-BFGS inverse-Hessian approximation to ``g_flat`` (two-loop recursion).

    Args:
        g_flat: 1-D vector to precondition.
        s_list / y_list: parameter and gradient differences, oldest first.
        rho_list: ``1 / (y_i . s_i)`` for each stored pair.
        gamma: scale of the initial approximation ``H_0 = gamma * I``.

    Returns:
        ``H_k @ g_flat``; with no stored pairs this is ``gamma * g_flat``.
    """
    count = len(s_list)
    q = g_flat.clone()
    alphas = [None] * count

    # Backward pass: newest pair to oldest.
    for i in reversed(range(count)):
        alphas[i] = rho_list[i] * torch.dot(s_list[i], q)
        q = q - alphas[i] * y_list[i]

    r = gamma * q

    # Forward pass: oldest pair to newest.
    for s, y, rho, alpha in zip(s_list, y_list, rho_list, alphas):
        beta = rho * torch.dot(y, r)
        r = r + s * (alpha - beta)

    return r

# ======================
# 2. Model Definitions
# ======================

class CLIPModelWrapper(nn.Module):
    """A visual backbone (cast to float32) followed by a new linear classifier head.

    Single-channel inputs are tiled to three channels before the backbone.
    """

    def __init__(self, backbone, input_dim, num_classes):
        super().__init__()
        self.backbone = backbone.float()
        self.head = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        is_grayscale = x.shape[1] == 1
        rgb = x.repeat(1, 3, 1, 1) if is_grayscale else x
        return self.head(self.backbone(rgb))

class TimmModelWrapper(nn.Module):
    """A timm backbone built without its classifier, plus a new linear head.

    Single-channel inputs are tiled to three channels before the backbone.
    """

    def __init__(self, model_name, num_classes, pretrained=True):
        super().__init__()
        shared_kwargs = dict(pretrained=pretrained, num_classes=0)
        try:
            # Ask for 224 px input; architectures that reject `img_size`
            # raise TypeError and are rebuilt without it.
            self.backbone = timm.create_model(model_name, img_size=224, **shared_kwargs)
        except TypeError:
            self.backbone = timm.create_model(model_name, **shared_kwargs)

        self.in_features = self.backbone.num_features
        self.head = nn.Linear(self.in_features, num_classes)

    def forward(self, x):
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        return self.head(self.backbone(x))

def build_model(model_name, num_classes, in_channels, force_random=False):
    """Build a classifier: a CLIP or timm backbone plus a linear head.

    Names containing "clip" (case-insensitive) are first loaded through the
    ``clip`` package (its ``visual`` encoder is wrapped); any failure there is
    logged and the timm path is tried instead.

    Args:
        model_name: CLIP or timm model identifier.
        num_classes: output size of the linear head.
        in_channels: accepted for interface compatibility; not used here (the
            wrappers tile 1-channel input to 3 channels themselves).
        force_random: accepted for interface compatibility; currently ignored,
            pretrained weights are always requested.

    Raises:
        ValueError: if the model could not be built (original error chained).
    """
    use_pretrained = True  # NOTE(review): force_random is not honoured — confirm intended.

    if "clip" in model_name.lower():
        try:
            import clip  # local import: the package is optional
            model, _ = clip.load(model_name, device="cpu")
            if hasattr(model, 'visual'):
                return CLIPModelWrapper(model.visual, model.visual.output_dim, num_classes)
        except Exception as e:
            log_print(f"[Warning] Failed to load via CLIP library: {e}. Trying timm...")

    try:
        return TimmModelWrapper(
            model_name=model_name,
            num_classes=num_classes,
            pretrained=use_pretrained,
        )
    except Exception as e:
        raise ValueError(f"Model '{model_name}' could not be built via CLIP or timm.\nError: {e}") from e


# ======================
# 3. Client Logic
# ======================
class Client:
    """One federated-learning participant: a data shard plus a private model copy.

    Each training/solver method starts by loading the server's
    ``global_state_dict`` into ``self.model``, so nothing done locally carries
    over between calls.
    """

    def __init__(self, client_id, dataset, indices, device, args, num_classes, in_channels):
        self.client_id = client_id
        self.device = device
        self.args = args
        self.num_classes = num_classes
        self.in_channels = in_channels

        # Drop the ragged last batch only when the shard is larger than one
        # batch, so a shard with <= batch_size samples still yields a batch.
        should_drop = len(indices) > args.batch_size
        self.loader = DataLoader(
            Subset(dataset, indices),
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=should_drop,
        )
        self.model = build_model(args.model_name, num_classes, in_channels, force_random=True).to(self.device)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _head_module(self):
        """Return the classifier head (attribute ``head``, ``fc`` or ``classifier``), or None."""
        for attr in ('head', 'fc', 'classifier'):
            if hasattr(self.model, attr):
                return getattr(self.model, attr)
        return None

    @staticmethod
    def _is_head_name(name):
        """Substring rule for 'head' parameters used on the L-BFGS path.

        Kept verbatim: per the original code's note, the server filters
        parameters with the same rule when building its flat L-BFGS vectors.
        """
        return ('head' in name) or ('fc' in name) or ('classifier' in name)

    def _set_freeze_status(self):
        """Set requires_grad and train/eval mode according to ``args.training_mode``.

        'head': only the classifier head is trainable; the model is in eval mode
        except the head, which is in train mode. Any other value: the whole model
        is trainable and in train mode.
        """
        if self.args.training_mode != 'head':
            self.model.train()
            for param in self.model.parameters():
                param.requires_grad = True
            return

        for param in self.model.parameters():
            param.requires_grad = False

        head = self._head_module()
        if head is None:
            log_print(f"[Error] Client {self.client_id}: Could not find head layer to unfreeze!")
        else:
            for param in head.parameters():
                param.requires_grad = True

        self.model.eval()
        if head is not None:
            head.train()

    # ------------------------------------------------------------------
    # First-order work
    # ------------------------------------------------------------------
    def local_train(self, global_state_dict):
        """FedAvg / FedProx / FedSophia Step 1: local optimisation from the global weights.

        Runs ``args.local_epochs`` epochs of SGD (momentum 0.9) or Adam over the
        trainable parameters. With ``args.method == "FedProx"`` and ``args.mu > 0``
        the proximal term ``mu/2 * ||w - w_global||^2`` is added to the loss.

        Returns:
            (state_dict, num_samples, mean_loss). ``mean_loss`` includes the
            proximal term. ``state_dict`` shares storage with ``self.model`` (it
            is not a copy).
        """
        self.model.load_state_dict(global_state_dict)
        self._set_freeze_status()

        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        if not trainable_params:
            return self.model.state_dict(), 0, 0.0

        if self.args.optimizer == 'sgd':
            optimizer = optim.SGD(trainable_params, lr=self.args.lr, momentum=0.9)
        else:
            optimizer = optim.Adam(trainable_params, lr=self.args.lr)

        criterion = nn.CrossEntropyLoss()
        use_prox = self.args.method == "FedProx" and self.args.mu > 0
        global_params_dict = {}
        if self.args.method == "FedProx":
            # Snapshot of the global weights: the anchor of the proximal term.
            global_params_dict = {n: p.detach().clone() for n, p in self.model.named_parameters() if p.requires_grad}

        num_samples = 0
        loss_sum = 0.0

        for _ in range(self.args.local_epochs):
            for x, y in self.loader:
                x, y = x.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                loss = criterion(self.model(x), y)

                if use_prox:
                    prox_term = 0.0
                    for n, p in self.model.named_parameters():
                        if p.requires_grad:
                            prox_term += (p - global_params_dict[n]).norm(2) ** 2
                    loss = loss + (self.args.mu / 2) * prox_term

                loss.backward()
                optimizer.step()

                num_samples += x.size(0)
                loss_sum += loss.item() * x.size(0)

        return self.model.state_dict(), num_samples, loss_sum / max(num_samples, 1)

    def compute_diag_hessian(self, global_state_dict):
        """Diagonal curvature estimate on one mini-batch (FedSophia / FedNew / FedNewton-Approx).

        Labels are sampled from the model's own softmax; with ``g`` the gradient
        of the summed cross-entropy, ``g * g / batch_size`` is returned per
        trainable parameter. The train/eval configuration is the one set by
        ``_set_freeze_status`` (in 'head' mode the frozen backbone stays in eval
        mode, matching local training).

        Returns:
            (dict name -> tensor shaped like the parameter, batch_size);
            ({}, 0) when nothing is trainable or the loader is empty.
        """
        self.model.load_state_dict(global_state_dict)
        self._set_freeze_status()

        trainable = [(n, p) for n, p in self.model.named_parameters() if p.requires_grad]
        if not trainable:
            return {}, 0

        try:
            x, _ = next(iter(self.loader))
        except StopIteration:
            return {}, 0

        x = x.to(self.device)
        bs = x.size(0)

        logits = self.model(x)
        with torch.no_grad():
            probs = torch.softmax(logits, dim=1)
            y_sampled = torch.multinomial(probs, num_samples=1).squeeze(1)

        loss = nn.CrossEntropyLoss(reduction='sum')(logits, y_sampled)
        grads = torch.autograd.grad(loss, [p for _, p in trainable])

        h_diag_local = {name: (g * g).detach() / bs for (name, _), g in zip(trainable, grads)}
        return h_diag_local, bs

    def compute_gradient(self, global_state_dict):
        """Full-pass mean gradient at the global weights (FedNewton / FedNew / FedDANE).

        Each batch's mean loss is multiplied by its size before ``backward``, so
        the accumulated gradient divided by the total sample count is the
        per-sample mean gradient over the whole shard.

        Returns:
            (dict name -> gradient for trainable params, num_samples, mean_loss);
            ({}, 0, 0.0) when nothing is trainable.
        """
        self.model.load_state_dict(global_state_dict)
        self._set_freeze_status()

        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        if not trainable_params:
            return {}, 0, 0.0

        criterion = nn.CrossEntropyLoss()
        for p in trainable_params:
            if p.grad is not None:
                p.grad.zero_()

        total_loss = 0.0
        total_samples = 0

        for x, y in self.loader:
            x, y = x.to(self.device), y.to(self.device)
            loss = criterion(self.model(x), y)
            (loss * x.size(0)).backward()
            total_loss += loss.item() * x.size(0)
            total_samples += x.size(0)

        denom = max(total_samples, 1)
        # Only trainable params: frozen ones could still hold stale .grad tensors.
        local_grads = {
            name: param.grad.detach().clone() / denom
            for name, param in self.model.named_parameters()
            if param.requires_grad and param.grad is not None
        }
        return local_grads, total_samples, total_loss / denom

    # ------------------------------------------------------------------
    # Second-order work
    # ------------------------------------------------------------------
    def compute_newton_step(self, global_state_dict, global_grads, lbfgs_state=None):
        r"""FedNewton solver hub: compute this client's (damped) Newton-type update.

        ``args.newton_solver`` selects one of 'lbfgs', 'exact', 'diag', 'cg',
        'neumann', 'lowrank'. Except for 'lbfgs', curvature is taken only w.r.t.
        the classifier head's ``weight``/``bias`` through a head-only summed
        cross-entropy on cached head-input features (up to ``args.hessian_batches``
        batches), averaged per sample and damped by ``args.damping``. The head
        step is norm-clipped to ``args.max_norm`` (if > 0) and scaled by
        ``args.newton_lr``. In 'full' mode every other trainable parameter gets
        ``args.lr * global_grads[name]``.

        Args:
            global_state_dict: server weights to evaluate at.
            global_grads: dict name -> aggregated gradient tensor.
            lbfgs_state: for 'lbfgs', dict with 's_list', 'y_list', 'rho_list'.

        Returns:
            (dict name -> update tensor on CPU, count). ``count`` is 0/1 on the
            L-BFGS path and the number of Hessian samples otherwise.
        """
        self.model.load_state_dict(global_state_dict)
        self._set_freeze_status()
        self.model.eval()

        if self.args.newton_solver == 'lbfgs':
            return self._lbfgs_step(global_grads, lbfgs_state)

        head_module = self._head_module()
        if head_module is None:
            return {}, 0

        curr_w = head_module.weight.detach()
        curr_b = head_module.bias.detach() if head_module.bias is not None else None
        inputs = (curr_w, curr_b) if curr_b is not None else (curr_w,)

        # Identify head parameters by identity with the tensors the Hessian is
        # taken over. A substring test ('fc' in name) would also match body
        # layers (e.g. ViT `mlp.fc1`, CLIP `mlp.c_fc`) and make flat_g longer
        # than the Hessian.
        name_of = {id(p): n for n, p in self.model.named_parameters()}
        head_params = [head_module.weight] + ([head_module.bias] if curr_b is not None else [])
        head_keys = [name_of[id(p)] for p in head_params]
        head_key_set = set(head_keys)
        body_keys = [n for n, p in self.model.named_parameters() if p.requires_grad and n not in head_key_set]

        all_features, all_targets, total_samples = self._gather_head_inputs(head_module)
        if all_features is None:
            return {}, 0

        # Global gradient restricted to the head, flattened in the order of `inputs`.
        flat_g = torch.cat([
            (global_grads[k] if k in global_grads else torch.zeros_like(p)).reshape(-1).to(self.device)
            for k, p in zip(head_keys, head_params)
        ])

        n_samples = max(total_samples, 1)
        damping = self.args.damping

        def head_loss(*params):
            """Summed cross-entropy of the head alone on the cached features."""
            w = params[0]
            b = params[1] if len(params) > 1 else None
            out = torch.nn.functional.linear(all_features, w, b)
            return torch.nn.functional.cross_entropy(out, all_targets, reduction='sum')

        solver = self.args.newton_solver
        delta_flat = None

        if solver == 'exact':
            # With tuple `inputs` (even a 1-tuple), hessian() returns a tuple of
            # tuples H[i][j]; no extra wrapping is needed.
            H_blocks = torch.autograd.functional.hessian(head_loss, inputs)
            H_mat = torch.cat([
                torch.cat([H_blocks[i][j].reshape(inp_i.numel(), inp_j.numel())
                           for j, inp_j in enumerate(inputs)], dim=1)
                for i, inp_i in enumerate(inputs)
            ], dim=0) / n_samples
            reg_H = H_mat + damping * torch.eye(H_mat.size(0), device=self.device)
            try:
                delta_flat = torch.linalg.solve(reg_H, flat_g)
            except RuntimeError:
                # Singular / ill-posed system: fall back to a plain gradient step.
                delta_flat = flat_g

        elif solver == 'diag':
            # Diagonal curvature from compute_diag_hessian (squared sampled-label
            # gradients on one fresh mini-batch); missing keys default to 1.
            h_diag, _ = self.compute_diag_hessian(global_state_dict)
            flat_h = torch.cat([
                h_diag[k].reshape(-1).to(self.device) if k in h_diag
                else torch.ones(p.numel(), dtype=flat_g.dtype, device=self.device)
                for k, p in zip(head_keys, head_params)
            ])
            delta_flat = flat_g / (flat_h + damping)

        elif solver in ('cg', 'neumann', 'lowrank'):
            def mvp_fn(v_flat):
                """Damped, per-sample-averaged Hessian-vector product on the head."""
                v_parts, ptr = [], 0
                for inp in inputs:
                    v_parts.append(v_flat[ptr:ptr + inp.numel()].reshape(inp.shape))
                    ptr += inp.numel()
                _, hvp_parts = torch.autograd.functional.hvp(head_loss, inputs, tuple(v_parts))
                hvp_flat = torch.cat([t.reshape(-1) for t in hvp_parts])
                return hvp_flat / n_samples + damping * v_flat

            if solver == 'cg':
                delta_flat = conjugate_gradient(mvp_fn, flat_g, max_iter=self.args.newton_cg_max_iter, device=self.device)
            elif solver == 'neumann':
                # The series converges only if ||I - scale*H|| < 1; 0.1 is a fixed heuristic.
                delta_flat = neumann_series_solver(mvp_fn, flat_g, k=5, scale=0.1, device=self.device)
            else:
                # 'lowrank': CG truncated after 2 iterations as a cheap approximation.
                delta_flat = conjugate_gradient(mvp_fn, flat_g, max_iter=2, device=self.device)

        if delta_flat is None:
            # Unrecognised solver name: plain gradient step on the head.
            delta_flat = flat_g

        if self.args.max_norm > 0:
            total_norm = torch.norm(delta_flat)
            if total_norm > self.args.max_norm:
                delta_flat = delta_flat * (self.args.max_norm / (total_norm + 1e-6))

        delta_flat = delta_flat * self.args.newton_lr

        local_update_s, ptr = {}, 0
        for name, param in zip(head_keys, head_params):
            numel = param.numel()
            local_update_s[name] = delta_flat[ptr:ptr + numel].reshape(param.shape).cpu()
            ptr += numel

        if self.args.training_mode == 'full':
            # Non-head parameters receive a first-order step with the local LR.
            for name in body_keys:
                if name in global_grads:
                    local_update_s[name] = (self.args.lr * global_grads[name].to(self.device)).cpu()

        return local_update_s, total_samples

    def _lbfgs_step(self, global_grads, lbfgs_state):
        """L-BFGS branch of ``compute_newton_step`` (same return contract)."""
        # Parameters included in the flat L-BFGS vector (must match the server's filter).
        targets = []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if self.args.training_mode == 'head' and not self._is_head_name(name):
                continue
            targets.append((name, param))

        def gradient_fallback():
            # No usable curvature history: return the raw global gradient.
            return {k: global_grads[k].cpu() for k, _ in targets if k in global_grads}, 0

        if lbfgs_state is None or len(lbfgs_state['s_list']) == 0:
            return gradient_fallback()
        if not targets:
            return {}, 0

        flat_g = torch.cat([
            (global_grads[name] if name in global_grads else torch.zeros_like(param)).reshape(-1).to(self.device)
            for name, param in targets
        ])

        if flat_g.shape[0] != lbfgs_state['s_list'][0].shape[0]:
            # Stored history has a different dimension: fall back to the gradient.
            return gradient_fallback()

        delta_flat = lbfgs_solver(
            flat_g,
            lbfgs_state['s_list'],
            lbfgs_state['y_list'],
            lbfgs_state['rho_list'],
            gamma=1.0,
        ) * self.args.newton_lr

        local_update_s, ptr = {}, 0
        for name, param in targets:
            numel = param.numel()
            local_update_s[name] = delta_flat[ptr:ptr + numel].reshape(param.shape).cpu()
            ptr += numel
        return local_update_s, 1

    def _gather_head_inputs(self, head_module):
        """Run up to ``args.hessian_batches`` batches and capture the head's input features.

        Returns:
            (features, targets, num_samples); (None, None, 0) when nothing was captured.
        """
        captured = []

        def hook_fn(module, inputs, output):
            captured.append(inputs[0].detach())

        handle = head_module.register_forward_hook(hook_fn)
        features_list, targets_list, total_samples = [], [], 0
        try:
            for i, (x, y) in enumerate(self.loader):
                if i >= self.args.hessian_batches:
                    break
                x, y = x.to(self.device), y.to(self.device)
                total_samples += x.size(0)
                captured.clear()
                with torch.no_grad():
                    self.model(x)
                if captured:
                    features_list.append(captured[0])
                    targets_list.append(y)
        finally:
            # Always detach the hook, even if a forward pass raises.
            handle.remove()

        if total_samples == 0 or not features_list:
            return None, None, 0
        return torch.cat(features_list, dim=0), torch.cat(targets_list, dim=0), total_samples

# ======================
# 4. Server Logic
# ======================
def average_weights(w_list, s_list):
    """Sample-weighted average of client state dicts (FedAvg aggregation).

    ``num_batches_tracked`` entries are not averaged; the first client's value
    is kept. If every weight in ``s_list`` is 0, clients are averaged uniformly
    instead of silently producing NaNs from a 0/0 tensor division.

    Args:
        w_list: non-empty list of state dicts sharing the same keys.
        s_list: per-client sample counts aligned with ``w_list``.

    Returns:
        A new state dict (a deep copy of ``w_list[0]`` with averaged entries).
    """
    w_avg = copy.deepcopy(w_list[0])
    weights = list(s_list)
    total_samples = sum(weights)
    if total_samples == 0:
        weights = [1] * len(w_list)
        total_samples = len(w_list)

    for k in w_avg.keys():
        if 'num_batches_tracked' in k:
            continue
        weighted_sum = sum(w[k] * s for w, s in zip(w_list, weights))
        w_avg[k] = weighted_sum / total_samples
    return w_avg

def average_gradients(g_list, s_list):
    """Sample-weighted average of per-client gradient / update dicts.

    Keys may differ between clients. A client lacking a key adds nothing to
    that key's numerator, but its weight still counts in the denominator, so it
    is effectively treated as a zero entry. If every weight is 0 (e.g. all
    clients took a zero-count fallback path), clients are weighted uniformly
    instead of silently producing NaN/inf from a division by zero.

    Returns:
        dict key -> averaged tensor (keys in first-seen order); {} for empty input.
    """
    if not g_list:
        return {}

    weights = list(s_list)
    total_samples = sum(weights)
    if total_samples == 0:
        weights = [1] * len(g_list)
        total_samples = len(g_list)

    # Union of keys across clients, in first-seen order.
    all_keys = dict.fromkeys(k for g in g_list for k in g)

    g_avg = {}
    for k in all_keys:
        weighted_sum = 0.0
        for g, s in zip(g_list, weights):
            if k in g:
                weighted_sum += g[k] * s
        g_avg[k] = weighted_sum / total_samples
    return g_avg

def test_global(model, global_weights, test_loader, device):
    model.load_state_dict(global_weights)
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            loss_sum += loss.item() * x.size(0)
            _, pred = torch.max(out, 1)
            correct += (pred == y).sum().item()
            total += x.size(0)
    return loss_sum / max(total, 1), correct / max(total, 1)

# ======================
# 5. Main Process
# ======================
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Script entry point: parse CLI options, build the federated setup
    # (non-IID client split, global model, one Client per shard), run
    # `args.rounds` communication rounds with the selected method, and save
    # the per-round metrics to an .npz file in `args.log_dir`.
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--method', type=str, default='FedAvg',
                        choices=['FedAvg', 'FedProx', 'FedNewton', 'FedNewton-Approx', 'FedSophia', 'FedNew', 'FedDANE'])
    parser.add_argument('--training_mode', type=str, default='full', choices=['full', 'head'], help='full: update all; head: update head only')
    parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10, cifar100, mnist, femnist')
    parser.add_argument('--model_name', type=str, default='resnet18')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--rounds', type=int, default=20)
    parser.add_argument('--num_clients', type=int, default=10)
    parser.add_argument('--fraction', type=float, default=0.5)
    parser.add_argument('--local_epochs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=32)
    # LR settings
    parser.add_argument('--lr', type=float, default=0.01, help="Local SGD LR or Global Learning Rate for Newton methods")
    parser.add_argument('--newton_lr', type=float, default=0.1, help="For FedNewton / FedDANE / FedNewton-Approx")
    # FedSophia / FedNew Params
    parser.add_argument('--sophia_lr', type=float, default=0.05, help="Server side LR for FedSophia/FedNew")
    # NOTE: in the server update below, rho is the lower clip of the Hessian
    # EMA for FedSophia and an additive damping term for FedNew.
    parser.add_argument('--rho', type=float, default=0.04, help="Hessian smoothing parameter")
    parser.add_argument('--betas', type=str, default='0.9,0.99', help="Betas for Sophia/New (Momentum, Hessian average)")
    parser.add_argument('--mu', type=float, default=0.01, help="FedProx mu or FedDANE regularization")
    parser.add_argument('--log_dir', type=str, default='./logs')
    parser.add_argument('--optimizer', type=str, default='sgd')
    parser.add_argument('--hessian_batches', type=int, default=32)
    parser.add_argument('--damping', type=float, default=0.01, help="Damping for FedNewton and FedNewton-Approx")
    parser.add_argument('--alpha', type=float, default=0.5, help="Parameter to control data heterogeneity")
    parser.add_argument('--max_norm', type=float, default=0, help='Maximum L2 norm for the Newton update step. Set to 0 to disable.')

    # [New] Solver Options
    parser.add_argument('--newton_solver', type=str, default='exact',
                        choices=['exact', 'cg', 'diag', 'neumann', 'lbfgs', 'lowrank'],
                        help='Solver for linear system Hx=g')
    parser.add_argument('--newton_cg_max_iter', type=int, default=10, help='Max iterations for CG solver')
    parser.add_argument('--lbfgs_m', type=int, default=5, help='Memory size for L-BFGS')

    args = parser.parse_args()

    set_seed(args.seed)
    log_print(args)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs(args.log_dir, exist_ok=True)

    log_print(f"Start: {args.method} | Mode: {args.training_mode} | Dataset: {args.dataset} | Solver: {args.newton_solver}")

    train_ds, test_ds, num_classes, in_channels = get_dataset(args.dataset)

    test_loader = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=2)
    client_indices = split_dataset_non_iid(train_ds, args.num_clients, alpha=args.alpha, num_classes=num_classes)

    global_model = build_model(args.model_name, num_classes, in_channels).to(device)
    # NOTE: state_dict() tensors share storage with global_model, so the
    # in-place FedSophia/FedNew update below also mutates the model directly.
    global_weights = global_model.state_dict()

    # Initialize Momentum (m) and Hessian Accumulation (h) for FedSophia / FedNew.
    # Only float32/float64 entries are tracked; integer buffers are skipped.
    sophia_m = {}
    sophia_h = {}
    beta1, beta2 = [float(x) for x in args.betas.split(',')]

    for k, v in global_weights.items():
        if v.dtype == torch.float32 or v.dtype == torch.float64:
            sophia_m[k] = torch.zeros_like(v)
            sophia_h[k] = torch.zeros_like(v)

    # L-BFGS State (shared with clients through compute_newton_step)
    lbfgs_state = {
        's_list': [], # parameter updates
        'y_list': [], # gradient differences
        'rho_list': [],
        'prev_w': None,
        'prev_g': None
    }

    clients = [Client(i, train_ds, client_indices[i], device, args, num_classes, in_channels) for i in range(args.num_clients)]

    logs = {"rounds": [], "test_acc": [], "test_loss": [], "train_loss": [], "wall_time": [], "max_gpu_mem": []}
    start_time = time.time()

    for rnd in range(1, args.rounds + 1):
        start_rnd_time = time.time()

        # Reset GPU memory stats so max_memory_allocated reports this round only
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(device)

        m = max(int(args.fraction * args.num_clients), 1)
        selected_users = np.random.choice(range(args.num_clients), m, replace=False)

        local_losses = []
        local_samples_list = []

        # ==========================================
        # Method Implementation
        # ==========================================
        if args.method in ['FedSophia', 'FedNew']:
            local_deltas = []
            local_hessians = []

            for idx in selected_users:
                w_before = copy.deepcopy(global_weights)
                w_after, n, loss = clients[idx].local_train(w_before)

                diff = {}
                for k in w_before.keys():
                    if k in w_after:
                        diff[k] = (w_after[k] - w_before[k]).cpu()

                local_deltas.append(diff)
                local_losses.append(loss)
                local_samples_list.append(n)

                h_local, _ = clients[idx].compute_diag_hessian(w_before)
                local_hessians.append(h_local)

            avg_diff = average_weights(local_deltas, local_samples_list)
            avg_h = average_gradients(local_hessians, local_samples_list)

            # Loop-invariant lower bound for the FedSophia preconditioner.
            rho_floor = torch.tensor(args.rho, device=device)
            for k in global_weights.keys():
                if k not in avg_diff or k not in sophia_m: continue
                # The negated average local delta serves as a pseudo-gradient.
                g = -avg_diff[k].to(device)
                sophia_m[k] = beta1 * sophia_m[k] + (1 - beta1) * g

                if k in avg_h:
                    h = avg_h[k].to(device)
                    sophia_h[k] = beta2 * sophia_h[k] + (1 - beta2) * h

                if args.method == 'FedSophia':
                    preconditioner = torch.maximum(sophia_h[k], rho_floor)
                    step = args.sophia_lr * (sophia_m[k] / preconditioner)
                else: # FedNew
                    preconditioner = sophia_h[k] + args.rho
                    step = args.sophia_lr * (sophia_m[k] / preconditioner)

                global_weights[k] -= step

        # --- Strategy: FedAvg / FedProx ---
        elif args.method in ['FedAvg', 'FedProx']:
            local_weights = []
            for idx in selected_users:
                w, n, loss = clients[idx].local_train(copy.deepcopy(global_weights))
                local_weights.append(w)
                local_samples_list.append(n)
                local_losses.append(loss)

            global_weights = average_weights(local_weights, local_samples_list)

        # --- Strategy: FedNewton (The Master Solver) ---
        elif args.method == 'FedNewton':
            local_grads_list = []
            for idx in selected_users:
                g, n, loss = clients[idx].compute_gradient(copy.deepcopy(global_weights))
                local_grads_list.append(g)
                local_samples_list.append(n)
                local_losses.append(loss)

            global_grads = average_gradients(local_grads_list, local_samples_list)

            # --- L-BFGS Update (Global State) ---
            if args.newton_solver == 'lbfgs':
                current_w_flat = []
                current_g_flat = []

                # 1. Define layers to track
                target_keys = []
                for k, p in global_model.named_parameters():
                    if not p.requires_grad: continue

                    # [Fix] Filter parameters based on training_mode
                    is_head = ('head' in k) or ('fc' in k) or ('classifier' in k)
                    if args.training_mode == 'head' and not is_head:
                        continue

                    target_keys.append(k)

                # 2. Flatten these specific parameters
                for k in target_keys:
                    p = global_weights[k]
                    current_w_flat.append(p.view(-1).to(device))

                    if k in global_grads:
                        current_g_flat.append(global_grads[k].view(-1).to(device))
                    else:
                        current_g_flat.append(torch.zeros_like(p).view(-1).to(device))

                if len(current_w_flat) > 0:
                    curr_w_vec = torch.cat(current_w_flat)
                    curr_g_vec = torch.cat(current_g_flat)

                    if lbfgs_state['prev_w'] is not None:
                        # Ensure dimension match; if not (e.g., first round or mode switch), reset
                        if curr_w_vec.shape != lbfgs_state['prev_w'].shape:
                            log_print(f"[Warning] L-BFGS dimension mismatch ({lbfgs_state['prev_w'].shape} vs {curr_w_vec.shape}). Resetting state.")
                            lbfgs_state['s_list'] = []
                            lbfgs_state['y_list'] = []
                            lbfgs_state['rho_list'] = []
                            lbfgs_state['prev_w'] = None
                            lbfgs_state['prev_g'] = None
                        else:
                            s = curr_w_vec - lbfgs_state['prev_w']
                            y = curr_g_vec - lbfgs_state['prev_g']

                            # Keep only curvature pairs with y^T s > 0, which
                            # keeps the implied inverse-Hessian positive definite.
                            denom = torch.dot(y, s)
                            if denom > 1e-10:
                                rho = 1.0 / denom
                                lbfgs_state['s_list'].append(s)
                                lbfgs_state['y_list'].append(y)
                                lbfgs_state['rho_list'].append(rho)
                                if len(lbfgs_state['s_list']) > args.lbfgs_m:
                                    lbfgs_state['s_list'].pop(0)
                                    lbfgs_state['y_list'].pop(0)
                                    lbfgs_state['rho_list'].pop(0)

                    lbfgs_state['prev_w'] = curr_w_vec.clone()
                    lbfgs_state['prev_g'] = curr_g_vec.clone()

            # Calculate Step (clients already scale it by newton_lr / lr)
            local_steps_list = []
            for idx in selected_users:
                s_j, _ = clients[idx].compute_newton_step(copy.deepcopy(global_weights), global_grads, lbfgs_state)
                local_steps_list.append(s_j)

            global_step = average_gradients(local_steps_list, local_samples_list)

            for name, param in global_weights.items():
                if name in global_step:
                    if param.dtype == torch.long: continue
                    global_weights[name] = param - global_step[name].to(device)

        # --- Legacy: FedNewton-Approx (Can be replaced by FedNewton + solver='diag') ---
        elif args.method == 'FedNewton-Approx':
            local_grads_list = []
            for idx in selected_users:
                g, n, loss = clients[idx].compute_gradient(copy.deepcopy(global_weights))
                local_grads_list.append(g)
                local_samples_list.append(n)
                local_losses.append(loss)

            global_grads = average_gradients(local_grads_list, local_samples_list)

            local_steps_list = []
            for idx in selected_users:
                s_j, _ = clients[idx].compute_newton_approx_step(copy.deepcopy(global_weights), global_grads)
                local_steps_list.append(s_j)

            global_step = average_gradients(local_steps_list, local_samples_list)
            for name, param in global_weights.items():
                if name in global_step:
                    if param.dtype == torch.long: continue
                    global_weights[name] = param - args.newton_lr * global_step[name].to(device)

        # --- Legacy: FedDANE ---
        elif args.method == 'FedDANE':
            local_grads_list = []
            for idx in selected_users:
                g, n, loss = clients[idx].compute_gradient(copy.deepcopy(global_weights))
                local_grads_list.append(g)
                local_samples_list.append(n)
                local_losses.append(loss)

            global_grads = average_gradients(local_grads_list, local_samples_list)

            local_steps_list = []
            for idx in selected_users:
                s_j, _ = clients[idx].compute_dane_step(copy.deepcopy(global_weights), global_grads)
                local_steps_list.append(s_j)

            global_step = average_gradients(local_steps_list, local_samples_list)
            for name, param in global_weights.items():
                if name in global_step:
                    if param.dtype == torch.long: continue
                    global_weights[name] = param - args.newton_lr * global_step[name].to(device)

        # Sample-weighted mean of the clients' reported training losses. Every
        # method computes it the same way, so it is done once here.
        train_loss_avg = sum(client_loss * n for client_loss, n in zip(local_losses, local_samples_list)) / max(sum(local_samples_list), 1)

        test_loss, test_acc = test_global(global_model, global_weights, test_loader, device)

        # Time Calculation
        total_elapsed = time.time() - start_time
        round_elapsed = time.time() - start_rnd_time

        # Record Memory (MiB)
        max_mem = 0.0
        if torch.cuda.is_available():
            max_mem = torch.cuda.max_memory_allocated(device) / (1024 * 1024)

        logs["rounds"].append(rnd)
        logs["test_acc"].append(test_acc)
        logs["test_loss"].append(test_loss)
        logs["train_loss"].append(train_loss_avg)
        logs["wall_time"].append(total_elapsed)
        logs["max_gpu_mem"].append(max_mem)

        log_print(f"[Round {rnd:03d}] {args.method} (Solver: {args.newton_solver}) | Acc: {test_acc*100:.2f}% | Loss: {test_loss:.4f} | Time: {round_elapsed:.2f}s | Max GPU Mem: {max_mem:.2f} MB")

        if test_loss < 1e-6:
            log_print("Early stopping as test loss is very low.")
            break
        elif test_loss > 10.0:
            log_print("Early stopping as test loss is very high.")
            break
        elif not np.isfinite(test_loss):
            # NaN compares False against both thresholds above. Without this
            # check a diverged run would keep going for every remaining round.
            log_print("Early stopping as test loss is NaN.")
            break

    save_filename = f"{args.method}_{args.training_mode}_{args.dataset}_{args.model_name}_{args.newton_solver}_s{args.seed}.npz"
    save_path = os.path.join(args.log_dir, save_filename)
    np.savez(save_path, method=args.method, solver=args.newton_solver, training_mode=args.training_mode, **logs)
    log_print(f"Saved to {save_path}")
