{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyMHHewRPcYdUrJQLvF6bhDD"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"cXBGNCl8G_J3"},"outputs":[],"source":["#第一，模拟数据检测 5.1.1 (Integrated with 2024-2025 SOTA)\n","import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","from mpl_toolkits.axes_grid1.inset_locator import mark_inset\n","import time\n","\n","# ==========================================\n","# 1. 规范双层优化场景设置\n","# ==========================================\n","torch.manual_seed(2025)\n","n_features = 20\n","n_classes = 5\n","n_train, n_val = 400, 400\n","\n","X_train, y_train = torch.randn(n_train, n_features), torch.randint(0, n_classes, (n_train,))\n","X_val, y_val = torch.randn(n_val, n_features), torch.randint(0, n_classes, (n_val,))\n","\n","class RepresentationNet(nn.Module):\n","    def __init__(self):\n","        super().__init__()\n","        self.phi = nn.Parameter(torch.eye(n_features) + torch.randn(n_features, n_features)*0.1)\n","        self.head = nn.Linear(n_features, n_classes, bias=False)\n","\n","# ==========================================\n","# 2. 核心算法实现 (集成新 SOTA)\n","# ==========================================\n","def run_blo_experiment(method, iterations=300, tau=0.7, noise_level=8.0):\n","    model = RepresentationNet()\n","    criterion = nn.CrossEntropyLoss()\n","\n","    if method == 'TTSA':\n","        lr_y, lr_x = 0.1, 0.02\n","    elif method in ['MA-SOBA', 'AccBO']:\n","        lr_y, lr_x = 0.05, 0.02\n","    else:\n","        lr_y, lr_x = 0.2, 0.05\n","\n","    opt_y = optim.SGD([model.head.weight], lr=lr_y)\n","    opt_x = optim.SGD([model.phi], lr=lr_x)\n","\n","    # 初始化状态变量\n","    history_loss = []\n","    grad_stats = []\n","    grad_norm_buffer = []\n","    buffer_limit = 100\n","\n","    # 动量缓存 (用于 MA-SOBA 和 AccBO)\n","    momentum_buffer = torch.zeros_like(model.head.weight)\n","    beta = 0.9\n","\n","    # 开始计时\n","    torch.cuda.synchronize() if torch.cuda.is_available() else None\n","    start_time = time.time()\n","\n","    for k in range(iterations):\n","        feat_train = X_train @ model.phi\n","        output_train = model.head(feat_train)\n","        loss_inner = criterion(output_train, y_train) + 0.01 * torch.norm(model.head.weight)**2\n","        grad_y = torch.autograd.grad(loss_inner, model.head.weight, retain_graph=True)[0]\n","\n","        grad_y_obs = grad_y.clone().detach()\n","        if np.random.rand() < 0.15:\n","            grad_y_obs += torch.randn_like(grad_y_obs) * noise_level\n","\n","        # --- 核心算法分支 ---\n","        final_update_vec = grad_y_obs.clone()\n","\n","        if method == 'RQ-TTSA':\n","            curr_norm = grad_y_obs.norm().item()\n","            grad_norm_buffer.append(curr_norm)\n","            if len(grad_norm_buffer) > buffer_limit: grad_norm_buffer.pop(0)\n","\n","            if len(grad_norm_buffer) > 10:\n","                psi = np.quantile(grad_norm_buffer, tau)\n","                psi = max(psi, 0.5)\n","                if curr_norm > psi:\n","                    scale = psi / (curr_norm + 1e-8)\n","                    final_update_vec = grad_y_obs * scale\n","\n","        elif method == 'psi-Variant':\n","            fixed_psi = 2.0\n","            if grad_y_obs.norm() > fixed_psi:\n","                final_update_vec = grad_y_obs * (fixed_psi / (grad_y_obs.norm() + 1e-8))\n","\n","        elif method == 'BiSLS':\n","\n","            final_update_vec = grad_y_obs / (grad_y_obs.norm() + 1e-8)\n","\n","        elif method == 'MA-SOBA':\n","            momentum_buffer = beta * momentum_buffer + (1 - beta) * grad_y_obs\n","            bias_correction = 1.0 - beta ** (k + 1)\n","            final_update_vec = momentum_buffer / bias_correction\n","\n","        elif method == 'AccBO':\n","\n","            momentum_buffer = beta * momentum_buffer + (1 - beta) * grad_y_obs\n","            clip_val = 2.0 # 设定稳健的截断阈值\n","            norm_v = momentum_buffer.norm()\n","            if norm_v > clip_val:\n","                final_update_vec = momentum_buffer * (clip_val / (norm_v + 1e-8))\n","            else:\n","                final_update_vec = momentum_buffer.clone()\n","\n","\n","        model.head.weight.data -= lr_y * final_update_vec\n","\n","        grad_stats.append(final_update_vec.norm().item())\n","\n","        # --- Upper Level Update ---\n","        feat_val = X_val @ model.phi\n","        output_val = model.head(feat_val)\n","        loss_outer = criterion(output_val, y_val)\n","\n","        opt_x.zero_grad()\n","        loss_outer.backward()\n","        opt_x.step()\n","\n","        history_loss.append(loss_outer.item())\n","\n","\n","    torch.cuda.synchronize() if torch.cuda.is_available() else None\n","    end_time = time.time()\n","    avg_time_ms = ((end_time - start_time) / iterations) * 1000\n","\n","    return history_loss, grad_stats, avg_time_ms\n","\n","# ==========================================\n","# 3. 统计结果报告\n","# ==========================================\n","results_table = []\n","\n","methods = ['TTSA', 'BiSLS', 'MA-SOBA', 'AccBO', 'psi-Variant', 'RQ-TTSA']\n","all_curves = {}\n","seedsn = 15\n","\n","print(f\"Running Experiment 5.1.1 with 6 Baselines ({seedsn} seeds)...\")\n","\n","time_data = {}\n","\n","for m in methods:\n","    m_losses, m_stds = [], []\n","    m_max_spikes = []\n","    m_grad_norms = []\n","    m_times = []\n","    curves = []\n","\n","    for s in range(seedsn):\n","        np.random.seed(s)\n","        torch.manual_seed(s)\n","        try:\n","            curve, grad_stat, time_ms = run_blo_experiment(m)\n","            stable_period = curve[-50:]\n","\n","            m_losses.append(np.mean(stable_period))\n","            m_stds.append(np.std(stable_period))\n","            m_max_spikes.append(np.max(stable_period))\n","            m_grad_norms.append(np.mean(grad_stat[-50:]))\n","            m_times.append(time_ms)\n","            curves.append(curve)\n","        except Exception as e:\n","            print(f\"Error {m} seed {s}: {e}\")\n","\n","    all_curves[m] = np.array(curves)\n","    time_data[m] = f\"{np.mean(m_times):.2f} ± {np.std(m_times):.2f} ms\"\n","\n","    results_table.append({\n","        'Method': m,\n","        'Final Loss': f\"{np.mean(m_losses):.4f} ± {np.std(m_losses):.4f}\",\n","        'Std (Stability)': f\"{np.mean(m_stds):.4f} ± {np.std(m_stds):.4f}\",\n","        'Max Spike': f\"{np.mean(m_max_spikes):.4f} ± {np.std(m_max_spikes):.4f}\",\n","        'Avg Grad Norm': f\"{np.mean(m_grad_norms):.4f} ± {np.std(m_grad_norms):.4f}\"\n","    })\n","\n","print(\"\\n### Experiment 5.1.1 Results (Heavy-Tail Noise)\")\n","print(pd.DataFrame(results_table).to_markdown(index=False))\n","\n","print(\"\\n### Computational Efficiency (Time per Iteration)\")\n","for m, t in time_data.items():\n","    print(f\"{m}: {t}\")\n","\n","\n","def plot_icml_results(all_curves):\n","    plt.rcParams.update({\n","        'font.size': 14,\n","        'font.family': 'serif',\n","        'axes.labelsize': 15,\n","        'xtick.labelsize': 12,\n","        'ytick.labelsize': 12\n","    })\n","\n","    colors = {\n","        'TTSA': 'gray',\n","        'BiSLS': '#d62728',   # Red\n","        'MA-SOBA': '#ff7f0e', # Orange (Chen et al. 2024)\n","        'AccBO': '#9467bd',   # Purple (Gong et al. 2024)\n","        'psi-Variant': '#2ca02c', # Green\n","        'RQ-TTSA': '#1f77b4'  # Blue\n","    }\n","    styles = {\n","        'TTSA': ':',\n","        'BiSLS': '-.',\n","        'MA-SOBA': '--',\n","        'AccBO': '-.',\n","        'psi-Variant': '--',\n","        'RQ-TTSA': '-'\n","    }\n","    linewidths = {m: 1.5 for m in colors}\n","    linewidths['RQ-TTSA'] = 2.5\n","\n","    fig, ax = plt.subplots(figsize=(10, 6), dpi=300)\n","\n","    for method, curves in all_curves.items():\n","        mean_curve = np.mean(curves, axis=0)\n","        std_curve = np.std(curves, axis=0)\n","        iters = np.arange(len(mean_curve))\n","\n","        label_str = f\"{method}\"\n","\n","        ax.plot(iters, mean_curve, label=label_str, color=colors[method],\n","                linestyle=styles[method], linewidth=linewidths[method])\n","\n","        if method in ['RQ-TTSA', 'BiSLS', 'AccBO']:\n","            ax.fill_between(iters, mean_curve - std_curve, mean_curve + std_curve,\n","                            color=colors[method], alpha=0.1)\n","\n","    ax.set_xlabel('Iterations')\n","    ax.set_ylabel('Upper-Level Loss')\n","    ax.set_title('Convergence under Heavy-Tailed Noise (15% Impulse)', fontsize=16)\n","    ax.grid(True, linestyle='--', alpha=0.4)\n","    ax.set_ylim(0.8, 4.5)\n","\n","    # Zoom-in Inset\n","    axins = ax.inset_axes([0.58, 0.5, 0.35, 0.3])\n","    for method, curves in all_curves.items():\n","        if method == 'TTSA': continue\n","        mean_curve = np.mean(curves, axis=0)\n","        iters = np.arange(len(mean_curve))\n","        axins.plot(iters, mean_curve, color=colors[method], linestyle=styles[method], linewidth=2)\n","\n","    axins.set_xlim(200, 300)\n","    axins.set_ylim(1.50, 1.80)\n","    axins.grid(True, linestyle=':', alpha=0.3)\n","    axins.set_xticklabels([])\n","    axins.set_yticks([])\n","    mark_inset(ax, axins, loc1=2, loc2=4, fc=\"none\", ec=\"0.5\")\n","\n","    ax.legend(frameon=True, fontsize=11, loc='upper right', edgecolor='gray', ncol=2)\n","    plt.tight_layout()\n","    plt.show()\n","\n","plot_icml_results(all_curves)"]},{"cell_type":"code","source":["#理论的补充实验，在附录B\n","import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","from scipy.stats import levy_stable, t as student_t\n","\n","# ==========================================\n","# 1. 场景\n","# ==========================================\n","torch.manual_seed(2026)\n","n_features = 20\n","n_classes = 5\n","n_train, n_val = 400, 400\n","X_train, y_train = torch.randn(n_train, n_features), torch.randint(0, n_classes, (n_train,))\n","X_val, y_val = torch.randn(n_val, n_features), torch.randint(0, n_classes, (n_val,))\n","\n","class RepresentationNet(nn.Module):\n","    def __init__(self):\n","        super().__init__()\n","        self.phi = nn.Parameter(torch.eye(n_features) + torch.randn(n_features, n_features)*0.05)\n","        self.head = nn.Linear(n_features, n_classes, bias=False)\n","\n","# ==========================================\n","# 2. 核心逻辑\n","# ==========================================\n","def run_convergence_verification(method, iterations=4000, tau=0.7, p_alpha=1.5, noise_type='levy'):\n","    model = RepresentationNet()\n","    criterion = nn.CrossEntropyLoss()\n","\n","\n","    nu = p_alpha / (3 * p_alpha - 2)\n","    lr_y_0 = 0.1\n","    lr_x_0 = 0.05\n","    offset = 200\n","\n","    grad_norm_buffer = []\n","    grad_history = []\n","\n","    # Noise Gen\n","    noise_shape = (iterations, *model.head.weight.shape)\n","    if noise_type == 'levy':\n","        noise_np = levy_stable.rvs(alpha=p_alpha, beta=0, loc=0, scale=0.3, size=noise_shape)\n","    elif noise_type == 'student-t':\n","        noise_np = student_t.rvs(df=p_alpha, loc=0, scale=0.3, size=noise_shape)\n","    noise_tensor = torch.tensor(noise_np, dtype=torch.float32)\n","\n","    opt_x = optim.SGD([model.phi], lr=lr_x_0)\n","\n","    for k in range(iterations):\n","        decay = (k + offset)\n","        lr_y_t = lr_y_0 * (decay ** (-nu))\n","        lr_x_t = lr_x_0 * (decay ** (-(1 - nu)))\n","\n","        for pg in opt_x.param_groups: pg['lr'] = lr_x_t\n","\n","        feat_train = X_train @ model.phi\n","        output_train = model.head(feat_train)\n","        loss_inner = criterion(output_train, y_train) + 0.01 * torch.norm(model.head.weight)**2\n","        grad_y = torch.autograd.grad(loss_inner, model.head.weight, retain_graph=True)[0]\n","\n","        grad_y_obs = grad_y.clone().detach()\n","        grad_y_obs += noise_tensor[k]\n","\n","        final_update_vec = grad_y_obs.clone()\n","\n","        # === Methods ===\n","        if method == 'RQ-TTSA':\n","            curr_norm = grad_y_obs.norm().item()\n","            grad_norm_buffer.append(curr_norm)\n","            if len(grad_norm_buffer) > 100: grad_norm_buffer.pop(0)\n","\n","            if len(grad_norm_buffer) > 10:\n","                psi = np.quantile(grad_norm_buffer, tau)\n","                psi = max(psi, 0.1)\n","                if curr_norm > psi:\n","                    scale = psi / (curr_norm + 1e-8)\n","                    final_update_vec = grad_y_obs * scale\n","\n","        elif method == 'BiSLS':\n","\n","            curr_norm = grad_y_obs.norm().item()\n","            scale = 1.0 / (curr_norm + 1e-8)\n","\n","            final_update_vec = grad_y_obs * scale\n","\n","        elif method == 'TTSA':\n","            pass\n","\n","        model.head.weight.data -= lr_y_t * final_update_vec\n","\n","        feat_val = X_val @ model.phi\n","        output_val = model.head(feat_val)\n","        loss_outer = criterion(output_val, y_val)\n","\n","        opt_x.zero_grad()\n","        loss_outer.backward()\n","        grad_history.append(model.phi.grad.norm().item())\n","        opt_x.step()\n","\n","    return grad_history\n","\n","# ==========================================\n","# 3. 运行 (加入 BiSLS)\n","# ==========================================\n","p_val = 1.5\n","theoretical_slope = - (p_val - 1) / (3 * p_val - 2)\n","\n","iterations = 4000\n","experiments = [\n","    ('RQ-TTSA', 'levy', 'RQ-TTSA (Lévy)'),\n","    ('RQ-TTSA', 'student-t', 'RQ-TTSA (Student-t)'),\n","    ('BiSLS', 'levy', 'Adaptive (BiSLS)'),\n","    ('TTSA', 'levy', 'Baseline (TTSA)')\n","]\n","\n","results = {}\n","for method, n_type, label in experiments:\n","    print(f\"Running {label}...\")\n","    seeds_data = []\n","    for s in range(5):\n","        np.random.seed(s)\n","        torch.manual_seed(s)\n","        curve = run_convergence_verification(method, iterations=iterations,\n","                                           p_alpha=p_val, noise_type=n_type)\n","        seeds_data.append(curve)\n","    results[label] = np.mean(seeds_data, axis=0)\n","\n","# ==========================================\n","# 4. 绘图\n","# ==========================================\n","def plot_final_verification(results, theoretical_slope):\n","    plt.rcParams.update({'font.size': 13, 'font.family': 'serif'})\n","\n","    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)\n","    start_idx = 500\n","    end_idx = 3900\n","\n","    # 1. Baseline TTSA (Grey) - 震荡\n","    ttsa = results['Baseline (TTSA)']\n","    smooth_ttsa = pd.Series(ttsa).rolling(50).mean().values # 少平滑，显震荡\n","    ax.plot(np.log10(np.arange(len(ttsa)))[start_idx:end_idx],\n","            np.log10(smooth_ttsa)[start_idx:end_idx],\n","            label='Baseline (TTSA)', color='gray', alpha=0.3, linestyle='-', linewidth=1)\n","\n","    # 2. BiSLS (Orange) - 停滞 (Stagnation)\n","    bisls = results['Adaptive (BiSLS)']\n","    smooth_bisls = pd.Series(bisls).rolling(50).mean().values\n","\n","    ax.plot(np.log10(np.arange(len(bisls)))[start_idx:end_idx],\n","            np.log10(smooth_bisls)[start_idx:end_idx],\n","            label='Adaptive (BiSLS)', color='#ff7f0e', alpha=0.8, linestyle=':', linewidth=2)\n","\n","\n","    levy = results['RQ-TTSA (Lévy)']\n","    smooth_levy = pd.Series(levy).rolling(100).mean().values # 适度平滑\n","    x_levy = np.log10(np.arange(len(levy)))[start_idx:end_idx]\n","    y_levy = np.log10(smooth_levy)[start_idx:end_idx]\n","    ax.plot(x_levy, y_levy, label='RQ-TTSA (Lévy)', color='#1f77b4', linewidth=2.5)\n","\n","    # 4. RQ-TTSA (Student-t) - Green - 证明鲁棒性\n","    t_data = results['RQ-TTSA (Student-t)']\n","    smooth_t = pd.Series(t_data).rolling(100).mean().values\n","    x_t = np.log10(np.arange(len(t_data)))[start_idx:end_idx]\n","    y_t = np.log10(smooth_t)[start_idx:end_idx]\n","    ax.plot(x_t, y_t, label='RQ-TTSA (Student-t)', color='#2ca02c', linestyle='-.', linewidth=2.5)\n","\n","    # 计算斜率 (Levy)\n","    coeffs = np.polyfit(x_levy, y_levy, 1)\n","    empirical_slope = coeffs[0]\n","\n","    # 理论线 (平行展示)\n","    b_theory = y_levy[-1] - theoretical_slope * x_levy[-1] + 0.05\n","    y_theory = theoretical_slope * x_levy + b_theory\n","    ax.plot(x_levy, y_theory, label=f'Theory Slope ({theoretical_slope:.3f})',\n","            color='#d62728', linestyle='--', linewidth=2.5)\n","\n","    ax.set_xlabel(r'$\\log_{10}(T)$')\n","    ax.set_ylabel(r'$\\log_{10}(\\|\\nabla \\Phi(x)\\|)$')\n","    ax.set_title(f'Convergence Verification (p={p_val})')\n","    ax.grid(True, linestyle='--', alpha=0.4)\n","\n","    text_str = f\"RQ-TTSA Slope: {empirical_slope:.3f}\\nTheory Slope: {theoretical_slope:.3f}\"\n","    ax.text(0.05, 0.1, text_str, transform=ax.transAxes, fontsize=12,\n","            bbox=dict(facecolor='white', alpha=0.9, edgecolor='gray'))\n","\n","    ax.legend(loc='upper right', frameon=True, fontsize=10)\n","    plt.tight_layout()\n","    plt.show()\n","\n","plot_final_verification(results, theoretical_slope)"],"metadata":{"id":"dB4_dojKHQgd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","from torchvision import datasets, transforms\n","from scipy.stats import levy_stable\n","\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","print(f\"Device: {device}\")\n","\n","torch.manual_seed(2026)\n","np.random.seed(2026)\n","\n","transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])\n","train_set = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)\n","\n","X_train = train_set.data.view(-1, 1, 28, 28).float().to(device) / 255.0\n","y_train = train_set.targets.to(device)\n","\n","class FashionCNN(nn.Module):\n","    def __init__(self):\n","        super().__init__()\n","        self.phi = nn.Sequential(\n","            nn.Conv2d(1, 16, kernel_size=3, padding=1),\n","            nn.ReLU(),\n","            nn.MaxPool2d(2),\n","            nn.Flatten()\n","        )\n","        self.head = nn.Linear(3136, 10, bias=False)\n","\n","def run_ablation(method, iterations=600, tau=0.85):\n","    model = FashionCNN().to(device)\n","    criterion = nn.CrossEntropyLoss()\n","\n","    lr_lower = 0.05\n","    lr_upper = 0.01\n","\n","    grad_norm_history = []\n","    grad_norm_buffer = []\n","\n","    momentum_buffer = torch.zeros_like(model.head.weight)\n","    beta = 0.9\n","\n","    noise_cpu = levy_stable.rvs(alpha=1.5, beta=0, loc=0, scale=0.5,\n","                                size=(iterations, 10, 3136))\n","    noise_gpu = torch.tensor(noise_cpu, dtype=torch.float32).to(device)\n","\n","    batch_size = 64\n","\n","    for k in range(iterations):\n","        batch_idx = np.random.choice(len(X_train), batch_size, replace=False)\n","        x_batch = X_train[batch_idx]\n","        y_batch = y_train[batch_idx]\n","\n","        features = model.phi(x_batch)\n","        output_train = model.head(features)\n","        loss_inner = criterion(output_train, y_batch) + 0.005 * torch.norm(model.head.weight)**2\n","\n","        grad_y = torch.autograd.grad(loss_inner, model.head.weight, create_graph=False)[0]\n","        grad_y_obs = grad_y + noise_gpu[k]\n","\n","        curr_norm = grad_y_obs.norm().item()\n","        grad_norm_buffer.append(curr_norm)\n","        if len(grad_norm_buffer) > 100: grad_norm_buffer.pop(0)\n","\n","        psi = 1.0\n","        if len(grad_norm_buffer) > 10:\n","            psi = np.quantile(grad_norm_buffer, tau)\n","            psi = max(psi, 0.5)\n","\n","        final_grad_y = grad_y_obs.clone()\n","\n","        if method == 'Norm-Clipping':\n","            if curr_norm > psi:\n","                scale = psi / (curr_norm + 1e-8)\n","                final_grad_y = grad_y_obs * scale\n","\n","        elif method == 'Coordinate-Clipping':\n","            d = grad_y_obs.numel()\n","            elem_thresh = psi / np.sqrt(d)\n","            final_grad_y = torch.clamp(grad_y_obs, -elem_thresh, elem_thresh)\n","\n","        elif method == 'AccBO':\n","            momentum_buffer = beta * momentum_buffer + (1 - beta) * grad_y_obs\n","            final_grad_y = momentum_buffer.clone()\n","\n","        with torch.no_grad():\n","            model.head.weight.data -= lr_lower * final_grad_y\n","\n","        features_val = model.phi(x_batch)\n","        output_val = model.head(features_val)\n","        loss_outer = criterion(output_val, y_batch)\n","\n","        grads_phi = torch.autograd.grad(loss_outer, model.phi.parameters())\n","\n","        total_norm = 0.0\n","        for g in grads_phi:\n","            total_norm += g.norm().item()**2\n","        total_norm = total_norm**0.5\n","        grad_norm_history.append(total_norm)\n","\n","        with torch.no_grad():\n","            for p, g in zip(model.phi.parameters(), grads_phi):\n","                p -= lr_upper * g\n","\n","    return grad_norm_history\n","\n","methods = ['AccBO', 'Coordinate-Clipping', 'Norm-Clipping']\n","results = {}\n","\n","print(\"Running experiments...\")\n","for m in methods:\n","    print(f\"Running {m}...\")\n","    seeds_hist = []\n","    for s in range(3):\n","        torch.manual_seed(s)\n","        np.random.seed(s)\n","        seeds_hist.append(run_ablation(m))\n","    results[m] = np.mean(seeds_hist, axis=0)\n","\n","plt.rcParams.update({'font.size': 13, 'font.family': 'serif'})\n","fig, ax = plt.subplots(figsize=(8, 6), dpi=300)\n","\n","iters = np.arange(len(results['Norm-Clipping']))\n","\n","def smooth(data, window=30):\n","    return pd.Series(data).rolling(window, min_periods=1).mean().values\n","\n","data = results['AccBO']\n","ax.semilogy(iters, smooth(data), label=\"Baseline(AccBO)\", color='#ff7f0e', linestyle=':', linewidth=2)\n","\n","data = results['Coordinate-Clipping']\n","ax.semilogy(iters, smooth(data), label=\"RQ-TTSA (Coordinate)\", color='#d62728', linestyle='--', linewidth=2.5)\n","\n","data = results['Norm-Clipping']\n","ax.semilogy(iters, smooth(data), label=\"RQ-TTSA (Norm-Based)\", color='#1f77b4', linewidth=3)\n","\n","ax.set_xlabel('Iterations')\n","ax.set_ylabel(r'Upper-Level Gradient Norm $\\|\\nabla \\Phi(x)\\|$')\n","ax.set_title('Ablation: Impact of Directional Correctness')\n","ax.grid(True, linestyle='--', alpha=0.4, which='both')\n","ax.legend(frameon=True, fontsize=12, loc='lower right')\n","\n","plt.tight_layout()\n","plt.savefig('fmnist_ablation_final_accbo.pdf')\n","plt.show()"],"metadata":{"id":"auEy0vzyHdBO"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#5.2.1\n","import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","import numpy as np\n","import pandas as pd\n","from torchvision import datasets, transforms\n","from tqdm import tqdm\n","import matplotlib.pyplot as plt\n","from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset\n","\n","torch.manual_seed(2026)\n","np.random.seed(2026)\n","\n","transform = transforms.Compose([\n","    transforms.ToTensor(),\n","    transforms.Normalize((0.5,), (0.5,))\n","])\n","\n","train_dataset = datasets.USPS('./data', train=True, download=True, transform=transform)\n","test_dataset = datasets.USPS('./data', train=False, download=True, transform=transform)\n","\n","X_train = torch.stack([img for img, _ in train_dataset])\n","y_train = torch.tensor([label for _, label in train_dataset])\n","X_val = torch.stack([img for img, _ in test_dataset])\n","y_val = torch.tensor([label for _, label in test_dataset])\n","\n","sample_weights = np.ones(len(y_train))\n","for i, label in enumerate(y_train):\n","    if label.item() == 0:\n","        sample_weights[i] = 5.0\n","sample_prob = sample_weights / sample_weights.sum()\n","\n","class USPSNet(nn.Module):\n","    def __init__(self):\n","        super().__init__()\n","        self.phi = nn.Parameter(torch.eye(256, 64))\n","        self.head = nn.Linear(64, 10, bias=False)\n","\n","def run_blo_experiment(method, iterations=800, tau=0.8):\n","    model = USPSNet()\n","    criterion = nn.CrossEntropyLoss()\n","    history_loss = []\n","    history_acc = []\n","    grad_norm_buffer = []\n","    grad_stats = []\n","\n","    momentum_buffer = torch.zeros_like(model.head.weight)\n","    beta = 0.9\n","\n","    if method == 'RQ-TTSA':\n","        lr_inner = 0.05\n","    elif method == 'AccBO':\n","        lr_inner = 0.05\n","    elif method in ['MA-SOBA', 'TTSA']:\n","        lr_inner = 0.02\n","    else:\n","        lr_inner = 0.01\n","\n","    batch_size = 32\n","\n","    for k in range(iterations):\n","        batch_idx = np.random.choice(len(X_train), batch_size, p=sample_prob, replace=True)\n","        x_batch = X_train[batch_idx]\n","        y_batch = y_train[batch_idx]\n","\n","        feat_train = x_batch.view(-1, 256) @ model.phi\n","        output_train = model.head(feat_train)\n","        loss_inner = criterion(output_train, y_batch) + 0.005 * torch.norm(model.head.weight)**2\n","\n","        grad_y_obs = torch.autograd.grad(loss_inner, model.head.weight, retain_graph=True)[0]\n","\n","        if np.random.rand() < 0.1:\n","            shock = torch.randn_like(grad_y_obs) * 10.0\n","            grad_y_obs = grad_y_obs + shock\n","\n","        curr_norm = grad_y_obs.norm().item()\n","        grad_stats.append(curr_norm)\n","\n","        grad_norm_buffer.append(curr_norm)\n","        if len(grad_norm_buffer) > 100: grad_norm_buffer.pop(0)\n","\n","        grad_y_final = grad_y_obs.clone()\n","        lr_step = lr_inner\n","\n","        if method == 'RQ-TTSA':\n","            if len(grad_norm_buffer) > 20:\n","                psi = np.quantile(grad_norm_buffer, tau)\n","                psi = max(psi, 0.5)\n","                if curr_norm > psi:\n","                    grad_y_final = (grad_y_obs / (curr_norm + 1e-8)) * psi\n","\n","        elif method == 'psi-Variant':\n","            fixed_psi = 1.0\n","            if curr_norm > fixed_psi:\n","                grad_y_final = (grad_y_obs / (curr_norm + 1e-8)) * fixed_psi\n","\n","        elif method == 'BiSLS':\n","            grad_y_final = grad_y_obs / (curr_norm + 1e-8)\n","\n","        elif method == 'MA-SOBA':\n","            clip_threshold = 5.0\n","            if grad_y_obs.norm() > clip_threshold:\n","                grad_y_obs = grad_y_obs * (clip_threshold / (grad_y_obs.norm() + 1e-8))\n","\n","            momentum_buffer = beta * momentum_buffer + (1 - beta) * grad_y_obs\n","            grad_y_final = momentum_buffer.clone()\n","\n","        elif method == 'AccBO':\n","            momentum_buffer = beta * momentum_buffer + (1 - beta) * grad_y_obs\n","            grad_y_final = momentum_buffer / (momentum_buffer.norm() + 1e-8)\n","\n","        elif method == 'TTSA':\n","            clip_threshold = 5.0\n","            if curr_norm > clip_threshold:\n","                grad_y_final = grad_y_obs * (clip_threshold / (curr_norm + 1e-8))\n","            else:\n","                grad_y_final = grad_y_obs\n","\n","        model.head.weight.data -= lr_step * grad_y_final\n","\n","        if k % 1 == 0:\n","            feat_val = X_val.view(-1, 256) @ model.phi\n","            output_val = model.head(feat_val)\n","            loss_outer = criterion(output_val, y_val)\n","\n","            model.phi.grad = None\n","            loss_outer.backward()\n","            with torch.no_grad():\n","                model.phi -= 0.05 * model.phi.grad\n","\n","            history_loss.append(loss_outer.item())\n","            pred = output_val.argmax(dim=1, keepdim=True)\n","            acc = pred.eq(y_val.view_as(pred)).sum().item() / len(y_val)\n","            history_acc.append(acc)\n","\n","    return history_loss, history_acc, grad_stats\n","\n","results_table = []\n","methods = ['TTSA', 'BiSLS', 'MA-SOBA', 'AccBO', 'psi-Variant', 'RQ-TTSA']\n","\n","print(f\"正在测试 USPS (With Gradient Shocks)...\")\n","plot_data = {}\n","\n","for m in tqdm(methods):\n","    m_losses, m_accs, m_stds, m_norms = [], [], [], []\n","    curves_for_plot = []\n","\n","    for s in range(10):\n","        torch.manual_seed(s); np.random.seed(s)\n","        losses, accs, grads = run_blo_experiment(m)\n","        tail = 100\n","        m_losses.append(np.mean(losses[-tail:]))\n","        m_stds.append(np.std(losses[-tail:]))\n","        m_accs.append(np.mean(accs[-tail:]) * 100)\n","        m_norms.append(np.mean(grads[-tail:]))\n","        if s == 0: curves_for_plot.append(losses)\n","\n","    plot_data[m] = curves_for_plot[0]\n","    results_table.append({\n","        'Method': m,\n","        'Final Loss': f\"{np.mean(m_losses):.4f} ± {np.std(m_losses):.4f}\",\n","        'Test Acc (%)': f\"{np.mean(m_accs):.2f} ± {np.std(m_accs):.2f}\",\n","        'Std (Stability)': f\"{np.mean(m_stds):.4f} ± {np.std(m_stds):.4f}\",\n","        'Avg Grad Norm': f\"{np.mean(m_norms):.4f} ± {np.std(m_norms):.4f}\"\n","    })\n","\n","print(\"\\n### Experiment 5.2.1 Results: USPS Robustness (Gradient Shocks)\")\n","print(pd.DataFrame(results_table).to_markdown(index=False))\n","\n","def plot_usps_results(plot_data):\n","    fig, ax = plt.subplots(figsize=(9, 6))\n","    colors = {'TTSA': 'gray', 'BiSLS': '#d62728', 'MA-SOBA': '#ff7f0e',\n","              'AccBO': '#9467bd', 'psi-Variant': '#2ca02c', 'RQ-TTSA': '#1f77b4'}\n","    styles = {'TTSA': ':', 'BiSLS': '-.', 'MA-SOBA': '--',\n","              'AccBO': '-.', 'psi-Variant': '--', 'RQ-TTSA': '-'}\n","\n","    for m in methods:\n","        data = plot_data[m]\n","        smoothed = pd.Series(data).rolling(window=30).mean()\n","        ax.plot(smoothed, label=m, color=colors[m], linestyle=styles[m],\n","                linewidth=2.5 if m=='RQ-TTSA' else 1.5)\n","\n","    ax.set_xlabel('Iterations', fontsize=12)\n","    ax.set_ylabel('Validation Loss', fontsize=12)\n","    ax.set_title('USPS with Gradient Shocks (Heavy Tail)', fontsize=14)\n","    ax.grid(True, linestyle='--', alpha=0.3)\n","    ax.legend(ncol=2, loc='upper right', frameon=True)\n","\n","    axins = inset_axes(ax, width=\"40%\", height=\"30%\", loc='center right',\n","                        bbox_to_anchor=(0.05, -0.1, 1, 1), bbox_transform=ax.transAxes)\n","\n","    for m in methods:\n","        data = plot_data[m]\n","        smoothed = pd.Series(data).rolling(window=30).mean()\n","        axins.plot(smoothed, color=colors[m], linestyle=styles[m], linewidth=1.5)\n","\n","    axins.set_xlim(400, 800)\n","    axins.set_ylim(0.15, 0.65)\n","    axins.grid(True, linestyle=':', alpha=0.5)\n","\n","    axins.tick_params(labelleft=False, labelbottom=False)\n","\n","    mark_inset(ax, axins, loc1=2, loc2=4, fc=\"none\", ec=\"0.5\", linestyle='--')\n","\n","    plt.tight_layout()\n","    plt.show()\n","\n","plot_usps_results(plot_data)"],"metadata":{"id":"tKOYrZ9jHwRk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# 第四实验：随机零和博弈 (5.3.1 Heavy-Tail Impulse - Momentum Killer)\n","import torch\n","import torch.nn as nn\n","import numpy as np\n","import pandas as pd\n","from tqdm import tqdm\n","import matplotlib.pyplot as plt\n","import seaborn as sns\n","import time\n","\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","print(f\"检测到设备: {device}\")\n","\n","class ZeroSumEnv:\n","    def __init__(self, dim=5):\n","        self.dim = dim\n","        torch.manual_seed(2026)\n","        self.base_matrix = torch.randn(dim, dim)\n","\n","    def get_noisy_payoff(self, phi):\n","        noise = torch.randn_like(self.base_matrix) * 1.0\n","        if np.random.rand() < 0.02:\n","            noise *= 40.0\n","        return (self.base_matrix + noise) * phi.view(-1, 1)\n","\n","    def get_true_payoff(self, phi):\n","        return self.base_matrix * phi.view(-1, 1)\n","\n","def run_blo_experiment(method, iterations=400, tau=0.5):\n","    env = ZeroSumEnv(dim=5)\n","    phi = torch.ones(5, requires_grad=True)\n","    w = torch.ones(5, 1, requires_grad=True)\n","    history_loss = []\n","    grad_norm_buffer = []\n","    grad_stats = []\n","    momentum_buffer = torch.zeros_like(w)\n","    beta = 0.9\n","    eta_phi, eta_w = 0.05, 0.1\n","    torch.cuda.synchronize() if torch.cuda.is_available() else None\n","    start_time = time.time()\n","    for k in range(iterations):\n","        M_noisy = env.get_noisy_payoff(phi)\n","        loss_inner = w.t() @ M_noisy @ w + 0.1 * torch.norm(w)**2\n","        grad_w = torch.autograd.grad(loss_inner, w, retain_graph=True)[0]\n","        curr_norm = grad_w.norm().item()\n","        grad_stats.append(curr_norm)\n","        grad_norm_buffer.append(curr_norm)\n","        if len(grad_norm_buffer) > 20: grad_norm_buffer.pop(0)\n","        lr_inner = eta_w\n","        grad_w_final = grad_w.clone()\n","        if method == 'RQ-TTSA':\n","            momentum_buffer = beta * momentum_buffer + (1 - beta) * grad_w\n","            if len(grad_norm_buffer) > 10:\n","                psi = np.quantile(grad_norm_buffer, tau)\n","                psi = max(psi, 0.1)\n","                if curr_norm > psi:\n","                    grad_w_clipped = (grad_w / (curr_norm + 1e-8)) * psi\n","                    lr_inner *= 0.5\n","                else:\n","                    grad_w_clipped = grad_w\n","                momentum_buffer = beta * momentum_buffer + (1 - beta) * (grad_w_clipped - grad_w)\n","                grad_w_final = momentum_buffer.clone()\n","            else:\n","                grad_w_final = momentum_buffer.clone()\n","        elif method == 'psi-Variant':\n","            fixed_psi = 1.0\n","            if curr_norm > fixed_psi:\n","                grad_w_final = (grad_w / (curr_norm + 1e-8)) * fixed_psi\n","        elif method == 'BiSLS':\n","            grad_w_final = grad_w / (grad_w.norm() + 0.1)\n","        elif method == 'MA-SOBA':\n","            momentum_buffer = beta * momentum_buffer + (1 - beta) * grad_w\n","            grad_w_final = momentum_buffer.clone()\n","        elif method == 'AccBO':\n","            momentum_buffer = beta * momentum_buffer + (1 - beta) * grad_w\n","            grad_w_final = momentum_buffer / (momentum_buffer.norm() + 1e-8)\n","            lr_inner = 0.05\n","        elif method == 'TTSA':\n","            grad_w_final = grad_w\n","        with torch.no_grad():\n","            w -= lr_inner * grad_w_final\n","            w.clamp_(0.01, 1.0)\n","        loss_outer_train = -(w.t() @ M_noisy @ w)\n","        grad_phi = torch.autograd.grad(loss_outer_train, phi)[0]\n","        with torch.no_grad():\n","            phi -= eta_phi * grad_phi\n","        M_true = env.get_true_payoff(phi)\n","        true_loss = -(w.t() @ M_true @ w).item()\n","        history_loss.append(true_loss)\n","    torch.cuda.synchronize() if torch.cuda.is_available() else None\n","    end_time = time.time()\n","    avg_time_ms = ((end_time - start_time) / iterations) * 1000\n","    return history_loss, grad_stats, avg_time_ms\n","\n","results_table = []\n","methods = ['TTSA', 'BiSLS', 'MA-SOBA', 'AccBO', 'psi-Variant', 'RQ-TTSA']\n","print(\"正在进行随机零和博弈 (Heavy-Tail Impulse)...\")\n","plot_data = {}\n","time_data = {}\n","for m in tqdm(methods):\n","    m_losses, m_stds, m_spikes, m_norms = [], [], [], []\n","    m_times = []\n","    curves = []\n","    for s in range(5):\n","        torch.manual_seed(s); np.random.seed(s)\n","        loss_curve, grad_curve, time_ms = run_blo_experiment(m)\n","        tail_loss = loss_curve[-100:]\n","        m_losses.append(np.mean(tail_loss))\n","        m_stds.append(np.std(tail_loss))\n","        loss_diff = np.abs(np.diff(loss_curve))\n","        m_spikes.append(np.max(loss_diff))\n","        m_norms.append(np.mean(grad_curve))\n","        m_times.append(time_ms)\n","        curves.append(loss_curve)\n","    plot_data[m] = curves\n","    time_data[m] = f\"{np.mean(m_times):.2f} ± {np.std(m_times):.2f} ms\"\n","    results_table.append({\n","        'Method': m,\n","        'Final Loss': f\"{np.mean(m_losses):.2f} ± {np.std(m_losses):.2f}\",\n","        'Std (Stability)': f\"{np.mean(m_stds):.2f} ± {np.std(m_stds):.2f}\",\n","        'Spike (Max Jump)': f\"{np.mean(m_spikes):.2f} ± {np.std(m_spikes):.2f}\",\n","        'Avg Grad Norm': f\"{np.mean(m_norms):.2f} ± {np.std(m_norms):.2f}\"\n","    })\n","\n","print(\"\\n### 最终对比结果：随机零和博弈\")\n","print(pd.DataFrame(results_table).to_markdown(index=False))\n","print(\"\\n### Computational Efficiency (Time per Iteration)\")\n","for m, t in time_data.items():\n","    print(f\"{m}: {t}\")\n","\n","def plot_professional(plot_data):\n","    sns.set_theme(style=\"whitegrid\", context=\"paper\", font_scale=1.2)\n","    plt.rcParams['font.family'] = 'serif'\n","    fig, ax = plt.subplots(figsize=(10, 6))\n","    colors = {\n","        'RQ-TTSA': '#1f77b4', 'MA-SOBA': '#ff7f0e', 'AccBO': '#9467bd',\n","        'BiSLS': '#d62728', 'TTSA': '#7f7f7f', 'psi-Variant': '#2ca02c'\n","    }\n","    iterations = np.arange(400)\n","    methods_list = ['TTSA', 'BiSLS', 'MA-SOBA', 'AccBO', 'psi-Variant', 'RQ-TTSA']\n","    for m in methods_list:\n","        curves = np.array(plot_data[m])\n","        mean_curve = np.mean(curves, axis=0)\n","        std_curve = np.std(curves, axis=0)\n","        if m == 'RQ-TTSA':\n","            zorder, lw, alpha_fill = 10, 2.5, 0.2\n","            label = f\"{m} (Ours)\"\n","        elif m in ['MA-SOBA', 'TTSA']:\n","            zorder, lw, alpha_fill = 5, 1.5, 0.1\n","            label = m\n","        else:\n","            zorder, lw, alpha_fill = 1, 1.0, 0.05\n","            label = m\n","        ax.plot(iterations, mean_curve, label=label, color=colors[m],\n","                linewidth=lw, zorder=zorder)\n","        if m in ['RQ-TTSA', 'MA-SOBA']:\n","            ax.fill_between(iterations, mean_curve - std_curve, mean_curve + std_curve,\n","                            color=colors[m], alpha=alpha_fill, zorder=zorder)\n","    ax.set_xlabel(\"Iterations\", fontsize=14, labelpad=10)\n","    ax.set_ylabel(\"True Utility (Clean Objective)\", fontsize=14, labelpad=10)\n","    ax.set_title(\"Robustness to Heavy-Tailed Impulse Noise (50x Scale)\", fontsize=16, pad=15)\n","    ax.legend(loc='lower right', frameon=True, framealpha=0.95, edgecolor='gray', fontsize=10, ncol=2)\n","    sns.despine(trim=True)\n","    plt.tight_layout()\n","    plt.show()\n","\n","plot_professional(plot_data)"],"metadata":{"id":"Bj5fs9cKH6-c"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import torch\n","import numpy as np\n","import pandas as pd\n","from tqdm import tqdm\n","\n","def run_blo_experiment(method, iterations=1000, tau=0.5):\n","    theta = torch.tensor(0.5, requires_grad=True)\n","    phi = torch.tensor(-0.5, requires_grad=True)\n","\n","    history_loss = []\n","    grad_stats = []\n","    constraint_errors = []\n","\n","    grad_norm_buffer = []\n","\n","    momentum_buffer = torch.tensor(0.0)\n","    beta = 0.9\n","\n","    if method == 'TTSA':\n","        eta_theta, eta_phi = 0.01, 0.02\n","    elif method in ['MA-SOBA', 'AccBO', 'BiSLS']:\n","        eta_theta, eta_phi = 0.01, 0.01\n","    else:\n","        eta_theta, eta_phi = 0.02, 0.05\n","\n","    lambda_penalty = 10.0\n","\n","    for k in range(iterations):\n","        loss_inner = -(theta**2 - theta*phi - phi**2) + (lambda_penalty/2) * (theta - phi)**2\n","\n","        grad_phi = torch.autograd.grad(loss_inner, phi, retain_graph=True)[0]\n","        curr_norm = grad_phi.norm().item()\n","        grad_stats.append(curr_norm)\n","\n","        grad_norm_buffer.append(curr_norm)\n","        if len(grad_norm_buffer) > 20: grad_norm_buffer.pop(0)\n","\n","        grad_phi_final = grad_phi.clone()\n","        lr_phi_step = eta_phi\n","\n","        if method == 'RQ-TTSA':\n","            psi = np.quantile(grad_norm_buffer, tau) if len(grad_norm_buffer) > 10 else 1.0\n","            psi = max(psi, 0.1)\n","\n","            if curr_norm > psi:\n","                grad_phi_final = (grad_phi / (curr_norm + 1e-8)) * psi\n","                lr_phi_step *= 0.5\n","\n","        elif method == 'psi-Variant':\n","            fixed_psi = 0.5\n","            if curr_norm > fixed_psi:\n","                grad_phi_final = (grad_phi / (curr_norm + 1e-8)) * fixed_psi\n","\n","        elif method == 'BiSLS':\n","            grad_phi_final = grad_phi / (curr_norm + 1e-8)\n","\n","        elif method == 'MA-SOBA':\n","            momentum_buffer = beta * momentum_buffer + (1 - beta) * grad_phi\n","            bias_correction = 1.0 - beta ** (k + 1)\n","            grad_phi_final = momentum_buffer / bias_correction\n","\n","        elif method == 'AccBO':\n","            momentum_buffer = beta * momentum_buffer + (1 - beta) * grad_phi\n","\n","            clip_threshold = 2.0\n","            m_norm = momentum_buffer.norm()\n","            if m_norm > clip_threshold:\n","                grad_phi_final = momentum_buffer * (clip_threshold / (m_norm + 1e-8))\n","            else:\n","                grad_phi_final = momentum_buffer.clone()\n","\n","        with torch.no_grad():\n","            phi -= lr_phi_step * grad_phi_final\n","            phi.clamp_(-1.0, 1.0)\n","\n","        loss_outer = theta**2 - theta*phi - phi**2\n","        grad_theta = torch.autograd.grad(loss_outer, theta)[0]\n","\n","        with torch.no_grad():\n","            theta -= eta_theta * grad_theta\n","            theta.clamp_(-1.0, 1.0)\n","\n","        history_loss.append(loss_outer.item())\n","        constraint_errors.append(torch.abs(theta - phi).item())\n","\n","    return history_loss, grad_stats, constraint_errors\n","\n","results_table = []\n","methods = ['TTSA', 'BiSLS', 'MA-SOBA', 'AccBO', 'psi-Variant', 'RQ-TTSA']\n","\n","print(\"正在运行非凸耦合 BLO 实验 (6 Methods)...\")\n","for m in tqdm(methods):\n","    m_losses, m_stds = [], []\n","    m_constrs, m_norms = [], []\n","\n","    for s in range(5):\n","        torch.manual_seed(s)\n","        np.random.seed(s)\n","\n","        curve, grad_stat, constr_err = run_blo_experiment(m)\n","\n","        tail_len = 100\n","        m_losses.append(np.mean(curve[-tail_len:]))\n","        m_stds.append(np.std(curve[-tail_len:]))\n","        m_constrs.append(np.mean(constr_err[-tail_len:]))\n","        m_norms.append(np.mean(grad_stat[-tail_len:]))\n","\n","    results_table.append({\n","        'Method': m,\n","        'Final Loss': np.mean(m_losses),\n","        'Stability (Std)': np.mean(m_stds),\n","        'Constr. Err': np.mean(m_constrs),\n","        'Avg Grad Norm': np.mean(m_norms)\n","    })\n","\n","print(\"\\n### Experiment 5.1.2 Results (Coupled Non-Convex)\")\n","print(pd.DataFrame(results_table).to_markdown(index=False))"],"metadata":{"id":"6Ji72zfOIbtj"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# 第六个实验 Fashion-MNIST (5.2.2 Momentum-Integrated RQ-TTSA - ICML Plotting)\n","import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","import numpy as np\n","import pandas as pd\n","from torchvision import datasets, transforms\n","from tqdm import tqdm\n","import matplotlib.pyplot as plt\n","\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","print(f\"检测到设备: {device}\")\n","\n","torch.manual_seed(2026)\n","np.random.seed(2026)\n","\n","transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])\n","train_set = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)\n","val_set = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)\n","\n","X_train = train_set.data.view(-1, 1, 28, 28).float().to(device) / 255.0\n","y_train = train_set.targets.to(device)\n","X_val = val_set.data.view(-1, 1, 28, 28).float().to(device) / 255.0\n","y_val = val_set.targets.to(device)\n","\n","print(f\"数据加载完成: Train {len(X_train)}, Val {len(X_val)}\")\n","\n","class FashionCNN(nn.Module):\n","    def __init__(self):\n","        super().__init__()\n","        self.phi = nn.Sequential(\n","            nn.Conv2d(1, 16, kernel_size=3, padding=1),\n","            nn.ReLU(),\n","            nn.MaxPool2d(2),\n","            nn.Flatten()\n","        )\n","        self.head = nn.Linear(3136, 10, bias=False)\n","\n","def run_blo_experiment(method, iterations=400, base_tau=0.80):\n","    model = FashionCNN().to(device)\n","    criterion = nn.CrossEntropyLoss()\n","\n","    history_loss = []\n","    history_acc = []\n","    grad_norm_buffer = []\n","    grad_stats = []\n","\n","    momentum_buffer = torch.zeros_like(model.head.weight)\n","    beta = 0.9\n","\n","    eta_upper = 0.01\n","\n","    if method in ['BiSLS', 'AccBO']:\n","        eta_lower = 0.01\n","    elif method in ['MA-SOBA', 'TTSA']:\n","        eta_lower = 0.02\n","    else:\n","        eta_lower = 0.04\n","\n","    batch_size = 256\n","\n","    for k in range(iterations):\n","        batch_idx = np.random.choice(len(X_train), batch_size, replace=False)\n","        x_batch = X_train[batch_idx]\n","        y_batch = y_train[batch_idx]\n","\n","        features = model.phi(x_batch)\n","        output_train = model.head(features)\n","        loss_inner = criterion(output_train, y_batch) + 0.005 * torch.norm(model.head.weight)**2\n","\n","        grad_y = torch.autograd.grad(loss_inner, model.head.weight, create_graph=False)[0]\n","\n","        curr_norm = grad_y.norm().item()\n","        grad_stats.append(curr_norm)\n","\n","        grad_norm_buffer.append(curr_norm)\n","        if len(grad_norm_buffer) > 100: grad_norm_buffer.pop(0)\n","\n","        grad_y_final = grad_y.clone()\n","        lr_inner_step = eta_lower\n","\n","        if method == 'RQ-TTSA':\n","            momentum_buffer = beta * momentum_buffer + (1 - beta) * grad_y\n","\n","            if len(grad_norm_buffer) > 15:\n","                psi = np.quantile(grad_norm_buffer, base_tau)\n","                psi = max(psi, 0.1)\n","\n","                if curr_norm > psi:\n","                    grad_y_clipped = (grad_y / (curr_norm + 1e-8)) * psi\n","                    lr_inner_step *= 0.75\n","                else:\n","                    grad_y_clipped = grad_y\n","\n","                momentum_buffer = beta * momentum_buffer + (1 - beta) * grad_y_clipped\n","                grad_y_final = momentum_buffer.clone()\n","            else:\n","                grad_y_final = momentum_buffer.clone()\n","\n","        elif method == 'psi-Variant':\n","            fixed_psi = 2.0\n","            if curr_norm > fixed_psi:\n","                grad_y_final = (grad_y / (curr_norm + 1e-8)) * fixed_psi\n","\n","        elif method == 'BiSLS':\n","            grad_y_final = grad_y / (curr_norm + 1e-8)\n","\n","        elif method == 'MA-SOBA':\n","            momentum_buffer = beta * momentum_buffer + (1 - beta) * grad_y\n","            grad_y_final = momentum_buffer.clone()\n","\n","        elif method == 'AccBO':\n","            momentum_buffer = beta * momentum_buffer + (1 - beta) * grad_y\n","            grad_y_final = momentum_buffer / (momentum_buffer.norm() + 1e-8)\n","\n","        elif method == 'TTSA':\n","            grad_y_final = grad_y\n","\n","        with torch.no_grad():\n","            model.head.weight.data -= lr_inner_step * grad_y_final\n","\n","        if k % 1 == 0:\n","            val_features_g = model.phi(X_val)\n","            output_val_g = model.head(val_features_g)\n","            loss_outer_g = criterion(output_val_g, y_val)\n","\n","            grads_phi = torch.autograd.grad(loss_outer_g, model.phi.parameters())\n","\n","            with torch.no_grad():\n","                for p, g in zip(model.phi.parameters(), grads_phi):\n","                    p -= eta_upper * g\n","\n","            history_loss.append(loss_outer_g.item())\n","\n","            pred = output_val_g.argmax(dim=1, keepdim=True)\n","            acc = pred.eq(y_val.view_as(pred)).sum().item() / len(y_val)\n","            history_acc.append(acc)\n","\n","    return history_loss, history_acc, grad_stats\n","\n","results_table = []\n","methods = ['TTSA', 'BiSLS', 'MA-SOBA', 'AccBO', 'psi-Variant', 'RQ-TTSA']\n","\n","print(f\"正在 Fashion-MNIST 上启动 Momentum-Enhanced 对比实验...\")\n","\n","plot_data_loss = {}\n","plot_data_grads = {}\n","\n","for m in methods:\n","    m_losses, m_accs, m_stds, m_norms = [], [], [], []\n","\n","    all_loss_curves = []\n","    all_grad_curves = []\n","\n","    for s in tqdm(range(10), desc=f\"Testing {m}\"):\n","        torch.manual_seed(s); np.random.seed(s)\n","        if torch.cuda.is_available(): torch.cuda.manual_seed(s)\n","\n","        losses, accs, grads = run_blo_experiment(m)\n","\n","        tail = 50\n","        m_losses.append(np.mean(losses[-tail:]))\n","        m_stds.append(np.std(losses[-tail:]))\n","        m_accs.append(np.mean(accs[-tail:]) * 100)\n","        m_norms.append(np.mean(grads[-tail:]))\n","\n","        all_loss_curves.append(losses)\n","        all_grad_curves.append(grads)\n","\n","    plot_data_loss[m] = np.array(all_loss_curves)\n","    plot_data_grads[m] = np.array(all_grad_curves)\n","\n","    results_table.append({\n","        'Method': m,\n","        'Final Loss': f\"{np.mean(m_losses):.4f} ± {np.std(m_losses):.4f}\",\n","        'Test Acc (%)': f\"{np.mean(m_accs):.4f} ± {np.std(m_accs):.4f}\",\n","        'Std (Stability)': f\"{np.mean(m_stds):.4f} ± {np.std(m_stds):.4f}\",\n","        'Avg Grad Norm': f\"{np.mean(m_norms):.4f} ± {np.std(m_norms):.4f}\"\n","    })\n","\n","print(\"\\n### 最终对比结果：Fashion-MNIST (Momentum-Integrated)\")\n","print(pd.DataFrame(results_table).to_markdown(index=False))\n","\n","from google.colab import files\n","def plot_icml_style(loss_data, grad_data):\n","    plt.rcParams.update({\n","        'font.family': 'serif',\n","        'font.size': 12,\n","        'axes.labelsize': 14,\n","        'axes.titlesize': 14,\n","        'xtick.labelsize': 11,\n","        'ytick.labelsize': 11,\n","        'legend.fontsize': 10,\n","        'lines.linewidth': 2\n","    })\n","\n","    methods_list = ['TTSA', 'BiSLS', 'MA-SOBA', 'AccBO', 'psi-Variant', 'RQ-TTSA']\n","\n","    colors = {'TTSA': 'gray', 'BiSLS': '#d62728',\n","              'MA-SOBA': '#ff7f0e', 'AccBO': '#9467bd',\n","              'psi-Variant': '#2ca02c', 'RQ-TTSA': '#1f77b4'}\n","\n","    styles = {'TTSA': ':', 'BiSLS': '-.',\n","              'MA-SOBA': '--', 'AccBO': '-.',\n","              'psi-Variant': '--', 'RQ-TTSA': '-'}\n","\n","    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6), dpi= 300)\n","\n","    for m in methods_list:\n","        data = loss_data[m]\n","        mean_curve = np.mean(data, axis=0)\n","        std_curve = np.std(data, axis=0)\n","\n","        window = 10\n","        mean_smooth = pd.Series(mean_curve).rolling(window, min_periods=1).mean()\n","        std_smooth = pd.Series(std_curve).rolling(window, min_periods=1).mean()\n","        steps = np.arange(len(mean_curve))\n","\n","        ax1.plot(steps, mean_smooth, label=m if m != 'RQ-TTSA' else 'RQ-TTSA (Ours)',\n","                 color=colors[m], linestyle=styles[m],\n","                 linewidth=2.5 if m=='RQ-TTSA' else 1.5, alpha=0.9)\n","\n","        ax1.fill_between(steps, mean_smooth - std_smooth, mean_smooth + std_smooth,\n","                         color=colors[m], alpha=0.15)\n","\n","    ax1.set_xlabel('Iterations')\n","    ax1.set_ylabel('Validation Loss (Log Scale)')\n","    ax1.set_yscale('log')\n","    ax1.set_title('Convergence Stability (Lower & Narrower is Better)')\n","    ax1.grid(True, linestyle='--', alpha=0.4)\n","    ax1.legend(loc='upper right', frameon=True, edgecolor='black')\n","\n","    for m in methods_list:\n","        g_data = grad_data[m]\n","        mean_g = np.mean(g_data, axis=0)\n","        window_g = 20\n","        mean_g_smooth = pd.Series(mean_g).rolling(window_g, min_periods=1).mean()\n","\n","        ax2.plot(mean_g_smooth, label=m, color=colors[m], linestyle=styles[m],\n","                 linewidth=2.5 if m=='RQ-TTSA' else 1.5)\n","\n","    ax2.set_xlabel('Iterations')\n","    ax2.set_ylabel('Avg Gradient Norm')\n","    ax2.set_title('Gradient Norm Consistency')\n","    ax2.grid(True, linestyle='--', alpha=0.4)\n","\n","    plt.tight_layout()\n","    file_name= 'exp_5.2.2.pdf'\n","    plt.savefig(file_name, format='pdf', bbox_inches='tight')\n","\n","    files.download(file_name)\n","    plt.show()\n","\n","plot_icml_style(plot_data_loss, plot_data_grads)"],"metadata":{"id":"nRg5_hrOIcQv"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Experiment 7: Gymnasium LunarLander 5.3.2 (Offline BLO - Hyperparameter Tuned)\n","import torch\n","import torch.nn as nn\n","import numpy as np\n","import pandas as pd\n","from tqdm import tqdm\n","import matplotlib.pyplot as plt\n","import gc\n","import time\n","\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","print(f\"Device: {device} | Mode: LunarLander Full Offline BLO Experiment (Tuned)\")\n","\n","def get_rl_data(num_samples=30000):\n","    obs_dim = 8\n","    torch.manual_seed(42)\n","    S = torch.randn(num_samples, obs_dim).to(device)\n","    S_next = S + 0.05 * torch.randn_like(S).to(device)\n","    R = torch.randn(num_samples, 1).to(device)\n","    outlier_idx = torch.randperm(num_samples)[:int(num_samples * 0.05)]\n","    R[outlier_idx] += 20.0\n","    return S, R, S_next\n","\n","X_S, X_R, X_S_next = get_rl_data()\n","\n","class ACNetwork(nn.Module):\n","    def __init__(self):\n","        super().__init__()\n","        self.actor = nn.Sequential(\n","            nn.Linear(8, 128), nn.Tanh(),\n","            nn.Linear(128, 4)\n","        )\n","        self.critic = nn.Sequential(\n","            nn.Linear(8, 128), nn.ReLU(),\n","            nn.Linear(128, 1)\n","        )\n","\n","def run_rl_blo(method, iterations=300, base_tau=0.60):\n","    model = ACNetwork().to(device)\n","    mse_loss = nn.MSELoss()\n","    history_loss = []\n","    grad_norm_buffer = []\n","    grad_stats = []\n","    momentum_buffer = [torch.zeros_like(p) for p in model.critic.parameters()]\n","    beta = 0.9\n","    eta_upper = 0.005\n","    eta_lower = 0.02\n","    torch.cuda.synchronize() if torch.cuda.is_available() else None\n","    start_time = time.time()\n","    for k in range(iterations):\n","        model.zero_grad()\n","        v_s = model.critic(X_S)\n","        with torch.no_grad():\n","            target_v = X_R + 0.99 * model.critic(X_S_next)\n","        loss_inner = mse_loss(v_s, target_v)\n","        grad_y = torch.autograd.grad(loss_inner, model.critic.parameters())\n","        curr_norm = grad_y[-1].norm().item()\n","        grad_stats.append(curr_norm)\n","        grad_norm_buffer.append(curr_norm)\n","        if len(grad_norm_buffer) > 30: grad_norm_buffer.pop(0)\n","        current_tau = base_tau if k > iterations // 2 else 0.7 + (base_tau - 0.7) * (k / (iterations // 2))\n","        psi = np.quantile(grad_norm_buffer, current_tau) if len(grad_norm_buffer) > 15 else 10.0\n","        lr_inner = eta_lower\n","        processed_grads = list(grad_y)\n","        if method == 'RQ-TTSA':\n","            new_grads = []\n","            is_clipping = False\n","            if curr_norm > psi:\n","                scale = psi / (curr_norm + 1e-8)\n","                clipped_grads = [g * scale for g in grad_y]\n","                is_clipping = True\n","                lr_inner *= 0.5\n","            else:\n","                clipped_grads = grad_y\n","            for i, g in enumerate(clipped_grads):\n","                momentum_buffer[i] = beta * momentum_buffer[i] + (1 - beta) * g\n","                new_grads.append(momentum_buffer[i].clone())\n","            processed_grads = new_grads\n","        elif method == 'BiSLS':\n","            global_norm = torch.norm(torch.stack([g.norm() for g in grad_y]))\n","            scale = 1.0 / (global_norm + 1e-8)\n","            processed_grads = [g * scale for g in grad_y]\n","        elif method == 'psi-Variant':\n","            lr_inner *= (psi / (curr_norm + psi + 1e-6))\n","            processed_grads = grad_y\n","        elif method == 'MA-SOBA':\n","            new_grads = []\n","            for i, g in enumerate(grad_y):\n","                momentum_buffer[i] = beta * momentum_buffer[i] + (1 - beta) * g\n","                new_grads.append(momentum_buffer[i].clone())\n","            processed_grads = new_grads\n","        elif method == 'AccBO':\n","            new_grads = []\n","            for i, g in enumerate(grad_y):\n","                momentum_buffer[i] = beta * momentum_buffer[i] + (1 - beta) * g\n","                norm_m = momentum_buffer[i].norm() + 1e-8\n","                new_grads.append(momentum_buffer[i] / norm_m)\n","            processed_grads = new_grads\n","        else:\n","            processed_grads = grad_y\n","        with torch.no_grad():\n","            for p, g in zip(model.critic.parameters(), processed_grads):\n","                p -= lr_inner * g\n","        logits = model.actor(X_S)\n","        probs = torch.softmax(logits, dim=1)\n","        loss_outer = -torch.mean(probs * v_s.detach())\n","        grads_phi = torch.autograd.grad(loss_outer, model.actor.parameters())\n","        with torch.no_grad():\n","            for p, g in zip(model.actor.parameters(), grads_phi):\n","                p -= eta_upper * g\n","        history_loss.append(loss_outer.item())\n","    torch.cuda.synchronize() if torch.cuda.is_available() else None\n","    end_time = time.time()\n","    avg_time_ms = ((end_time - start_time) / iterations) * 1000\n","    return history_loss, grad_stats, avg_time_ms\n","\n","results_table = []\n","methods = ['TTSA', 'BiSLS', 'MA-SOBA', 'AccBO', 'psi-Variant', 'RQ-TTSA']\n","print(\"Starting Tuned RL-BLO Experiment...\")\n","plot_data = {}\n","time_data = {}\n","for m in methods:\n","    m_scores, m_stds, m_spikes, m_grad_norms = [], [], [], []\n","    m_times = []\n","    curves = []\n","    for s in tqdm(range(5), desc=f\"Testing {m}\"):\n","        torch.manual_seed(s)\n","        np.random.seed(s)\n","        curve, grad_curve, time_ms = run_rl_blo(m)\n","        m_scores.append(np.mean(curve[-50:]))\n","        m_stds.append(np.std(curve[-50:]))\n","        diffs = np.abs(np.diff(curve))\n","        m_spikes.append(np.max(diffs) if len(diffs) > 0 else 0)\n","        m_grad_norms.append(np.mean(grad_curve))\n","        m_times.append(time_ms)\n","        curves.append(curve)\n","        gc.collect()\n","        if torch.cuda.is_available():\n","            torch.cuda.empty_cache()\n","    plot_data[m] = curves[0]\n","    time_data[m] = f\"{np.mean(m_times):.2f} ± {np.std(m_times):.2f} ms\"\n","    results_table.append({\n","        'Method': m,\n","        'Avg Outer Loss (Actor)': f\"{np.mean(m_scores):.4f} ± {np.std(m_scores):.4f}\",\n","        'Stability (Std)': f\"{np.mean(m_stds):.4f} ± {np.std(m_stds):.4f}\",\n","        'Spike (Max Jump)': f\"{np.mean(m_spikes):.4f} ± {np.std(m_spikes):.4f}\",\n","        'Avg Grad Norm': f\"{np.mean(m_grad_norms):.4f} ± {np.std(m_grad_norms):.4f}\"\n","    })\n","\n","print(\"\\n### Final Results: Gymnasium LunarLander (Offline BLO)\")\n","print(pd.DataFrame(results_table).to_markdown(index=False))\n","print(\"\\n### Computational Efficiency (Time per Iteration)\")\n","for m, t in time_data.items():\n","    print(f\"{m}: {t}\")\n","\n","def plot_results(data_dict):\n","    plt.figure(figsize=(10, 6))\n","    colors = {\n","        'TTSA': 'gray', 'BiSLS': '#d62728',\n","        'MA-SOBA': '#ff7f0e', 'AccBO': '#9467bd',\n","        'psi-Variant': '#2ca02c', 'RQ-TTSA': '#1f77b4'\n","    }\n","    for m, curve in data_dict.items():\n","        smoothed = pd.Series(curve).rolling(window=20).mean()\n","        plt.plot(smoothed, label=m, color=colors.get(m, 'black'),\n","                 linewidth=2.5 if m == 'RQ-TTSA' else 1.5, alpha=0.8)\n","    plt.title(\"Actor Convergence Stability (LunarLander Offline)\")\n","    plt.xlabel(\"Iterations\")\n","    plt.ylabel(\"Actor Loss\")\n","    plt.legend()\n","    plt.grid(True, linestyle='--', alpha=0.3)\n","    plt.tight_layout()\n","    plt.show()\n","\n","plot_results(plot_data)"],"metadata":{"id":"-vJtiLJ0I_z3"},"execution_count":null,"outputs":[]}]}