import argparse
import random
import copy
import time
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import warnings

# Silence warnings process-wide. The intent is to hide torchvision compatibility
# warnings, but an "ignore" filter with no message/category/module restriction
# suppresses *every* warning raised afterwards (unless a later filter overrides it).
warnings.simplefilter("ignore")

# ======================
# 1. 工具函数 & 数据集管理
# ======================
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with ``seed``.

    When CUDA is available, all GPU generators are seeded as well and cuDNN is
    switched to deterministic (non-benchmarking) algorithm selection.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_dataset(name, root='./data'):
    """Build the train/test splits of a supported torchvision dataset.

    CIFAR-10/100 training images get RandomCrop(32, padding=4) + RandomHorizontalFlip
    augmentation followed by per-channel normalisation; their test images are only
    normalised. MNIST / FashionMNIST use ToTensor + Normalize((0.5,), (0.5,)) for both
    splits. Data is downloaded into ``root`` if it is not already there.

    Args:
        name: Case-insensitive dataset name: 'cifar10', 'cifar100', 'mnist' or
            'fashionmnist'.
        root: Directory handed to the torchvision dataset constructors.

    Returns:
        (train_dataset, test_dataset, num_classes, in_channels)

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    key = name.lower()

    # key -> (dataset class, num_classes, in_channels, mean, std, use augmentation)
    specs = {
        'cifar10': (torchvision.datasets.CIFAR10, 10, 3,
                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010), True),
        'cifar100': (torchvision.datasets.CIFAR100, 100, 3,
                     (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761), True),
        'mnist': (torchvision.datasets.MNIST, 10, 1, (0.5,), (0.5,), False),
        'fashionmnist': (torchvision.datasets.FashionMNIST, 10, 1, (0.5,), (0.5,), False),
    }
    if key not in specs:
        raise ValueError(f"Dataset {key} not supported.")
    dataset_cls, num_classes, in_channels, mean, std, augment = specs[key]

    def to_normalized_tensor():
        # Fresh transform instances for each pipeline.
        return [transforms.ToTensor(), transforms.Normalize(mean, std)]

    augmentation = (
        [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
        if augment else []
    )
    train_tf = transforms.Compose(augmentation + to_normalized_tensor())
    test_tf = transforms.Compose(to_normalized_tensor())

    train_dataset = dataset_cls(root=root, train=True, download=True, transform=train_tf)
    test_dataset = dataset_cls(root=root, train=False, download=True, transform=test_tf)
    return train_dataset, test_dataset, num_classes, in_channels

def split_dataset_non_iid(train_dataset, num_clients, alpha=0.5, num_classes=10):
    """Partition sample indices across clients with a per-class Dirichlet prior (non-IID).

    For every class c, a proportion vector p ~ Dir(alpha * 1_K) (K = num_clients) is
    drawn, and the shuffled indices of class c are cut into K contiguous chunks whose
    sizes follow p. Smaller ``alpha`` gives more skewed per-client label distributions.

    Args:
        train_dataset: Dataset exposing ``targets`` or ``labels``; otherwise it is
            iterated as ``(x, y)`` pairs to read the labels.
        num_clients: Number of partitions (must be >= 1).
        alpha: Dirichlet concentration parameter.
        num_classes: Labels ``0 .. num_classes-1`` are distributed; other labels are
            not assigned to any client.

    Returns:
        A list of ``num_clients`` index lists, each shuffled. Together they contain
        every index whose label is in ``range(num_classes)`` exactly once.

    Raises:
        ValueError: If ``num_clients < 1``.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    if hasattr(train_dataset, 'targets'):
        labels = np.asarray(train_dataset.targets)
    elif hasattr(train_dataset, 'labels'):
        labels = np.asarray(train_dataset.labels)
    else:
        labels = np.array([y for _, y in train_dataset])

    all_idxs = np.arange(len(train_dataset))
    client_indices = [[] for _ in range(num_clients)]

    for c in range(num_classes):
        idxs_c = all_idxs[labels == c]
        np.random.shuffle(idxs_c)

        proportions = np.random.dirichlet(np.repeat(alpha, num_clients))

        # Cut points come from the *cumulative* proportions. Flooring the running
        # total (instead of flooring each share and giving the whole remainder to
        # the last client) keeps every chunk within one sample of p_k * n_c. The
        # old scheme gave the last client up to num_clients-1 extra samples per class.
        cut_points = (np.cumsum(proportions) * len(idxs_c)).astype(int)[:-1]

        for client_id, chunk in enumerate(np.split(idxs_c, cut_points)):
            client_indices[client_id].extend(chunk.tolist())

    for indices in client_indices:
        np.random.shuffle(indices)

    return client_indices

