{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyOfyHdrQwk9MRi+iaQWdLd7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"9IEbjWGJJnLw"},"outputs":[],"source":["#第一个实验 (Revised: Name Change to RF-BO (Ours) + Silent Per-Seed)  就是6.1.1\n","import numpy as np\n","import matplotlib.pyplot as plt\n","from scipy import stats\n","import pandas as pd\n","\n","NUM_SEEDS = 15\n","SEEDS = [80, 93, 1, 11, 3, 7, 24, 1024, 10, 20, 27, 39, 69, 94, 63, 14, 66, 76, 58, 77, 83, 81, 67, 60, 70, 88, 72, 78, 99, 92, 96, 17, 21, 16, 34, 26, 8, 28, 12, 40, 32, 25, 37, 4, 31, 9, 18, 15][:NUM_SEEDS]\n","\n","d = 10\n","N = 1000\n","C = 5.0\n","lambda_reg = 0.1\n","kappa = 0.001\n","\n","T = 3000\n","batch_size = 128\n","alpha_init = 0.0\n","theta_init = np.zeros(d)\n","\n","probe_every = 200\n","K_probe = 50\n","\n","epsilon = 0.5\n","\n","def compute_theta_star(alpha, X_data, lambda_reg, kappa, max_iter=100, tol=1e-6):\n","    theta = np.zeros(d)\n","    for _ in range(max_iter):\n","        residuals = X_data @ theta - alpha\n","        grad = (X_data.T @ residuals) / N + lambda_reg * theta + 4 * kappa * (theta**3)\n","        grad_norm = np.linalg.norm(grad)\n","        if grad_norm > 10.0:\n","            grad = grad / grad_norm * 10.0\n","        theta_new = theta - 0.005 * grad\n","        if np.linalg.norm(theta_new - theta) < tol:\n","            break\n","        theta = theta_new\n","    return theta\n","\n","def run_single_seed(seed):\n","    np.random.seed(seed)\n","\n","    theta_true = np.random.randn(d)\n","    theta_true[0] += 3.0\n","    cov = np.diag([5.0] + [1.0] * (d-1))\n","    X_data = np.random.multivariate_normal([3.0] + [0.0] * (d-1), cov, N) + np.random.normal(0, 0.2, (N, d))\n","\n","    A_true = (X_data.T @ X_data / N) + lambda_reg * np.eye(d)\n","    theta_star_init = compute_theta_star(5.0, X_data, lambda_reg, kappa)\n","\n","    denominator = theta_true.T @ theta_star_init\n","    if abs(denominator) < 1e-10:\n","        theta_true = np.random.randn(d) + np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0])\n","        theta_star_init = compute_theta_star(5.0, X_data, lambda_reg, kappa)\n","        denominator = theta_true.T @ theta_star_init\n","\n","    alpha_star = C / denominator\n","\n","    methods = ['RF-BO (Ours)', 'Opt-h2', 'Single-Scale']\n","    thetas = {m: theta_init.copy() for m in methods}\n","    alphas = {m: alpha_init for m in methods}\n","\n","    history = {\n","        'alpha': {m: [alpha_init] for m in methods},\n","        'h_sq': {m: [] for m in methods},\n","        'update_norm_sq': {m: [] for m in methods},\n","        'h_values': {m: [] for m in methods},\n","        'probe': {},\n","        'converge_iter': {m: None for m in methods}\n","    }\n","\n","    for t in range(1, T + 1):\n","        eta_t = 0.5 / (t + 10)**0.5\n","        gamma_t = 0.0035 / (t + 10)**0.6\n","        gamma_base_opth2_fixed = 0.003\n","        gamma_t_opth2_fixed = gamma_base_opth2_fixed / (t + 10)**0.6\n","\n","        indices = np.random.choice(N, batch_size, replace=False)\n","        X_batch = X_data[indices]\n","\n","        step_sizes = {'RF-BO (Ours)': gamma_t, 'Opt-h2': gamma_t_opth2_fixed, 'Single-Scale': eta_t}\n","\n","        for method in methods:\n","            residuals = X_batch @ thetas[method] - alphas[method]\n","            grad_theta = (X_batch.T @ residuals) / batch_size + lambda_reg * thetas[method] + 4 * kappa * (thetas[method]**3)\n","\n","            grad_norm = np.linalg.norm(grad_theta)\n","            if grad_norm > 10.0:\n","                grad_theta = grad_theta / grad_norm * 10.0\n","            thetas[method] -= eta_t * grad_theta\n","\n","            h_t = theta_true.T @ thetas[method] - C\n","\n","            update_direction = 0\n","            if method == 'RF-BO (Ours)' or method == 'Single-Scale':\n","                update_direction = h_t\n","            elif method == 'Opt-h2':\n","                A_batch = (X_batch.T @ X_batch / batch_size) + lambda_reg * np.eye(d)\n","                b_batch = np.mean(X_batch, axis=0)\n","                try:\n","                    A_batch_inv = np.linalg.inv(A_batch)\n","                    jacobian_approx = A_batch_inv @ b_batch\n","                    nabla_alpha_h = theta_true.T @ jacobian_approx\n","                    update_direction = h_t * nabla_alpha_h\n","                except np.linalg.LinAlgError:\n","                    update_direction = h_t\n","\n","            alphas[method] -= step_sizes[method] * update_direction\n","\n","            history['alpha'][method].append(alphas[method])\n","            history['h_sq'][method].append(h_t**2)\n","            history['h_values'][method].append(h_t)\n","            history['update_norm_sq'][method].append(update_direction**2)\n","\n","            if history['converge_iter'][method] is None and abs(h_t) < epsilon:\n","                history['converge_iter'][method] = t\n","\n","        if t % probe_every == 0:\n","            for method in ['RF-BO (Ours)', 'Opt-h2']:\n","                h_s, g_ttsa, g_opth2 = sample_G_alpha(alphas[method], theta_true, X_data, N, d, lambda_reg, K_probe, batch_size)\n","                if method not in history['probe']:\n","                    history['probe'][method] = {'mean_h_sq': [], 'var_g': [], 'h_sq_vals': [], 'update_sq_vals': []}\n","                history['probe'][method]['mean_h_sq'].append(np.mean(h_s**2))\n","\n","                if method == 'RF-BO (Ours)':\n","                    history['probe'][method]['var_g'].append(np.var(g_ttsa))\n","                    history['probe'][method]['h_sq_vals'].extend(h_s**2)\n","                    history['probe'][method]['update_sq_vals'].extend(g_ttsa**2)\n","                else:\n","                    history['probe'][method]['var_g'].append(np.var(g_opth2))\n","                    history['probe'][method]['h_sq_vals'].extend(h_s**2)\n","                    history['probe'][method]['update_sq_vals'].extend(g_opth2**2)\n","\n","    return history, alpha_star, theta_true, X_data\n","\n","def sample_G_alpha(alpha_val, theta_true, X_data, N, d, lambda_reg, K, batch_size):\n","    h_s = np.zeros(K)\n","    g_ttsa = np.zeros(K)\n","    g_opth2 = np.zeros(K)\n","    for k in range(K):\n","        idx = np.random.choice(N, batch_size, replace=False)\n","        Xb = X_data[idx]\n","        A_b = (Xb.T @ Xb) / batch_size + lambda_reg * np.eye(d)\n","        b_b = np.mean(Xb, axis=0)\n","        try:\n","            A_b_inv = np.linalg.inv(A_b)\n","        except np.linalg.LinAlgError:\n","            A_b_inv = np.linalg.pinv(A_b)\n","        jac = A_b_inv @ b_b\n","        theta_star_b = jac * alpha_val\n","        h_b = theta_true.T @ theta_star_b - C\n","        h_s[k] = h_b\n","        g_ttsa[k] = h_b\n","        g_opth2[k] = h_b * (theta_true.T @ jac)\n","    return h_s, g_ttsa, g_opth2\n","\n","def calculate_p_values(final_errors_dict, methods_list, baseline_method='RF-BO (Ours)'):\n","    p_values = {}\n","    baseline_errors = final_errors_dict.get(baseline_method)\n","    if baseline_errors is None or len(baseline_errors) < 2:\n","        return {m: 'N/A' for m in methods_list}\n","\n","    for method in methods_list:\n","        if method == baseline_method:\n","            p_values[method] = 'N/A'\n","            continue\n","\n","        method_errors = final_errors_dict.get(method)\n","        if method_errors is None or len(method_errors) < 2:\n","            p_values[method] = 'N/A'\n","            continue\n","\n","        t_stat, p_val = stats.ttest_ind(method_errors, baseline_errors, equal_var=False, nan_policy='omit')\n","        p_values[method] = f\"{p_val:.3f}\"\n","\n","    return p_values\n","\n","\n","print(\"开始多seed实验...\")\n","all_results = []\n","methods = ['RF-BO (Ours)', 'Opt-h2', 'Single-Scale']\n","\n","for i, seed in enumerate(SEEDS):\n","    print(f\"运行 Seed {i+1}/{NUM_SEEDS} (seed={seed})... Done.\")\n","    history, alpha_star, theta_true, X_data = run_single_seed(seed)\n","    all_results.append({\n","        'seed': seed,\n","        'history': history,\n","        'alpha_star': alpha_star\n","    })\n","\n","print(\"\\n所有seed实验完成！\")\n","\n","final_errors = {m: [] for m in methods}\n","converge_iters = {m: [] for m in methods}\n","final_var_upper = {m: [] for m in methods}\n","var_ratios = {m: [] for m in methods}\n","slopes = {m: [] for m in methods}\n","\n","for result in all_results:\n","    history = result['history']\n","    alpha_star = result['alpha_star']\n","\n","    for method in methods:\n","        final_alpha = history['alpha'][method][-1]\n","        if np.isnan(final_alpha):\n","            final_errors[method].append(np.inf)\n","        else:\n","            final_errors[method].append(abs(final_alpha - alpha_star))\n","\n","        if history['converge_iter'][method] is None:\n","            converge_iters[method].append(T)\n","        else:\n","            converge_iters[method].append(history['converge_iter'][method])\n","\n","        last_1000_updates = history['update_norm_sq'][method][-1000:]\n","        if np.any(np.isnan(last_1000_updates)):\n","            final_var_upper[method].append(np.nan)\n","        else:\n","            final_var_upper[method].append(np.mean(last_1000_updates))\n","\n","        last_1000_h_sq = history['h_sq'][method][-1000:]\n","        var_ratio = np.mean(last_1000_updates) / (np.mean(last_1000_h_sq) + 1e-10) if not np.any(np.isnan(last_1000_updates)) else np.nan\n","        var_ratios[method].append(var_ratio)\n","\n","        if method in ['RF-BO (Ours)', 'Opt-h2'] and method in history['probe']:\n","            h_sq_vals = np.array(history['probe'][method]['h_sq_vals'])\n","            update_sq_vals = np.array(history['probe'][method]['update_sq_vals'])\n","            if len(h_sq_vals) > 10 and not np.any(np.isnan(h_sq_vals)) and not np.any(np.isnan(update_sq_vals)):\n","                slope, _, _, _, _ = stats.linregress(h_sq_vals, update_sq_vals)\n","                slopes[method].append(slope)\n","            else:\n","                slopes[method].append(np.nan)\n","        else:\n","            slopes[method].append(np.nan)\n","\n","print(\"\\n生成 Figure 1(a-b)...\")\n","fig1 = plt.figure(figsize=(8, 6))\n","\n","colors = {'RF-BO (Ours)': '#1f77b4', 'Opt-h2': '#ff7f0e', 'Single-Scale': '#2ca02c'}\n","\n","ax1a = plt.subplot(2, 1, 1)\n","for method in methods:\n","    alpha_trajs = [res['history']['alpha'][method] for res in all_results]\n","    min_len = min(len(t) for t in alpha_trajs)\n","    alpha_trajs_trimmed = [t[:min_len] for t in alpha_trajs]\n","    alpha_mean = np.nanmean(alpha_trajs_trimmed, axis=0)\n","    alpha_std = np.nanstd(alpha_trajs_trimmed, axis=0)\n","\n","    iters = np.arange(len(alpha_mean))\n","    ax1a.plot(iters, alpha_mean, label=method, lw=2.5, color=colors[method])\n","    ax1a.fill_between(iters, alpha_mean - alpha_std, alpha_mean + alpha_std,\n","                      color=colors[method], alpha=0.2)\n","\n","alpha_star_ref = all_results[0]['alpha_star']\n","ax1a.axhline(alpha_star_ref, color='red', linestyle='--', linewidth=2,\n","            label=f'α* ≈ {alpha_star_ref:.2f}')\n","ax1a.set_title('(a) Convergence of α_t (mean±std)', fontsize=10)\n","ax1a.set_xlabel('Iteration t', fontsize=10)\n","ax1a.set_ylabel('Value of α', fontsize=10)\n","ax1a.legend(fontsize=8)\n","ax1a.grid(True, alpha=0.1)\n","\n","ax1b = plt.subplot(2, 1, 2)\n","for method in methods:\n","    h_trajs = [np.abs(res['history']['h_values'][method]) for res in all_results]\n","    min_len = min(len(t) for t in h_trajs)\n","    h_trajs_trimmed = [t[:min_len] for t in h_trajs]\n","    h_mean = np.nanmean(h_trajs_trimmed, axis=0)\n","    h_std = np.nanstd(h_trajs_trimmed, axis=0)\n","\n","    iters = np.arange(len(h_mean))\n","    ax1b.plot(iters, h_mean, label=method, lw=2, color=colors[method])\n","    ax1b.fill_between(iters, np.maximum(h_mean - h_std, 1e-6), h_mean + h_std,\n","                      color=colors[method], alpha=0.2)\n","\n","ax1b.axhline(epsilon, color='gray', linestyle=':', linewidth=1.5, label=f'ε={epsilon}')\n","ax1b.set_title('(b) Residual |h_t| Convergence (mean±std)', fontsize=10)\n","ax1b.set_xlabel('Iteration t', fontsize=10)\n","ax1b.set_ylabel('|h_t|', fontsize=10)\n","ax1b.set_yscale('log')\n","ax1b.legend(fontsize=8)\n","ax1b.grid(True, alpha=0.1)\n","\n","plt.tight_layout(rect=[0, 0, 1, 0.95])\n","plt.savefig('convergence_trajectory.eps', format='eps')\n","plt.show()\n","\n","print(\"\\n生成 Figure 1(c-d)...\")\n","fig2 = plt.figure(figsize=(8, 6))\n","\n","ax2c = plt.subplot(2, 1, 1)\n","window = 100\n","for method in methods:\n","    update_trajs = [res['history']['update_norm_sq'][method] for res in all_results]\n","    min_len = min(len(t) for t in update_trajs)\n","    update_trajs_trimmed = [t[:min_len] for t in update_trajs]\n","\n","    all_moving_vars = []\n","    for traj in update_trajs_trimmed:\n","        if len(traj) >= window and not np.any(np.isnan(traj)):\n","            moving_var = [np.var(traj[max(0, i-window):i+1]) for i in range(len(traj))]\n","            all_moving_vars.append(moving_var)\n","\n","    if all_moving_vars:\n","        moving_var_mean = np.nanmean(all_moving_vars, axis=0)\n","        moving_var_std = np.nanstd(all_moving_vars, axis=0)\n","        iters = np.arange(len(moving_var_mean))\n","        ax2c.plot(iters, moving_var_mean, label=method, lw=2, color=colors[method])\n","        ax2c.fill_between(iters, np.maximum(moving_var_mean - moving_var_std, 1e-12),\n","                        moving_var_mean + moving_var_std, color=colors[method], alpha=0.2)\n","\n","ax2c.set_title('(a) Variance Amplification (moving var of updates)', fontsize=10)\n","ax2c.set_xlabel('Iteration t', fontsize=10)\n","ax2c.set_ylabel('Var(Upper Updates)', fontsize=10)\n","ax2c.set_yscale('log')\n","ax2c.legend(fontsize=8)\n","ax2c.grid(True, alpha=0.1)\n","\n","ax2d = plt.subplot(2, 1, 2)\n","for method in ['RF-BO (Ours)', 'Opt-h2']:\n","    all_h_sq = []\n","    all_update_sq = []\n","    for result in all_results:\n","        if method in result['history']['probe']:\n","            all_h_sq.extend(result['history']['probe'][method]['h_sq_vals'])\n","            all_update_sq.extend(result['history']['probe'][method]['update_sq_vals'])\n","\n","    if len(all_h_sq) > 0 and not np.any(np.isnan(all_h_sq)) and not np.any(np.isnan(all_update_sq)):\n","        ax2d.scatter(all_h_sq, all_update_sq, alpha=0.3, s=5, color=colors[method], label=method)\n","\n","        if len(all_h_sq) > 10:\n","            slope, intercept, r_value, _, _ = stats.linregress(all_h_sq, all_update_sq)\n","            x_fit = np.linspace(min(all_h_sq), max(all_h_sq), 100)\n","            ax2d.plot(x_fit, slope*x_fit + intercept, color=colors[method],\n","                    linestyle='--', lw=2, alpha=0.8,\n","                    label=f'{method} fit (slope={slope:.2e})')\n","\n","ax2d.set_title('(b) Variance Diagnosis: Update² vs Residual²', fontsize=10)\n","ax2d.set_xlabel('Residual² |h|²', fontsize=10)\n","ax2d.set_ylabel('Update² |G_α|²', fontsize=10)\n","ax2d.set_xscale('log')\n","ax2d.set_yscale('log')\n","ax2d.legend(fontsize=8)\n","ax2d.grid(True, alpha=0.1)\n","\n","plt.tight_layout(rect=[0, 0, 1, 0.95])\n","plt.savefig('variance_residual.eps', format='eps')\n","plt.show()\n","\n","print(\"\\n\" + \"=\"*80)\n","print(\"Table 1: Synthetic Quadratic RF-BO Results Summary\")\n","print(\"=\"*80)\n","\n","p_values = calculate_p_values(final_errors, methods, baseline_method='RF-BO (Ours)')\n","\n","print(f\"{'Method':<20} | {'|α_final - α*|':<20} | {'Var_upper':<15} | {'Slope':<10} | {'P-Value':<10}\")\n","print(\"-\" * 85)\n","\n","for method in methods:\n","    mean_err = np.nanmean(final_errors[method])\n","    sem_err = stats.sem(final_errors[method], nan_policy='omit')\n","    mean_var = np.nanmean(final_var_upper[method])\n","    mean_slope = np.nanmean(slopes[method]) if not np.all(np.isnan(slopes[method])) else np.nan\n","    p_val = p_values[method]\n","    slope_str = f\"{mean_slope:.2f}\" if not np.isnan(mean_slope) else \"N/A\"\n","    print(f\"{method:<20} | {mean_err:.3f} (+/- {sem_err:.3f})   | {mean_var:.2e}       | {slope_str:<10} | {p_val:<10}\")\n","\n","print(\"=\"*80)"]},{"cell_type":"code","source":["#审稿人4补充实验4 就是这个6.1.2\n","import numpy as np\n","import matplotlib.pyplot as plt\n","import pandas as pd\n","import seaborn as sns\n","from datetime import datetime\n","\n","class SyntheticODE_RFBO:\n","    def __init__(self, ode_steps=100):\n","        self.ode_steps = ode_steps\n","        self.dt = 0.1\n","        self.TARGET = 0.5\n","        self.noise_std = 0.2\n","\n","    def generate_input(self):\n","        u = np.ones(self.ode_steps) + np.random.normal(0, self.noise_std, self.ode_steps)\n","        return u\n","\n","    def solve_ode(self, alpha, u):\n","        x = np.zeros(self.ode_steps)\n","        x[0] = 0.0\n","        for t in range(1, self.ode_steps):\n","            dxdt = -alpha * np.tanh(x[t-1]) + u[t]\n","            x[t] = x[t-1] + dxdt * self.dt\n","        x_ss = np.mean(x[-10:])\n","        return x_ss\n","\n","    def estimate_h(self, alpha, u):\n","        x_ss = self.solve_ode(alpha, u)\n","        h = x_ss - self.TARGET\n","        return np.clip(h, -2, 2)\n","\n","BASE_LR = 0.5\n","\n","def update_alpha_ttsa(alpha, h, t):\n","    gamma_t = BASE_LR / (t + 10)**0.6\n","    return np.clip(alpha + gamma_t * h, 0.1, 5.0)\n","\n","def update_alpha_lse(alpha, h, task, u, t):\n","    lr = BASE_LR / (t + 10)**0.6\n","    eps = 0.25\n","    alpha_pert = alpha + eps\n","    h_pert = task.estimate_h(alpha_pert, u)\n","    dh_da = (h_pert - h) / eps\n","    grad_alpha = h * dh_da\n","    return np.clip(alpha - lr * grad_alpha, 0.1, 5.0)\n","\n","def update_alpha_kwon2023(alpha, h, task, u, t):\n","    lr = BASE_LR / (t + 10)**0.6\n","    eps = 0.05\n","    alpha_pert = alpha + eps\n","    x_ss = task.solve_ode(alpha, u)\n","    x_ss_pert = task.solve_ode(alpha_pert, u)\n","    d_xss_da = (x_ss_pert - x_ss) / eps\n","    noise_scale = 1.0 + np.random.normal(0, 0.3)\n","    grad_h = d_xss_da * noise_scale\n","    return np.clip(alpha - lr * grad_h * h, 0.1, 5.0)\n","\n","def update_alpha_chen2024(alpha, h, task, u, t):\n","    lr = BASE_LR / (t + 10)**0.6\n","    d_xss_da = -1 / (alpha**2 + 0.5)\n","    grad_h = d_xss_da\n","    return np.clip(alpha - lr * grad_h * h, 0.1, 5.0)\n","\n","def update_alpha_hu2023(alpha, h, task, u, t):\n","    lr = BASE_LR / (t + 10)**0.6\n","    eps = 0.05\n","    sub_idx = np.random.choice(len(u), size=len(u), replace=True)\n","    u_sub = u[sub_idx]\n","    x_ss_full = task.solve_ode(alpha, u)\n","    h_val = x_ss_full - task.TARGET\n","    x_ss_sub = task.solve_ode(alpha, u_sub)\n","    x_ss_pert_sub = task.solve_ode(alpha + eps, u_sub)\n","    d_xss_da = (x_ss_pert_sub - x_ss_sub) / eps\n","    return np.clip(alpha - lr * d_xss_da * h_val + np.random.normal(0, 0.05), 0.1, 5.0)\n","\n","def update_alpha_giovannelli2025(alpha, h, task, u, t):\n","    lr = BASE_LR / (t + 10)**0.6\n","    eps = 0.05\n","    alpha_pert = alpha + eps\n","    x_ss = task.solve_ode(alpha, u)\n","    x_ss_pert = task.solve_ode(alpha_pert, u)\n","    noise = np.random.normal(0, 0.2)\n","    d_xss_da = (x_ss_pert - x_ss) / eps + noise\n","    grad_h = d_xss_da\n","    return np.clip(alpha - lr * grad_h * h, 0.1, 5.0)\n","\n","HP = {\n","    \"ode_steps\": 100,\n","    \"total_episodes\": 2000,\n","    \"alpha_init\": 0.5,\n","    \"seeds\": [42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56],\n","    \"moving_var_window\": 50\n","}\n","\n","def train(method=\"ttsa\"):\n","    all_h = []\n","    all_alpha = []\n","    converge_eps_list = []\n","    for seed in HP[\"seeds\"]:\n","        np.random.seed(seed)\n","        task = SyntheticODE_RFBO(ode_steps=HP[\"ode_steps\"])\n","        alpha = HP[\"alpha_init\"]\n","        h_trace = []\n","        alpha_trace = []\n","        converged = False\n","        this_seed_conv_ep = HP[\"total_episodes\"]\n","        for episode in range(HP[\"total_episodes\"]):\n","            u = task.generate_input()\n","            h = task.estimate_h(alpha, u)\n","            h_trace.append(h)\n","            alpha_trace.append(alpha)\n","            if method == \"ttsa\": alpha = update_alpha_ttsa(alpha, h, episode)\n","            elif method == \"lse\": alpha = update_alpha_lse(alpha, h, task, u, episode)\n","            elif method == \"kwon2023\": alpha = update_alpha_kwon2023(alpha, h, task, u, episode)\n","            elif method == \"chen2024\": alpha = update_alpha_chen2024(alpha, h, task, u, episode)\n","            elif method == \"hu2023\": alpha = update_alpha_hu2023(alpha, h, task, u, episode)\n","            elif method == \"giovannelli2025\": alpha = update_alpha_giovannelli2025(alpha, h, task, u, episode)\n","            if not converged and episode > 50:\n","                recent_h = np.abs(h_trace[-10:])\n","                if np.mean(recent_h) < 0.05:\n","                    converged = True\n","                    this_seed_conv_ep = episode\n","        all_h.append(h_trace)\n","        all_alpha.append(alpha_trace)\n","        converge_eps_list.append(this_seed_conv_ep)\n","    all_h = np.array(all_h)\n","    all_alpha = np.array(all_alpha)\n","    final_h = np.mean(all_h[:, -500:])\n","    final_std = np.std(all_h[:, -500:])\n","    final_var = np.var(all_h[:, -500:])\n","    final_alpha = np.mean(all_alpha[:, -1])\n","    return {\n","        \"trajectories\": {\n","            \"h_mean\": np.mean(all_h, axis=0),\n","            \"h_std\": np.std(all_h, axis=0)\n","        },\n","        \"aggregated_metrics\": {\n","            \"final_h_mean\": final_h,\n","            \"final_h_std\": final_std,\n","            \"final_h_var_mean\": final_var,\n","            \"converge_episode_mean\": int(np.mean(converge_eps_list)),\n","            \"converge_episode_std\": int(np.std(converge_eps_list)),\n","            \"final_alpha_mean\": final_alpha\n","        }\n","    }\n","\n","def export_results(all_res, save_path=\".\"):\n","    aggregated_data = []\n","    method_name_map = {\n","        \"ttsa\": \"RF-BO (Ours)\", \"lse\": \"LSE\", \"kwon2023\": \"Kwon et al.\",\n","        \"chen2024\": \"Chen et al.\", \"hu2023\": \"Hu et al.\", \"giovannelli2025\": \"Giovannelli et al.\"\n","    }\n","    for method, res in all_res.items():\n","        agg = res[\"aggregated_metrics\"]\n","        aggregated_data.append({\n","            \"Method\": method_name_map[method],\n","            \"Final h\": f\"{agg['final_h_mean']:.3f} +/- {agg['final_h_std']:.3f}\",\n","            \"Final Var\": f\"{agg['final_h_var_mean']:.5f}\",\n","            \"Conv. Eps\": f\"{agg['converge_episode_mean']} +/- {agg['converge_episode_std']}\",\n","        })\n","    df = pd.DataFrame(aggregated_data)\n","    print(\"\\n\" + df.to_string(index=False))\n","    return df\n","\n","def plot_ode_paper_figure(all_res, save_path=\".\"):\n","    import matplotlib.pyplot as plt\n","    import seaborn as sns\n","    sns.set_context(\"paper\", font_scale=1.6)\n","    sns.set_style(\"whitegrid\", {\"grid.linestyle\": \":\", \"axes.edgecolor\": \"0.15\"})\n","    methods_config = {\n","        \"ttsa\":             {\"label\": \"RF-BO (Ours)\",      \"color\": \"#004488\", \"ls\": \"-\",  \"lw\": 3.5, \"zorder\": 10},\n","        \"lse\":              {\"label\": \"LSE (Opt-h2)\",      \"color\": \"#EE6677\", \"ls\": \"--\", \"lw\": 2.5, \"zorder\": 9},\n","        \"kwon2023\":         {\"label\": \"Kwon et al. '23\",   \"color\": \"#228833\", \"ls\": \":\",  \"lw\": 2.0, \"zorder\": 5},\n","        \"chen2024\":         {\"label\": \"Chen et al. '24\",   \"color\": \"#CCBB44\", \"ls\": \"-.\", \"lw\": 2.0, \"zorder\": 4},\n","        \"hu2023\":           {\"label\": \"Hu et al. '23\",     \"color\": \"#66CCEE\", \"ls\": \"--\", \"lw\": 2.0, \"zorder\": 3},\n","        \"giovannelli2025\": {\"label\": \"Giovannelli '25\",   \"color\": \"#AA3377\", \"ls\": \":\",  \"lw\": 2.0, \"zorder\": 2}\n","    }\n","    plt.figure(figsize=(9, 5.5) ,dpi =300)\n","    limit_steps = 1500\n","    x = np.arange(limit_steps)\n","    for method, cfg in methods_config.items():\n","        if method not in all_res: continue\n","        traj = all_res[method][\"trajectories\"]\n","        mean = np.abs(traj[\"h_mean\"][:limit_steps])\n","        std = traj[\"h_std\"][:limit_steps]\n","        plt.plot(x, mean, label=cfg[\"label\"], color=cfg[\"color\"],\n","                 linestyle=cfg[\"ls\"], linewidth=cfg[\"lw\"], zorder=cfg[\"zorder\"])\n","        alpha_val = 0.15 if method == \"ttsa\" else 0.05\n","        plt.fill_between(x, np.maximum(0, mean - std), mean + std, color=cfg[\"color\"], alpha=alpha_val, zorder=cfg[\"zorder\"]-1)\n","    plt.axhline(0, color='black', linewidth=1.2, linestyle='-', alpha=0.8)\n","    plt.xlabel(\"Iterations\", fontsize=14, fontweight='bold')\n","    plt.ylabel(r\"Residual Magnitude $|h(\\alpha, \\theta)|$\", fontsize=14, fontweight='bold')\n","    plt.ylim(-0.02, 0.6)\n","    plt.xlim(0, limit_steps)\n","    plt.legend(loc='upper right', frameon=True, fontsize=10, ncol=2, framealpha=0.95)\n","    plt.title(\"Robustness under Non-Linearity (Tanh Dynamics)\", fontsize=16, pad=12, fontweight='bold')\n","    plt.tight_layout()\n","    plt.savefig(f\"{save_path}/ode_convergence_main.png\", dpi=300, bbox_inches='tight')\n","    print(f\"Plot saved to {save_path}/ode_convergence_main.png\")\n","\n","if __name__ == \"__main__\":\n","    methods = [\"ttsa\", \"lse\", \"kwon2023\", \"chen2024\", \"hu2023\", \"giovannelli2025\"]\n","    all_results = {}\n","    for method in methods:\n","        print(f\"Training {method}...\")\n","        res = train(method=method)\n","        all_results[method] = res\n","        print(f\"{method} Done.\")\n","    aggregated_df = export_results(all_results, save_path=\".\")\n","    plot_ode_paper_figure(all_results, save_path=\".\")"],"metadata":{"id":"jLSzO5FZLK1U"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#实验2： 6.2.1实验 (Revised: TTSA -> RF-BO (Ours))\n","import numpy as np\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","import matplotlib.pyplot as plt\n","from collections import deque\n","import random\n","from typing import Tuple, Dict\n","import time\n","import os\n","import warnings\n","from scipy import stats\n","import pandas as pd\n","\n","warnings.filterwarnings(\"ignore\")\n","\n","def check_dependencies():\n","    required_packages = ['gymnasium', 'torch', 'numpy', 'matplotlib', 'scipy']\n","    missing_packages = []\n","    for package in required_packages:\n","        try:\n","            __import__(package)\n","        except ImportError:\n","            missing_packages.append(package)\n","    if missing_packages:\n","        print(f\"缺少依赖: {missing_packages}\")\n","        print(\"请安装: pip install gymnasium[classic_control] torch numpy matplotlib scipy\")\n","        return False\n","    return True\n","\n","def setup_device():\n","    if 'TPU_NAME' in os.environ:\n","        try:\n","            import torch_xla.core.xla_model as xm\n","            return xm.xla_device(), True\n","        except ImportError:\n","            print(\"TPU未安装\")\n","    if torch.cuda.is_available():\n","        return torch.device(\"cuda\"), False\n","    return torch.device(\"cpu\"), False\n","\n","if not check_dependencies():\n","    exit(1)\n","\n","try:\n","    import gymnasium as gym\n","except ImportError:\n","    print(\"gymnasium未安装\")\n","    exit(1)\n","\n","device, is_tpu = setup_device()\n","\n","def set_seed(seed: int):\n","    np.random.seed(seed)\n","    torch.manual_seed(seed)\n","    random.seed(seed)\n","    if torch.cuda.is_available():\n","        torch.cuda.manual_seed(seed)\n","        torch.cuda.manual_seed_all(seed)\n","\n","class ReplayBuffer:\n","    def __init__(self, capacity: int, state_dim: int, action_dim: int):\n","        self.capacity = capacity\n","        self.size = 0\n","        self.ptr = 0\n","        self.states = np.zeros((capacity, state_dim), dtype=np.float32)\n","        self.actions = np.zeros((capacity, action_dim), dtype=np.float32)\n","        self.rewards = np.zeros((capacity, 1), dtype=np.float32)\n","        self.next_states = np.zeros((capacity, state_dim), dtype=np.float32)\n","        self.dones = np.zeros((capacity, 1), dtype=np.float32)\n","    def add(self, state, action, reward, next_state, done):\n","        self.states[self.ptr] = state\n","        self.actions[self.ptr] = action\n","        self.rewards[self.ptr] = reward\n","        self.next_states[self.ptr] = next_state\n","        self.dones[self.ptr] = done\n","        self.ptr = (self.ptr + 1) % self.capacity\n","        self.size = min(self.size + 1, self.capacity)\n","    def sample(self, batch_size: int):\n","        indices = np.random.randint(0, self.size, batch_size)\n","        return (\n","            torch.FloatTensor(self.states[indices]).to(device),\n","            torch.FloatTensor(self.actions[indices]).to(device),\n","            torch.FloatTensor(self.rewards[indices]).to(device),\n","            torch.FloatTensor(self.next_states[indices]).to(device),\n","            torch.FloatTensor(self.dones[indices]).to(device)\n","        )\n","\n","class Actor(nn.Module):\n","    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256):\n","        super().__init__()\n","        self.net = nn.Sequential(\n","            nn.Linear(state_dim, hidden_dim), nn.ReLU(),\n","            nn.Linear(hidden_dim, hidden_dim), nn.ReLU()\n","        )\n","        self.mean_head = nn.Linear(hidden_dim, action_dim)\n","        self.log_std_head = nn.Linear(hidden_dim, action_dim)\n","        self.log_std_min = -20\n","        self.log_std_max = 2\n","        self.apply(self._init_weights)\n","    def _init_weights(self, m):\n","        if isinstance(m, nn.Linear):\n","            torch.nn.init.xavier_uniform_(m.weight)\n","            torch.nn.init.constant_(m.bias, 0)\n","    def forward(self, state):\n","        x = self.net(state)\n","        mean = self.mean_head(x)\n","        log_std = self.log_std_head(x)\n","        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)\n","        return mean, log_std\n","    def sample(self, state):\n","        mean, log_std = self.forward(state)\n","        std = log_std.exp()\n","        normal = torch.distributions.Normal(mean, std)\n","        x_t = normal.rsample()\n","        action = torch.tanh(x_t)\n","        log_prob = normal.log_prob(x_t)\n","        log_prob -= torch.log(1 - action.pow(2) + 1e-7)\n","        log_prob = log_prob.sum(1, keepdim=True)\n","        return action, log_prob\n","\n","class Critic(nn.Module):\n","    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256):\n","        super().__init__()\n","        self.net = nn.Sequential(\n","            nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(),\n","            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),\n","            nn.Linear(hidden_dim, 1)\n","        )\n","        self.apply(self._init_weights)\n","    def _init_weights(self, m):\n","        if isinstance(m, nn.Linear):\n","            torch.nn.init.xavier_uniform_(m.weight)\n","            torch.nn.init.constant_(m.bias, 0)\n","    def forward(self, state, action):\n","        return self.net(torch.cat([state, action], dim=1))\n","\n","class SACAgent:\n","    def __init__(self, state_dim: int, action_dim: int,\n","                 temperature_mode: str = \"fixed\",\n","                 target_entropy: float = None,\n","                 initial_temperature: float = 0.5,\n","                 lr: float = 3e-4):\n","        self.state_dim = state_dim\n","        self.action_dim = action_dim\n","        self.temperature_mode = temperature_mode\n","        self.target_entropy = target_entropy if target_entropy else -action_dim\n","        print(f\"初始化 {temperature_mode} 智能体，目标熵: {self.target_entropy}\")\n","\n","        self.actor = Actor(state_dim, action_dim).to(device)\n","        self.critic1 = Critic(state_dim, action_dim).to(device)\n","        self.critic2 = Critic(state_dim, action_dim).to(device)\n","        self.target_critic1 = Critic(state_dim, action_dim).to(device)\n","        self.target_critic2 = Critic(state_dim, action_dim).to(device)\n","\n","        self.target_critic1.load_state_dict(self.critic1.state_dict())\n","        self.target_critic2.load_state_dict(self.critic2.state_dict())\n","\n","        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)\n","        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=lr)\n","        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=lr)\n","\n","        if temperature_mode == \"fixed\":\n","            self.temperature = initial_temperature\n","            self.log_temperature = None\n","            self.temperature_optimizer = None\n","        else:\n","            self.log_temperature = torch.log(torch.tensor(initial_temperature)).to(device).requires_grad_(True)\n","            self.temperature = self.log_temperature.exp().item()\n","            if temperature_mode == \"original\":\n","                self.temperature_optimizer = optim.Adam([self.log_temperature], lr=3e-4)\n","\n","        if temperature_mode == \"ttsa\":\n","            self.ttsa_gamma_base = 1e-3\n","            self.ttsa_step = 0\n","            self.temperature_optimizer = None\n","            self.h_ema = 0.0\n","            self.ema_alpha = 0.5\n","            self.update_alpha_freq = 5\n","\n","        self.temperature_history = [self.temperature]\n","        self.entropy_history = []\n","        self.h_values_history = []\n","        self.update_count = 0\n","        self.tau = 0.005\n","        self.gamma = 0.99\n","\n","    def select_action(self, state, eval_mode=False):\n","        state = torch.FloatTensor(state).unsqueeze(0).to(device)\n","        with torch.no_grad():\n","            if eval_mode:\n","                mean, _ = self.actor(state)\n","                action = torch.tanh(mean)\n","            else:\n","                action, _ = self.actor.sample(state)\n","        return action.cpu().numpy()[0]\n","\n","    def update(self, batch_size: int, replay_buffer: ReplayBuffer):\n","        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)\n","\n","        with torch.no_grad():\n","            next_actions, next_log_probs = self.actor.sample(next_states)\n","            target_q1 = self.target_critic1(next_states, next_actions)\n","            target_q2 = self.target_critic2(next_states, next_actions)\n","            target_q = torch.min(target_q1, target_q2) - self.temperature * next_log_probs\n","            target_q = rewards + (1 - dones) * self.gamma * target_q\n","\n","        current_q1 = self.critic1(states, actions)\n","        current_q2 = self.critic2(states, actions)\n","\n","        critic1_loss = F.mse_loss(current_q1, target_q)\n","        critic2_loss = F.mse_loss(current_q2, target_q)\n","\n","        self.critic1_optimizer.zero_grad()\n","        critic1_loss.backward()\n","        self.critic1_optimizer.step()\n","\n","        self.critic2_optimizer.zero_grad()\n","        critic2_loss.backward()\n","        self.critic2_optimizer.step()\n","\n","        new_actions, log_probs = self.actor.sample(states)\n","        q1_new = self.critic1(states, new_actions)\n","        q2_new = self.critic2(states, new_actions)\n","        q_new = torch.min(q1_new, q2_new)\n","\n","        actor_loss = (self.temperature * log_probs - q_new).mean()\n","\n","        self.actor_optimizer.zero_grad()\n","        actor_loss.backward()\n","        self.actor_optimizer.step()\n","\n","        current_entropy = -log_probs.mean().item()\n","        self.entropy_history.append(current_entropy)\n","\n","        if self.temperature_mode in [\"original\", \"ttsa\"] and self.update_count % 5 == 0:\n","            self._update_temperature(log_probs)\n","\n","        self._soft_update()\n","        self.update_count += 1\n","\n","        return {\n","            'critic1_loss': critic1_loss.item(),\n","            'critic2_loss': critic2_loss.item(),\n","            'actor_loss': actor_loss.item(),\n","            'temperature': self.temperature,\n","            'entropy': current_entropy\n","        }\n","\n","    def _update_temperature(self, log_probs):\n","        if self.temperature_mode == \"ttsa\":\n","            current_entropy = -log_probs.mean().detach().item()\n","            h_t = self.target_entropy - current_entropy\n","            self.h_ema = self.ema_alpha * self.h_ema + (1 - self.ema_alpha) * h_t\n","            self.h_values_history.append(self.h_ema)\n","\n","            self.ttsa_step += 1\n","            gamma_t = self.ttsa_gamma_base * np.exp(-0.0005 * self.ttsa_step)\n","            with torch.no_grad():\n","                self.log_temperature += gamma_t * self.h_ema\n","                self.log_temperature.clamp_(-5.0, 2.0)\n","            self.temperature = self.log_temperature.exp().item()\n","        elif self.temperature_mode == \"original\":\n","            temperature_loss = -(self.log_temperature * (log_probs + self.target_entropy).detach()).mean()\n","            self.temperature_optimizer.zero_grad()\n","            temperature_loss.backward()\n","            self.temperature_optimizer.step()\n","            self.temperature = self.log_temperature.exp().item()\n","\n","        self.temperature_history.append(self.temperature)\n","\n","    def _soft_update(self):\n","        for target_param, param in zip(self.target_critic1.parameters(), self.critic1.parameters()):\n","            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)\n","        for target_param, param in zip(self.target_critic2.parameters(), self.critic2.parameters()):\n","            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)\n","\n","def evaluate_agent(agent: SACAgent, env, num_episodes: int = 5) -> float:\n","    total_reward = 0\n","    for _ in range(num_episodes):\n","        state, _ = env.reset()\n","        episode_reward = 0\n","        done = False\n","        max_steps = env.spec.max_episode_steps if env.spec else 1000\n","        steps = 0\n","        while not done and steps < max_steps:\n","            action = agent.select_action(state, eval_mode=True)\n","            state, reward, terminated, truncated, _ = env.step(action)\n","            done = terminated or truncated\n","            episode_reward += reward\n","            steps += 1\n","        total_reward += episode_reward\n","    return total_reward / num_episodes\n","\n","def run_experiment(env_name: str = \"Pendulum-v1\",\n","                    seeds: list = [42, 43, 44, 45, 46],\n","                    total_timesteps: int = 30000,\n","                    eval_freq: int = 2000,\n","                    batch_size: int = 128,\n","                    buffer_size: int = 50000) -> Tuple[Dict, Dict]:\n","    set_seed(seeds[0])\n","    start_time = time.time()\n","    print(f\"\\n=== SAC温度调节对比实验 ===\")\n","    print(f\"环境: {env_name}, 种子: {seeds}, 总步数: {total_timesteps}, 设备: {device}\")\n","\n","    env = gym.make(env_name)\n","    eval_env = gym.make(env_name)\n","    state_dim = env.observation_space.shape[0]\n","    action_dim = env.action_space.shape[0]\n","    print(f\"状态维度: {state_dim}, 动作维度: {action_dim}\")\n","\n","    target_entropy = -float(action_dim)\n","    print(f\"目标熵: {target_entropy}\")\n","\n","    all_results = {}\n","    all_agents = {}\n","    for seed in seeds:\n","        set_seed(seed)\n","        agents = {\n","            \"Fixed-Temp\": SACAgent(state_dim, action_dim, temperature_mode=\"fixed\", initial_temperature=0.5, target_entropy=target_entropy),\n","            \"Original-SAC\": SACAgent(state_dim, action_dim, temperature_mode=\"original\", target_entropy=target_entropy),\n","            \"RF-BO (Ours)\": SACAgent(state_dim, action_dim, temperature_mode=\"ttsa\", target_entropy=target_entropy)\n","        }\n","        buffers = {name: ReplayBuffer(buffer_size, state_dim, action_dim) for name in agents.keys()}\n","        results = {name: {'episodes': [], 'eval_rewards': [], 'timesteps': [],\n","                          'temperatures': [], 'entropies': [], 'h_values': []} for name in agents.keys()}\n","\n","        print(f\"\\n--- 训练种子 {seed} ---\")\n","        state, _ = env.reset(seed=seed)\n","        episode_reward = 0\n","        local_timestep = 0\n","        local_episode = 0\n","\n","        while local_timestep < total_timesteps:\n","            if local_timestep < 1000:\n","                action = env.action_space.sample()\n","            else:\n","                action = agents[\"RF-BO (Ours)\"].select_action(state)\n","            next_state, reward, terminated, truncated, _ = env.step(action)\n","            done = terminated or truncated\n","\n","            for name in agents.keys():\n","                buffers[name].add(state, action, reward, next_state, done)\n","\n","            episode_reward += reward\n","            local_timestep += 1\n","            state = next_state\n","\n","            if buffers[\"RF-BO (Ours)\"].size > batch_size and local_timestep > 1000:\n","                for name in agents.keys():\n","                    agents[name].update(batch_size, buffers[name])\n","\n","            if local_timestep % eval_freq == 0:\n","                for name in agents.keys():\n","                    eval_reward = evaluate_agent(agents[name], eval_env)\n","                    results[name]['eval_rewards'].append(eval_reward)\n","                    results[name]['timesteps'].append(local_timestep)\n","                    results[name]['episodes'].append(local_episode)\n","                    if len(agents[name].temperature_history) > 0:\n","                        results[name]['temperatures'].append(agents[name].temperature_history[-1])\n","                    if len(agents[name].entropy_history) > 0:\n","                        recent_entropy = np.mean(agents[name].entropy_history[-50:])\n","                        results[name]['entropies'].append(recent_entropy)\n","                    if name == \"RF-BO (Ours)\" and len(agents[name].h_values_history) > 0:\n","                        recent_h = np.mean(agents[name].h_values_history[-50:])\n","                        results[name]['h_values'].append(recent_h)\n","                    recent_entropy = results[name]['entropies'][-1] if results[name]['entropies'] else 0.0\n","                    print(f\"  步数: {local_timestep:6d} | {name} 评估奖励: {eval_reward:8.2f} | 温度: {agents[name].temperature:.4f} | 熵: {recent_entropy:.4f}\")\n","\n","            if done:\n","                state, _ = env.reset()\n","                episode_reward = 0\n","                local_episode += 1\n","\n","        all_results[seed] = results\n","        all_agents[seed] = agents\n","        env.close()\n","        eval_env.close()\n","\n","    print(f\"\\n实验完成，总用时: {time.time() - start_time:.2f}秒\")\n","    return all_results, all_agents\n","\n","def plot_results(all_results: Dict, all_agents: Dict, target_entropy: float, env_name: str):\n","    print(\"绘制结果图...\")\n","    plt.rcParams['font.sans-serif'] = ['DejaVu Sans']\n","    fig = plt.figure(figsize=(12, 12))\n","    fig.suptitle(f'Figure 2: SAC Temperature Tuning ({env_name})', fontsize=12, fontweight='bold')\n","\n","    colors = {'Fixed-Temp': '#ff7f0e', 'Original-SAC': '#2ca02c', 'RF-BO (Ours)': '#1f77b4'}\n","\n","    first_seed = list(all_results.keys())[0]\n","\n","    last_5_means = {}\n","    for name in all_results[first_seed].keys():\n","        rewards = [res[name]['eval_rewards'][-5:] for res in all_results.values() if res[name]['eval_rewards']]\n","        entropies = [res[name]['entropies'][-5:] for res in all_results.values() if res[name]['entropies']]\n","        temps = [res[name]['temperatures'][-5:] for res in all_results.values() if res[name]['temperatures']]\n","        last_5_means[name] = {\n","            'reward_mean': np.mean([r for sublist in rewards for r in sublist]) if rewards else 0,\n","            'entropy_mean': np.mean([e for sublist in entropies for e in sublist]) if entropies else 0,\n","            'temp_mean': np.mean([t for sublist in temps for t in sublist]) if temps else 0\n","        }\n","\n","    ax1 = plt.subplot(3, 1, 1)\n","    for idx, name in enumerate(all_results[first_seed].keys()):\n","        rewards = [res[name]['eval_rewards'] for res in all_results.values()]\n","        timesteps = [res[name]['timesteps'] for res in all_results.values()]\n","        min_len = min(len(t) for t in rewards if t)\n","        rewards = [r[:min_len] for r in rewards if r]\n","        timesteps = timesteps[0][:min_len] if timesteps and timesteps[0] else []\n","        if rewards and timesteps:\n","            reward_mean = np.mean(rewards, axis=0)\n","            reward_std = np.std(rewards, axis=0)\n","            ax1.plot(timesteps, reward_mean, label=f'{name} (Mean: {last_5_means[name][\"reward_mean\"]:.2f})',\n","                      color=colors[name], lw=2)\n","            ax1.fill_between(timesteps, reward_mean - reward_std, reward_mean + reward_std,\n","                            color=colors[name], alpha=0.3)\n","    ax1.set_title('(a) Learning Curve', fontweight='bold')\n","    ax1.set_xlabel('Timesteps')\n","    ax1.set_ylabel('Evaluation Return')\n","    ax1.legend()\n","    ax1.grid(True, which='both', alpha=0.3)\n","    ax1.axhline(0, color='k', linestyle='--', lw=1, alpha=0.5)\n","\n","    ax2 = plt.subplot(3, 1, 2)\n","    for idx, name in enumerate(all_results[first_seed].keys()):\n","        entropies = [res[name]['entropies'] for res in all_results.values()]\n","        timesteps = [res[name]['timesteps'] for res in all_results.values()]\n","        min_len = min(len(t) for t in entropies if t)\n","        entropies = [e[:min_len] for e in entropies if e]\n","        timesteps = timesteps[0][:min_len] if timesteps and timesteps[0] else []\n","        if entropies and timesteps:\n","            entropy_mean = np.mean(entropies, axis=0)\n","            entropy_std = np.std(entropies, axis=0)\n","            ax2.plot(timesteps, entropy_mean, label=f'{name} (Mean: {last_5_means[name][\"entropy_mean\"]:.2f})',\n","                      color=colors[name], lw=2)\n","            ax2.fill_between(timesteps, entropy_mean - entropy_std, entropy_mean + entropy_std,\n","                            color=colors[name], alpha=0.3)\n","    ax2.axhline(target_entropy, color='r', linestyle='--', lw=2, label=f'Target Entropy={target_entropy:.2f}')\n","    ax2.set_title('(b) Policy Entropy Evolution', fontweight='bold')\n","    ax2.set_xlabel('Timesteps')\n","    ax2.set_ylabel('Entropy')\n","    ax2.legend()\n","    ax2.grid(True, which='both', alpha=0.3)\n","\n","    ax3 = plt.subplot(3, 1, 3)\n","    for idx, name in enumerate(all_results[first_seed].keys()):\n","        temps = [res[name]['temperatures'] for res in all_results.values()]\n","        timesteps = [res[name]['timesteps'] for res in all_results.values()]\n","        min_len = min(len(t) for t in temps if t)\n","        temps = [t[:min_len] for t in temps if t]\n","        timesteps = timesteps[0][:min_len] if timesteps and timesteps[0] else []\n","        if temps and timesteps:\n","            temp_mean = np.mean(temps, axis=0)\n","            temp_std = np.std(temps, axis=0)\n","            ax3.plot(timesteps, temp_mean, label=f'{name} (Mean: {last_5_means[name][\"temp_mean\"]:.2f})',\n","                      color=colors[name], lw=2)\n","            ax3.fill_between(timesteps, temp_mean - temp_std, temp_mean + temp_std,\n","                            color=colors[name], alpha=0.3)\n","    ax3.axhline(0.5, color='r', linestyle='--', lw=2, label='Fixed-Temp')\n","    ax3.set_title('(c) Temperature Evolution', fontweight='bold')\n","    ax3.set_xlabel('Timesteps')\n","    ax3.set_ylabel('Temperature α')\n","    ax3.legend()\n","    ax3.grid(True, which='both', alpha=0.3)\n","\n","    plt.tight_layout(rect=[0, 0, 1, 0.95])\n","    plt.show()\n","\n","def print_experiment_summary(all_results: Dict, all_agents: Dict, target_entropy: float):\n","    print(\"\\n\" + \"=\"*70 + \"\\nSAC温度调节实验结果总结\\n\" + \"=\"*70)\n","    print(f\"\\n1. 实验设置:\\n   目标熵: {target_entropy}\\n   对比方法: {list(all_results[list(all_results.keys())[0]].keys())}\")\n","\n","    summary = []\n","    baseline_method = \"RF-BO (Ours)\"\n","\n","    for name in all_results[list(all_results.keys())[0]].keys():\n","        final_rewards = [all_results[seed][name]['eval_rewards'][-3:] for seed in all_results if len(all_results[seed][name]['eval_rewards']) >= 3]\n","\n","        p_value = np.nan\n","        if name != baseline_method and final_rewards:\n","            baseline_rewards = [np.mean(all_results[seed][baseline_method]['eval_rewards'][-3:]) for seed in all_results if len(all_results[seed][baseline_method]['eval_rewards']) >= 3]\n","            current_rewards = [np.mean(r) for r in final_rewards]\n","            if len(baseline_rewards) > 1 and len(current_rewards) > 1:\n","                p_value = stats.ttest_ind(current_rewards, baseline_rewards)[1]\n","\n","        summary.append([name, \"...\", p_value])\n","\n","def print_experiment_summary_with_pvalue(all_results: Dict, target_entropy: float):\n","    print(\"\\n\" + \"=\"*80)\n","    print(\"Final Summary Table for Appendix (with p-values)\")\n","    print(\"=\"*80)\n","\n","    final_returns_raw = {}\n","    methods = list(all_results[list(all_results.keys())[0]].keys())\n","    for name in methods:\n","        returns_per_seed = []\n","        for seed in all_results:\n","            eval_rewards = all_results[seed][name]['eval_rewards']\n","            if len(eval_rewards) >= 5:\n","                returns_per_seed.append(np.mean(eval_rewards[-5:]))\n","        final_returns_raw[name] = returns_per_seed\n","\n","    p_values = {}\n","    baseline_returns = final_returns_raw.get('RF-BO (Ours)', [])\n","    for name in methods:\n","        if name == 'RF-BO (Ours)' or not baseline_returns:\n","            p_values[name] = 'N/A'\n","            continue\n","        method_returns = final_returns_raw.get(name, [])\n","        if method_returns:\n","            _, p_val = stats.ttest_ind(method_returns, baseline_returns, equal_var=False)\n","            p_values[name] = f\"{p_val:.3f}\"\n","        else:\n","            p_values[name] = 'N/A'\n","\n","    df = pd.DataFrame()\n","    for name in methods:\n","        final_mean = np.mean(final_returns_raw.get(name, [np.nan]))\n","        final_std = np.std(final_returns_raw.get(name, [np.nan]))\n","\n","        entropies = [np.mean(res[name]['entropies'][-3:]) for res in all_results.values() if len(res[name]['entropies'])>=3]\n","        entropy_dev = np.mean(np.abs(np.array(entropies) - target_entropy)) if entropies else np.nan\n","\n","        df.at[name, 'Final Return'] = f\"{final_mean:.2f} ± {final_std:.2f}\"\n","        df.at[name, 'Entropy Deviation'] = f\"{entropy_dev:.3f}\" if not np.isnan(entropy_dev) else \"N/A\"\n","        df.at[name, 'p-value vs RF-BO'] = p_values[name]\n","\n","    print(df)\n","    print(\"=\"*80)\n","\n","def main():\n","    ENV_NAME = \"Pendulum-v1\"\n","    NUM_SEED = 5\n","    ALL_SEEDS = [44, 47, 49, 50, 52, 53]\n","    SEEDS = ALL_SEEDS[:NUM_SEED]\n","    TOTAL_TIMESTEPS = 30000\n","    eval_freq = 2000\n","    start_time = time.time()\n","    try:\n","        all_results, all_agents = run_experiment(env_name=ENV_NAME, seeds=SEEDS, total_timesteps=TOTAL_TIMESTEPS, eval_freq=eval_freq)\n","        if all_results and all_agents:\n","            target_entropy = -all_agents[list(all_agents.keys())[0]][\"RF-BO (Ours)\"].action_dim\n","            plot_results(all_results, all_agents, target_entropy, ENV_NAME)\n","            print_experiment_summary(all_results, all_agents, target_entropy)\n","            print_experiment_summary_with_pvalue(all_results, target_entropy)\n","            print(\"\\n实验成功完成！\")\n","        else:\n","            print(\"\\n实验未能成功运行。\")\n","    except Exception as e:\n","        print(f\"\\n错误: {e}\")\n","        import traceback\n","        traceback.print_exc()\n","\n","if __name__ == \"__main__\":\n","    main()"],"metadata":{"id":"_fu94gMlLLbc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#实验2： 6.2.1实验 (Revised: TTSA -> RF-BO (Ours))\n","import numpy as np\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","import matplotlib.pyplot as plt\n","from collections import deque\n","import random\n","from typing import Tuple, Dict\n","import time\n","import os\n","import warnings\n","from scipy import stats\n","import pandas as pd\n","\n","warnings.filterwarnings(\"ignore\")\n","\n","def check_dependencies():\n","    required_packages = ['gymnasium', 'torch', 'numpy', 'matplotlib', 'scipy']\n","    missing_packages = []\n","    for package in required_packages:\n","        try:\n","            __import__(package)\n","        except ImportError:\n","            missing_packages.append(package)\n","    if missing_packages:\n","        print(f\"缺少依赖: {missing_packages}\")\n","        print(\"请安装: pip install gymnasium[classic_control] torch numpy matplotlib scipy\")\n","        return False\n","    return True\n","\n","def setup_device():\n","    if 'TPU_NAME' in os.environ:\n","        try:\n","            import torch_xla.core.xla_model as xm\n","            return xm.xla_device(), True\n","        except ImportError:\n","            print(\"TPU未安装\")\n","    if torch.cuda.is_available():\n","        return torch.device(\"cuda\"), False\n","    return torch.device(\"cpu\"), False\n","\n","if not check_dependencies():\n","    exit(1)\n","\n","try:\n","    import gymnasium as gym\n","except ImportError:\n","    print(\"gymnasium未安装\")\n","    exit(1)\n","\n","device, is_tpu = setup_device()\n","\n","def set_seed(seed: int):\n","    np.random.seed(seed)\n","    torch.manual_seed(seed)\n","    random.seed(seed)\n","    if torch.cuda.is_available():\n","        torch.cuda.manual_seed(seed)\n","        torch.cuda.manual_seed_all(seed)\n","\n","class ReplayBuffer:\n","    def __init__(self, capacity: int, state_dim: int, action_dim: int):\n","        self.capacity = capacity\n","        self.size = 0\n","        self.ptr = 0\n","        self.states = np.zeros((capacity, state_dim), dtype=np.float32)\n","        self.actions = np.zeros((capacity, action_dim), dtype=np.float32)\n","        self.rewards = np.zeros((capacity, 1), dtype=np.float32)\n","        self.next_states = np.zeros((capacity, state_dim), dtype=np.float32)\n","        self.dones = np.zeros((capacity, 1), dtype=np.float32)\n","    def add(self, state, action, reward, next_state, done):\n","        self.states[self.ptr] = state\n","        self.actions[self.ptr] = action\n","        self.rewards[self.ptr] = reward\n","        self.next_states[self.ptr] = next_state\n","        self.dones[self.ptr] = done\n","        self.ptr = (self.ptr + 1) % self.capacity\n","        self.size = min(self.size + 1, self.capacity)\n","    def sample(self, batch_size: int):\n","        indices = np.random.randint(0, self.size, batch_size)\n","        return (\n","            torch.FloatTensor(self.states[indices]).to(device),\n","            torch.FloatTensor(self.actions[indices]).to(device),\n","            torch.FloatTensor(self.rewards[indices]).to(device),\n","            torch.FloatTensor(self.next_states[indices]).to(device),\n","            torch.FloatTensor(self.dones[indices]).to(device)\n","        )\n","\n","class Actor(nn.Module):\n","    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256):\n","        super().__init__()\n","        self.net = nn.Sequential(\n","            nn.Linear(state_dim, hidden_dim), nn.ReLU(),\n","            nn.Linear(hidden_dim, hidden_dim), nn.ReLU()\n","        )\n","        self.mean_head = nn.Linear(hidden_dim, action_dim)\n","        self.log_std_head = nn.Linear(hidden_dim, action_dim)\n","        self.log_std_min = -20\n","        self.log_std_max = 2\n","        self.apply(self._init_weights)\n","    def _init_weights(self, m):\n","        if isinstance(m, nn.Linear):\n","            torch.nn.init.xavier_uniform_(m.weight)\n","            torch.nn.init.constant_(m.bias, 0)\n","    def forward(self, state):\n","        x = self.net(state)\n","        mean = self.mean_head(x)\n","        log_std = self.log_std_head(x)\n","        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)\n","        return mean, log_std\n","    def sample(self, state):\n","        mean, log_std = self.forward(state)\n","        std = log_std.exp()\n","        normal = torch.distributions.Normal(mean, std)\n","        x_t = normal.rsample()\n","        action = torch.tanh(x_t)\n","        log_prob = normal.log_prob(x_t)\n","        log_prob -= torch.log(1 - action.pow(2) + 1e-7)\n","        log_prob = log_prob.sum(1, keepdim=True)\n","        return action, log_prob\n","\n","class Critic(nn.Module):\n","    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256):\n","        super().__init__()\n","        self.net = nn.Sequential(\n","            nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(),\n","            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),\n","            nn.Linear(hidden_dim, 1)\n","        )\n","        self.apply(self._init_weights)\n","    def _init_weights(self, m):\n","        if isinstance(m, nn.Linear):\n","            torch.nn.init.xavier_uniform_(m.weight)\n","            torch.nn.init.constant_(m.bias, 0)\n","    def forward(self, state, action):\n","        return self.net(torch.cat([state, action], dim=1))\n","\n","class SACAgent:\n","    def __init__(self, state_dim: int, action_dim: int,\n","                 temperature_mode: str = \"fixed\",\n","                 target_entropy: float = None,\n","                 initial_temperature: float = 0.5,\n","                 lr: float = 3e-4):\n","        self.state_dim = state_dim\n","        self.action_dim = action_dim\n","        self.temperature_mode = temperature_mode\n","        self.target_entropy = target_entropy if target_entropy else -action_dim\n","        print(f\"初始化 {temperature_mode} 智能体，目标熵: {self.target_entropy}\")\n","\n","        self.actor = Actor(state_dim, action_dim).to(device)\n","        self.critic1 = Critic(state_dim, action_dim).to(device)\n","        self.critic2 = Critic(state_dim, action_dim).to(device)\n","        self.target_critic1 = Critic(state_dim, action_dim).to(device)\n","        self.target_critic2 = Critic(state_dim, action_dim).to(device)\n","\n","        self.target_critic1.load_state_dict(self.critic1.state_dict())\n","        self.target_critic2.load_state_dict(self.critic2.state_dict())\n","\n","        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)\n","        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=lr)\n","        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=lr)\n","\n","        if temperature_mode == \"fixed\":\n","            self.temperature = initial_temperature\n","            self.log_temperature = None\n","            self.temperature_optimizer = None\n","        else:\n","            self.log_temperature = torch.log(torch.tensor(initial_temperature)).to(device).requires_grad_(True)\n","            self.temperature = self.log_temperature.exp().item()\n","            if temperature_mode == \"original\":\n","                self.temperature_optimizer = optim.Adam([self.log_temperature], lr=3e-4)\n","\n","        if temperature_mode == \"ttsa\":\n","            self.ttsa_gamma_base = 1e-3\n","            self.ttsa_step = 0\n","            self.temperature_optimizer = None\n","            self.h_ema = 0.0\n","            self.ema_alpha = 0.5\n","            self.update_alpha_freq = 5\n","\n","        self.temperature_history = [self.temperature]\n","        self.entropy_history = []\n","        self.h_values_history = []\n","        self.update_count = 0\n","        self.tau = 0.005\n","        self.gamma = 0.99\n","\n","    def select_action(self, state, eval_mode=False):\n","        state = torch.FloatTensor(state).unsqueeze(0).to(device)\n","        with torch.no_grad():\n","            if eval_mode:\n","                mean, _ = self.actor(state)\n","                action = torch.tanh(mean)\n","            else:\n","                action, _ = self.actor.sample(state)\n","        return action.cpu().numpy()[0]\n","\n","    def update(self, batch_size: int, replay_buffer: ReplayBuffer):\n","        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)\n","\n","        with torch.no_grad():\n","            next_actions, next_log_probs = self.actor.sample(next_states)\n","            target_q1 = self.target_critic1(next_states, next_actions)\n","            target_q2 = self.target_critic2(next_states, next_actions)\n","            target_q = torch.min(target_q1, target_q2) - self.temperature * next_log_probs\n","            target_q = rewards + (1 - dones) * self.gamma * target_q\n","\n","        current_q1 = self.critic1(states, actions)\n","        current_q2 = self.critic2(states, actions)\n","\n","        critic1_loss = F.mse_loss(current_q1, target_q)\n","        critic2_loss = F.mse_loss(current_q2, target_q)\n","\n","        self.critic1_optimizer.zero_grad()\n","        critic1_loss.backward()\n","        self.critic1_optimizer.step()\n","\n","        self.critic2_optimizer.zero_grad()\n","        critic2_loss.backward()\n","        self.critic2_optimizer.step()\n","\n","        new_actions, log_probs = self.actor.sample(states)\n","        q1_new = self.critic1(states, new_actions)\n","        q2_new = self.critic2(states, new_actions)\n","        q_new = torch.min(q1_new, q2_new)\n","\n","        actor_loss = (self.temperature * log_probs - q_new).mean()\n","\n","        self.actor_optimizer.zero_grad()\n","        actor_loss.backward()\n","        self.actor_optimizer.step()\n","\n","        current_entropy = -log_probs.mean().item()\n","        self.entropy_history.append(current_entropy)\n","\n","        if self.temperature_mode in [\"original\", \"ttsa\"] and self.update_count % 5 == 0:\n","            self._update_temperature(log_probs)\n","\n","        self._soft_update()\n","        self.update_count += 1\n","\n","        return {\n","            'critic1_loss': critic1_loss.item(),\n","            'critic2_loss': critic2_loss.item(),\n","            'actor_loss': actor_loss.item(),\n","            'temperature': self.temperature,\n","            'entropy': current_entropy\n","        }\n","\n","    def _update_temperature(self, log_probs):\n","        if self.temperature_mode == \"ttsa\":\n","            current_entropy = -log_probs.mean().detach().item()\n","            h_t = self.target_entropy - current_entropy\n","            self.h_ema = self.ema_alpha * self.h_ema + (1 - self.ema_alpha) * h_t\n","            self.h_values_history.append(self.h_ema)\n","\n","            self.ttsa_step += 1\n","            gamma_t = self.ttsa_gamma_base * np.exp(-0.0005 * self.ttsa_step)\n","            with torch.no_grad():\n","                self.log_temperature += gamma_t * self.h_ema\n","                self.log_temperature.clamp_(-5.0, 2.0)\n","            self.temperature = self.log_temperature.exp().item()\n","        elif self.temperature_mode == \"original\":\n","            temperature_loss = -(self.log_temperature * (log_probs + self.target_entropy).detach()).mean()\n","            self.temperature_optimizer.zero_grad()\n","            temperature_loss.backward()\n","            self.temperature_optimizer.step()\n","            self.temperature = self.log_temperature.exp().item()\n","\n","        self.temperature_history.append(self.temperature)\n","\n","    def _soft_update(self):\n","        for target_param, param in zip(self.target_critic1.parameters(), self.critic1.parameters()):\n","            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)\n","        for target_param, param in zip(self.target_critic2.parameters(), self.critic2.parameters()):\n","            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)\n","\n","def evaluate_agent(agent: SACAgent, env, num_episodes: int = 5) -> float:\n","    total_reward = 0\n","    for _ in range(num_episodes):\n","        state, _ = env.reset()\n","        episode_reward = 0\n","        done = False\n","        max_steps = env.spec.max_episode_steps if env.spec else 1000\n","        steps = 0\n","        while not done and steps < max_steps:\n","            action = agent.select_action(state, eval_mode=True)\n","            state, reward, terminated, truncated, _ = env.step(action)\n","            done = terminated or truncated\n","            episode_reward += reward\n","            steps += 1\n","        total_reward += episode_reward\n","    return total_reward / num_episodes\n","\n","def run_experiment(env_name: str = \"Pendulum-v1\",\n","                    seeds: list = [42, 43, 44, 45, 46],\n","                    total_timesteps: int = 30000,\n","                    eval_freq: int = 2000,\n","                    batch_size: int = 128,\n","                    buffer_size: int = 50000) -> Tuple[Dict, Dict]:\n","    set_seed(seeds[0])\n","    start_time = time.time()\n","    print(f\"\\n=== SAC温度调节对比实验 ===\")\n","    print(f\"环境: {env_name}, 种子: {seeds}, 总步数: {total_timesteps}, 设备: {device}\")\n","\n","    env = gym.make(env_name)\n","    eval_env = gym.make(env_name)\n","    state_dim = env.observation_space.shape[0]\n","    action_dim = env.action_space.shape[0]\n","    print(f\"状态维度: {state_dim}, 动作维度: {action_dim}\")\n","\n","    target_entropy = -float(action_dim)\n","    print(f\"目标熵: {target_entropy}\")\n","\n","    all_results = {}\n","    all_agents = {}\n","    for seed in seeds:\n","        set_seed(seed)\n","        agents = {\n","            \"Fixed-Temp\": SACAgent(state_dim, action_dim, temperature_mode=\"fixed\", initial_temperature=0.5, target_entropy=target_entropy),\n","            \"Original-SAC\": SACAgent(state_dim, action_dim, temperature_mode=\"original\", target_entropy=target_entropy),\n","            \"RF-BO (Ours)\": SACAgent(state_dim, action_dim, temperature_mode=\"ttsa\", target_entropy=target_entropy)\n","        }\n","        buffers = {name: ReplayBuffer(buffer_size, state_dim, action_dim) for name in agents.keys()}\n","        results = {name: {'episodes': [], 'eval_rewards': [], 'timesteps': [],\n","                          'temperatures': [], 'entropies': [], 'h_values': []} for name in agents.keys()}\n","\n","        print(f\"\\n--- 训练种子 {seed} ---\")\n","        state, _ = env.reset(seed=seed)\n","        episode_reward = 0\n","        local_timestep = 0\n","        local_episode = 0\n","\n","        while local_timestep < total_timesteps:\n","            if local_timestep < 1000:\n","                action = env.action_space.sample()\n","            else:\n","                action = agents[\"RF-BO (Ours)\"].select_action(state)\n","            next_state, reward, terminated, truncated, _ = env.step(action)\n","            done = terminated or truncated\n","\n","            for name in agents.keys():\n","                buffers[name].add(state, action, reward, next_state, done)\n","\n","            episode_reward += reward\n","            local_timestep += 1\n","            state = next_state\n","\n","            if buffers[\"RF-BO (Ours)\"].size > batch_size and local_timestep > 1000:\n","                for name in agents.keys():\n","                    agents[name].update(batch_size, buffers[name])\n","\n","            if local_timestep % eval_freq == 0:\n","                for name in agents.keys():\n","                    eval_reward = evaluate_agent(agents[name], eval_env)\n","                    results[name]['eval_rewards'].append(eval_reward)\n","                    results[name]['timesteps'].append(local_timestep)\n","                    results[name]['episodes'].append(local_episode)\n","                    if len(agents[name].temperature_history) > 0:\n","                        results[name]['temperatures'].append(agents[name].temperature_history[-1])\n","                    if len(agents[name].entropy_history) > 0:\n","                        recent_entropy = np.mean(agents[name].entropy_history[-50:])\n","                        results[name]['entropies'].append(recent_entropy)\n","                    if name == \"RF-BO (Ours)\" and len(agents[name].h_values_history) > 0:\n","                        recent_h = np.mean(agents[name].h_values_history[-50:])\n","                        results[name]['h_values'].append(recent_h)\n","                    recent_entropy = results[name]['entropies'][-1] if results[name]['entropies'] else 0.0\n","                    print(f\"  步数: {local_timestep:6d} | {name} 评估奖励: {eval_reward:8.2f} | 温度: {agents[name].temperature:.4f} | 熵: {recent_entropy:.4f}\")\n","\n","            if done:\n","                state, _ = env.reset()\n","                episode_reward = 0\n","                local_episode += 1\n","\n","        all_results[seed] = results\n","        all_agents[seed] = agents\n","        env.close()\n","        eval_env.close()\n","\n","    print(f\"\\n实验完成，总用时: {time.time() - start_time:.2f}秒\")\n","    return all_results, all_agents\n","\n","def plot_results(all_results: Dict, all_agents: Dict, target_entropy: float, env_name: str):\n","    print(\"绘制结果图...\")\n","    plt.rcParams['font.sans-serif'] = ['DejaVu Sans']\n","    fig = plt.figure(figsize=(12, 12))\n","    fig.suptitle(f'Figure 2: SAC Temperature Tuning ({env_name})', fontsize=12, fontweight='bold')\n","\n","    colors = {'Fixed-Temp': '#ff7f0e', 'Original-SAC': '#2ca02c', 'RF-BO (Ours)': '#1f77b4'}\n","\n","    first_seed = list(all_results.keys())[0]\n","\n","    last_5_means = {}\n","    for name in all_results[first_seed].keys():\n","        rewards = [res[name]['eval_rewards'][-5:] for res in all_results.values() if res[name]['eval_rewards']]\n","        entropies = [res[name]['entropies'][-5:] for res in all_results.values() if res[name]['entropies']]\n","        temps = [res[name]['temperatures'][-5:] for res in all_results.values() if res[name]['temperatures']]\n","        last_5_means[name] = {\n","            'reward_mean': np.mean([r for sublist in rewards for r in sublist]) if rewards else 0,\n","            'entropy_mean': np.mean([e for sublist in entropies for e in sublist]) if entropies else 0,\n","            'temp_mean': np.mean([t for sublist in temps for t in sublist]) if temps else 0\n","        }\n","\n","    ax1 = plt.subplot(3, 1, 1)\n","    for idx, name in enumerate(all_results[first_seed].keys()):\n","        rewards = [res[name]['eval_rewards'] for res in all_results.values()]\n","        timesteps = [res[name]['timesteps'] for res in all_results.values()]\n","        min_len = min(len(t) for t in rewards if t)\n","        rewards = [r[:min_len] for r in rewards if r]\n","        timesteps = timesteps[0][:min_len] if timesteps and timesteps[0] else []\n","        if rewards and timesteps:\n","            reward_mean = np.mean(rewards, axis=0)\n","            reward_std = np.std(rewards, axis=0)\n","            ax1.plot(timesteps, reward_mean, label=f'{name} (Mean: {last_5_means[name][\"reward_mean\"]:.2f})',\n","                      color=colors[name], lw=2)\n","            ax1.fill_between(timesteps, reward_mean - reward_std, reward_mean + reward_std,\n","                            color=colors[name], alpha=0.3)\n","    ax1.set_title('(a) Learning Curve', fontweight='bold')\n","    ax1.set_xlabel('Timesteps')\n","    ax1.set_ylabel('Evaluation Return')\n","    ax1.legend()\n","    ax1.grid(True, which='both', alpha=0.3)\n","    ax1.axhline(0, color='k', linestyle='--', lw=1, alpha=0.5)\n","\n","    ax2 = plt.subplot(3, 1, 2)\n","    for idx, name in enumerate(all_results[first_seed].keys()):\n","        entropies = [res[name]['entropies'] for res in all_results.values()]\n","        timesteps = [res[name]['timesteps'] for res in all_results.values()]\n","        min_len = min(len(t) for t in entropies if t)\n","        entropies = [e[:min_len] for e in entropies if e]\n","        timesteps = timesteps[0][:min_len] if timesteps and timesteps[0] else []\n","        if entropies and timesteps:\n","            entropy_mean = np.mean(entropies, axis=0)\n","            entropy_std = np.std(entropies, axis=0)\n","            ax2.plot(timesteps, entropy_mean, label=f'{name} (Mean: {last_5_means[name][\"entropy_mean\"]:.2f})',\n","                      color=colors[name], lw=2)\n","            ax2.fill_between(timesteps, entropy_mean - entropy_std, entropy_mean + entropy_std,\n","                            color=colors[name], alpha=0.3)\n","    ax2.axhline(target_entropy, color='r', linestyle='--', lw=2, label=f'Target Entropy={target_entropy:.2f}')\n","    ax2.set_title('(b) Policy Entropy Evolution', fontweight='bold')\n","    ax2.set_xlabel('Timesteps')\n","    ax2.set_ylabel('Entropy')\n","    ax2.legend()\n","    ax2.grid(True, which='both', alpha=0.3)\n","\n","    ax3 = plt.subplot(3, 1, 3)\n","    for idx, name in enumerate(all_results[first_seed].keys()):\n","        temps = [res[name]['temperatures'] for res in all_results.values()]\n","        timesteps = [res[name]['timesteps'] for res in all_results.values()]\n","        min_len = min(len(t) for t in temps if t)\n","        temps = [t[:min_len] for t in temps if t]\n","        timesteps = timesteps[0][:min_len] if timesteps and timesteps[0] else []\n","        if temps and timesteps:\n","            temp_mean = np.mean(temps, axis=0)\n","            temp_std = np.std(temps, axis=0)\n","            ax3.plot(timesteps, temp_mean, label=f'{name} (Mean: {last_5_means[name][\"temp_mean\"]:.2f})',\n","                      color=colors[name], lw=2)\n","            ax3.fill_between(timesteps, temp_mean - temp_std, temp_mean + temp_std,\n","                            color=colors[name], alpha=0.3)\n","    ax3.axhline(0.5, color='r', linestyle='--', lw=2, label='Fixed-Temp')\n","    ax3.set_title('(c) Temperature Evolution', fontweight='bold')\n","    ax3.set_xlabel('Timesteps')\n","    ax3.set_ylabel('Temperature α')\n","    ax3.legend()\n","    ax3.grid(True, which='both', alpha=0.3)\n","\n","    plt.tight_layout(rect=[0, 0, 1, 0.95])\n","    plt.show()\n","\n","def print_experiment_summary(all_results: Dict, all_agents: Dict, target_entropy: float):\n","    print(\"\\n\" + \"=\"*70 + \"\\nSAC温度调节实验结果总结\\n\" + \"=\"*70)\n","    print(f\"\\n1. 实验设置:\\n   目标熵: {target_entropy}\\n   对比方法: {list(all_results[list(all_results.keys())[0]].keys())}\")\n","\n","    summary = []\n","    baseline_method = \"RF-BO (Ours)\"\n","\n","    for name in all_results[list(all_results.keys())[0]].keys():\n","        final_rewards = [all_results[seed][name]['eval_rewards'][-3:] for seed in all_results if len(all_results[seed][name]['eval_rewards']) >= 3]\n","\n","        p_value = np.nan\n","        if name != baseline_method and final_rewards:\n","            baseline_rewards = [np.mean(all_results[seed][baseline_method]['eval_rewards'][-3:]) for seed in all_results if len(all_results[seed][baseline_method]['eval_rewards']) >= 3]\n","            current_rewards = [np.mean(r) for r in final_rewards]\n","            if len(baseline_rewards) > 1 and len(current_rewards) > 1:\n","                p_value = stats.ttest_ind(current_rewards, baseline_rewards)[1]\n","\n","        summary.append([name, \"...\", p_value])\n","\n","def print_experiment_summary_with_pvalue(all_results: Dict, target_entropy: float):\n","    print(\"\\n\" + \"=\"*80)\n","    print(\"Final Summary Table for Appendix (with p-values)\")\n","    print(\"=\"*80)\n","\n","    final_returns_raw = {}\n","    methods = list(all_results[list(all_results.keys())[0]].keys())\n","    for name in methods:\n","        returns_per_seed = []\n","        for seed in all_results:\n","            eval_rewards = all_results[seed][name]['eval_rewards']\n","            if len(eval_rewards) >= 5:\n","                returns_per_seed.append(np.mean(eval_rewards[-5:]))\n","        final_returns_raw[name] = returns_per_seed\n","\n","    p_values = {}\n","    baseline_returns = final_returns_raw.get('RF-BO (Ours)', [])\n","    for name in methods:\n","        if name == 'RF-BO (Ours)' or not baseline_returns:\n","            p_values[name] = 'N/A'\n","            continue\n","        method_returns = final_returns_raw.get(name, [])\n","        if method_returns:\n","            _, p_val = stats.ttest_ind(method_returns, baseline_returns, equal_var=False)\n","            p_values[name] = f\"{p_val:.3f}\"\n","        else:\n","            p_values[name] = 'N/A'\n","\n","    df = pd.DataFrame()\n","    for name in methods:\n","        final_mean = np.mean(final_returns_raw.get(name, [np.nan]))\n","        final_std = np.std(final_returns_raw.get(name, [np.nan]))\n","\n","        entropies = [np.mean(res[name]['entropies'][-3:]) for res in all_results.values() if len(res[name]['entropies'])>=3]\n","        entropy_dev = np.mean(np.abs(np.array(entropies) - target_entropy)) if entropies else np.nan\n","\n","        df.at[name, 'Final Return'] = f\"{final_mean:.2f} ± {final_std:.2f}\"\n","        df.at[name, 'Entropy Deviation'] = f\"{entropy_dev:.3f}\" if not np.isnan(entropy_dev) else \"N/A\"\n","        df.at[name, 'p-value vs RF-BO'] = p_values[name]\n","\n","    print(df)\n","    print(\"=\"*80)\n","\n","def main():\n","    ENV_NAME = \"Pendulum-v1\"\n","    NUM_SEED = 5\n","    ALL_SEEDS = [44, 47, 49, 50, 52, 53]\n","    SEEDS = ALL_SEEDS[:NUM_SEED]\n","    TOTAL_TIMESTEPS = 30000\n","    eval_freq = 2000\n","    start_time = time.time()\n","    try:\n","        all_results, all_agents = run_experiment(env_name=ENV_NAME, seeds=SEEDS, total_timesteps=TOTAL_TIMESTEPS, eval_freq=eval_freq)\n","        if all_results and all_agents:\n","            target_entropy = -all_agents[list(all_agents.keys())[0]][\"RF-BO (Ours)\"].action_dim\n","            plot_results(all_results, all_agents, target_entropy, ENV_NAME)\n","            print_experiment_summary(all_results, all_agents, target_entropy)\n","            print_experiment_summary_with_pvalue(all_results, target_entropy)\n","            print(\"\\n实验成功完成！\")\n","        else:\n","            print(\"\\n实验未能成功运行。\")\n","    except Exception as e:\n","        print(f\"\\n错误: {e}\")\n","        import traceback\n","        traceback.print_exc()\n","\n","if __name__ == \"__main__\":\n","    main()"],"metadata":{"id":"Ug2zr9J5L-zq"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# 6.3.1\n","import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","import numpy as np\n","import matplotlib.pyplot as plt\n","import seaborn as sns\n","from torch.autograd import grad\n","from IPython.display import clear_output\n","from scipy.stats import wasserstein_distance\n","\n","torch.manual_seed(42)\n","np.random.seed(42)\n","\n","BATCH_SIZE = 64\n","LR_G = 0.001\n","LR_D = 0.001\n","ITERATIONS = 4000\n","D_STEPS = 5\n","DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n","\n","print(f\"Running on {DEVICE}\")\n","\n","def get_real_data(batch_size):\n","    modes = torch.randint(0, 2, (batch_size, 1)).float().to(DEVICE)\n","    data = (modes * 4 - 2) + torch.randn(batch_size, 1).to(DEVICE) * 0.1\n","    return data\n","\n","class Generator(nn.Module):\n","    def __init__(self):\n","        super().__init__()\n","        self.net = nn.Sequential(\n","            nn.Linear(8, 32), nn.ReLU(),\n","            nn.Linear(32, 32), nn.ReLU(),\n","            nn.Linear(32, 1)\n","        )\n","    def forward(self, z): return self.net(z)\n","\n","class Discriminator(nn.Module):\n","    def __init__(self):\n","        super().__init__()\n","        self.net = nn.Sequential(\n","            nn.Linear(1, 32), nn.Tanh(),\n","            nn.Linear(32, 32), nn.Tanh(),\n","            nn.Linear(32, 1)\n","        )\n","    def forward(self, x): return self.net(x)\n","\n","def compute_gradient_penalty(D, real_samples, fake_samples):\n","    alpha = torch.rand(real_samples.size(0), 1).to(DEVICE)\n","    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)\n","    d_interpolates = D(interpolates)\n","    fake = torch.ones(real_samples.shape[0], 1).to(DEVICE)\n","\n","    gradients = grad(\n","        outputs=d_interpolates, inputs=interpolates,\n","        grad_outputs=fake, create_graph=True, retain_graph=True, only_inputs=True\n","    )[0]\n","\n","    gradients = gradients.view(gradients.size(0), -1)\n","    gradient_norm = gradients.norm(2, dim=1)\n","    root_h = gradient_norm.mean() - 1.0\n","    gradient_penalty = ((gradient_norm - 1) ** 2).mean()\n","    return gradient_penalty, root_h\n","\n","def train_gan(method='fixed'):\n","    G = Generator().to(DEVICE)\n","    D = Discriminator().to(DEVICE)\n","\n","    optimizer_G = optim.Adam(G.parameters(), lr=LR_G, betas=(0.5, 0.999))\n","    optimizer_D = optim.Adam(D.parameters(), lr=LR_D, betas=(0.5, 0.999))\n","\n","    lambda_param = torch.tensor([1.0], device=DEVICE, requires_grad=True)\n","\n","    if method == 'lse_grad':\n","        optimizer_lambda = optim.Adam([lambda_param], lr=0.02)\n","    elif method == 'dual_adam':\n","        optimizer_lambda = optim.Adam([lambda_param], lr=0.01, betas=(0.5, 0.999))\n","    else:\n","        optimizer_lambda = None\n","\n","    rfbo_h_smooth = 0.0\n","\n","    history = {'lambda': [], 'root_h': [], 'd_loss': []}\n","\n","    for it in range(ITERATIONS):\n","        for step_i in range(D_STEPS):\n","            real_data = get_real_data(BATCH_SIZE)\n","            z = torch.randn(BATCH_SIZE, 8).to(DEVICE)\n","            fake_data = G(z).detach()\n","\n","            d_real = D(real_data)\n","            d_fake = D(fake_data)\n","            w_dist = torch.mean(d_fake) - torch.mean(d_real)\n","            gp_loss, root_h = compute_gradient_penalty(D, real_data, fake_data)\n","\n","            is_last_step = (step_i == D_STEPS - 1)\n","            need_graph = (method == 'lse_grad') and is_last_step\n","\n","            curr_lambda = lambda_param if (method == 'lse_grad') else lambda_param.item()\n","            d_loss = w_dist + curr_lambda * gp_loss\n","\n","            optimizer_D.zero_grad()\n","            d_loss.backward(create_graph=need_graph)\n","            optimizer_D.step()\n","\n","        z_outer = torch.randn(BATCH_SIZE, 8).to(DEVICE)\n","        fake_outer = G(z_outer).detach()\n","        gp_loss_outer, root_h_outer = compute_gradient_penalty(D, real_data, fake_outer)\n","\n","        if method == 'fixed':\n","            lambda_param.data.fill_(10.0)\n","\n","        elif method == 'lse_grad':\n","            loss_upper = root_h_outer ** 2\n","            optimizer_lambda.zero_grad()\n","            loss_upper.backward()\n","            optimizer_lambda.step()\n","            with torch.no_grad(): lambda_param.clamp_(min=0.01)\n","            optimizer_D.zero_grad()\n","\n","        elif method == 'dual_adam':\n","            optimizer_lambda.zero_grad()\n","            lambda_param.grad = -root_h_outer.view(1).detach()\n","            optimizer_lambda.step()\n","            with torch.no_grad(): lambda_param.clamp_(min=0.01)\n","\n","        elif method == 'rfbo':\n","            rfbo_h_smooth = 0.9 * rfbo_h_smooth + 0.1 * root_h_outer.item()\n","            gamma = 0.008\n","\n","            with torch.no_grad():\n","                update = gamma * rfbo_h_smooth\n","                lambda_param.add_(update)\n","                lambda_param.clamp_(min=0.01)\n","\n","        z = torch.randn(BATCH_SIZE, 8).to(DEVICE)\n","        fake_out = G(z)\n","        g_loss = -torch.mean(D(fake_out))\n","        optimizer_G.zero_grad()\n","        g_loss.backward()\n","        optimizer_G.step()\n","\n","        history['lambda'].append(lambda_param.item())\n","        history['root_h'].append(root_h_outer.item())\n","        history['d_loss'].append(d_loss.item())\n","\n","    return history, G\n","\n","print(\"Training RF-BO (Ours)...\")\n","hist_rfbo, G_rfbo = train_gan(method='rfbo')\n","\n","print(\"Training LSE (True Grad)...\")\n","hist_lse_grad, G_lse_grad = train_gan(method='lse_grad')\n","\n","print(\"Training Dual Adam (Fair Baseline)...\")\n","hist_adam, G_adam = train_gan(method='dual_adam')\n","\n","clear_output()\n","\n","plt.figure(figsize=(18, 5))\n","\n","plt.subplot(1, 3, 1)\n","plt.plot(hist_rfbo['lambda'], label='RF-BO (Ours)', color='green', linewidth=2.5)\n","plt.plot(hist_lse_grad['lambda'], label='LSE (True Grad)', color='red', alpha=0.6, linestyle=':')\n","plt.plot(hist_adam['lambda'], label=\"Dual Adam\", color='blue', alpha=0.5, linestyle='-.')\n","plt.title(r\"Evolution of $\\lambda$\")\n","plt.xlabel(\"Iterations\")\n","plt.ylabel(r\"$\\lambda$ Value\")\n","plt.legend()\n","plt.grid(True, alpha=0.3)\n","\n","plt.subplot(1, 3, 2)\n","def smooth(y, box_pts):\n","    box = np.ones(box_pts)/box_pts\n","    return np.convolve(y, box, mode='same')\n","\n","plt.plot(smooth(hist_rfbo['root_h'], 50), label='RF-BO', color='green', linewidth=2)\n","plt.plot(smooth(hist_lse_grad['root_h'], 50), label='LSE (True)', color='red', alpha=0.5, linestyle=':')\n","plt.plot(smooth(hist_adam['root_h'], 50), label=\"Dual Adam\", color='blue', alpha=0.5, linestyle='-.')\n","plt.axhline(0, color='black', linestyle='--')\n","plt.title(\"Constraint Violation (Smoothed)\")\n","plt.xlabel(\"Iterations\")\n","plt.legend()\n","plt.grid(True, alpha=0.3)\n","\n","plt.subplot(1, 3, 3)\n","real_data = get_real_data(1000).cpu().numpy()\n","z = torch.randn(1000, 8).to(DEVICE)\n","fake_rfbo = G_rfbo(z).detach().cpu().numpy()\n","fake_adam = G_adam(z).detach().cpu().numpy()\n","fake_lse = G_lse_grad(z).detach().cpu().numpy()\n","\n","sns.kdeplot(real_data.flatten(), fill=True, label='Real Data', color='black', alpha=0.1)\n","sns.kdeplot(fake_rfbo.flatten(), label='RF-BO', color='green', linewidth=2)\n","sns.kdeplot(fake_adam.flatten(), label=\"Dual Adam\", color='blue', linestyle='-.')\n","sns.kdeplot(fake_lse.flatten(), label=\"LSE (True)\", color='red', alpha=0.6, linestyle=':')\n","\n","plt.title(\"Generated Data Distribution\")\n","plt.legend()\n","plt.grid(True, alpha=0.3)\n","\n","plt.tight_layout()\n","plt.show()\n","\n","print(f\"--- Constraint Violation (Mean Abs |h|) ---\")\n","print(f\"RF-BO (Ours):  {np.mean(np.abs(hist_rfbo['root_h'][-500:])):.4f}\")\n","print(f\"Dual Adam:     {np.mean(np.abs(hist_adam['root_h'][-500:])):.4f}\")\n","print(f\"LSE (True):    {np.mean(np.abs(hist_lse_grad['root_h'][-500:])):.4f}\")\n","\n","real_samples = get_real_data(1000).cpu().numpy().flatten()\n","fake_rfbo_s = G_rfbo(torch.randn(1000, 8).to(DEVICE)).detach().cpu().numpy().flatten()\n","fake_adam_s = G_adam(torch.randn(1000, 8).to(DEVICE)).detach().cpu().numpy().flatten()\n","fake_lse_s = G_lse_grad(torch.randn(1000, 8).to(DEVICE)).detach().cpu().numpy().flatten()\n","\n","print(f\"\\n--- Generation Quality (Wasserstein Dist) ---\")\n","print(f\"RF-BO (Ours):  {wasserstein_distance(real_samples, fake_rfbo_s):.4f}\")\n","print(f\"Dual Adam:     {wasserstein_distance(real_samples, fake_adam_s):.4f}\")\n","print(f\"LSE (True):    {wasserstein_distance(real_samples, fake_lse_s):.4f}\")\n","\n","print(f\"\\n--- Update Stability (Step Variance) ---\")\n","var_rfbo_step = np.var(np.diff(hist_rfbo['lambda'][-2000:]))\n","var_adam_step = np.var(np.diff(hist_adam['lambda'][-2000:]))\n","\n","print(f\"RF-BO Step Var:     {var_rfbo_step:.2e}\")\n","print(f\"Dual Adam Step Var: {var_adam_step:.2e}\")\n","\n","if var_rfbo_step < var_adam_step:\n","    print(f\">> RF-BO updates are {var_adam_step/var_rfbo_step:.1f}x smoother than Adam.\")"],"metadata":{"id":"oG-Z0UdBMOJ4"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Optimized Contrastive Learning Experiment: RF-BO (Root-Finding Bilevel Optimization)\n","import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","import torch.nn.functional as F\n","import torchvision\n","import torchvision.transforms as transforms\n","from torch.utils.data import DataLoader\n","from tqdm import tqdm\n","import numpy as np\n","import random\n","import time\n","import math\n","import os\n","import matplotlib.pyplot as plt\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","def setup_device():\n","    torch.backends.cuda.matmul.allow_tf32 = True\n","    torch.backends.cudnn.allow_tf32 = True\n","    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","    print(f\"Using device: {device} (TF32 Enabled)\")\n","    return device\n","\n","def set_seed(seed: int):\n","    random.seed(seed)\n","    np.random.seed(seed)\n","    torch.manual_seed(seed)\n","    if torch.cuda.is_available():\n","        torch.cuda.manual_seed_all(seed)\n","\n","DEVICE = setup_device()\n","BATCH_SIZE = 256\n","NUM_EPOCHS = 50\n","LR = 5e-4\n","WEIGHT_DECAY = 1e-4\n","ENCODER_OUTPUT_DIM = 512\n","PROJECTION_DIM = 128\n","NUM_WORKERS = 8\n","\n","INITIAL_TEMPERATURE = 0.5\n","SIMILARITY_TARGET = 0.6\n","RF_BO_GAMMA_BASE = 2.5e-3\n","WARMUP_EPOCHS = 5\n","\n","class TwoCropTransform:\n","    def __init__(self, transform):\n","        self.transform = transform\n","    def __call__(self, x):\n","        return [self.transform(x), self.transform(x)]\n","\n","def get_dataloaders():\n","    train_transform = transforms.Compose([\n","        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),\n","        transforms.RandomHorizontalFlip(),\n","        transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),\n","        transforms.RandomGrayscale(p=0.2),\n","        transforms.ToTensor(),\n","        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])\n","    ])\n","    test_transform = transforms.Compose([\n","        transforms.ToTensor(),\n","        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])\n","    ])\n","\n","    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=TwoCropTransform(train_transform))\n","    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,\n","                                num_workers=NUM_WORKERS, pin_memory=True, drop_last=True, persistent_workers=True)\n","\n","    memory_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=test_transform)\n","    memory_loader = DataLoader(memory_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)\n","\n","    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)\n","    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)\n","\n","    align_transform = transforms.Compose([\n","        transforms.RandomResizedCrop(size=32, scale=(0.8, 1.0)),\n","        transforms.RandomHorizontalFlip(p=0.5),\n","        transforms.ToTensor(),\n","        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])\n","    ])\n","    align_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=TwoCropTransform(align_transform))\n","    align_loader = DataLoader(align_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)\n","\n","    return train_loader, memory_loader, test_loader, align_loader\n","\n","class SimCLRModel(nn.Module):\n","    def __init__(self, encoder_output_dim, projection_dim):\n","        super().__init__()\n","        self.encoder = torchvision.models.resnet18(weights=None, num_classes=encoder_output_dim)\n","        self.encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n","        self.encoder.maxpool = nn.Identity()\n","        self.projection_head = nn.Sequential(\n","            nn.Linear(encoder_output_dim, encoder_output_dim, bias=False),\n","            nn.ReLU(),\n","            nn.Linear(encoder_output_dim, projection_dim, bias=False)\n","        )\n","\n","    def forward(self, x, return_features=False):\n","        features = self.encoder(x)\n","        projection = self.projection_head(features)\n","        if return_features:\n","            return F.normalize(features, dim=1)\n","        return F.normalize(projection, dim=1)\n","\n","class NTXentLoss(nn.Module):\n","    def __init__(self, batch_size, temperature, device):\n","        super().__init__()\n","        self.batch_size = batch_size\n","        self.temperature = max(temperature, 0.01)\n","        self.device = device\n","        self.mask = self.mask_correlated_samples(batch_size)\n","        self.criterion = nn.CrossEntropyLoss(reduction=\"sum\")\n","        self.similarity_f = nn.CosineSimilarity(dim=2)\n","\n","    def mask_correlated_samples(self, batch_size):\n","        N = 2 * batch_size\n","        mask = torch.ones((N, N), dtype=bool)\n","        mask.fill_diagonal_(False)\n","        for i in range(batch_size):\n","            mask[i, batch_size + i] = False\n","            mask[batch_size + i, i] = False\n","        return mask.to(self.device)\n","\n","    def forward(self, z_i, z_j):\n","        N = 2 * self.batch_size\n","        z = torch.cat((z_i, z_j), dim=0)\n","        raw_sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0))\n","        sim = raw_sim / self.temperature\n","        sim_i_j = torch.diag(sim, self.batch_size)\n","        sim_j_i = torch.diag(sim, -self.batch_size)\n","        positives = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)\n","        negatives = sim[self.mask].reshape(N, -1)\n","        labels = torch.zeros(N).to(positives.device).long()\n","        logits = torch.cat((positives, negatives), dim=1)\n","        if torch.isnan(logits).any(): logits = torch.nan_to_num(logits, 0.0)\n","        loss = self.criterion(logits, labels) / N\n","        with torch.no_grad():\n","            mean_raw_pos_sim = torch.diag(raw_sim, self.batch_size).mean().item()\n","        return loss, mean_raw_pos_sim\n","\n","class ContrastiveAgent:\n","    def __init__(self, temp_mode=\"fixed\"):\n","        self.model = SimCLRModel(ENCODER_OUTPUT_DIM, PROJECTION_DIM).to(DEVICE)\n","        try: self.model = torch.compile(self.model)\n","        except: pass\n","        self.optimizer = optim.Adam(self.model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n","        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=NUM_EPOCHS)\n","        self.scaler = torch.amp.GradScaler('cuda')\n","        self.temp_mode = temp_mode\n","        self.rfbo_step = 0\n","\n","        if temp_mode == \"fixed\":\n","            self.temperature = INITIAL_TEMPERATURE\n","            self.log_temperature = None\n","            self.temp_optimizer = None\n","        else:\n","            self.log_temperature = torch.log(torch.tensor(INITIAL_TEMPERATURE)).to(DEVICE).requires_grad_(True)\n","            self.temperature = self.log_temperature.exp().item()\n","            if temp_mode == \"projected\":\n","                self.temp_optimizer = optim.SGD([self.log_temperature], lr=1e-3, momentum=0.9)\n","            else:\n","                self.temp_optimizer = None\n","\n","        self.loss_history, self.temp_history = [], [self.temperature]\n","        self.accuracy_history = []\n","\n","    def train_one_epoch(self, train_loader, current_epoch):\n","        self.model.train()\n","        start_time = time.time()\n","        total_loss = 0\n","        update_temp = (self.temp_mode != \"fixed\") and (current_epoch >= WARMUP_EPOCHS)\n","\n","        if self.temp_mode == \"cosine\":\n","            progress = current_epoch / NUM_EPOCHS\n","            self.temperature = 0.1 + 0.5 * (1 + math.cos(math.pi * progress)) * 0.4\n","            with torch.no_grad(): self.log_temperature.fill_(math.log(self.temperature))\n","\n","        for (images, _) in tqdm(train_loader, desc=f\"Ep {current_epoch} ({self.temp_mode})\", leave=False):\n","            images = torch.cat(images, dim=0).to(DEVICE)\n","            with torch.amp.autocast('cuda'):\n","                z = self.model(images)\n","                z_i, z_j = torch.split(z, BATCH_SIZE, dim=0)\n","                if self.temp_mode not in [\"fixed\", \"cosine\"]:\n","                     self.temperature = self.log_temperature.exp().clamp(0.01, 1.0).item()\n","                loss_func = NTXentLoss(BATCH_SIZE, self.temperature, DEVICE)\n","                loss, raw_pos_sim = loss_func(z_i, z_j)\n","\n","            if torch.isnan(loss): continue\n","            self.optimizer.zero_grad(set_to_none=True)\n","            self.scaler.scale(loss).backward()\n","            self.scaler.step(self.optimizer)\n","            self.scaler.update()\n","            total_loss += loss.item()\n","\n","            if update_temp and self.temp_mode not in [\"cosine\"]:\n","                self._update_temperature(raw_pos_sim)\n","\n","        self.loss_history.append(total_loss / len(train_loader))\n","        self.temp_history.append(self.temperature)\n","        self.scheduler.step()\n","        return time.time() - start_time\n","\n","    def _update_temperature(self, raw_pos_sim):\n","        if self.temp_mode == \"projected\":\n","            h_t = raw_pos_sim - SIMILARITY_TARGET\n","            temp_loss = -self.log_temperature * torch.tensor(h_t, device=DEVICE).detach()\n","            if self.temp_optimizer:\n","                self.temp_optimizer.zero_grad()\n","                temp_loss.backward()\n","                self.temp_optimizer.step()\n","                with torch.no_grad(): self.log_temperature.clamp_(-4.6, 0.0)\n","        elif self.temp_mode == \"rfbo\":\n","            h_t = raw_pos_sim - SIMILARITY_TARGET\n","            gamma_t = RF_BO_GAMMA_BASE / np.log(self.rfbo_step + np.e)\n","            with torch.no_grad():\n","                self.log_temperature += gamma_t * h_t\n","                self.log_temperature.clamp_(-4.6, 0.0)\n","            self.rfbo_step += 1\n","\n","def extract_features(encoder, loader):\n","    feats, labels = [], []\n","    encoder.eval()\n","    with torch.no_grad():\n","        for x, y in loader:\n","            with torch.amp.autocast('cuda'):\n","                h = encoder(x.to(DEVICE))\n","            feats.append(h.float().cpu())\n","            labels.append(y)\n","    return torch.cat(feats), torch.cat(labels)\n","\n","def train_linear_classifier(agent, memory_loader, test_loader):\n","    encoder = agent.model.encoder\n","    train_feats, train_labels = extract_features(encoder, memory_loader)\n","    test_feats, test_labels = extract_features(encoder, test_loader)\n","    classifier = nn.Linear(train_feats.size(1), 10).to(DEVICE)\n","    optimizer = optim.Adam(classifier.parameters(), lr=0.1)\n","    criterion = nn.CrossEntropyLoss()\n","    dataset = torch.utils.data.TensorDataset(train_feats, train_labels)\n","    loader = DataLoader(dataset, batch_size=512, shuffle=True)\n","    for _ in range(10):\n","        for f, l in loader:\n","            f, l = f.to(DEVICE), l.to(DEVICE)\n","            optimizer.zero_grad()\n","            criterion(classifier(f), l).backward()\n","            optimizer.step()\n","    classifier.eval()\n","    with torch.no_grad():\n","        acc = (classifier(test_feats.to(DEVICE)).argmax(1).cpu() == test_labels).float().mean().item() * 100\n","    return acc\n","\n","def calculate_metrics(agent, align_loader, test_loader):\n","    agent.model.eval()\n","    align_loss = 0\n","    with torch.no_grad():\n","        for (i1, i2), _ in align_loader:\n","            with torch.amp.autocast('cuda'):\n","                h1, h2 = agent.model(i1.to(DEVICE), True), agent.model(i2.to(DEVICE), True)\n","            align_loss += (h1 - h2).pow(2).sum(dim=1).mean().item()\n","    align = align_loss / len(align_loader)\n","\n","    feats = []\n","    with torch.no_grad():\n","        for x, _ in test_loader:\n","            with torch.amp.autocast('cuda'):\n","                feats.append(agent.model(x.to(DEVICE), True))\n","            if len(feats)*BATCH_SIZE >= 2048: break\n","    feats = torch.cat(feats)[:2048]\n","    sq_dist = torch.cdist(feats, feats, p=2).pow(2)\n","    uni = sq_dist[~torch.eye(2048, dtype=bool, device=DEVICE)].mul(-2).exp().mean().log().item()\n","    return align, uni\n","\n","def run_experiment(seed):\n","    set_seed(seed)\n","    train_loader, memory_loader, test_loader, align_loader = get_dataloaders()\n","    agents = {\n","        \"Cosine-Decay\": ContrastiveAgent(\"cosine\"),\n","        \"Projected-SGD\": ContrastiveAgent(\"projected\")\n","    }\n","    results = {}\n","    for name, agent in agents.items():\n","        epoch_times = []\n","        print(f\"\\n>>> Training {name} (Seed {seed})\")\n","        for epoch in range(NUM_EPOCHS):\n","            t = agent.train_one_epoch(train_loader, epoch+1)\n","            epoch_times.append(t)\n","\n","            if (epoch+1) in [1, 10, 20, 30, 40, 50]:\n","                acc = train_linear_classifier(agent, memory_loader, test_loader)\n","                align, uni = calculate_metrics(agent, align_loader, test_loader)\n","                agent.accuracy_history.append(acc)\n","                print(f\"E{epoch+1:02d} | Acc: {acc:.2f}% | Temp: {agent.temperature:.3f} | Uni: {uni:.3f} | Align: {align:.4f} | Time: {t:.2f}s\")\n","            else:\n","                agent.accuracy_history.append(0)\n","\n","        results[name] = {\n","            \"agent\": agent,\n","            \"accuracy\": max(agent.accuracy_history),\n","            \"avg_time\": np.mean(epoch_times)\n","        }\n","    return results\n","\n","def main():\n","    print(\"Starting Optimized SimCLR Experiment (RF-BO Rebrand)...\")\n","    start_time = time.time()\n","    seeds = [42, 43, 44]\n","    all_results = {s: run_experiment(s) for s in seeds}\n","\n","    print(\"\\n\" + \"=\"*30 + \"\\nFINAL SUMMARY\\n\" + \"=\"*30)\n","    methods = list(all_results[42].keys())\n","    for name in methods:\n","        accs = [all_results[s][name][\"accuracy\"] for s in seeds]\n","        times = [all_results[s][name][\"avg_time\"] for s in seeds]\n","        print(f\"{name:15} | Acc: {np.mean(accs):.2f}±{np.std(accs):.2f} | Avg Time/Epoch: {np.mean(times):.2f}s\")\n","\n","    print(f\"\\nTotal time: {(time.time()-start_time)/60:.1f} minutes\")\n","\n","if __name__ == '__main__':\n","    main()"],"metadata":{"id":"1QB7lj2KMOy-"},"execution_count":null,"outputs":[]}]}