# ======================
# 2. 模型定义
# ======================
def _find_input_conv(model):
    """Return (dotted_name, module) of the model's stem convolution, or (None, None).

    A top-level ``conv1`` that is an ``nn.Conv2d`` wins. Otherwise the first
    ``nn.Conv2d`` in module-registration order is used.
    """
    conv1 = getattr(model, 'conv1', None)
    if isinstance(conv1, nn.Conv2d):
        return 'conv1', conv1
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            return name, module
    return None, None


def _replace_submodule(model, dotted_name, new_module):
    """Swap the submodule at ``dotted_name`` (e.g. 'features.0') for ``new_module``."""
    parent_name, _, child_name = dotted_name.rpartition('.')
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, new_module)


def modify_model_architecture(model, num_classes, in_channels):
    """Adapt a classifier in place to a dataset's input channels and class count.

    Input: if the stem convolution (``conv1`` or else the first ``nn.Conv2d``) expects
    a different number of channels than ``in_channels``, it is replaced by a freshly
    initialised conv with the same geometry. Output: ``fc`` (if an ``nn.Linear``),
    else ``classifier`` (an ``nn.Linear``, or the last layer of an ``nn.Sequential``
    when that layer is an ``nn.Linear``) is replaced by a new ``num_classes``-way
    ``nn.Linear``.

    Args:
        model: The model to modify (mutated in place).
        num_classes: Number of output classes.
        in_channels: Number of input image channels.

    Returns:
        The same ``model`` object.
    """
    # --- Input layer (e.g. 1-channel MNIST / FashionMNIST) ---
    conv_name, conv = _find_input_conv(model)
    if conv is not None and conv.in_channels != in_channels:
        _replace_submodule(model, conv_name, nn.Conv2d(
            in_channels, conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            bias=conv.bias is not None,
            padding_mode=conv.padding_mode,
        ))

    # --- Output layer (fully connected classifier) ---
    if isinstance(getattr(model, 'fc', None), nn.Linear):
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif hasattr(model, 'classifier'):
        head = model.classifier
        if isinstance(head, nn.Linear):
            model.classifier = nn.Linear(head.in_features, num_classes)
        elif isinstance(head, nn.Sequential) and len(head) > 0 and isinstance(head[-1], nn.Linear):
            head[-1] = nn.Linear(head[-1].in_features, num_classes)

    return model

def _clean_checkpoint_state_dict(model, state_dict):
    """Strip a leading ``module.`` prefix and drop entries whose shape does not fit ``model``.

    ``load_state_dict(strict=False)`` tolerates missing/unexpected keys but still
    raises on size mismatches. An example is a head trained for a different number
    of classes, so such entries are skipped here.

    Returns:
        (cleaned_state_dict, skipped_keys)
    """
    model_state = model.state_dict()
    prefix = "module."
    cleaned, skipped = {}, []
    for key, value in state_dict.items():
        # Only a *leading* prefix (added by DataParallel/DDP wrappers) is removed.
        if key.startswith(prefix):
            key = key[len(prefix):]
        expected = model_state.get(key)
        if expected is not None and hasattr(value, 'shape') and tuple(value.shape) != tuple(expected.shape):
            skipped.append(key)
            continue
        cleaned[key] = value
    return cleaned, skipped


def build_model(model_name='resnet18', num_classes=10, in_channels=3, pretrained_path=None, force_random=False):
    """Instantiate a torchvision model and adapt it to the dataset.

    Weight source:
      * ``force_random=True``: random initialisation; ``pretrained_path`` is ignored.
      * ``pretrained_path`` given: random initialisation, then the checkpoint at that
        path is loaded non-strictly. A missing file is reported and skipped.
      * otherwise: torchvision's default pretrained weights.

    Args:
        model_name: Name of a constructor in ``torchvision.models`` (case-insensitive).
        num_classes: Output classes for the adapted head.
        in_channels: Input image channels for the adapted stem.
        pretrained_path: Optional path to a checkpoint (a state dict, or a dict with
            a ``'state_dict'`` entry).
        force_random: Skip every source of pretrained weights.

    Returns:
        The constructed ``nn.Module`` on CPU.

    Raises:
        ValueError: If ``model_name`` is not found in ``torchvision.models``.
    """
    try:
        model_fn = getattr(torchvision.models, model_name.lower())
    except AttributeError:
        raise ValueError(f"Model {model_name} not found in torchvision.models")

    use_default_weights = pretrained_path is None and not force_random
    try:
        model = model_fn(weights='DEFAULT' if use_default_weights else None)
    except TypeError:
        # torchvision < 0.13 has no ``weights=`` keyword: fall back to the legacy flag.
        # Other errors (e.g. a failed weight download) now propagate instead of being hidden.
        model = model_fn(pretrained=use_default_weights)

    model = modify_model_architecture(model, num_classes, in_channels)

    if pretrained_path is not None and not force_random:
        if os.path.exists(pretrained_path):
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(pretrained_path, map_location='cpu')
            is_wrapped = isinstance(checkpoint, dict) and 'state_dict' in checkpoint
            state_dict = checkpoint['state_dict'] if is_wrapped else checkpoint
            cleaned, skipped = _clean_checkpoint_state_dict(model, state_dict)
            if skipped:
                print(f"[build_model] skipped {len(skipped)} shape-mismatched key(s) from {pretrained_path}: {skipped}")
            model.load_state_dict(cleaned, strict=False)
        else:
            # Still best-effort, but no longer silent (warnings are globally ignored, so print).
            print(f"[build_model] pretrained_path '{pretrained_path}' not found; keeping random initialisation.")

    return model

# ======================
# 3. 客户端逻辑
# ======================
class Client:
    """A federated-learning participant that owns one data shard and a model replica.

    Every public method first loads the server's global ``state_dict`` into the local
    replica. As a result, no training state carries over from one call to the next.

    The "head" is the top-level ``fc`` / ``classifier`` submodule: a parameter or
    module belongs to it iff the first component of its dotted name is one of those
    names. Everything else is the backbone.
    """

    # Top-level attribute names that make up the classification head.
    _HEAD_NAMES = ('fc', 'classifier')

    def __init__(self, client_id, dataset, indices, device, args, num_classes, in_channels):
        self.client_id = client_id
        self.device = device
        self.args = args
        self.num_classes = num_classes
        self.in_channels = in_channels

        # Drop the ragged last batch only when the shard holds more than one batch,
        # so a tiny shard still yields at least one batch.
        should_drop = len(indices) > args.batch_size
        self.loader = DataLoader(
            Subset(dataset, indices),
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=should_drop,
        )
        self.model = build_model(args.model_name, num_classes, in_channels, force_random=True).to(self.device)

    # ------------------------------------------------------------------
    # Head / backbone helpers
    # ------------------------------------------------------------------
    @classmethod
    def _is_head(cls, name):
        """Return True if the dotted parameter/module ``name`` lies in the head.

        Comparing the first name component exactly (instead of a substring test)
        keeps nested backbone layers whose names merely *contain* "fc"
        (e.g. ``features.3.block.2.fc1``) out of the head.
        """
        return name.split('.', 1)[0] in cls._HEAD_NAMES

    def _newton_layer(self):
        """Return (dotted_name, module) of the last ``nn.Linear`` inside the head.

        The head is ``model.fc`` if the model has one, else ``model.classifier``.
        A Linear head is returned as-is. For a container head (e.g. a Sequential
        classifier) its last ``nn.Linear`` submodule is used.

        Raises:
            ValueError: If the head is missing or contains no ``nn.Linear``.
        """
        prefix = 'fc' if hasattr(self.model, 'fc') else 'classifier'
        head = getattr(self.model, prefix, None)
        linears = []
        if isinstance(head, nn.Module):
            linears = [(n, mod) for n, mod in head.named_modules() if isinstance(mod, nn.Linear)]
        if not linears:
            raise ValueError(f"FedNewton needs an nn.Linear in model.{prefix}, got {type(head).__name__}")
        sub_name, layer = linears[-1]
        return (f"{prefix}.{sub_name}" if sub_name else prefix), layer

    def _set_freeze_status(self):
        """Put the model in train mode. In 'head' mode, also freeze the backbone.

        In 'head' mode backbone parameters get ``requires_grad=False`` and backbone
        modules are put in eval() (BatchNorm running statistics stop updating);
        head modules stay in train().
        """
        self.model.train()
        for param in self.model.parameters():
            param.requires_grad = True

        if self.args.training_mode != 'head':
            return

        for name, param in self.model.named_parameters():
            if not self._is_head(name):
                param.requires_grad = False

        for name, module in self.model.named_modules():
            if name == "":
                continue  # the root module itself
            if self._is_head(name):
                module.train()
            else:
                module.eval()

    # ------------------------------------------------------------------
    # FedAvg / FedProx
    # ------------------------------------------------------------------
    def local_train(self, global_state_dict):
        """FedAvg / FedProx local update.

        Runs ``args.local_epochs`` epochs over the local shard. The optimizer is SGD
        (momentum 0.9) when ``args.optimizer == 'sgd'``, else Adam, both with ``args.lr``.
        For FedProx with ``args.mu > 0``, ``mu/2 * sum ||w - w_global||^2`` over the
        trainable parameters is added to the loss.

        Returns:
            (state_dict, num_samples, mean_loss):
                - ``state_dict``: the replica's ``state_dict()``. These are live tensors,
                  not a copy.
                - ``num_samples``: samples processed, summed over all epochs.
                - ``mean_loss``: the sample-weighted mean *cross-entropy*. The proximal
                  term is excluded, so FedProx and FedAvg report comparable losses.
        """
        self.model.load_state_dict(global_state_dict)
        self._set_freeze_status()

        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        if self.args.optimizer == 'sgd':
            optimizer = optim.SGD(trainable_params, lr=self.args.lr, momentum=0.9)
        else:
            optimizer = optim.Adam(trainable_params, lr=self.args.lr)
        criterion = nn.CrossEntropyLoss()

        use_prox = self.args.method == "FedProx" and self.args.mu > 0
        # Frozen copy of the global weights that the proximal term pulls towards.
        anchor = (
            {n: p.detach().clone() for n, p in self.model.named_parameters() if p.requires_grad}
            if use_prox else {}
        )

        num_samples = 0
        loss_sum = 0.0
        for _ in range(self.args.local_epochs):
            for x, y in self.loader:
                x, y = x.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                ce_loss = criterion(self.model(x), y)

                loss = ce_loss
                if use_prox:
                    prox_term = sum(
                        (p - anchor[n]).pow(2).sum()
                        for n, p in self.model.named_parameters()
                        if p.requires_grad
                    )
                    loss = loss + (self.args.mu / 2) * prox_term

                loss.backward()
                optimizer.step()

                num_samples += x.size(0)
                loss_sum += ce_loss.item() * x.size(0)

        return self.model.state_dict(), num_samples, loss_sum / max(num_samples, 1)

    # ------------------------------------------------------------------
    # FedNewton
    # ------------------------------------------------------------------
    def compute_gradient(self, global_state_dict):
        """FedNewton step 1: full-pass mean gradient of the cross-entropy on the local shard.

        Returns:
            (grads, num_samples, mean_loss). ``grads`` maps every parameter that received
            a gradient to (sum over samples of the per-sample gradient) / num_samples.
        """
        self.model.load_state_dict(global_state_dict)
        self._set_freeze_status()
        self.model.zero_grad()
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        total_samples = 0
        for x, y in self.loader:
            x, y = x.to(self.device), y.to(self.device)
            loss = criterion(self.model(x), y)
            # Scale the batch mean back to a batch sum so gradients accumulate per sample.
            (loss * x.size(0)).backward()
            total_loss += loss.item() * x.size(0)
            total_samples += x.size(0)

        denom = max(total_samples, 1)
        local_grads = {
            name: param.grad.clone().detach() / denom
            for name, param in self.model.named_parameters()
            if param.grad is not None
        }
        return local_grads, total_samples, total_loss / denom

    def _collect_head_inputs(self, layer):
        """Capture the inputs of ``layer`` on up to ``args.hessian_batches`` local batches.

        Returns:
            (features, targets, num_samples). ``features`` and ``targets`` are None
            when no batch was processed.
        """
        captured = []
        handle = layer.register_forward_hook(lambda module, inp, out: captured.append(inp[0].detach()))
        features, targets, num_samples = [], [], 0
        try:
            for i, (x, y) in enumerate(self.loader):
                if i >= self.args.hessian_batches:
                    break
                x, y = x.to(self.device), y.to(self.device)
                captured.clear()
                with torch.no_grad():
                    self.model(x)  # fires the hook
                features.append(captured[0])
                targets.append(y)
                num_samples += x.size(0)
        finally:
            handle.remove()  # never leave the hook attached, even if forward raises
        if num_samples == 0:
            return None, None, 0
        return torch.cat(features, dim=0), torch.cat(targets, dim=0), num_samples

    def compute_newton_step(self, global_state_dict, global_grads):
        """FedNewton step 2: compute this client's update Delta (the server subtracts it).

        - Newton layer (the last ``nn.Linear`` of the head):
          ``Delta = newton_lr * (H + damping * I)^-1 g``. Here H is the exact Hessian of
          the mean cross-entropy w.r.t. that layer's weight (and bias, if any). It is
          evaluated on the layer inputs from up to ``args.hessian_batches`` local
          batches, with the model in eval mode. g is the aggregated ``global_grads``
          entry, or zeros if missing.
        - Every other trainable parameter present in ``global_grads``:
          ``Delta = args.lr * g`` (plain SGD).

        Returns:
            (updates, num_samples). ``updates`` maps parameter names to CPU tensors;
            ``num_samples`` is the number of samples used for the Hessian.
            Returns ({}, 0) when no local batch is available.
        """
        self.model.load_state_dict(global_state_dict)
        self._set_freeze_status()
        self.model.eval()  # Hessian features: BN uses running stats, dropout disabled

        layer_name, layer = self._newton_layer()
        features, targets, total_samples = self._collect_head_inputs(layer)
        if total_samples == 0:
            return {}, 0

        # (weight,) or (weight, bias). This order fixes the layout of H and g below.
        newton_keys = [f"{layer_name}.weight"]
        newton_params = [layer.weight.detach()]
        if layer.bias is not None:
            newton_keys.append(f"{layer_name}.bias")
            newton_params.append(layer.bias.detach())
        inputs = tuple(newton_params)

        def summed_ce(*params):
            bias = params[1] if len(params) > 1 else None
            logits = torch.nn.functional.linear(features, params[0], bias)
            return torch.nn.functional.cross_entropy(logits, targets, reduction='sum')

        # For a tuple input, hessian() always returns a tuple of tuples:
        # H_blocks[i][j] has shape inputs[i].shape + inputs[j].shape.
        H_blocks = torch.autograd.functional.hessian(summed_ce, inputs)
        sizes = [p.numel() for p in inputs]
        H_mat = torch.cat([
            torch.cat([H_blocks[i][j].reshape(sizes[i], sizes[j]) for j in range(len(inputs))], dim=1)
            for i in range(len(inputs))
        ], dim=0)
        H_avg = H_mat / total_samples  # mean Hessian

        g_parts = []
        for key, p in zip(newton_keys, inputs):
            if key in global_grads:
                g_parts.append(global_grads[key].reshape(-1).to(self.device))
            else:
                g_parts.append(torch.zeros(p.numel(), dtype=p.dtype, device=p.device))
        flat_g = torch.cat(g_parts)

        # Solve the damped linear system instead of forming an explicit inverse.
        eye = torch.eye(H_avg.size(0), device=self.device)
        delta_flat = torch.linalg.solve(H_avg + self.args.damping * eye, flat_g) * self.args.newton_lr

        updates = {}
        offset = 0
        for key, p in zip(newton_keys, inputs):
            n = p.numel()
            updates[key] = delta_flat[offset:offset + n].reshape(p.shape).cpu()
            offset += n

        # All other trainable parameters (the backbone in 'full' mode, plus any extra
        # head layers before the Newton layer) take a plain SGD step.
        newton_set = set(newton_keys)
        for name, param in self.model.named_parameters():
            if not param.requires_grad or name in newton_set or name not in global_grads:
                continue
            updates[name] = (self.args.lr * global_grads[name].to(self.device)).cpu()

        return updates, total_samples

# ======================
# 4. 服务端逻辑
# ======================
def average_weights(w_list, s_list):
    """FedAvg aggregation: sample-weighted mean of client state dicts.

    Args:
        w_list: Client state dicts sharing the same keys.
        s_list: Per-client sample counts used as weights (same order as ``w_list``).

    Returns:
        A deep copy of ``w_list[0]`` in which every entry except the
        ``num_batches_tracked`` counters is replaced by sum(w * s) / sum(s).
        The counters keep client 0's value.
    """
    merged = copy.deepcopy(w_list[0])
    denom = sum(s_list)

    averaged_keys = [key for key in merged if 'num_batches_tracked' not in key]
    for key in averaged_keys:
        acc = 0
        for state, count in zip(w_list, s_list):
            acc = acc + state[key] * count
        merged[key] = acc / denom
    return merged

def average_gradients(g_list, s_list):
    """Sample-weighted average of per-client dicts (gradients or update steps).

    A key missing from a client's dict contributes zero for that client. The
    denominator is always the total sample count ``sum(s_list)``.

    Args:
        g_list: Per-client dicts name -> tensor (or number).
        s_list: Per-client sample counts (same order as ``g_list``).

    Returns:
        Dict over the union of all keys; ``{}`` when ``g_list`` is empty.
    """
    if not g_list:
        return {}

    denom = sum(s_list)
    keys = set()
    for grads in g_list:
        keys.update(grads.keys())

    def weighted_total(key):
        return sum((grads[key] * count for grads, count in zip(g_list, s_list) if key in grads), 0.0)

    return {key: weighted_total(key) / denom for key in keys}

def test_global(model, global_weights, test_loader, device):
    model.load_state_dict(global_weights)
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            loss_sum += loss.item() * x.size(0)
            _, pred = torch.max(out, 1)
            correct += (pred == y).sum().item()
            total += x.size(0)
    return loss_sum / max(total, 1), correct / max(total, 1)

# ======================
# 5. 主流程
# ======================
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--method', type=str, default='FedAvg', choices=['FedAvg', 'FedProx', 'FedNewton'])
    parser.add_argument('--training_mode', type=str, default='full', choices=['full', 'head'], help='full: update all; head: update head only')
    # type=str.lower keeps get_dataset()'s case-insensitivity while argparse rejects
    # unsupported names up front. The previous help text advertised 'femnist',
    # which get_dataset() does not support.
    parser.add_argument('--dataset', type=str.lower, default='cifar10',
                        choices=['cifar10', 'cifar100', 'mnist', 'fashionmnist'],
                        help='cifar10, cifar100, mnist, fashionmnist')
    parser.add_argument('--model_name', type=str, default='resnet18')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--rounds', type=int, default=20)
    parser.add_argument('--num_clients', type=int, default=10)
    parser.add_argument('--fraction', type=float, default=0.5)
    parser.add_argument('--local_epochs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=0.01, help="Learning rate for SGD/Adam (Backbone or FedAvg)")
    parser.add_argument('--newton_lr', type=float, default=0.1, help="Step size for FedNewton Head Update")
    parser.add_argument('--mu', type=float, default=0.01)
    parser.add_argument('--log_dir', type=str, default='./logs')
    parser.add_argument('--optimizer', type=str, default='sgd')
    parser.add_argument('--hessian_batches', type=int, default=32, help='Max batches for Hessian calculation')
    parser.add_argument('--damping', type=float, default=0.1, help='Damping factor (lambda) for Hessian')
    parser.add_argument('--alpha', type=float, default=0.5, help='Dirichlet concentration of the non-IID split (smaller = more skewed)')

    args = parser.parse_args()

    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs(args.log_dir, exist_ok=True)

    print(f"Start: {args.method} | Mode: {args.training_mode} | Dataset: {args.dataset} | Model: {args.model_name}")
    print(f"Params -> SGD LR: {args.lr}, Newton LR: {args.newton_lr}, Damping: {args.damping}")

    train_ds, test_ds, num_classes, in_channels = get_dataset(args.dataset)

    test_loader = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=2)
    client_indices = split_dataset_non_iid(train_ds, args.num_clients, alpha=args.alpha, num_classes=num_classes)
    global_model = build_model(args.model_name, num_classes, in_channels).to(device)
    global_weights = global_model.state_dict()

    clients = [Client(i, train_ds, client_indices[i], device, args, num_classes, in_channels) for i in range(args.num_clients)]

    logs = {"rounds": [], "test_acc": [], "test_loss": [], "train_loss": [], "wall_time": []}
    start_time = time.time()

    # Clients sampled per round: at least one and never more than exist
    # (np.random.choice(..., replace=False) raises when m > num_clients).
    m = min(max(int(args.fraction * args.num_clients), 1), args.num_clients)

    for rnd in range(1, args.rounds + 1):
        selected_users = np.random.choice(range(args.num_clients), m, replace=False)

        local_losses = []
        local_samples_list = []

        if args.method in ['FedAvg', 'FedProx']:
            # --- FedAvg / FedProx: local training, then sample-weighted averaging ---
            local_weights = []
            for idx in selected_users:
                w, n, loss = clients[idx].local_train(copy.deepcopy(global_weights))
                local_weights.append(w)
                local_samples_list.append(n)
                local_losses.append(loss)

            global_weights = average_weights(local_weights, local_samples_list)

        elif args.method == 'FedNewton':
            # Step 1: local gradients at the current global model.
            local_grads_list = []
            for idx in selected_users:
                g, n, loss = clients[idx].compute_gradient(copy.deepcopy(global_weights))
                local_grads_list.append(g)
                local_samples_list.append(n)
                local_losses.append(loss)

            global_grads = average_gradients(local_grads_list, local_samples_list)

            # Step 2: each client turns the global gradient into an update Delta
            # (head: damped Newton step; other trainable params: SGD step).
            local_steps_list = []
            for idx in selected_users:
                s_j, _ = clients[idx].compute_newton_step(copy.deepcopy(global_weights), global_grads)
                local_steps_list.append(s_j)

            # Step 3: w_new = w_old - weighted_avg(Delta); integer buffers are left untouched.
            global_step = average_gradients(local_steps_list, local_samples_list)
            for name in list(global_weights):
                if name in global_step and global_weights[name].dtype != torch.long:
                    global_weights[name] = global_weights[name] - global_step[name].to(device)

        train_loss_avg = sum(l * n for l, n in zip(local_losses, local_samples_list)) / max(sum(local_samples_list), 1)

        # Test & log
        test_loss, test_acc = test_global(global_model, global_weights, test_loader, device)
        elapsed = time.time() - start_time

        logs["rounds"].append(rnd)
        logs["test_acc"].append(test_acc)
        logs["test_loss"].append(test_loss)
        logs["train_loss"].append(train_loss_avg)
        logs["wall_time"].append(elapsed)

        print(f"[Round {rnd:03d}] {args.method} | Acc: {test_acc*100:.2f}% | Loss: {test_loss:.4f}")

    # Save results
    save_filename = f"{args.method}_{args.training_mode}_{args.dataset}_{args.model_name}_s{args.seed}.npz"
    save_path = os.path.join(args.log_dir, save_filename)

    np.savez(save_path,
             method=args.method,
             training_mode=args.training_mode,
             alpha=args.alpha,
             **logs)

    print(f"Saved to {save_path}")