{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7e263994",
   "metadata": {},
   "source": [
    "# Adaptive path Correction with Exponents (ACE)\n",
    "Demo code to reproduce the results and figures in the paper. A single A6000 GPU was used to run the demo notebook. Run the code blocks in order to prevent errors."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c36fbff5",
   "metadata": {},
   "source": [
    "## Step 1. Import libraries and train models (if not already trained)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c07eb92",
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import random\n",
    "from datetime import datetime\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import itertools\n",
    "import json\n",
    "from IPython.display import clear_output\n",
    "\n",
    "# Import ACE library components\n",
    "from ace_lib.metrics.export import compute_sample_based_metrics\n",
    "from ace_lib.interpolant import MLPInstFlexible\n",
    "from ace_lib.sample_data import ground_truth_hcg, plot_diagnostics\n",
    "from ace_lib.ace import simulate_ace\n",
    "from ace_lib.interpolant import Interpolant, FlowMatcher, plot_path_trajectories\n",
    "from ace_lib.utils import load_interpolants_from_json, set_seed, get_experiment_dir\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")\n",
    "set_seed(42)\n",
    "\n",
    "interpolant_schedules = load_interpolants_from_json(\"ace_lib/interpolant_schedules.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54e5ec1c",
   "metadata": {},
   "source": [
    "Skip the pretraining code if the 2D models are already trained under `/PretrainedToyModels/`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bbe7ec9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train toy models once (if not already trained) / Takes ~30 mins\n",
    "from ace_lib.train_toy_models import main\n",
    "main()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e671a38",
   "metadata": {},
   "source": [
    "## Figure 1. Marginal Path Collapse and Our Solution ACE \n",
    "$q^{(1)}q^{(2)}/q^{(3)}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5b9df84",
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_dir = get_experiment_dir(f\"ace_demo_runs_{datetime.now().strftime('%Y%m%d')}\", \"figure_1\")\n",
    "print(f\"Experiment directory: {exp_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47fabe6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "schedule_combination = [\"cos_t\", \"ddpm_linear\", \"default_linear\"]\n",
    "\n",
    "# Load the pretrained models\n",
    "model_path = \"PretrainedToyModels\"\n",
    "u_model1 = MLPInstFlexible(z_dim=1, cond_dim=1, width=256, depth=4, output_dim=1).to(device); u_model1.load_state_dict(torch.load(f\"{model_path}/u_model1_X_given_A_alpha={schedule_combination[0]}.pth\")); u_model1.eval()\n",
    "s_model1 = MLPInstFlexible(z_dim=1, cond_dim=1, width=256, depth=4, output_dim=1).to(device); s_model1.load_state_dict(torch.load(f\"{model_path}/s_model1_X_given_A_alpha={schedule_combination[0]}.pth\")); s_model1.eval()\n",
    "u_model2 = MLPInstFlexible(z_dim=2, cond_dim=1, width=256, depth=4, output_dim=2).to(device); u_model2.load_state_dict(torch.load(f\"{model_path}/u_model2_XY_given_B_alpha={schedule_combination[1]}.pth\")); u_model2.eval()\n",
    "s_model2 = MLPInstFlexible(z_dim=2, cond_dim=1, width=256, depth=4, output_dim=2).to(device); s_model2.load_state_dict(torch.load(f\"{model_path}/s_model2_XY_given_B_alpha={schedule_combination[1]}.pth\")); s_model2.eval()\n",
    "u_model3 = MLPInstFlexible(z_dim=1, cond_dim=0, width=256, depth=4, output_dim=1).to(device); u_model3.load_state_dict(torch.load(f\"{model_path}/u_model3_X_alpha={schedule_combination[2]}.pth\")); u_model3.eval()\n",
    "s_model3 = MLPInstFlexible(z_dim=1, cond_dim=0, width=256, depth=4, output_dim=1).to(device); s_model3.load_state_dict(torch.load(f\"{model_path}/s_model3_X_alpha={schedule_combination[2]}.pth\")); s_model3.eval()\n",
    "\n",
    "def v1_fn(x, t, A): return u_model1(x, t, A)\n",
    "def s1_fn(x, t, A): return s_model1(x, t, A)\n",
    "def v2_fn(z, t, B): return u_model2(z, t, B)\n",
    "def s2_fn(z, t, B): return s_model2(z, t, B)\n",
    "def v3_fn(x, t): return u_model3(x, t)\n",
    "def s3_fn(x, t): return s_model3(x, t)\n",
    "def sigma_fn(t): return 0.5 * torch.ones_like(t)\n",
    "\n",
    "print(\"Models loaded.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cc835aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the velocity, score, projection, embedding lists\n",
    "A, B = 1, 1\n",
    "v_fn_list=[\n",
    "        lambda x, t: v1_fn(x[:, :1], t, torch.full((x.size(0), 1), A, device=x.device)), # v1(X|A)\n",
    "        lambda x, t: v2_fn(x, t, torch.full((x.size(0), 1), B, device=x.device)), # v2(X|B)\n",
    "        lambda x, t: v3_fn(x[:, :1], t)                                                         # v3(Z)\n",
    "    ]\n",
    "s_fn_list=[\n",
    "        lambda x, t: s1_fn(x[:, :1], t, torch.full((x.size(0), 1), A, device=x.device)), # s1(X|A)\n",
    "        lambda x, t: s2_fn(x, t, torch.full((x.size(0), 1), B, device=x.device)), # s2(X|B)\n",
    "        lambda x, t: s3_fn(x[:, :1], t)                                                         # s3(Z)\n",
    "    ]\n",
    "proj_list=[\n",
    "        lambda z: z[:, :1],    # project to X \n",
    "        lambda z: z,           # identity for Z\n",
    "        lambda z: z[:, :1]     # project to X\n",
    "    ]\n",
    "emb_list=[\n",
    "        lambda x: torch.cat([x, torch.zeros(x.size(0), 1, device=x.device)], dim=1),  # embed X→Z\n",
    "        lambda z: z,                                                                  # identity\n",
    "        lambda x: torch.cat([x, torch.zeros(x.size(0), 1, device=x.device)], dim=1)   # embed X→Z\n",
    "    ]\n",
    "print(\"Velocity, score, projection, and embedding functions defined.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c29d1440",
   "metadata": {},
   "outputs": [],
   "source": [
    "Bump=0.0\n",
    "Ramp=0.0\n",
    "weight = 1.0\n",
    "\n",
    "# Define the exponent list\n",
    "gamma_list = [\n",
    "    lambda t : torch.tensor(1) + Bump * t * (1 - t) + Ramp * t,\n",
    "    lambda t : torch.tensor(1 * weight),\n",
    "    lambda t : torch.tensor(-1 * weight)\n",
    "]\n",
    "d_gamma_list = [\n",
    "    lambda t: torch.zeros_like(t) + Bump * (1 - 2 * t) + Ramp,\n",
    "    lambda t: torch.zeros_like(t),\n",
    "    lambda t: torch.zeros_like(t)\n",
    "]\n",
    "print(f\"Gamma and dGamma functions defined with Bump={Bump}_Ramp={Ramp}_weight={weight}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de23a157",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Criterion plot\n",
    "Criterion = lambda t: sum([ gamma_list[i](t) / (interpolant_schedules[schedule_combination[i]].alpha_t(t))**2 for i in range(len(schedule_combination)) ])\n",
    "t = torch.linspace(0.0, 1.0, 100)\n",
    "plt.plot(t.numpy(), Criterion(t).numpy())\n",
    "plt.xlabel('t')\n",
    "plt.ylabel('Criterion(t)')\n",
    "plt.ylim(-20, 100)\n",
    "plt.title('Criterion vs t')\n",
    "plt.grid(True)\n",
    "plt.savefig(os.path.join(exp_dir, f\"Criterion_plot_Bump={Bump}_Ramp={Ramp}_weight={weight}.png\"))\n",
    "plt.show()\n",
    "\n",
    "# print when Criterion = 0\n",
    "for i in range(len(t)-1):\n",
    "    if Criterion(t[i]) > 0 and Criterion(t[i+1]) < 0:\n",
    "        print(\"Criterion = 0 at t =\", t[i].item())\n",
    "        break\n",
    "for i in range(len(t)-1):\n",
    "    if Criterion(t[i]) < 0 and Criterion(t[i+1]) > 0:\n",
    "        print(\"Criterion = 0 at t =\", t[i].item())\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b09b4295",
   "metadata": {},
   "outputs": [],
   "source": [
    "A, B = 1, 1\n",
    "x0 = torch.randn(10000, 2).to(\"cuda\")  # (X, Y) sample\n",
    "\n",
    "samples, logw_final, logw_history, sample_history, resample_history = simulate_ace(\n",
    "    x0=x0, \n",
    "    v_fn_list=v_fn_list, \n",
    "    s_fn_list=s_fn_list, \n",
    "    proj_list=proj_list, \n",
    "    emb_list=emb_list, \n",
    "    sigma_fn=sigma_fn,\n",
    "    v_star= lambda z, t: v2_fn(z, t, torch.full((z.size(0), 1), B, device=z.device)), \n",
    "    t0=0.0, t1=1.0, n_steps=100, device=\"cuda\", ess_threshold=0.4, print_resample_history=True,\n",
    "    gamma_list=gamma_list,\n",
    "    d_gamma_list=d_gamma_list,\n",
    "    resample=False\n",
    "    )\n",
    "print(\"ACE simulation completed.\")\n",
    "samples = samples.cpu().numpy()\n",
    "plot_diagnostics(samples, logw_final, logw_history, save_name=f\"{exp_dir}/diagnostic_plot_Bump={Bump}_Ramp={Ramp}_weight={weight}\")\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "246eb7f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_path_trajectories(sample_history, n_frame=4, experiment_id=exp_dir, name=f\"trajectory_plot_Bump={Bump}_Ramp={Ramp}_weight={weight}_NR\", deg=-50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a16ec11f",
   "metadata": {},
   "outputs": [],
   "source": [
    "A, B = 1, 1\n",
    "x0 = torch.randn(10000, 2).to(\"cuda\")  # (X, Y) sample\n",
    "\n",
    "samples, logw_final, logw_history, sample_history, resample_history = simulate_ace(\n",
    "    x0=x0, \n",
    "    v_fn_list=v_fn_list, \n",
    "    s_fn_list=s_fn_list, \n",
    "    proj_list=proj_list, \n",
    "    emb_list=emb_list, \n",
    "    sigma_fn=sigma_fn,\n",
    "    v_star= lambda z, t: v2_fn(z, t, torch.full((z.size(0), 1), B, device=z.device)), \n",
    "    t0=0.0, t1=1.0, n_steps=100, device=\"cuda\", ess_threshold=0.4, print_resample_history=True,\n",
    "    gamma_list=gamma_list,\n",
    "    d_gamma_list=d_gamma_list,\n",
    "    resample=True\n",
    "    )\n",
    "print(\"ACE simulation completed.\")\n",
    "samples = samples.cpu().numpy()\n",
    "plot_diagnostics(samples, logw_final, logw_history, save_name=f\"{exp_dir}/diagnostic_plot_Bump={Bump}_Ramp={Ramp}_weight={weight}_FKC\")\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca663455",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_path_trajectories(sample_history, n_frame=4, experiment_id=exp_dir, name=f\"trajectory_plot_Bump={Bump}_Ramp={Ramp}_weight={weight}_FKC\", deg=-50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc4b9088",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(schedule_combination)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fea23da9",
   "metadata": {},
   "outputs": [],
   "source": [
    "Bump=10.0\n",
    "Ramp=2.0\n",
    "weight = 1.0\n",
    "\n",
    "# Define the exponent list\n",
    "gamma_list = [\n",
    "    lambda t : torch.tensor(1) + Bump * t * (1 - t) + Ramp * t,\n",
    "    lambda t : torch.tensor(1 * weight),\n",
    "    lambda t : torch.tensor(-1 * weight)\n",
    "]\n",
    "d_gamma_list = [\n",
    "    lambda t: torch.zeros_like(t) + Bump * (1 - 2 * t) + Ramp,\n",
    "    lambda t: torch.zeros_like(t),\n",
    "    lambda t: torch.zeros_like(t)\n",
    "]\n",
    "print(f\"Gamma and dGamma functions defined with Bump={Bump}_Ramp={Ramp}_weight={weight}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c43b61ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Criterion plot\n",
    "Criterion = lambda t: sum([ gamma_list[i](t) / (interpolant_schedules[schedule_combination[i]].alpha_t(t))**2 for i in range(len(schedule_combination)) ])\n",
    "t = torch.linspace(0.0, 1.0, 2000)\n",
    "plt.plot(t.numpy(), Criterion(t).numpy())\n",
    "plt.xlabel('t')\n",
    "plt.ylabel('Criterion(t)')\n",
    "plt.ylim(-20, 100)\n",
    "plt.title('Criterion vs t')\n",
    "plt.grid(True)\n",
    "plt.savefig(os.path.join(exp_dir, f\"Criterion_plot_Bump={Bump}_Ramp={Ramp}_weight={weight}.png\"))\n",
    "plt.show()\n",
    "\n",
    "# print when Criterion = 0\n",
    "for i in range(len(t)-1):\n",
    "    if Criterion(t[i]) > 0 and Criterion(t[i+1]) < 0:\n",
    "        print(\"Criterion = 0 at t =\", t[i].item())\n",
    "        break\n",
    "for i in range(len(t)-1):\n",
    "    if Criterion(t[i]) < 0 and Criterion(t[i+1]) > 0:\n",
    "        print(\"Criterion = 0 at t =\", t[i].item())\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "803bdb5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "Bump_Values = [0.0, 10.0, 20.0, 30.0, 50.0, 100.0]\n",
    "Ramp_Values = [0.0, 0.5, 1.0, 1.5, 2.0]\n",
    "\n",
    "for Bump in Bump_Values:\n",
    "    for Ramp in Ramp_Values:\n",
    "        print(f\"\\nAnalyzing for Bump={Bump}, Ramp={Ramp}\")\n",
    "        weight = 1.0\n",
    "\n",
    "        # Define the exponent list\n",
    "        gamma_list = [\n",
    "            lambda t : torch.tensor(1) + Bump * t * (1 - t) + Ramp * t,\n",
    "            lambda t : torch.tensor(1 * weight),\n",
    "            lambda t : torch.tensor(-1 * weight)\n",
    "        ]\n",
    "        d_gamma_list = [\n",
    "            lambda t: torch.zeros_like(t) + Bump * (1 - 2 * t) + Ramp,\n",
    "            lambda t: torch.zeros_like(t),\n",
    "            lambda t: torch.zeros_like(t)\n",
    "        ]\n",
    "        Criterion = lambda t: sum([ gamma_list[i](t) / (interpolant_schedules[schedule_combination[i]].alpha_t(t))**2 for i in range(len(schedule_combination)) ])\n",
    "        t = torch.linspace(0.0, 1.0, 2000)\n",
    "        for i in range(len(t)-1):\n",
    "            if Criterion(t[i]) > 0 and Criterion(t[i+1]) < 0:\n",
    "                print(\"Criterion = 0 at t =\", t[i].item())\n",
    "                break\n",
    "        for i in range(len(t)-1):\n",
    "            if Criterion(t[i]) < 0 and Criterion(t[i+1]) > 0:\n",
    "                print(\"Criterion = 0 at t =\", t[i].item())\n",
    "                break\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6b68c20",
   "metadata": {},
   "outputs": [],
   "source": [
    "A, B = 1, 1\n",
    "x0 = torch.randn(10000, 2).to(\"cuda\")  # (X, Y) sample\n",
    "\n",
    "samples, logw_final, logw_history, sample_history, resample_history = simulate_ace(\n",
    "    x0=x0, \n",
    "    v_fn_list=v_fn_list, \n",
    "    s_fn_list=s_fn_list, \n",
    "    proj_list=proj_list, \n",
    "    emb_list=emb_list, \n",
    "    sigma_fn=sigma_fn,\n",
    "    v_star= lambda z, t: v2_fn(z, t, torch.full((z.size(0), 1), B, device=z.device)), \n",
    "    t0=0.0, t1=1.0, n_steps=100, device=\"cuda\", ess_threshold=0.4, print_resample_history=True,\n",
    "    gamma_list=gamma_list,\n",
    "    d_gamma_list=d_gamma_list,\n",
    "    resample=True\n",
    "    )\n",
    "print(\"ACE simulation completed.\")\n",
    "samples = samples.cpu().numpy()\n",
    "plot_diagnostics(samples, logw_final, logw_history, save_name=f\"{exp_dir}/diagnostic_plot_Bump={Bump}_Ramp={Ramp}_weight={weight}_ACE\")\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48d61f7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_path_trajectories(sample_history, n_frame=4, experiment_id=exp_dir, name=f\"trajectory_plot_Bump={Bump}_Ramp={Ramp}_weight={weight}_ACE\", deg=-50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9e044ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "Bump=0.0\n",
    "Ramp=0.0\n",
    "weight = 1.0\n",
    "\n",
    "# Define the exponent list\n",
    "gamma_list = [\n",
    "    lambda t : torch.tensor(1) + Bump * t * (1 - t) + Ramp * t,\n",
    "    lambda t : torch.tensor(1 * weight),\n",
    "    lambda t : torch.tensor(-1 * weight)\n",
    "]\n",
    "d_gamma_list = [\n",
    "    lambda t: torch.zeros_like(t) + Bump * (1 - 2 * t) + Ramp,\n",
    "    lambda t: torch.zeros_like(t),\n",
    "    lambda t: torch.zeros_like(t)\n",
    "]\n",
    "print(f\"Gamma and dGamma functions defined with Bump={Bump}_Ramp={Ramp}_weight={weight}.\")\n",
    "\n",
    "Criterion = lambda t: sum([ gamma_list[i](t) / (interpolant_schedules[schedule_combination[i]].alpha_t(t))**2 for i in range(len(schedule_combination)) ])\n",
    "\n",
    "t = torch.linspace(0.0, 1.0, 100)\n",
    "plt.plot(t.numpy(), Criterion(t).numpy(), label=\"Constant Exponents\")\n",
    "plt.xlabel('t')\n",
    "plt.ylabel('C(t)')\n",
    "plt.grid(True)\n",
    "\n",
    "\n",
    "Bump=10.0\n",
    "Ramp=5.0\n",
    "weight = 1.0\n",
    "\n",
    "# Define the exponent list\n",
    "gamma_list = [\n",
    "    lambda t : torch.tensor(1) + Bump * t * (1 - t) + Ramp * t,\n",
    "    lambda t : torch.tensor(1 * weight),\n",
    "    lambda t : torch.tensor(-1 * weight)\n",
    "]\n",
    "d_gamma_list = [\n",
    "    lambda t: torch.zeros_like(t) + Bump * (1 - 2 * t) + Ramp,\n",
    "    lambda t: torch.zeros_like(t),\n",
    "    lambda t: torch.zeros_like(t)\n",
    "]\n",
    "print(f\"Gamma and dGamma functions defined with Bump={Bump}_Ramp={Ramp}_weight={weight}.\")\n",
    "\n",
    "Criterion = lambda t: sum([ gamma_list[i](t) / (interpolant_schedules[schedule_combination[i]].alpha_t(t))**2 for i in range(len(schedule_combination)) ])\n",
    "\n",
    "t = torch.linspace(0.0, 1.0, 2000)\n",
    "plt.plot(t.numpy(), Criterion(t).numpy(), label='Adaptive Exponents (with Bump)')\n",
    "plt.legend()\n",
    "plt.ylim((-50,100))\n",
    "plt.savefig(os.path.join(exp_dir, f\"Criterion_plot_{schedule_combination}_Bump{Bump}.png\"))\n",
    "plt.show()\n",
    "\n",
    "# print when Criterion = 0\n",
    "for i in range(len(t)-1):\n",
    "    if Criterion(t[i]) > 0 and Criterion(t[i+1]) < 0:\n",
    "        print(\"Criterion(t): pos -> neg at t =\", t[i].item())\n",
    "        break\n",
    "for i in range(len(t)-1):\n",
    "    if Criterion(t[i]) < 0 and Criterion(t[i+1]) > 0:\n",
    "        print(\"Criterion(t) neg -> pos at t =\", t[i].item())\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "177d90bb",
   "metadata": {},
   "source": [
    "## Figure 2. Non-integrable Region in the ratio-of-Gaussians example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c60a14dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_dir = get_experiment_dir(f\"ace_demo_runs_{datetime.now().strftime('%Y%m%d')}\", \"figure_2\")\n",
    "print(f\"Experiment directory: {exp_dir}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b373058",
   "metadata": {},
   "source": [
    "We can simplify the GMM example for convenience:\n",
    "\n",
    "Let $p_t$ denote the path from $p_0=\\mathcal{N}(0,\\sigma_0^2I)$ to $p_1=\\mathcal{N}(0,\\sigma_1^2I)$ via $X_t = (1-t)X_0 + tX_1$. \n",
    "\n",
    "Then, the intermediate densities will be $p_t = \\mathcal{N}(0, \\left((1-t)^2\\sigma_0^2 + t^2 \\sigma_1^2\\right)I)$\n",
    "\n",
    "The score function is simply\n",
    "$$\n",
    "\\nabla \\log p_t (X) = -\\frac{1}{(1-t)^2\\sigma_0^2 + t^2 \\sigma_1^2} X\n",
    "$$\n",
    "\n",
    "Deriving the SDE from the Path\n",
    "\n",
    "For a general SDE of the form $dX_t = f(t,X_t)dt + g(t)dW_t$, the evolution of its variance $\\sigma_t^2$ for a zero-mean Guassian process is governed by the Fokker-Planck equation:\n",
    "$$\n",
    "\\frac{d\\sigma_t^2}{dt} = 2\\mu(t)\\sigma_t^2 + g(t)^2\n",
    "$$\n",
    "where we assumed linear drift of the form $f(t,X_t) = \\mu(t)X_t$. For this specific path, we have\n",
    "$$\n",
    "\\mu(t) = \\frac{-2(1-t)\\sigma_0^2 + 2t\\sigma_1^2 - g(t)^2}{2((1-t)^2\\sigma_0^2 + t^2\\sigma_1^2)}\n",
    "$$\n",
    "and the full drift is then $f(t,X_t) = \\mu(t)X_t$.\n",
    "\n",
    "As a result, we have\n",
    "$$\n",
    "dX_t = \\frac{-2(1-t)\\sigma_0^2 + 2t\\sigma_1^2 - g(t)^2}{2((1-t)^2\\sigma_0^2 + t^2\\sigma_1^2)} X_tdt + g(t)dW_t\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81db1c59",
   "metadata": {},
   "outputs": [],
   "source": [
    "BUMP_VALUE = 0.0\n",
    "\n",
    "def sigma_p1(t): return ((1-t)**2 + 0.5*t**2)/(1 + BUMP_VALUE * t * (1-t))\n",
    "def sigma_p2(t): return (1-t)**2 + 7*t**2\n",
    "def sigma_q1(t): return 1.5*(1-t)**2 + t**2\n",
    "def sigma_q2(t): return 1.5*(1-t)**2 + t**2\n",
    "\n",
    "def sigma_eff(var1, var2):\n",
    "    return 1.0 / (1/var1 + 1/var2)\n",
    "\n",
    "def sigma_P(t):\n",
    "    return sigma_eff(sigma_p1(t), sigma_p2(t))\n",
    "\n",
    "def sigma_Q(t):\n",
    "    return sigma_eff(sigma_q1(t), sigma_q2(t))\n",
    "\n",
    "def eff_var(ts, var1, var2, filter_negative=False):\n",
    "    if filter_negative:\n",
    "        indices = np.where(var1 < var2)[0]\n",
    "    else:\n",
    "        indices = np.arange(len(ts))\n",
    "    eff_ts = ts[indices]\n",
    "    eff_vars = 1.0 / (1/var1[indices] - 1/var2[indices])\n",
    "    return eff_ts, eff_vars\n",
    "\n",
    "# Find break points where integrability flips\n",
    "ts = np.linspace(0,1,500)\n",
    "sigP = np.array([sigma_P(t) for t in ts])\n",
    "sigQ = np.array([sigma_Q(t) for t in ts])\n",
    "diff = sigP - sigQ\n",
    "sign_changes = ts[np.where(np.diff(np.sign(diff)))[0]]\n",
    "eff_ts, eff_vars = eff_var(ts, sigP, sigQ)\n",
    "\n",
    "# Generate 10000 samples * 500 timesteps from Gaussian with effective variance \n",
    "samples = np.random.randn(len(eff_ts), 10000, 2) * np.sqrt(eff_vars[:, np.newaxis, np.newaxis])\n",
    "print(\"Samples shape:\", samples.shape)\n",
    "samples = torch.tensor(samples, device=\"cuda\", dtype=torch.float32)\n",
    "plot_path_trajectories(samples, n_frame=6, resample_history=None, divergence_points=[0.456, 0.63],  experiment_id=exp_dir, name=f\"gaussian_ratio_path_Bump={BUMP_VALUE}\", deg=-45, num_trajectory_points=0, hard_lim=15)\n",
    "\n",
    "print(\"Potential divergence points at:\", sign_changes)\n",
    "t_eff = 0.456\n",
    "i_eff = (np.abs(eff_ts - t_eff)).argmin()\n",
    "print(f\"Effective variance at {t_eff}: {eff_vars[i_eff]}\")\n",
    "\n",
    "# Plot variances\n",
    "plt.rcParams['figure.dpi'] = 300\n",
    "plt.figure(figsize=(8,4))\n",
    "plt.plot(ts, sigP, label=r\"$\\sigma_P^2(t)$\")\n",
    "plt.plot(ts, sigQ, label=r\"$\\sigma_Q^2(t)$\")\n",
    "plt.plot(eff_ts, eff_vars, label=r\"$\\sigma_{eff}^2(t)$\", ls=':')\n",
    "plt.axhline(0, color='black', lw=0.5)\n",
    "for sc in sign_changes:\n",
    "    plt.axvline(sc, color='red', ls='--', label=\"break point\")\n",
    "plt.legend()\n",
    "plt.xlabel(\"t\")\n",
    "plt.ylabel(\"effective variance\")\n",
    "plt.ylim(-1, 40)\n",
    "# plt.title(\"Where integrability breaks (2D Gaussian ratio)\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bcb790b",
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 5000\n",
    "ts = np.linspace(0,1,500)\n",
    "sigP = np.array([sigma_P(t) for t in ts])\n",
    "sigQ = np.array([sigma_Q(t) for t in ts])\n",
    "ts, eff_vars = eff_var(ts, sigP, sigQ, filter_negative=False)\n",
    "\n",
    "samples = [torch.randn(bs,2) * torch.sqrt(torch.tensor(eff_vars[i])) for i in range(len(ts))]\n",
    "\n",
    "plot_path_trajectories(samples, divergence_points=[0.456, 0.63], hard_lim=10, experiment_id=\".\", name=\"blowup_path\", n_frame=5, deg=-80, num_trajectory_points=0, interval_d=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da9973ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "sigma_0 = 1.5\n",
    "sigma_1 = 1\n",
    "\n",
    "n_samples = 5000\n",
    "dim = 2\n",
    "n_timesteps = 1000\n",
    "T = 1.0\n",
    "dt = T / n_timesteps\n",
    "\n",
    "# For any linear path between two isotropic gaussians, the score and velocity functions are:\n",
    "def v_fn(x, t, sigma_0, sigma_1):\n",
    "    return (-2 * (1-t) * sigma_0**2 + 2 * t * sigma_1 **2) * x / ( 2 * ((1-t)**2 * sigma_0**2 + t**2 * sigma_1**2) )\n",
    "def s_fn(x, t, sigma_0, sigma_1):\n",
    "    return - x / ((1-t)**2 * sigma_0**2 + t**2 * sigma_1**2)\n",
    "def g_fn(t):\n",
    "    return 0.5\n",
    "\n",
    "X_t = torch.randn(n_samples, dim, device=device) * sigma_0\n",
    "sample_history = [X_t.clone()]\n",
    "\n",
    "for i in tqdm(range(n_timesteps), desc=\"Simulating SDE\"):\n",
    "    t = i * dt\n",
    "    g_t = g_fn(t)\n",
    "    \n",
    "    v_value = v_fn(X_t, t, sigma_0, sigma_1) + 0.5 * g_t**2 * s_fn(X_t, t, sigma_0, sigma_1)\n",
    "    drift = v_value * dt\n",
    "\n",
    "    random_noise = torch.randn_like(X_t)\n",
    "    diffusion = g_t * np.sqrt(dt) * random_noise\n",
    "    \n",
    "    X_t = X_t + drift + diffusion\n",
    "    \n",
    "    sample_history.append(X_t.clone())\n",
    "\n",
    "print(f\"Simulation finished. `sample_history` contains {len(sample_history)} timesteps.\")\n",
    "# plot_path_trajectories(sample_history, resample_history=None, hard_lim=10, experiment_id=exp_dir, name=\"gaussian_path_4\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f57710b6",
   "metadata": {},
   "source": [
    "## Figure E.12: Stabilizing a ratio-of-Gaussians path via the bump parameter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "443c46d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_dir = get_experiment_dir(f\"ace_demo_runs_{datetime.now().strftime('%Y%m%d')}\", \"figure_E12\")\n",
    "print(f\"Experiment directory: {exp_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b6d2ba7",
   "metadata": {},
   "outputs": [],
   "source": [
    "BUMP_VALUE = 0.1 #0.5\n",
    "\n",
    "def sigma_p1(t): return ((1-t)**2 + 0.5*t**2)/(1 + BUMP_VALUE * t * (1-t))\n",
    "def sigma_p2(t): return (1-t)**2 + 7*t**2\n",
    "def sigma_q1(t): return 1.5*(1-t)**2 + t**2\n",
    "def sigma_q2(t): return 1.5*(1-t)**2 + t**2\n",
    "\n",
    "def sigma_eff(var1, var2):\n",
    "    return 1.0 / (1/var1 + 1/var2)\n",
    "\n",
    "def sigma_P(t):\n",
    "    return sigma_eff(sigma_p1(t), sigma_p2(t))\n",
    "\n",
    "def sigma_Q(t):\n",
    "    return sigma_eff(sigma_q1(t), sigma_q2(t))\n",
    "\n",
    "def eff_var(ts, var1, var2, filter_negative=False):\n",
    "    if filter_negative:\n",
    "        indices = np.where(var1 < var2)[0]\n",
    "    else:\n",
    "        indices = np.arange(len(ts))\n",
    "    eff_ts = ts[indices]\n",
    "    eff_vars = 1.0 / (1/var1[indices] - 1/var2[indices])\n",
    "    return eff_ts, eff_vars\n",
    "\n",
    "# Find break points where integrability flips\n",
    "ts = np.linspace(0,1,500)\n",
    "sigP = np.array([sigma_P(t) for t in ts])\n",
    "sigQ = np.array([sigma_Q(t) for t in ts])\n",
    "diff = sigP - sigQ\n",
    "sign_changes = ts[np.where(np.diff(np.sign(diff)))[0]]\n",
    "eff_ts, eff_vars = eff_var(ts, sigP, sigQ)\n",
    "\n",
    "# Generate 10000 samples * 500 timesteps from Gaussian with effective variance \n",
    "samples = np.random.randn(len(eff_ts), 10000, 2) * np.sqrt(eff_vars[:, np.newaxis, np.newaxis])\n",
    "print(\"Samples shape:\", samples.shape)\n",
    "samples = torch.tensor(samples, device=\"cuda\", dtype=torch.float32)\n",
    "plot_path_trajectories(samples, overwrite_fractions=[0.0, 0.25, 0.5, 0.75, 1.0], experiment_id=exp_dir, name=f\"gaussian_ratio_path_Bump={BUMP_VALUE}\", deg=-45, num_trajectory_points=0, hard_lim=15)\n",
    "\n",
    "print(\"Potential divergence points at:\", sign_changes)\n",
    "t_eff = 0.456\n",
    "i_eff = (np.abs(eff_ts - t_eff)).argmin()\n",
    "print(f\"Effective variance at {t_eff}: {eff_vars[i_eff]}\")\n",
    "\n",
    "# Plot variances\n",
    "plt.rcParams['figure.dpi'] = 300\n",
    "plt.figure(figsize=(8,2))\n",
    "plt.plot(ts, sigP, label=r\"$\\sigma_P^2(t)$\")\n",
    "plt.plot(ts, sigQ, label=r\"$\\sigma_Q^2(t)$\")\n",
    "plt.plot(eff_ts, eff_vars, label=r\"$\\sigma_{eff}^2(t)$\", ls=':')\n",
    "plt.axhline(0, color='black', lw=0.5)\n",
    "for sc in sign_changes:\n",
    "    plt.axvline(sc, color='red', ls='--', label=\"break point\")\n",
    "plt.legend()\n",
    "plt.xlabel(\"t\")\n",
    "plt.ylabel(\"effective variance\")\n",
    "plt.ylim(-1, 40)\n",
    "# plt.title(\"Where integrability breaks (2D Gaussian ratio)\")\n",
    "plt.show()\n",
    "\n",
    "# print the minimum 1/effective variance\n",
    "# min_C = 1.0 / np.max(eff_vars[eff_vars > 0])\n",
    "# print(f\"Minimum Criterion min_t C(t): {min_C}\")\n",
    "print(f\"C(0.5) = {1.0 / (eff_vars[250])}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7e30fda",
   "metadata": {},
   "outputs": [],
   "source": [
    "BUMP_VALUE = 0.4\n",
    "\n",
    "def sigma_p1(t): return ((1-t)**2 + 0.5*t**2)/(1 + BUMP_VALUE * t * (1-t))\n",
    "def sigma_p2(t): return (1-t)**2 + 7*t**2\n",
    "def sigma_q1(t): return 1.5*(1-t)**2 + t**2\n",
    "def sigma_q2(t): return 1.5*(1-t)**2 + t**2\n",
    "\n",
    "def sigma_eff(var1, var2):\n",
    "    return 1.0 / (1/var1 + 1/var2)\n",
    "\n",
    "def sigma_P(t):\n",
    "    return sigma_eff(sigma_p1(t), sigma_p2(t))\n",
    "\n",
    "def sigma_Q(t):\n",
    "    return sigma_eff(sigma_q1(t), sigma_q2(t))\n",
    "\n",
    "def eff_var(ts, var1, var2, filter_negative=False):\n",
    "    if filter_negative:\n",
    "        indices = np.where(var1 < var2)[0]\n",
    "    else:\n",
    "        indices = np.arange(len(ts))\n",
    "    eff_ts = ts[indices]\n",
    "    eff_vars = 1.0 / (1/var1[indices] - 1/var2[indices])\n",
    "    return eff_ts, eff_vars\n",
    "\n",
    "# Find break points where integrability flips\n",
    "ts = np.linspace(0,1,500)\n",
    "sigP = np.array([sigma_P(t) for t in ts])\n",
    "sigQ = np.array([sigma_Q(t) for t in ts])\n",
    "diff = sigP - sigQ\n",
    "# np.diff(np.sign(diff)) is non-zero at the grid point just before each sign change\n",
    "sign_changes = ts[np.where(np.diff(np.sign(diff)))[0]]\n",
    "eff_ts, eff_vars = eff_var(ts, sigP, sigQ)\n",
    "\n",
    "# Generate 10000 samples * 500 timesteps from a Gaussian with the effective variance.\n",
    "# Where eff_vars < 0, np.sqrt yields NaN: no valid Gaussian exists at those timesteps.\n",
    "samples = np.random.randn(len(eff_ts), 10000, 2) * np.sqrt(eff_vars[:, np.newaxis, np.newaxis])\n",
    "print(\"Samples shape:\", samples.shape)\n",
    "# Use the device chosen in the setup cell (CUDA if available, else CPU) instead of hard-coding \"cuda\"\n",
    "samples = torch.tensor(samples, device=device, dtype=torch.float32)\n",
    "plot_path_trajectories(samples, n_frame=6, resample_history=None,  experiment_id=exp_dir, name=f\"gaussian_ratio_path_Bump={BUMP_VALUE}\", deg=-45, num_trajectory_points=0, hard_lim=15)\n",
    "\n",
    "print(\"Potential divergence points at:\", sign_changes)\n",
    "t_eff = 0.456  # time at which the effective variance is reported\n",
    "i_eff = (np.abs(eff_ts - t_eff)).argmin()\n",
    "print(f\"Effective variance at {t_eff}: {eff_vars[i_eff]}\")\n",
    "\n",
    "# Plot variances\n",
    "plt.rcParams['figure.dpi'] = 300\n",
    "plt.figure(figsize=(8,4))\n",
    "plt.plot(ts, sigP, label=r\"$\\sigma_P^2(t)$\")\n",
    "plt.plot(ts, sigQ, label=r\"$\\sigma_Q^2(t)$\")\n",
    "plt.plot(eff_ts, eff_vars, label=r\"$\\sigma_{eff}^2(t)$\", ls=':')\n",
    "plt.axhline(0, color='black', lw=0.5)\n",
    "# Label only the first break point so the legend shows a single \"break point\" entry\n",
    "for k, sc in enumerate(sign_changes):\n",
    "    plt.axvline(sc, color='red', ls='--', label=\"break point\" if k == 0 else \"_nolegend_\")\n",
    "plt.legend()\n",
    "plt.xlabel(\"t\")\n",
    "plt.ylabel(\"effective variance\")\n",
    "plt.ylim(-1, 40)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65d51033",
   "metadata": {},
   "source": [
    "## Figure 3. Common noise schedules and Marginal Path Collapse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "accf994d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory where the Figure 3 artefacts are saved; the run name embeds today's date (YYYYMMDD)\n",
    "exp_dir = get_experiment_dir(\n",
    "    f\"ace_demo_runs_{datetime.now().strftime('%Y%m%d')}\",\n",
    "    \"figure_3\",\n",
    ")\n",
    "print(f\"Experiment directory: {exp_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3ad135f",
   "metadata": {},
   "outputs": [],
   "source": [
    "names = list(interpolant_schedules.keys())\n",
    "\n",
    "t = torch.linspace(0.0, 1.0, 100)\n",
    "\n",
    "# LaTeX legend labels for the schedules loaded from interpolant_schedules.json\n",
    "name_eq = {\n",
    "    \"custom_poly\": r\"$\\alpha_t = -4t^3+7t^2-4t+1$\",\n",
    "    \"ddpm_linear\": r\"$\\alpha_t = \\text{DDPM}$\",\n",
    "    \"1-t**2\": r\"$\\alpha_t = 1-t^2$\",\n",
    "    \"sigmoid\": r\"$\\alpha_t = \\text{Sigmoid}$\",\n",
    "    \"default_linear\": r\"$\\alpha_t = 1-t$\",\n",
    "    \"cos_t\": r\"$\\alpha_t=\\cos(\\frac{\\pi}{2}t)$\"\n",
    "}\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 7))\n",
    "for name in names:\n",
    "    # Fall back to the raw key so a schedule added to the JSON file does not raise KeyError\n",
    "    ax.plot(t.numpy(), interpolant_schedules[name].alpha_t(t).numpy(), label=name_eq.get(name, name))\n",
    "ax.set_xlabel(r'$t$')\n",
    "ax.set_ylabel(r'$\\alpha_t$')\n",
    "ax.legend()\n",
    "ax.grid(True)\n",
    "fig.savefig(os.path.join(exp_dir, \"alpha_t_plot_common_schedules.png\"))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39b4f5a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reparametrisation check: with tau(t) = sqrt(1 - alpha_sigmoid(t)), the \"1-t**2\" schedule (assumed\n",
    "# alpha(s) = 1 - s^2) evaluated at tau(t) equals alpha_sigmoid(t), so the last two curves should overlap.\n",
    "t = torch.linspace(0.0, 1.0, 100)  # defined here so the cell does not rely on the previous cell's t\n",
    "tau = lambda t: torch.sqrt(1 - interpolant_schedules[\"sigmoid\"].alpha_t(t))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 7))\n",
    "ax.plot(t.numpy(), tau(t).numpy(), label=r\"$\\tau(t)$\")\n",
    "ax.plot(t.numpy(), interpolant_schedules[\"1-t**2\"].alpha_t(tau(t)).numpy(), label=r\"$1-\\tau^2(t)$\")\n",
    "ax.plot(t.numpy(), interpolant_schedules[\"sigmoid\"].alpha_t(t).numpy(), label=r\"sigmoid$(t)$\", ls='--')\n",
    "ax.set_xlabel(r'$t$')\n",
    "ax.set_ylabel(r'$\\tau_t$')\n",
    "ax.legend()\n",
    "ax.grid(True)\n",
    "# Own file name: reusing \"alpha_t_plot_common_schedules.png\" overwrote the schedule plot of the previous cell\n",
    "fig.savefig(os.path.join(exp_dir, \"tau_t_reparametrization_plot.png\"))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64cff9d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_condition(alpha_funcs, g=1.0, n_grid=2000, Bump = 0.0, Ramp = 0.0, NegRamp = 0.0):\n",
    "    \"\"\"Evaluate the criterion C(t) for [a1, a2] or [a1, a2, a3] on a uniform grid over [0, 1].\n",
    "\n",
    "    Two schedules:   C = 2/a1^2 - g/a2^2 + (Bump*t(1-t) + Ramp*t)/a1^2\n",
    "    Three schedules: C = 1/a1^2 + 1/a2^2 - (g - NegRamp*t)/a3^2 + (Bump*t(1-t) + Ramp*t)/a1^2\n",
    "    (NegRamp is only used with three schedules; 1e-12 guards against alpha = 0.)\n",
    "\n",
    "    Returns (flag, ts, C): flag is True when C starts positive and turns negative on the grid.\n",
    "    Raises ValueError for any other number of schedules.\n",
    "    \"\"\"\n",
    "    ts = torch.linspace(0.0, 1.0, n_grid)\n",
    "\n",
    "    alphas = [f(ts) for f in alpha_funcs]\n",
    "    # Vectorised corrections (same values as evaluating every grid point in a Python loop)\n",
    "    Bumps = Bump * ts * (1 - ts)\n",
    "    Ramps = Ramp * ts\n",
    "    NegRamps = NegRamp * ts\n",
    "    if len(alpha_funcs) == 2:\n",
    "        C = 2 / (alphas[0]**2 + 1e-12) - g / (alphas[1]**2 + 1e-12) + Bumps / (alphas[0]**2 + 1e-12) + Ramps / (alphas[0]**2 + 1e-12)\n",
    "    elif len(alpha_funcs) == 3:\n",
    "        cos_alpha = interpolant_schedules[\"cos_t\"].alpha_t\n",
    "        if cos_alpha in alpha_funcs:\n",
    "            # Put cos_t into slot 0, the slot that receives the Bump/Ramp corrections.\n",
    "            # NOTE(review): when cos_t is the third schedule this also moves it out of the\n",
    "            # negative -(g - NegRamp*t)/a3^2 term -- confirm that is intended.\n",
    "            cos_index = alpha_funcs.index(cos_alpha)\n",
    "            if cos_index != 0:\n",
    "                alphas[0], alphas[cos_index] = alphas[cos_index], alphas[0]\n",
    "                print(f\"Swapped cos_t {cos_index} to the front for calculation.\")\n",
    "        C = 1 / (alphas[0]**2 + 1e-12) + 1 / (alphas[1]**2 + 1e-12) - (g - NegRamps) / (alphas[2]**2 + 1e-12) + Bumps / (alphas[0]**2 + 1e-12) + Ramps / (alphas[0]**2 + 1e-12)\n",
    "    else:\n",
    "        raise ValueError(\"Only supports 2 or 3 schedules\")\n",
    "\n",
    "    return (C[0] > 0) and (C.min() < 0), ts, C\n",
    "\n",
    "\n",
    "name_eq_plot = {\n",
    "    \"custom_poly\": r\"$-4t^3+7t^2-4t+1$\",\n",
    "    \"ddpm_linear\": r\"$\\text{DDPM}$\",\n",
    "    \"1-t**2\": r\"$1-t^2$\",\n",
    "    \"sigmoid\": r\"$\\text{Sigmoid}$\",\n",
    "    \"default_linear\": r\"$1-t$\",\n",
    "    \"cos_t\": r\"$\\cos(\\frac{\\pi}{2}t)$\"\n",
    "}\n",
    "\n",
    "\n",
    "def find_valid_combinations(interpolants, g=1.0, Bump=0.0, Ramp=0.0, NegRamp=0.0, max_plots=6):\n",
    "    \"\"\"Find schedule triples (or, if none qualify, pairs) flagged by check_condition and plot their C(t).\n",
    "\n",
    "    interpolants: dict of schedule name -> schedule object with .alpha_t (e.g. interpolant_schedules),\n",
    "        or an iterable of names, which are then looked up in interpolant_schedules.\n",
    "    max_plots: stop the triple search after this many hits to keep the figure readable.\n",
    "    Pairs are stored as [a1, a1, a2] because the two-schedule criterion counts a1 twice.\n",
    "    Figures are saved under exp_dir. Returns (valid_pairs, valid_triples).\n",
    "    \"\"\"\n",
    "    # One lookup table for both searches (the pair search previously indexed `interpolants`\n",
    "    # while the triple search indexed `interpolant_schedules`)\n",
    "    schedules = interpolants if isinstance(interpolants, dict) else interpolant_schedules\n",
    "    valid_pairs, valid_triples = [], []\n",
    "    plt.rcParams['figure.dpi'] = 300\n",
    "    plt.rcParams['font.size'] = 14\n",
    "\n",
    "    for a1, a2, a3 in reversed(list(itertools.permutations(interpolants, 3))):\n",
    "        if len(valid_triples) >= max_plots:\n",
    "            break\n",
    "        if \"custom_poly\" in [a1, a2, a3]:\n",
    "            continue\n",
    "\n",
    "        valid, t, C = check_condition([schedules[a1].alpha_t, schedules[a2].alpha_t, schedules[a3].alpha_t], g=g, Bump=Bump, Ramp=Ramp, NegRamp=NegRamp)\n",
    "        if valid:\n",
    "            valid_triples.append([a1, a2, a3])\n",
    "            plt.ylim((-20,100))\n",
    "            plt.grid(True)\n",
    "            plt.plot(t.numpy(), C.numpy(), label=f'{name_eq_plot[a1]}, {name_eq_plot[a2]}, {name_eq_plot[a3]}')\n",
    "            plt.legend()\n",
    "    plt.xlabel('t')\n",
    "    plt.ylabel(r\"$C(t)$\")\n",
    "    plt.savefig(os.path.join(exp_dir, \"Criterion_plot_a1a2a3_no_bump.png\"))\n",
    "    plt.show()\n",
    "    plt.close()\n",
    "\n",
    "    if len(valid_triples) == 0:\n",
    "        for a1, a2 in itertools.permutations(interpolants, 2):\n",
    "            valid, t, C = check_condition([schedules[a1].alpha_t, schedules[a2].alpha_t], g=g, Bump=Bump, Ramp=Ramp)\n",
    "            if valid:\n",
    "                valid_pairs.append([a1, a1, a2])\n",
    "                plt.ylim((-20,100))\n",
    "                plt.grid(True)\n",
    "                plt.plot(t.numpy(), C.numpy(), label=f'{name_eq_plot[a1]}, {name_eq_plot[a1]}, {name_eq_plot[a2]}')\n",
    "                plt.legend()\n",
    "        plt.xlabel('t')\n",
    "        plt.ylabel(r\"$C(t)$\")\n",
    "        plt.savefig(os.path.join(exp_dir, \"Criterion_plot_a1a1a2_no_bump.png\"))\n",
    "        plt.show()\n",
    "        plt.close()\n",
    "\n",
    "    return valid_pairs, valid_triples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cb4bee1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize six hand-picked triples that violate the criterion (kept small for clean plots; full search in appendix),\n",
    "# first without exponent corrections and then with Bump=10, Ramp=2.\n",
    "# Uses the Figure 3 check_condition(g, Bump, Ramp, NegRamp); the Table E.5 cell redefines check_condition with a\n",
    "# different signature, so re-run the Figure 3 definitions first if that cell has already been executed.\n",
    "triples = [\n",
    "    [\"cos_t\", \"ddpm_linear\", \"default_linear\"],\n",
    "    [\"1-t**2\", \"cos_t\", \"ddpm_linear\"],\n",
    "    [\"sigmoid\", \"cos_t\", \"ddpm_linear\"],\n",
    "    [\"cos_t\", \"cos_t\", \"sigmoid\"],\n",
    "    [\"sigmoid\", \"cos_t\", \"default_linear\"],\n",
    "    [\"cos_t\", \"cos_t\", \"default_linear\"]\n",
    "]\n",
    "\n",
    "# (Bump, Ramp, output file): distinct file names so the second figure no longer overwrites the first\n",
    "criterion_settings = [\n",
    "    (0.0, 0.0, \"Criterion_plot_a1a2a3_without_bump.png\"),\n",
    "    (10.0, 2.0, \"Criterion_plot_a1a2a3_with_bump.png\"),\n",
    "]\n",
    "\n",
    "for bump_value, ramp_value, file_name in criterion_settings:\n",
    "    for a1, a2, a3 in triples:\n",
    "        valid, t, C = check_condition([interpolant_schedules[a1].alpha_t, interpolant_schedules[a2].alpha_t, interpolant_schedules[a3].alpha_t], g=1.0,\n",
    "                                      Bump=bump_value, Ramp=ramp_value, NegRamp=0.0)\n",
    "        plt.plot(t.numpy(), C.numpy(), label=f'{name_eq_plot[a1]}, {name_eq_plot[a2]}, {name_eq_plot[a3]}')\n",
    "    plt.ylim((-20,100))\n",
    "    plt.grid(True)\n",
    "    plt.legend()\n",
    "    plt.xlabel('t')\n",
    "    plt.ylabel(r\"$C(t)$\")\n",
    "    plt.savefig(os.path.join(exp_dir, file_name))\n",
    "    plt.show()\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e927b7fc",
   "metadata": {},
   "source": [
    "## Figure 4. Visualization of the sampling trajectories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9b0db9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory where the Figure 4 artefacts are saved; the run name embeds today's date (YYYYMMDD)\n",
    "exp_dir = get_experiment_dir(\n",
    "    f\"ace_demo_runs_{datetime.now().strftime('%Y%m%d')}\",\n",
    "    \"figure_4\",\n",
    ")\n",
    "print(f\"Experiment directory: {exp_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72c23ac2",
   "metadata": {},
   "outputs": [],
   "source": [
    "Bump=0.0\n",
    "Ramp=2.0\n",
    "weight = 1.0\n",
    "\n",
    "# Exponent schedules gamma_i(t) and their time derivatives d_gamma_i(t). The first term carries the\n",
    "# Bump * t(1-t) + Ramp * t correction; the other two are the constants +weight and -weight.\n",
    "# The lambdas read the globals Bump/Ramp/weight when they are called, not when they are defined.\n",
    "gamma_list = [\n",
    "    lambda t : torch.tensor(1) + Bump * t * (1 - t) + Ramp * t,\n",
    "    lambda t : torch.tensor(1 * weight),\n",
    "    lambda t : torch.tensor(-1 * weight)\n",
    "]\n",
    "d_gamma_list = [\n",
    "    lambda t: torch.zeros_like(t) + Bump * (1 - 2 * t) + Ramp,\n",
    "    lambda t: torch.zeros_like(t),\n",
    "    lambda t: torch.zeros_like(t)\n",
    "]\n",
    "A, B = 1, 1  # B is passed to v2_fn as the conditioning value in v_star below\n",
    "# 10000 initial (X, Y) samples on the device chosen in the setup cell (CPU fallback instead of hard-coded CUDA)\n",
    "x0 = torch.randn(10000, 2).to(device)\n",
    "\n",
    "samples, logw_final, logw_history, sample_history, resample_history = simulate_ace(\n",
    "    x0=x0,\n",
    "    v_fn_list=v_fn_list,\n",
    "    s_fn_list=s_fn_list,\n",
    "    proj_list=proj_list,\n",
    "    emb_list=emb_list,\n",
    "    sigma_fn=sigma_fn,\n",
    "    v_star= lambda z, t: v2_fn(z, t, torch.full((z.size(0), 1), B, device=z.device)),\n",
    "    t0=0.0, t1=1.0, n_steps=100, device=str(device), ess_threshold=0.7, print_resample_history=True,\n",
    "    gamma_list=gamma_list,\n",
    "    d_gamma_list=d_gamma_list,\n",
    "    resample=True\n",
    "    )\n",
    "print(\"ACE simulation completed.\")\n",
    "samples = samples.cpu().numpy()\n",
    "plot_path_trajectories(method_figure=True, sample_history=sample_history, resample_history=resample_history, n_frame=0, experiment_id=exp_dir, name=f\"effect_of_resampling_Bump={Bump}_Ramp={Ramp}_weight={weight}_ACE\", deg=-50)\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b23b995",
   "metadata": {},
   "outputs": [],
   "source": [
    "Bump=0.0\n",
    "Ramp=2.0\n",
    "weight = 1.0\n",
    "\n",
    "# Same exponent schedules as the resampling run above; only resample=False differs below.\n",
    "# The lambdas read the globals Bump/Ramp/weight when they are called, not when they are defined.\n",
    "gamma_list = [\n",
    "    lambda t : torch.tensor(1) + Bump * t * (1 - t) + Ramp * t,\n",
    "    lambda t : torch.tensor(1 * weight),\n",
    "    lambda t : torch.tensor(-1 * weight)\n",
    "]\n",
    "d_gamma_list = [\n",
    "    lambda t: torch.zeros_like(t) + Bump * (1 - 2 * t) + Ramp,\n",
    "    lambda t: torch.zeros_like(t),\n",
    "    lambda t: torch.zeros_like(t)\n",
    "]\n",
    "A, B = 1, 1  # B is passed to v2_fn as the conditioning value in v_star below\n",
    "# Fresh initial samples on the configured device (CPU fallback instead of hard-coded CUDA).\n",
    "# NOTE: x0 is redrawn here, so this run does not start from the same samples as the resampling run.\n",
    "x0 = torch.randn(10000, 2).to(device)\n",
    "\n",
    "samples, logw_final, logw_history, sample_history, resample_history = simulate_ace(\n",
    "    x0=x0,\n",
    "    v_fn_list=v_fn_list,\n",
    "    s_fn_list=s_fn_list,\n",
    "    proj_list=proj_list,\n",
    "    emb_list=emb_list,\n",
    "    sigma_fn=sigma_fn,\n",
    "    v_star= lambda z, t: v2_fn(z, t, torch.full((z.size(0), 1), B, device=z.device)),\n",
    "    t0=0.0, t1=1.0, n_steps=100, device=str(device), ess_threshold=0.7, print_resample_history=True,\n",
    "    gamma_list=gamma_list,\n",
    "    d_gamma_list=d_gamma_list,\n",
    "    resample=False\n",
    "    )\n",
    "print(\"ACE simulation completed.\")\n",
    "samples = samples.cpu().numpy()\n",
    "plot_path_trajectories(method_figure=False, sample_history=sample_history, resample_history=resample_history, n_frame=6, experiment_id=exp_dir, name=f\"effect_of_no_resampling_Bump={Bump}_Ramp={Ramp}_weight={weight}_ACE\", deg=-50)\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc195abe",
   "metadata": {},
   "source": [
    "## Table 2. Distributional similarity metrics (lower is better)\n",
    "\n",
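    "Run `ace_eval_script.py` first: the cells below read its per-run results from `results.csv` and summarize them into `summary_stats.csv`."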
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7e0094f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory for the Table 2 artefacts; the run name embeds today's date (YYYYMMDD)\n",
    "exp_dir = get_experiment_dir(\n",
    "    f\"ace_demo_runs_{datetime.now().strftime('%Y%m%d')}\",\n",
    "    \"Table_2\",\n",
    ")\n",
    "print(f\"Experiment directory: {exp_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c467c1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_summary_stats(input_file=\"results.csv\", output_file=\"summary_stats.csv\"):\n",
    "    \"\"\"Aggregate per-run metrics into mean/std/min/max per (Method, Bump, Ramp, weight).\n",
    "\n",
    "    input_file: per-run CSV with the grouping columns and the W1/W2/MMD/TV metric columns\n",
    "        (presumably written by ace_eval_script.py).\n",
    "    output_file: where the flattened summary (columns such as 'W1_mean') is written.\n",
    "    Returns the summary DataFrame, or None if input_file does not exist.\n",
    "    Raises KeyError naming the missing columns if input_file lacks a required column.\n",
    "    \"\"\"\n",
    "    if not os.path.exists(input_file):\n",
    "        print(f\"Error: File '{input_file}' not found.\")\n",
    "        return None\n",
    "\n",
    "    df = pd.read_csv(input_file)\n",
    "    group_cols = ['Method', 'Bump', 'Ramp', 'weight']\n",
    "    metric_cols = ['W1', 'W2', 'MMD', 'TV']\n",
    "\n",
    "    # Fail early with an explicit list instead of a bare KeyError from groupby\n",
    "    missing = [c for c in group_cols + metric_cols if c not in df.columns]\n",
    "    if missing:\n",
    "        raise KeyError(f\"'{input_file}' is missing required columns: {missing}\")\n",
    "\n",
    "    summary = df.groupby(group_cols)[metric_cols].agg(['mean', 'std', 'min', 'max'])\n",
    "\n",
    "    # Flatten the (metric, statistic) column MultiIndex: ('W1', 'mean') -> 'W1_mean'\n",
    "    summary.columns = ['_'.join(col).strip() for col in summary.columns.values]\n",
    "    summary = summary.reset_index()\n",
    "    summary.to_csv(output_file, index=False)\n",
    "\n",
    "    print(f\"Summary statistics saved to '{output_file}'\")\n",
    "    print(\"\\n--- Preview of Results ---\")\n",
    "    print(summary.head().to_string())\n",
    "    return summary\n",
    "\n",
    "compute_summary_stats();  # ';' suppresses echoing the returned DataFrame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a70d1f5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_file = \"summary_stats.csv\"\n",
    "if not os.path.exists(input_file):\n",
    "    # Stop with a clear message instead of printing and then crashing inside pd.read_csv\n",
    "    raise FileNotFoundError(f\"{input_file} not found. Run compute_summary_stats() in the previous cell first.\")\n",
    "\n",
    "# 1. Load Data\n",
    "df = pd.read_csv(input_file)\n",
    "\n",
    "# 2. Table rows as (LaTeX label, pandas query); all ACE rows use Ramp == 1.5 and vary Bump\n",
    "row_definitions = [\n",
    "    (\"NR*\",            \"Method == 'NR'\"),\n",
    "    (\"FKC*\",           \"Method == 'FKC'\"),\n",
    "    (\"ACE ($B=0$)\",    \"Method == 'ACE' and Bump == 0.0 and Ramp == 1.5\"),\n",
    "    (\"ACE ($B=10$)\",   \"Method == 'ACE' and Bump == 10.0 and Ramp == 1.5\"),\n",
    "    (\"ACE ($B=20$)\",   \"Method == 'ACE' and Bump == 20.0 and Ramp == 1.5\"),\n",
    "    (\"ACE ($B=30$)\",   \"Method == 'ACE' and Bump == 30.0 and Ramp == 1.5\"),\n",
    "    (\"ACE ($B=40$)\",   \"Method == 'ACE' and Bump == 40.0 and Ramp == 1.5\"),\n",
    "    (\"ACE ($B=50$)\",   \"Method == 'ACE' and Bump == 50.0 and Ramp == 1.5\"),\n",
    "]\n",
    "\n",
    "# 3. Metrics and the number of decimals printed for each\n",
    "metrics = [\n",
    "    (\"W1\", 2),\n",
    "    (\"W2\", 2),\n",
    "    (\"MMD\", 3)\n",
    "]\n",
    "n_data_cols = 2 * len(metrics)  # each metric gets a \"Mean +/- Std\" and a \"Max\" column\n",
    "\n",
    "# 4. Extract data and track the best (minimum) mean per metric for bolding\n",
    "table_rows = []\n",
    "best_means = {m: float('inf') for m, _ in metrics}\n",
    "\n",
    "for label, query in row_definitions:\n",
    "    try:\n",
    "        subset = df.query(query)\n",
    "        if subset.empty:\n",
    "            table_rows.append({'label': label, 'data': None})\n",
    "            continue\n",
    "\n",
    "        # NOTE(review): if several rows match (e.g. different 'weight' values) only the first is used\n",
    "        row_data = subset.iloc[0]\n",
    "        processed_row = {'label': label, 'data': {}}\n",
    "        for metric, _ in metrics:\n",
    "            mean_val = row_data[f\"{metric}_mean\"]\n",
    "            if mean_val < best_means[metric]:\n",
    "                best_means[metric] = mean_val\n",
    "            processed_row['data'][metric] = {\n",
    "                'max':  row_data[f\"{metric}_max\"],\n",
    "                'mean': mean_val,\n",
    "                'std':  row_data[f\"{metric}_std\"]\n",
    "            }\n",
    "        table_rows.append(processed_row)\n",
    "    except Exception as e:\n",
    "        # Keep the row (printed as dashes) so a failing row stays visible instead of silently vanishing\n",
    "        print(f\"Error processing row '{label}': {e}\")\n",
    "        table_rows.append({'label': label, 'data': None})\n",
    "\n",
    "# 5. Generate LaTeX: label column + (Mean +/- Std, Max) per metric\n",
    "print(r\"\\begin{tabular}{l\" + \"c\" * n_data_cols + \"}\")\n",
    "print(r\"\\toprule\")\n",
    "print(r\"\\multirow{2}{*}{Method}\")\n",
    "print(r\"% & \\multirow{2}{*}{Path Validity}\")\n",
    "print(r\"& \\multicolumn{2}{c}{$W_1$ ($\\downarrow$)} \")\n",
    "print(r\"& \\multicolumn{2}{c}{$W_2$ ($\\downarrow$)} \")\n",
    "print(r\"& \\multicolumn{2}{c}{MMD (RBF) ($\\downarrow$)} \\\\\")\n",
    "print(r\"% & \\multirow{2}{*}{Exponent Schedule} \\\\\")\n",
    "print(r\"\\cmidrule(lr){2-3} \\cmidrule(lr){4-5} \\cmidrule(lr){6-7}\")\n",
    "print(r\"& Mean $\\pm$ Std & Max \")\n",
    "print(r\"& Mean $\\pm$ Std & Max \")\n",
    "print(r\"& Mean $\\pm$ Std & Max \\\\\")\n",
    "print(r\"\\midrule\")\n",
    "\n",
    "for i, row in enumerate(table_rows):\n",
    "    label = row['label']\n",
    "\n",
    "    # Midrule between the baselines and the first ACE row (the old check compared against the\n",
    "    # label \"ACE* ($B=10$)\", which is not in row_definitions, so it never fired)\n",
    "    if i > 0 and label.startswith(\"ACE\") and not table_rows[i - 1]['label'].startswith(\"ACE\"):\n",
    "        print(r\"\\midrule\")\n",
    "\n",
    "    if row['data'] is None:\n",
    "        print(\" & \".join([label] + [\"-\"] * n_data_cols) + r\" \\\\\")\n",
    "        continue\n",
    "\n",
    "    line_parts = [label]\n",
    "    for metric, precision in metrics:\n",
    "        vals = row['data'][metric]\n",
    "        mean_str = f\"{vals['mean']:.{precision}f}\"\n",
    "        if abs(vals['mean'] - best_means[metric]) < 1e-9:\n",
    "            mean_str = f\"\\\\textbf{{{mean_str}}}\"\n",
    "        # Mean +/- Std first, then Max (matches the header order)\n",
    "        line_parts.append(f\"{mean_str} $\\\\pm$ {vals['std']:.{precision}f}\")\n",
    "        line_parts.append(f\"{vals['max']:.{precision}f}\")\n",
    "\n",
    "    print(\" & \".join(line_parts) + r\" \\\\\")\n",
    "\n",
    "print(r\"\\bottomrule\")\n",
    "print(r\"\\end{tabular}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c0c5155",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_sensitivity_grid(input_file=\"summary_stats.csv\"):\n",
    "    \"\"\"Print one LaTeX table per metric with the ACE mean (+/- std) over the Bump x Ramp grid.\n",
    "\n",
    "    Reads the summary CSV written by compute_summary_stats. Only Ramp values evaluated for\n",
    "    every Bump value are kept, so the grid has no holes; the smallest mean is printed in bold.\n",
    "    NOTE(review): DataFrame.pivot raises ValueError if a (Bump, Ramp) pair occurs more than once\n",
    "    (e.g. several 'weight' values) -- filter by weight first if that happens.\n",
    "    \"\"\"\n",
    "    if not os.path.exists(input_file):\n",
    "        print(f\"Error: {input_file} not found.\")\n",
    "        return\n",
    "\n",
    "    df = pd.read_csv(input_file)\n",
    "    df = df[df['Method'] == 'ACE']\n",
    "\n",
    "    # LaTeX for each metric name; a bare $W1$ would typeset as \"W times 1\"\n",
    "    metric_tex = {'W1': 'W_1', 'W2': 'W_2', 'MMD': r'\\mathrm{MMD}'}\n",
    "\n",
    "    for metric in ['W1', 'W2', 'MMD']:\n",
    "        print(f\"\\n% --- Metric: {metric} ---\")\n",
    "\n",
    "        # Bump values as rows, Ramp values as columns. Drop every Ramp column that is\n",
    "        # missing for at least one Bump so all rows stay complete.\n",
    "        mean_pivot = (\n",
    "            df.pivot(index='Bump', columns='Ramp', values=f'{metric}_mean')\n",
    "            .dropna(axis=1, how='any')\n",
    "            .sort_index(axis=0)\n",
    "            .sort_index(axis=1)\n",
    "        )\n",
    "        if mean_pivot.empty:\n",
    "            print(f\"% WARNING: No common Ramps found across all Bump values for {metric}.\")\n",
    "            continue\n",
    "        std_pivot = df.pivot(index='Bump', columns='Ramp', values=f'{metric}_std').loc[mean_pivot.index, mean_pivot.columns]\n",
    "\n",
    "        min_val = mean_pivot.min().min()\n",
    "        prec = 3 if metric == 'MMD' else 2\n",
    "        col_format = \"l\" + \"c\" * len(mean_pivot.columns)\n",
    "\n",
    "        print(r\"\\begin{table}[h]\")\n",
    "        print(r\"\\centering\")\n",
    "        print(f\"\\\\caption{{ACE Sensitivity Analysis: ${metric_tex[metric]}$ ($\\\\downarrow$)}}\")\n",
    "        print(r\"\\resizebox{\\linewidth}{!}{\")  # scale the table to the line width\n",
    "        print(f\"\\\\begin{{tabular}}{{{col_format}}}\")\n",
    "        print(r\"\\toprule\")\n",
    "        header_str = \" & \".join(f\"{c}\" for c in mean_pivot.columns)\n",
    "        print(f\"Bump $\\\\backslash$ Ramp & {header_str} \\\\\\\\\")\n",
    "        print(r\"\\midrule\")\n",
    "\n",
    "        for bump_val in mean_pivot.index:\n",
    "            row_str = [f\"B={bump_val}\"]\n",
    "            for ramp_val in mean_pivot.columns:\n",
    "                m = mean_pivot.loc[bump_val, ramp_val]\n",
    "                s = std_pivot.loc[bump_val, ramp_val]\n",
    "                m_str = f\"{m:.{prec}f}\"\n",
    "                if abs(m - min_val) < 1e-9:\n",
    "                    m_str = f\"\\\\textbf{{{m_str}}}\"\n",
    "                # \\tiny is a size switch, not a one-argument command: scope it with braces\n",
    "                row_str.append(f\"{m_str} {{\\\\tiny $\\\\pm${s:.{prec}f}}}\")\n",
    "            print(\" & \".join(row_str) + r\" \\\\\")\n",
    "\n",
    "        print(r\"\\bottomrule\")\n",
    "        print(r\"\\end{tabular}\")\n",
    "        print(r\"}\")\n",
    "        print(r\"\\end{table}\")\n",
    "        print(\"\\n\")\n",
    "\n",
    "generate_sensitivity_grid()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4af32d43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize linear bump and quadratic bump schedules\n",
    "plt.rcParams['figure.dpi'] = 300\n",
    "plt.rcParams['font.size'] = 14\n",
    "t = torch.linspace(0.0, 1.0, 100)\n",
    "t_end = 0.8\n",
    "# Piecewise-linear bump: equals t up to t_end, then falls linearly to 0 at t = 1,\n",
    "# i.e. min(t, tau * (1 - t)) with tau = t_end / (1 - t_end).\n",
    "linear_bump = t * (t < t_end) + (t_end / (t_end - 1) * (t - t_end) + t_end) * (t >= t_end)\n",
    "t_np = t.numpy()\n",
    "\n",
    "plt.figure(figsize=(8, 4))\n",
    "plt.plot(t_np, (t * (1 - t)).numpy(), label=r'$Q(t)=t(1-t)$ (Quadratic Bump)')\n",
    "plt.plot(t_np, linear_bump.numpy(), label=r'$L_\\tau(t)=\\min(t, \\tau(1-t))$ (Linear Bump)')\n",
    "plt.plot(t_np, (1.3 * t * (1 - t) + 0.3 * linear_bump).numpy(), label=r'$b(t)=B_1Q(t)+B_2L_\\tau(t)$ (Combined Bump)')\n",
    "plt.xlabel(r'$t$')\n",
    "plt.ylabel('Bump Functions')\n",
    "plt.legend(loc='upper left')\n",
    "plt.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7e138f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sort_parameter_combinations():\n",
    "    \"\"\"Rank every (Bump, Ramp) grid pair by the smoothness cost B1^2 / 3 + B2^2 and print the table.\"\"\"\n",
    "    # Search grid: B1 is the parabolic Bump coefficient, B2 the linear Ramp coefficient\n",
    "    bump_grid = [0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 100.0]\n",
    "    ramp_grid = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 4.0, 8.0, 16.0]\n",
    "\n",
    "    # Cartesian product of the two grids, cheapest first\n",
    "    ranked = pd.DataFrame([\n",
    "        {\"Bump (B1)\": b1, \"Ramp (B2)\": b2, \"Cost\": (1/3) * (b1**2) + (b2**2)}\n",
    "        for b1 in bump_grid\n",
    "        for b2 in ramp_grid\n",
    "    ]).sort_values(by=\"Cost\", ascending=True).reset_index(drop=True)\n",
    "\n",
    "    print(f\"{'Rank':<6} | {'Bump (B1)':<10} | {'Ramp (B2)':<10} | {'Cost':<15}\")\n",
    "    print(\"-\" * 50)\n",
    "    columns = zip(ranked[\"Bump (B1)\"], ranked[\"Ramp (B2)\"], ranked[\"Cost\"])\n",
    "    for rank, (b1, b2, cost) in enumerate(columns, start=1):\n",
    "        print(f\"{rank:<6} | {b1:<10.1f} | {b2:<10.1f} | {cost:<15.4f}\")\n",
    "\n",
    "sort_parameter_combinations()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d03c897",
   "metadata": {},
   "outputs": [],
   "source": [
    "schedule_combination = [\"cos_t\", \"ddpm_linear\", \"default_linear\"]\n",
    "\n",
    "def check_combinations(schedules=None, bump_values=None, ramp_values=None, weight=1.0, t_max=0.999, n_eval=1000):\n",
    "    \"\"\"Rank (Bump, Ramp) pairs by smoothness cost and flag whether the criterion stays positive.\n",
    "\n",
    "    For each pair, C(t) = sum_i gamma_i(t) / alpha_i(t)^2 is evaluated on n_eval points in\n",
    "    [0, t_max] (t_max < 1 avoids the t = 1 singularity), with gamma_0 = 1 + B*t(1-t) + R*t\n",
    "    (target + correction), gamma_1 = weight (conditional) and gamma_2 = -weight (prior).\n",
    "    A pair is marked \"O\" if min C(t) > 0 and \"X\" otherwise; cost = B^2/3 + R^2.\n",
    "\n",
    "    schedules: three schedule names (default: schedule_combination); the first gets the correction.\n",
    "    bump_values / ramp_values: grids to search (default: the grids of sort_parameter_combinations).\n",
    "    Prints the ranked table and returns it as a DataFrame.\n",
    "    Raises ValueError unless exactly three schedules are given.\n",
    "    \"\"\"\n",
    "    if schedules is None:\n",
    "        schedules = schedule_combination\n",
    "    if len(schedules) != 3:\n",
    "        raise ValueError(\"check_combinations expects exactly three schedules\")\n",
    "    if bump_values is None:\n",
    "        bump_values = [0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 100.0]\n",
    "    if ramp_values is None:\n",
    "        ramp_values = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 4.0, 8.0, 16.0]\n",
    "\n",
    "    t_eval = torch.linspace(0, t_max, n_eval)\n",
    "    # The schedules do not depend on (B, R): compute each alpha^2 + eps denominator once\n",
    "    denoms = [interpolant_schedules[name].alpha_t(t_eval)**2 + 1e-8 for name in schedules]\n",
    "\n",
    "    def get_validity(B, R):\n",
    "        g0 = 1 + B * t_eval * (1 - t_eval) + R * t_eval  # target + correction\n",
    "        g1 = torch.full_like(t_eval, 1 * weight)          # conditional\n",
    "        g2 = torch.full_like(t_eval, -1 * weight)         # prior (negative)\n",
    "        criterion = torch.zeros_like(t_eval)\n",
    "        for gamma, denom in zip((g0, g1, g2), denoms):\n",
    "            criterion += gamma / denom\n",
    "        return \"O\" if criterion.min().item() > 0 else \"X\"\n",
    "\n",
    "    results = [\n",
    "        {\"Bump\": b, \"Ramp\": r, \"Cost\": (1/3) * (b**2) + (r**2), \"Valid\": get_validity(b, r)}\n",
    "        for b in bump_values\n",
    "        for r in ramp_values\n",
    "    ]\n",
    "    df = pd.DataFrame(results).sort_values(by=\"Cost\", ascending=True).reset_index(drop=True)\n",
    "\n",
    "    print(f\"{'Rank':<5} | {'Bump':<6} | {'Ramp':<6} | {'Valid':<5} | {'Cost':<10}\")\n",
    "    print(\"-\" * 45)\n",
    "    for i, row in df.iterrows():\n",
    "        print(f\"{i+1:<5} | {row['Bump']:<6.1f} | {row['Ramp']:<6.1f} | {row['Valid']:<5} | {row['Cost']:<10.4f}\")\n",
    "    return df\n",
    "\n",
    "check_combinations();  # ';' suppresses echoing the returned DataFrame"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ca44469",
   "metadata": {},
   "source": [
    "## Table E.5. Frequency of Marginal Path Collapse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f1798ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory where the Table E.5 JSON outputs are saved; the run name embeds today's date (YYYYMMDD)\n",
    "exp_dir = get_experiment_dir(\n",
    "    f\"ace_demo_runs_{datetime.now().strftime('%Y%m%d')}\",\n",
    "    \"Table_E5\",\n",
    ")\n",
    "print(f\"Experiment directory: {exp_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ffebac9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_interpolants_from_json_alpha_only(path):\n",
    "    with open(path, \"r\") as f:\n",
    "        interpolants_raw = json.load(f)\n",
    "\n",
    "    interpolants = {}\n",
    "    for name, funcs in interpolants_raw.items():\n",
    "        alpha_t = eval(funcs[\"alpha_t\"], {\"torch\": torch})\n",
    "        interpolants[name] = alpha_t\n",
    "    return interpolants\n",
    "\n",
    "# Load\n",
    "interpolants = load_interpolants_from_json_alpha_only(\"ace_lib/interpolant_schedules.json\")\n",
    "names = list(interpolants.keys())\n",
    "\n",
    "def check_condition(alpha_funcs, n_grid=200, Bump = 0.0, Anneal_weight=1.0):\n",
    "    \"\"\"Check sign conditions for [a1, a2, a3]. a1 a3 / a2 and anneal weight applies to (a1 / a2)^w a3\"\"\"\n",
    "    ts = torch.linspace(0.0, 0.99, n_grid)\n",
    "\n",
    "    if n_grid > 1:\n",
    "        dt = ts[1] - ts[0]\n",
    "    else:\n",
    "        dt = torch.tensor(0.0) \n",
    "\n",
    "    alphas = [f(ts) for f in alpha_funcs]\n",
    "    Bumps = torch.tensor([Bump * t * (1-t) for t in ts])\n",
    "    if len(alpha_funcs) == 3:\n",
    "        C = (Anneal_weight + Bumps) / (alphas[0]**2 + 1e-12) - Anneal_weight / (alphas[1]**2 + 1e-12) + 1 / (alphas[2]**2 + 1e-12)\n",
    "    else:\n",
    "        raise ValueError(\"Only supports 3 schedules\")\n",
    "\n",
    "    total_negative_length = 0.0\n",
    "    if n_grid > 1:\n",
    "        negative_intervals = C[:-1] < 0\n",
    "        total_negative_length_tensor = torch.sum(negative_intervals.float()) * dt\n",
    "        total_negative_length = total_negative_length_tensor.item()\n",
    "\n",
    "    return (C.min() < 0), ts, C, total_negative_length\n",
    "\n",
    "# LaTeX legend label for each schedule name\n",
    "name_eq_plot = dict([\n",
    "    (\"ddpm_linear\", r\"$\\text{DDPM}$\"),\n",
    "    (\"1-t**2\", r\"$1-t^2$\"),\n",
    "    (\"sigmoid\", r\"$\\text{Sigmoid}$\"),\n",
    "    (\"default_linear\", r\"$1-t$\"),\n",
    "    (\"cos_t\", r\"$\\cos(\\frac{\\pi}{2}t)$\"),\n",
    "])\n",
    "\n",
    "\n",
    "def find_valid_combinations(interpolants, Anneal_weight=1.0, Bump=0.0, unique=True):\n",
    "    \"\"\"List every ordered schedule triple (a1, a2, a3) whose criterion C(t) dips below zero.\n",
    "\n",
    "    Each hit is ``[a1, a2, a3, negative_length]`` (negative length as returned by\n",
    "    ``check_condition``); hits are sorted by negative length, longest first. With\n",
    "    ``unique=True`` only the first triple for each distinct negative length is kept.\n",
    "    \"\"\"\n",
    "    hits = []\n",
    "    for triple in itertools.product(interpolants, repeat=3):\n",
    "        collapses, _, _, negative_length = check_condition(\n",
    "            [interpolants[key] for key in triple], Anneal_weight=Anneal_weight, Bump=Bump\n",
    "        )\n",
    "        if collapses:\n",
    "            hits.append([*triple, negative_length])\n",
    "    hits.sort(key=lambda hit: hit[3], reverse=True)\n",
    "\n",
    "    if not unique:\n",
    "        return hits\n",
    "\n",
    "    # Triples sharing the same negative length are treated as duplicates.\n",
    "    deduped, lengths_seen = [], set()\n",
    "    for hit in hits:\n",
    "        if hit[3] in lengths_seen:\n",
    "            continue\n",
    "        lengths_seen.add(hit[3])\n",
    "        deduped.append(hit)\n",
    "    return deduped\n",
    "\n",
    "weights = [1.0, 1.1, 1.5, 2.0, 7.5, 15]\n",
    "for ANNEAL_WEIGHT in weights:\n",
    "    # Keep every collapsing triple (no dedup by collapse length) and dump one JSON per weight\n",
    "    collapse_combinations = find_valid_combinations(\n",
    "        interpolants, Anneal_weight=ANNEAL_WEIGHT, Bump=0.0, unique=False\n",
    "    )\n",
    "    n_collapsing = len(collapse_combinations)\n",
    "    print(f\"There are {n_collapsing} combinations that have path collapse under anneal weight {ANNEAL_WEIGHT}.\")\n",
    "\n",
    "    out_path = os.path.join(exp_dir, f\"non_unique_collapse_combinations_anneal_weight={ANNEAL_WEIGHT}.json\")\n",
    "    with open(out_path, \"w\") as f:\n",
    "        json.dump(collapse_combinations, f, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "451078be",
   "metadata": {},
   "source": [
    "## Table E.6: Metrics under Homogeneous Composition (collapse duration 0%)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22f53b2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "run_date = datetime.now().strftime('%Y%m%d')\n",
    "exp_dir = get_experiment_dir(f\"ace_demo_runs_{run_date}\", \"Table_E6\")\n",
    "print(f\"Experiment directory: {exp_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8739be86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Schedules [alpha_1, alpha_2, alpha_3] followed by the precomputed collapse length (0 here)\n",
    "names = ['ddpm_linear', 'ddpm_linear', 'ddpm_linear', 0]\n",
    "\n",
    "bs = 10000\n",
    "n_steps = 1000\n",
    "seeds = list(range(5))\n",
    "ESS_THRESHOLD = 0.7\n",
    "ANNEAL_WEIGHT, BUMP_VALUE = 1.0, 0.0\n",
    "RUN_METRIC_EVAL = True\n",
    "results = []  # one row per (seed, method, condition), filled by the seed loop\n",
    "\n",
    "def _load_pretrained_mlp(z_dim, cond_dim, output_dim, checkpoint):\n",
    "    \"\"\"Build an MLPInstFlexible (width 256, depth 4) on `device`, load\n",
    "    PretrainedToyModels/<checkpoint>.pth into it and return it in eval mode.\n",
    "\n",
    "    map_location=device lets checkpoints saved from GPU tensors load on a CPU-only machine.\n",
    "    \"\"\"\n",
    "    model = MLPInstFlexible(z_dim=z_dim, cond_dim=cond_dim, width=256, depth=4, output_dim=output_dim).to(device)\n",
    "    model.load_state_dict(torch.load(f\"PretrainedToyModels/{checkpoint}.pth\", map_location=device))\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "# Checkpoint schedules: model 1 (X|A) uses names[0], model 2 (XY|B) names[2], model 3 (X) names[1].\n",
    "u_model1 = _load_pretrained_mlp(1, 1, 1, f\"u_model1_X_given_A_alpha={names[0]}\")\n",
    "s_model1 = _load_pretrained_mlp(1, 1, 1, f\"s_model1_X_given_A_alpha={names[0]}\")\n",
    "u_model2 = _load_pretrained_mlp(2, 1, 2, f\"u_model2_XY_given_B_alpha={names[2]}\")\n",
    "s_model2 = _load_pretrained_mlp(2, 1, 2, f\"s_model2_XY_given_B_alpha={names[2]}\")\n",
    "u_model3 = _load_pretrained_mlp(1, 0, 1, f\"u_model3_X_alpha={names[1]}\")\n",
    "s_model3 = _load_pretrained_mlp(1, 0, 1, f\"s_model3_X_alpha={names[1]}\")\n",
    "\n",
    "def v1_fn(x, t, A):\n",
    "    \"\"\"v1(X|A): u_model1 conditioned on A.\"\"\"\n",
    "    return u_model1(x, t, A)\n",
    "\n",
    "def s1_fn(x, t, A):\n",
    "    \"\"\"s1(X|A): s_model1 conditioned on A.\"\"\"\n",
    "    return s_model1(x, t, A)\n",
    "\n",
    "def v2_fn(x, t):\n",
    "    \"\"\"v2(X): unconditional u_model3.\"\"\"\n",
    "    return u_model3(x, t)\n",
    "\n",
    "def s2_fn(x, t):\n",
    "    \"\"\"s2(X): unconditional s_model3.\"\"\"\n",
    "    return s_model3(x, t)\n",
    "\n",
    "def v3_fn(z, t, B):\n",
    "    \"\"\"v3(Z|B): u_model2 conditioned on B.\"\"\"\n",
    "    return u_model2(z, t, B)\n",
    "\n",
    "def s3_fn(z, t, B):\n",
    "    \"\"\"s3(Z|B): s_model2 conditioned on B.\"\"\"\n",
    "    return s_model2(z, t, B)\n",
    "\n",
    "def sigma_fn(t):\n",
    "    \"\"\"Constant sigma(t) = 0.5, shaped like t.\"\"\"\n",
    "    return 0.5 * torch.ones_like(t)\n",
    "\n",
    "def _const_column(x, value):\n",
    "    \"\"\"(batch, 1) tensor filled with `value` on x's device, used as the conditioning input.\"\"\"\n",
    "    return torch.full((x.size(0), 1), value, device=x.device)\n",
    "\n",
    "def _append_zero_column(x):\n",
    "    \"\"\"Embed an X sample into Z by appending a column of zeros.\"\"\"\n",
    "    return torch.cat([x, torch.zeros(x.size(0), 1, device=x.device)], dim=1)\n",
    "\n",
    "# A and B are looked up when the lambdas run, i.e. inside the `for A, B in ...` conditioning loop.\n",
    "v_fn_list = [\n",
    "    lambda x, t: v1_fn(x[:, :1], t, _const_column(x, A)),  # v1(X|A)\n",
    "    lambda x, t: v2_fn(x[:, :1], t),                        # v2(X)\n",
    "    lambda x, t: v3_fn(x, t, _const_column(x, B)),          # v3(Z|B)\n",
    "]\n",
    "s_fn_list = [\n",
    "    lambda x, t: s1_fn(x[:, :1], t, _const_column(x, A)),  # s1(X|A)\n",
    "    lambda x, t: s2_fn(x[:, :1], t),                        # s2(X)\n",
    "    lambda x, t: s3_fn(x, t, _const_column(x, B)),          # s3(Z|B)\n",
    "]\n",
    "proj_list = [\n",
    "    lambda z: z[:, :1],  # project to X\n",
    "    lambda z: z[:, :1],  # project to X\n",
    "    lambda z: z,         # identity for Z\n",
    "]\n",
    "emb_list = [_append_zero_column, _append_zero_column, lambda z: z]  # embed X->Z, X->Z, identity\n",
    "print(f\"{names} models loaded.\")\n",
    "\n",
    "\n",
    "for seed in seeds:\n",
    "    torch.manual_seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    print(f\"Seed set to {seed}\")\n",
    "\n",
    "    for Method_name in [\"NR\", \"ACE\"]: # FKC:\n",
    "        print(f\"Method: {Method_name}\")\n",
    "        if Method_name == \"FKC\" or Method_name == \"NR\":\n",
    "            # NR shares FKC's constant gammas; it differs only by resample=False below.\n",
    "            print(f\"Simulating {Method_name} (Constant Gammas)\")\n",
    "            gamma_list = [\n",
    "                lambda t : torch.tensor(1) * ANNEAL_WEIGHT,\n",
    "                lambda t : torch.tensor(-1) * ANNEAL_WEIGHT,\n",
    "                lambda t : torch.tensor(1)\n",
    "            ]\n",
    "            d_gamma_list = [\n",
    "                lambda t: torch.zeros_like(t),\n",
    "                lambda t: torch.zeros_like(t),\n",
    "                lambda t: torch.zeros_like(t)\n",
    "            ]\n",
    "        elif Method_name == \"ACE\":\n",
    "            print(\"Simulating ACE (Adaptive Gammas)\")\n",
    "            # gamma_1 gets the bump BUMP_VALUE * t * (1 - t); d_gamma_list holds the t-derivatives.\n",
    "            gamma_list = [\n",
    "                lambda t : torch.tensor(1) * ANNEAL_WEIGHT + (BUMP_VALUE * t * (1 - t)),\n",
    "                lambda t : torch.tensor(-1) * ANNEAL_WEIGHT,\n",
    "                lambda t : torch.tensor(1)\n",
    "            ]\n",
    "            d_gamma_list = [\n",
    "                lambda t: torch.zeros_like(t) + (BUMP_VALUE * (1 - 2*t)),\n",
    "                lambda t: torch.zeros_like(t),\n",
    "                lambda t: torch.zeros_like(t)\n",
    "            ]\n",
    "\n",
    "        if seed == 0:\n",
    "            # C(t) = sum_i gamma_i(t) / alpha_i(t)^2 over the three schedules (names[-1] is the collapse length).\n",
    "            Criterion = lambda t: sum([ gamma_list[i](t) / (interpolants[names[i]](t))**2 for i in range(len(names)-1) ])\n",
    "            t = torch.linspace(0.0, 0.99, 100)\n",
    "            C_vals = Criterion(t)\n",
    "            plt.plot(t.numpy(), C_vals.numpy())\n",
    "            plt.xlabel('t')\n",
    "            plt.ylabel('Criterion C(t)')\n",
    "            plt.title('Criterion C(t) vs t')\n",
    "            plt.grid(True)\n",
    "            plt.ylim(-20,100)\n",
    "            plt.savefig(os.path.join(exp_dir, f\"Criterion_plot_{names}_ANNEAL={ANNEAL_WEIGHT}_BUMP={BUMP_VALUE}_Method={Method_name}.png\"))\n",
    "            plt.show()\n",
    "            plt.close()\n",
    "\n",
    "            # Zero crossings depend only on the method's gammas, not on the seed, so report the\n",
    "            # first downward and the first upward crossing once, from the vectorized values above.\n",
    "            positive, negative = C_vals > 0, C_vals < 0\n",
    "            for crossing in (positive[:-1] & negative[1:], negative[:-1] & positive[1:]):\n",
    "                hits = torch.nonzero(crossing)\n",
    "                if len(hits) > 0:\n",
    "                    print(\"Criterion = 0 at t =\", t[hits[0, 0]].item())\n",
    "\n",
    "        for A, B in [(1,1)]: #[(1,1), (1,0), (0,1), (0,0)]:\n",
    "            print(f\"Conditioning on A={A}, B={B}\")\n",
    "\n",
    "            # Noise is drawn on the CPU generator and then moved, so the seeded x0 does not depend on\n",
    "            # the device; `device` (set in the import cell) falls back to CPU when CUDA is unavailable.\n",
    "            x0 = torch.randn(bs, 2).to(device)\n",
    "            samples, logw_final, logw_history, sample_history, resample_history = simulate_ace(\n",
    "                x0=x0, v_fn_list=v_fn_list, s_fn_list=s_fn_list, proj_list=proj_list, emb_list=emb_list, sigma_fn=sigma_fn,\n",
    "                v_star= lambda z, t: v3_fn(z, t, torch.full((z.size(0), 1), B, device=z.device)), \n",
    "                t0=0.0, t1=1.0, n_steps=n_steps, device=str(device), ess_threshold=ESS_THRESHOLD, print_resample_history=True,\n",
    "                gamma_list=gamma_list,\n",
    "                d_gamma_list=d_gamma_list,\n",
    "                resample= (Method_name != \"NR\")\n",
    "            )\n",
    "            samples = samples.cpu().numpy()\n",
    "            if seed == 0:\n",
    "                plot_diagnostics(samples, logw_final, logw_history, save_name=f\"{exp_dir}/alpha={names}_{Method_name}_seed{seed}_AB={A}{B}_Bump={BUMP_VALUE}\")\n",
    "                plot_path_trajectories(sample_history, n_frame=6, resample_history=None, experiment_id=exp_dir, name=f\"alpha={names}_{Method_name}_seed{seed}_AB={A}{B}_Bump={BUMP_VALUE}\", deg=-50)\n",
    "                plt.close(); clear_output()\n",
    "\n",
    "            if RUN_METRIC_EVAL:\n",
    "                print(f\"Evaluating {Method_name} for AB={A}{B}\")\n",
    "                samples_gt = ground_truth_hcg(bs, cond_A=A, cond_B=B).numpy()\n",
    "                w1, w2, mmd_rbf, total_var = compute_sample_based_metrics(\n",
    "                    torch.tensor(samples_gt), torch.tensor(samples)\n",
    "                )\n",
    "                results.append([seed, Method_name, w1, w2, mmd_rbf, total_var, A, B, ESS_THRESHOLD])\n",
    "\n",
    "                df = pd.DataFrame(results, columns=[\"seed\", \"method\", \"W1\", \"W2\", \"MMD_RBF\", \"Total Var.\", \"A\", \"B\", \"ESS_Threshold\"])\n",
    "                df.to_csv(f\"{exp_dir}/experiment_results_numseeds{len(seeds)}_bs{bs}_n_steps{n_steps}_ESS{ESS_THRESHOLD}_{names}_ANNEAL={ANNEAL_WEIGHT}_BUMP={BUMP_VALUE}.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb8cde82",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Overlay C(t) of every collapsing triple for each bump value (the list is empty, so this is skipped).\n",
    "for Bump in []:\n",
    "    plt.figure(figsize=(15,12))\n",
    "    plt.title(f'Criterion C(t) with Bump={Bump} for Invalid Common Cases', fontsize=24)\n",
    "    plt.ylim((-20,100))\n",
    "    plt.grid(True)\n",
    "    for combo in collapse_combinations:\n",
    "        a1, a2, a3 = combo[:3]\n",
    "        collapses, t, C, neg_len = check_condition([interpolants[key] for key in (a1, a2, a3)], Anneal_weight=1.0, Bump=Bump)\n",
    "        plt.plot(t.numpy(), C.numpy(), label=f'{name_eq_plot[a1]}, {name_eq_plot[a2]}, {name_eq_plot[a3]}')\n",
    "    plt.legend(loc='upper left', fontsize=12, ncol=3)\n",
    "    plt.xlabel('t', fontsize=24)\n",
    "    plt.ylabel(r\"$C(t)$\", fontsize=24)\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(os.path.join(exp_dir, f\"Criterion_plot_a1a2a3_bump={Bump}.png\"))\n",
    "    plt.show()\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f55eb665",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import re\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import ast\n",
    "\n",
    "# ==========================================\n",
    "# CONFIGURATION\n",
    "# ==========================================\n",
    "PARENT_DIR = exp_dir  # whichever experiment directory was set most recently\n",
    "OUTPUT_FILENAME = \"Rebuttal_Grid_with_Length.png\"\n",
    "plt.rcParams['savefig.dpi'] = 300  # resolution of saved figures\n",
    "\n",
    "# Metrics plotted as columns, with their panel titles\n",
    "METRICS = ['W1', 'W2', 'MMD_RBF']\n",
    "METRIC_LABELS = dict(\n",
    "    W1=r'$W_1$ Distance ($\\downarrow$)',\n",
    "    W2=r'$W_2$ Distance ($\\downarrow$)',\n",
    "    MMD_RBF=r'MMD-RBF ($\\downarrow$)',\n",
    ")\n",
    "\n",
    "# LaTeX label for each schedule name\n",
    "NAME_EQ_PLOT = {\n",
    "    \"ddpm_linear\": r\"$\\text{DDPM}$\",\n",
    "    \"1-t**2\": r\"$1-t^2$\",\n",
    "    \"sigmoid\": r\"$\\text{Sigmoid}$\",\n",
    "    \"default_linear\": r\"$1-t$\",\n",
    "    \"cos_t\": r\"$\\cos(\\frac{\\pi}{2}t)$\",\n",
    "}\n",
    "\n",
    "# Method order on the x-axis and bar colours (baselines light/dark grey, ACE highlighted in red)\n",
    "METHOD_ORDER = ['NR', 'FKC', 'ACE']\n",
    "PALETTE = dict(zip(METHOD_ORDER, ['#B0B0B0', '#696969', '#D62728']))\n",
    "\n",
    "# ==========================================\n",
    "# 1. DATA LOADING\n",
    "# ==========================================\n",
    "def _parse_case_label(folder_name, case_index, suffix):\n",
    "    \"\"\"Build the multi-line LaTeX row label for one result folder.\n",
    "\n",
    "    Expects names like \"[0.0500]['s1', 's2', 's3', 0.05]<suffix>\" (the layout the experiment\n",
    "    cells write); falls back to \"Case <case_index + 1>\" with a warning if parsing fails.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # Text between the first ']' and the suffix is the schedule list minus its closing bracket.\n",
    "        raw_list_part = folder_name.split(']')[1].split(suffix)[0]\n",
    "        if not raw_list_part.endswith(']'):\n",
    "            raw_list_part += ']'\n",
    "        schedule_list = ast.literal_eval(raw_list_part)\n",
    "        schedules = schedule_list[:3]\n",
    "        schedule_collapse_length = schedule_list[3] if len(schedule_list) > 3 else \"N/A\"\n",
    "        # Collapse length as a percentage with one decimal place.\n",
    "        try:\n",
    "            schedule_collapse_length = f\"{float(schedule_collapse_length)*100:.1f}\"\n",
    "        except (TypeError, ValueError):\n",
    "            schedule_collapse_length = \"N/A\"\n",
    "\n",
    "        # Map to LaTeX without the surrounding $ so each line can be wrapped uniformly,\n",
    "        # then display in the order [s3, s1, s2].\n",
    "        latex_schedules = [NAME_EQ_PLOT.get(s, s).replace('$', '') for s in schedules]\n",
    "        reordered = [latex_schedules[2]] + latex_schedules[:2]\n",
    "        formatted_lines = [\n",
    "            r\"$\\alpha^{(\" + str(idx+1) + r\")}_t = \" + val + r\"$\"\n",
    "            for idx, val in enumerate(reordered)\n",
    "        ]\n",
    "        formatted_lines.append(f\"\\nCollapse Duration: {schedule_collapse_length}%\")\n",
    "        return \"\\n\".join(formatted_lines)\n",
    "    except Exception as e:\n",
    "        print(f\"Warning: Could not parse label for {folder_name}, using default. Error: {e}\")\n",
    "        return f\"Case {case_index+1}\"\n",
    "\n",
    "\n",
    "def load_all_data(parent_dir, suffix=\"_ESS0.9_Noncollapse\"):\n",
    "    \"\"\"Concatenate the experiment_results*.csv of every ``*<suffix>`` subfolder of parent_dir.\n",
    "\n",
    "    Args:\n",
    "        parent_dir: directory holding one results subfolder per schedule case.\n",
    "        suffix: folder-name suffix selecting the runs to load (default: the ESS 0.9 runs).\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with the CSV columns plus 'Condition' (\"(A,B)\"), 'Case' (row label) and\n",
    "        'Case_Index' (folder position after sorting by the leading \"[<number>]\").\n",
    "        An empty DataFrame (with a printed message) when parent_dir is missing or no\n",
    "        results CSV is found, so callers can rely on a single ``.empty`` check.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        subdirs = [\n",
    "            os.path.join(parent_dir, d) for d in os.listdir(parent_dir)\n",
    "            if os.path.isdir(os.path.join(parent_dir, d)) and d.endswith(suffix)\n",
    "        ]\n",
    "    except FileNotFoundError:\n",
    "        print(f\"Error: Parent directory '{parent_dir}' not found.\")\n",
    "        return pd.DataFrame()\n",
    "\n",
    "    # Order folders by the number in the first \"[...]\" of their name (0 if there is none).\n",
    "    def sort_key(path):\n",
    "        match = re.search(r\"\\[([\\d\\.]+)\\]\", os.path.basename(path))\n",
    "        return float(match.group(1)) if match else 0\n",
    "\n",
    "    subdirs.sort(key=sort_key)\n",
    "\n",
    "    all_data = []\n",
    "    for i, folder_path in enumerate(subdirs):\n",
    "        folder_name = os.path.basename(folder_path)\n",
    "        case_label = _parse_case_label(folder_name, i, suffix)\n",
    "\n",
    "        try:\n",
    "            # Sorted so the chosen CSV does not depend on directory listing order.\n",
    "            csv_files = sorted(\n",
    "                os.path.join(folder_path, f) for f in os.listdir(folder_path)\n",
    "                if f.startswith(\"experiment_results\") and f.endswith(\".csv\")\n",
    "            )\n",
    "        except OSError:\n",
    "            print(f\"Could not access folder: {folder_name}\")\n",
    "            continue\n",
    "\n",
    "        if not csv_files:\n",
    "            print(f\"Skipping {folder_name}: No results CSV found.\")\n",
    "            continue\n",
    "\n",
    "        df = pd.read_csv(csv_files[0])\n",
    "        # Vectorized \"(A,B)\" label; also safe for a CSV with no rows.\n",
    "        df['Condition'] = \"(\" + df['A'].astype(int).astype(str) + \",\" + df['B'].astype(int).astype(str) + \")\"\n",
    "        df['Case'] = case_label\n",
    "        df['Case_Index'] = i\n",
    "        all_data.append(df)\n",
    "\n",
    "    if not all_data:\n",
    "        print(f\"No data found in '{parent_dir}' for suffix '{suffix}'! Check your paths.\")\n",
    "        return pd.DataFrame()\n",
    "\n",
    "    return pd.concat(all_data, ignore_index=True)\n",
    "\n",
    "# ==========================================\n",
    "# 2. PLOTTING\n",
    "# ==========================================\n",
    "def create_facet_plot(df):\n",
    "    \"\"\"Plot a (cases x METRICS) grid of per-method bar charts, save it and show it.\n",
    "\n",
    "    Rows follow 'Case_Index'; each panel shows one metric per Condition as bars (one per\n",
    "    method, ordered by METHOD_ORDER and coloured by PALETTE) with standard-deviation error\n",
    "    bars. The figure is written to PARENT_DIR/OUTPUT_FILENAME.\n",
    "    \"\"\"\n",
    "    # --- FONT / STYLE SETTINGS ---\n",
    "    plt.rcParams['font.family'] = 'Times New Roman'\n",
    "    sns.set_context(\"paper\", font_scale=2)\n",
    "    # The style must be set before the axes exist: it only affects axes created afterwards.\n",
    "    sns.set_style(\"whitegrid\")\n",
    "\n",
    "    cases = df.sort_values('Case_Index')['Case'].unique()\n",
    "    n_cases = len(cases)\n",
    "    n_metrics = len(METRICS)\n",
    "\n",
    "    # 5 in per metric column, 3.5 in per case row (room for the 3-line case labels).\n",
    "    # squeeze=False keeps `axes` 2-D whatever the number of rows/columns.\n",
    "    fig, axes = plt.subplots(n_cases, n_metrics, figsize=(5 * n_metrics, 3.5 * n_cases), sharex=True, squeeze=False)\n",
    "\n",
    "    print(\"Generating Plots...\")\n",
    "\n",
    "    for row_idx, case in enumerate(cases):\n",
    "        subset = df[df['Case'] == case]\n",
    "        for col_idx, metric in enumerate(METRICS):\n",
    "            ax = axes[row_idx, col_idx]\n",
    "\n",
    "            sns.barplot(\n",
    "                data=subset,\n",
    "                x='Condition',\n",
    "                y=metric,\n",
    "                hue='method',\n",
    "                hue_order=METHOD_ORDER,\n",
    "                palette=PALETTE,\n",
    "                ax=ax,\n",
    "                edgecolor='black',\n",
    "                linewidth=0.5,\n",
    "                errorbar='sd',\n",
    "                capsize=0.1,\n",
    "                err_kws={'linewidth': 1}\n",
    "            )\n",
    "\n",
    "            # Column headers on the first row only\n",
    "            if row_idx == 0:\n",
    "                ax.set_title(METRIC_LABELS[metric], fontsize=20, fontweight='bold', pad=15)\n",
    "            else:\n",
    "                ax.set_title(\"\")\n",
    "\n",
    "            # Row labels: the case string already contains its own line breaks;\n",
    "            # va='center' centres the block on the row and labelpad pushes it left.\n",
    "            if col_idx == 0:\n",
    "                ax.set_ylabel(case, fontsize=18, rotation=0, labelpad=90, va='center')\n",
    "            else:\n",
    "                ax.set_ylabel(\"\")\n",
    "\n",
    "            if row_idx == n_cases - 1:\n",
    "                ax.set_xlabel(r\"Condition $(1_A,1_B)$\", fontsize=18)\n",
    "            else:\n",
    "                ax.set_xlabel(\"\")\n",
    "\n",
    "            ax.grid(True, axis='y', linestyle='--', alpha=0.6)\n",
    "            sns.despine(ax=ax, left=True)\n",
    "\n",
    "            # Per-axes legends are replaced by one figure-level legend below.\n",
    "            if ax.get_legend():\n",
    "                ax.get_legend().remove()\n",
    "\n",
    "    handles, labels = axes[0, 0].get_legend_handles_labels()\n",
    "    fig.legend(handles, labels, loc='lower center', ncol=3,\n",
    "               bbox_to_anchor=(0.55, -0.01), fontsize=18, frameon=False)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    # Wide left margin (0.25) for the multi-line case labels.\n",
    "    fig.subplots_adjust(top=0.92, bottom=0.07, left=0.25)\n",
    "\n",
    "    # \"Schedules\" header above the row labels, in the left margin.\n",
    "    fig.text(0.17, 0.93, \"Schedules\", fontsize=20, fontweight='bold', ha='center', va='bottom', fontfamily='Times New Roman')\n",
    "\n",
    "    output_path = os.path.join(PARENT_DIR, OUTPUT_FILENAME)\n",
    "    fig.savefig(output_path, dpi=300, bbox_inches='tight')\n",
    "    print(f\"Success! Plot saved to: {output_path}\")\n",
    "    plt.show()\n",
    "    # Release the figure so repeated calls do not accumulate open figures.\n",
    "    plt.close(fig)\n",
    "\n",
    "# ==========================================\n",
    "# 3. PRINT STATS TABLE\n",
    "# ==========================================\n",
    "def print_markdown_stats(df):\n",
    "    \"\"\"Print, and save as Summary_Stats_Table.csv, mean ± std of METRICS per Case/Condition/method.\n",
    "\n",
    "    Case labels are flattened to one line and stripped of '$' and backslashes for plain-text\n",
    "    output; the table is rendered with DataFrame.to_markdown when available, else to_string.\n",
    "    \"\"\"\n",
    "    rule = \"=\" * 80\n",
    "    print(\"\\n\" + rule)\n",
    "    print(\" COMPACT SUMMARY STATISTICS (Mean ± Std)\")\n",
    "    print(rule + \"\\n\")\n",
    "\n",
    "    def _flatten(label):\n",
    "        # newlines -> ' | ', then drop '$' and backslashes\n",
    "        return label.replace('\\n', ' | ').replace('$', '').replace('\\\\', '')\n",
    "\n",
    "    text_df = df.assign(Case=df['Case'].map(_flatten))\n",
    "    stats = text_df.groupby(['Case', 'Condition', 'method'])[METRICS].agg(['mean', 'std'])\n",
    "\n",
    "    summary_df = pd.DataFrame(index=stats.index)\n",
    "    for metric in METRICS:\n",
    "        summary_df[metric] = [\n",
    "            f\"{mu:.3f} ± {sd:.3f}\"\n",
    "            for mu, sd in zip(stats[(metric, 'mean')], stats[(metric, 'std')])\n",
    "        ]\n",
    "    summary_df = summary_df.reset_index()\n",
    "\n",
    "    try:\n",
    "        print(summary_df.to_markdown(index=False))\n",
    "    except ImportError:  # to_markdown needs the optional 'tabulate' package\n",
    "        print(summary_df.to_string(index=False))\n",
    "\n",
    "    stats_path = os.path.join(PARENT_DIR, \"Summary_Stats_Table.csv\")\n",
    "    summary_df.to_csv(stats_path, index=False)\n",
    "    print(f\"\\n[Info] Stats saved to: {stats_path}\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    if not os.path.exists(PARENT_DIR):\n",
    "        print(f\"Error: Directory '{PARENT_DIR}' not found.\")\n",
    "    else:\n",
    "        combined_df = load_all_data(PARENT_DIR)\n",
    "        if not combined_df.empty:\n",
    "            create_facet_plot(combined_df)\n",
    "            print_markdown_stats(combined_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ac132d9",
   "metadata": {},
   "source": [
    "## Sensitivity to the bump parameter $B$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fa48dcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "run_date = datetime.now().strftime('%Y%m%d')\n",
    "exp_dir = get_experiment_dir(f\"ace_demo_runs_{run_date}\", \"figure_6\")\n",
    "print(f\"Experiment directory: {exp_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6277dd52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multiple Bumps: C(t) for several bump amplitudes, with gamma_1(t) = w + BUMP_VALUE * t * (1 - t)\n",
    "ANNEAL_WEIGHT = 1.0\n",
    "BUMP_VALUES = [10.0, 20.0, 30.0, 40.0, 50.0, 100.0]\n",
    "RUN_METRIC_EVAL = True\n",
    "names = [\"cos_t\", \"sigmoid\", \"1-t**2\"]\n",
    "print(names)\n",
    "\n",
    "# Set the display resolution before the figure is created: rcParams are read at creation\n",
    "# time, so changing figure.dpi after plotting would not affect this figure.\n",
    "plt.rcParams['figure.dpi'] = 300\n",
    "fig, ax = plt.subplots()\n",
    "t = torch.linspace(0.0, 0.99, 100)\n",
    "\n",
    "for BUMP_VALUE in BUMP_VALUES:\n",
    "    # The lambdas read BUMP_VALUE when called; they are only evaluated inside this iteration.\n",
    "    gamma_list = [\n",
    "        lambda t : torch.tensor(1) * ANNEAL_WEIGHT+ (BUMP_VALUE * t * (1 - t)),\n",
    "        lambda t : torch.tensor(-1) * ANNEAL_WEIGHT ,\n",
    "        lambda t : torch.tensor(1)\n",
    "    ]\n",
    "    d_gamma_list = [\n",
    "        lambda t: torch.zeros_like(t)+ (BUMP_VALUE * (1 - 2*t)),\n",
    "        lambda t: torch.zeros_like(t) ,\n",
    "        lambda t: torch.zeros_like(t)\n",
    "    ]\n",
    "    Criterion = lambda t: sum([ gamma_list[i](t) / (interpolant_schedules[names[i]].alpha_t(t))**2 for i in range(len(names)) ])\n",
    "\n",
    "    print(f\"Criterion(0.99) = {Criterion(torch.tensor(0.99)).item()}\")\n",
    "    ax.plot(t.numpy(), Criterion(t).numpy(), label=f'Bump={BUMP_VALUE}')\n",
    "\n",
    "# Reference curve 1 / alpha_1(t)^2\n",
    "heuristic = lambda t : 1 / (interpolant_schedules[names[0]].alpha_t(t))**2\n",
    "ax.plot(t.numpy(), heuristic(t).numpy(), label='Bump Guide', linestyle='--')\n",
    "ax.set_xlabel('t')\n",
    "ax.set_ylabel('C(t)')\n",
    "ax.set_title('Criterion C(t) for Multiple Bumps')\n",
    "ax.grid(True)\n",
    "ax.set_ylim(-10,40)\n",
    "ax.legend()\n",
    "fig.savefig(os.path.join(exp_dir, f\"Criterion_plot_{names}_ANNEAL={ANNEAL_WEIGHT}_BUMP={BUMP_VALUES}.png\"))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3209980b",
   "metadata": {},
   "source": [
    "## Figure C.1: Performance comparison across varying ESS thresholds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aa65e40",
   "metadata": {},
   "outputs": [],
   "source": [
    "run_date = datetime.now().strftime('%Y%m%d')\n",
    "exp_dir = get_experiment_dir(f\"ace_demo_runs_{run_date}\", \"figure_C1\")\n",
    "print(f\"Experiment directory: {exp_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c038877",
   "metadata": {},
   "outputs": [],
   "source": [
    "names = ['cos_t', 'ddpm_linear', 'default_linear',0]\n",
    "\n",
    "# for names in collapse_combinations:\n",
    "for ESS_THRESHOLD in [0.1,0.3,0.5,0.7,0.9]:\n",
    "    print(f\"ESS Threshold: {ESS_THRESHOLD}\")\n",
    "    bs = 10000; n_steps=1000\n",
    "    seeds = [0,1,2,3,4]\n",
    "    ANNEAL_WEIGHT = 1.0\n",
    "    BUMP_VALUE = 0.0\n",
    "    RUN_METRIC_EVAL = True\n",
    "    results = []\n",
    "\n",
    "    subexperiment_id = f\"{names}\"\n",
    "    experiment_id = f\"{exp_dir}/[{names[3]:.4f}]{subexperiment_id}_ESS{ESS_THRESHOLD}_Noncollapse\"\n",
    "    if not os.path.exists(experiment_id):\n",
    "        os.makedirs(experiment_id)\n",
    "\n",
    "    u_model1 = MLPInstFlexible(z_dim=1, cond_dim=1, width=256, depth=4, output_dim=1).to(device); u_model1.load_state_dict(torch.load(f\"PretrainedToyModels/u_model1_X_given_A_alpha={names[0]}.pth\")); u_model1.eval()\n",
    "    s_model1 = MLPInstFlexible(z_dim=1, cond_dim=1, width=256, depth=4, output_dim=1).to(device); s_model1.load_state_dict(torch.load(f\"PretrainedToyModels/s_model1_X_given_A_alpha={names[0]}.pth\")); s_model1.eval()\n",
    "    u_model2 = MLPInstFlexible(z_dim=2, cond_dim=1, width=256, depth=4, output_dim=2).to(device); u_model2.load_state_dict(torch.load(f\"PretrainedToyModels/u_model2_XY_given_B_alpha={names[2]}.pth\")); u_model2.eval()\n",
    "    s_model2 = MLPInstFlexible(z_dim=2, cond_dim=1, width=256, depth=4, output_dim=2).to(device); s_model2.load_state_dict(torch.load(f\"PretrainedToyModels/s_model2_XY_given_B_alpha={names[2]}.pth\")); s_model2.eval()\n",
    "    u_model3 = MLPInstFlexible(z_dim=1, cond_dim=0, width=256, depth=4, output_dim=1).to(device); u_model3.load_state_dict(torch.load(f\"PretrainedToyModels/u_model3_X_alpha={names[1]}.pth\")); u_model3.eval()\n",
    "    s_model3 = MLPInstFlexible(z_dim=1, cond_dim=0, width=256, depth=4, output_dim=1).to(device); s_model3.load_state_dict(torch.load(f\"PretrainedToyModels/s_model3_X_alpha={names[1]}.pth\")); s_model3.eval()\n",
    "\n",
    "    def v1_fn(x, t, A): return u_model1(x, t, A)\n",
    "    def s1_fn(x, t, A): return s_model1(x, t, A)\n",
    "    def v2_fn(x, t): return u_model3(x, t)\n",
    "    def s2_fn(x, t): return s_model3(x, t)\n",
    "    def v3_fn(z, t, B): return u_model2(z, t, B)\n",
    "    def s3_fn(z, t, B): return s_model2(z, t, B)\n",
    "    def sigma_fn(t): return 0.5 * torch.ones_like(t)\n",
    "\n",
    "    v_fn_list=[\n",
    "            lambda x, t: v1_fn(x[:, :1], t, torch.full((x.size(0), 1), A, device=x.device)), # v1(X|A)\n",
    "            lambda x, t: v2_fn(x[:, :1], t),                                                 # v2(X)\n",
    "            lambda x, t: v3_fn(x, t, torch.full((x.size(0), 1), B, device=x.device))         # v3(Z|B)\n",
    "        ]\n",
    "    s_fn_list=[\n",
    "            lambda x, t: s1_fn(x[:, :1], t, torch.full((x.size(0), 1), A, device=x.device)), # s1(X|A)\n",
    "            lambda x, t: s2_fn(x[:, :1], t),                                                 # s2(X)\n",
    "            lambda x, t: s3_fn(x, t, torch.full((x.size(0), 1), B, device=x.device))         # s3(Z|B)\n",
    "        ]\n",
    "    proj_list=[\n",
    "            lambda z: z[:, :1],    # project to X \n",
    "            lambda z: z[:, :1],    # project to X\n",
    "            lambda z: z            # identity for Z\n",
    "        ]\n",
    "    emb_list=[\n",
    "            lambda x: torch.cat([x, torch.zeros(x.size(0), 1, device=x.device)], dim=1),  # embed X→Z\n",
    "            lambda x: torch.cat([x, torch.zeros(x.size(0), 1, device=x.device)], dim=1),  # embed X→Z\n",
    "            lambda z: z  # identity\n",
    "        ]\n",
    "    print(f\"{names} models loaded.\")\n",
    "\n",
    "\n",
    "    for seed in seeds:\n",
    "        torch.manual_seed(seed)\n",
    "        np.random.seed(seed)\n",
    "        print(f\"Seed set to {seed}\")\n",
    "\n",
    "        for Method_name in [\"NR\", \"ACE\"]: # FKC:\n",
    "            print(f\"Method: {Method_name}\")\n",
    "            if Method_name == \"FKC\" or Method_name == \"NR\":\n",
    "                    print(\"Simulating FKC (Constant Gammas)\")\n",
    "                    gamma_list = [\n",
    "                        lambda t : torch.tensor(1) * ANNEAL_WEIGHT,\n",
    "                        lambda t : torch.tensor(-1) * ANNEAL_WEIGHT,\n",
    "                        lambda t : torch.tensor(1)\n",
    "                    ]\n",
    "                    d_gamma_list = [\n",
    "                        lambda t: torch.zeros_like(t),\n",
    "                        lambda t: torch.zeros_like(t),\n",
    "                        lambda t: torch.zeros_like(t)\n",
    "                    ]\n",
    "            elif Method_name == \"ACE\":\n",
    "                print(\"Simulating ACE (Adaptive Gammas)\")\n",
    "                gamma_list = [\n",
    "                    lambda t : torch.tensor(1) * ANNEAL_WEIGHT + (BUMP_VALUE * t * (1 - t)),\n",
    "                    lambda t : torch.tensor(-1) * ANNEAL_WEIGHT,\n",
    "                    lambda t : torch.tensor(1)\n",
    "                ]\n",
    "                d_gamma_list = [\n",
    "                    lambda t: torch.zeros_like(t) + (BUMP_VALUE * (1 - 2*t)),\n",
    "                    lambda t: torch.zeros_like(t),\n",
    "                    lambda t: torch.zeros_like(t)\n",
    "                ]\n",
    "            if seed == 0:\n",
    "                Criterion = lambda t: sum([ gamma_list[i](t) / (interpolants[names[i]](t))**2 for i in range(len(names)-1) ])\n",
    "                t = torch.linspace(0.0, 0.99, 100)\n",
    "                plt.plot(t.numpy(), Criterion(t).numpy())\n",
    "                plt.xlabel('t')\n",
    "                plt.ylabel('Criterion C(t)')\n",
    "                plt.title('Criterion C(t) vs t')\n",
    "                plt.grid(True)\n",
    "                plt.ylim(-20,100)\n",
    "                plt.savefig(os.path.join(experiment_id, f\"Criterion_plot_{names}_ANNEAL={ANNEAL_WEIGHT}_BUMP={BUMP_VALUE}_Method={Method_name}.png\"))\n",
    "                plt.show()\n",
    "                plt.close()\n",
    "\n",
    "            # print when Criterion = 0\n",
    "            for i in range(len(t)-1):\n",
    "                if Criterion(t[i]) > 0 and Criterion(t[i+1]) < 0:\n",
    "                    print(\"Criterion = 0 at t =\", t[i].item())\n",
    "                    break\n",
    "            for i in range(len(t)-1):\n",
    "                if Criterion(t[i]) < 0 and Criterion(t[i+1]) > 0:\n",
    "                    print(\"Criterion = 0 at t =\", t[i].item())\n",
    "                    break\n",
    "            for A, B in [(1,1)]: #[(1,1), (1,0), (0,1), (0,0)]:\n",
    "                print(f\"Conditioning on A={A}, B={B}\")\n",
    "\n",
    "                x0 = torch.randn(bs, 2).to(\"cuda\")\n",
    "                samples, logw_final, logw_history, sample_history, resample_history = simulate_ace(\n",
    "                    x0=x0, v_fn_list=v_fn_list, s_fn_list=s_fn_list, proj_list=proj_list, emb_list=emb_list, sigma_fn=sigma_fn,\n",
    "                    v_star= lambda z, t: v3_fn(z, t, torch.full((z.size(0), 1), B, device=z.device)), \n",
    "                    t0=0.0, t1=1.0, n_steps=n_steps, device=\"cuda\", ess_threshold=ESS_THRESHOLD, print_resample_history=True,\n",
    "                    gamma_list=gamma_list,\n",
    "                    d_gamma_list=d_gamma_list,\n",
    "                    resample= (Method_name != \"NR\")\n",
    "                )\n",
    "                samples = samples.cpu().numpy()\n",
    "                if seed == 0:\n",
    "                    plot_diagnostics(samples, logw_final, logw_history, save_name=f\"{experiment_id}/alpha={names}_{Method_name}_seed{seed}_AB={A}{B}_Bump={BUMP_VALUE}\")\n",
    "                    plot_path_trajectories(sample_history, n_frame=6, resample_history=None, experiment_id=experiment_id, name=f\"alpha={names}_{Method_name}_seed{seed}_AB={A}{B}_Bump={BUMP_VALUE}\", deg=-50)\n",
    "                    plt.close(); clear_output()\n",
    "\n",
    "                if RUN_METRIC_EVAL:\n",
    "                    print(f\"Evaluating {Method_name} for AB={A}{B}\")\n",
    "                    samples_gt = ground_truth_hcg(bs, cond_A=A, cond_B=B).numpy()\n",
    "                    w1, w2, mmd_rbf, total_var = compute_sample_based_metrics(\n",
    "                        torch.tensor(samples_gt), torch.tensor(samples)\n",
    "                    )\n",
    "                    results.append([seed, Method_name, w1, w2, mmd_rbf, total_var, A, B, ESS_THRESHOLD])\n",
    "\n",
    "                    df = pd.DataFrame(results, columns=[\"seed\", \"method\", \"W1\", \"W2\", \"MMD_RBF\", \"Total Var.\", \"A\", \"B\", \"ESS_Threshold\"])\n",
    "                    df.to_csv(f\"{experiment_id}/experiment_results_numseeds{len(seeds)}_bs{bs}_n_steps{n_steps}_ESS{ESS_THRESHOLD}_{names}_ANNEAL={ANNEAL_WEIGHT}_BUMP={BUMP_VALUE}.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74595e41",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "# !pip install seaborn\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import io\n",
    "\n",
    "# ==========================================\n",
    "# 1. LOAD DATA (You can run the previous cell to generate this data)\n",
    "# ==========================================\n",
    "csv_data = \"\"\"seed,method,W1,W2,MMD_RBF,Total Var.,A,B,ESS_Threshold\n",
    "0,FKC,2.671122276924688,2.958711624716925,1.267066478729248,0.9104000000000001,1,1,0.1\n",
    "0,ACE,0.45108392028256944,0.5294060959511929,0.09096813201904297,0.6539,1,1,0.1\n",
    "0,FKC,3.4695492583910754,3.7952742631138228,1.7368170022964478,0.9193060491493384,1,1,0.3\n",
    "0,ACE,0.2837551553082389,0.39101061972767515,0.026133954524993896,0.6348877853424109,1,1,0.3\n",
    "0,FKC,2.635403127283796,2.9104787290369014,1.4934245347976685,0.9286786924939467,1,1,0.5\n",
    "0,ACE,0.4462359031293904,0.5690063408124068,0.06979155540466309,0.6533240506329114,1,1,0.5\n",
    "0,FKC,2.5580212045630644,2.8336678031973097,1.5015771389007568,0.9128231367169639,1,1,0.7\n",
    "0,ACE,0.3362470547979451,0.4870415253971988,0.032257080078125,0.6351821164232847,1,1,0.7\n",
    "0,FKC,2.383411665109474,2.6509240935734875,2.8717684745788574,0.93428920889537,1,1,0.9\n",
    "0,ACE,0.37758675580423,0.5337254428397956,0.03982508182525635,0.6391539055449713,1,1,0.9\n",
    "1,FKC,4.2284852264504345,4.585312101147793,2.3089380264282227,0.9195476532302596,1,1,0.1\n",
    "1,ACE,0.4316968612763876,0.5378348636978592,0.07085424661636353,0.6533617661816358,1,1,0.1\n",
    "2,FKC,1.930929985235221,2.276854818614506,1.1867969036102295,0.8915554360353439,1,1,0.1\n",
    "2,ACE,0.4297983665602693,0.4962915973295295,0.0813295841217041,0.6556402243589743,1,1,0.1\n",
    "3,FKC,2.004467736415929,2.3474785539773624,1.4649248123168945,0.909593657086224,1,1,0.1\n",
    "3,ACE,0.5242704349986455,0.7009764740865341,0.03785830736160278,0.6377822764552911,1,1,0.1\n",
    "4,FKC,2.7904990279800783,3.064268541298036,1.2418321371078491,0.925,1,1,0.1\n",
    "4,ACE,0.4819212066345991,0.5788725075984056,0.0894121527671814,0.6654519356460533,1,1,0.1\n",
    "1,FKC,4.547946963459682,4.921675030264583,2.8481316566467285,0.9297020820575628,1,1,0.3\n",
    "1,ACE,0.2236679396480153,0.3075014204955778,0.018775463104248047,0.6322165849264338,1,1,0.3\n",
    "2,FKC,2.0499435269531423,2.3533058824504858,1.586472511291504,0.9040960521121201,1,1,0.3\n",
    "2,ACE,0.3237596034523107,0.4695252829295255,0.030008435249328613,0.6353644657863146,1,1,0.3\n",
    "3,FKC,2.0004597852714214,2.344475139888799,1.4729975461959839,0.9182887445887447,1,1,0.3\n",
    "3,ACE,0.25836356265242777,0.3714566058171081,0.020724594593048096,0.6320083083083083,1,1,0.3\n",
    "4,FKC,4.084050908036088,4.450108195686759,2.4135093688964844,0.9208823529411765,1,1,0.3\n",
    "4,ACE,0.24392678933833567,0.32033378496137577,0.021965444087982178,0.6313572864321608,1,1,0.3\n",
    "1,FKC,4.565417691592315,4.939379009426702,2.8578455448150635,0.9235714372346877,1,1,0.5\n",
    "1,ACE,0.347686367670207,0.48201250023863923,0.03843808174133301,0.6384377964575203,1,1,0.5\n",
    "2,FKC,2.0269332292908295,2.370237978907089,1.1330214738845825,0.9078942710560067,1,1,0.5\n",
    "2,ACE,0.37930631694566846,0.5050826542387004,0.04530757665634155,0.6368671679197995,1,1,0.5\n",
    "3,FKC,1.987620273176426,2.3376318451993527,1.4207051992416382,0.9090190766108575,1,1,0.5\n",
    "3,ACE,0.3156302133448867,0.451838641030446,0.02883732318878174,0.6356397111913357,1,1,0.5\n",
    "4,FKC,4.680488438897764,5.0621037885062945,3.112687349319458,0.9286780976220275,1,1,0.5\n",
    "4,ACE,0.2686304438141428,0.3546417513669526,0.0268593430519104,0.6375121121121121,1,1,0.5\n",
    "1,FKC,4.4482112878494116,4.815448088293516,2.644001007080078,0.9214455953016552,1,1,0.7\n",
    "1,ACE,0.26522724981493073,0.36215999959056405,0.022831201553344727,0.6282914891489149,1,1,0.7\n",
    "2,FKC,2.1187640176370395,2.467563381293868,1.0419046878814697,0.8853414075286417,1,1,0.7\n",
    "2,ACE,0.34453142230727757,0.4992698123291965,0.034919679164886475,0.6449338737115982,1,1,0.7\n",
    "3,FKC,1.9767251619131003,2.3259969709536645,1.4283051490783691,0.9098335427742871,1,1,0.7\n",
    "3,ACE,0.27452820674935186,0.3820974920280395,0.023885250091552734,0.6295593890836255,1,1,0.7\n",
    "4,FKC,4.665837171434316,5.047875703377264,3.082986354827881,0.9192302904564315,1,1,0.7\n",
    "4,ACE,0.3195528769562292,0.429236009123164,0.033013999462127686,0.6304146031428285,1,1,0.7\n",
    "1,FKC,4.065684461611906,4.411413775700504,2.0091426372528076,0.9173573382430299,1,1,0.9\n",
    "1,ACE,0.2973141426818991,0.4007880775002392,0.02447342872619629,0.6309641962944417,1,1,0.9\n",
    "2,FKC,1.98012851415699,2.3221128230509764,1.3376667499542236,0.8948727272727273,1,1,0.9\n",
    "2,ACE,0.2671305754103127,0.3840582329066812,0.022841036319732666,0.6349198278450605,1,1,0.9\n",
    "3,FKC,1.9941951706572165,2.335591893648943,1.3994040489196777,0.9083838741396264,1,1,0.9\n",
    "3,ACE,0.23806422919997175,0.3010554298584418,0.021109402179718018,0.6375040844929423,1,1,0.9\n",
    "4,FKC,4.713815317336253,5.092829365058878,3.1451077461242676,0.9236224852071007,1,1,0.9\n",
    "4,ACE,0.29863210080212105,0.39866761067141593,0.027861177921295166,0.6401947817360762,1,1,0.9\n",
    "\"\"\"\n",
    "\n",
    "# Parse the embedded CSV text: one row per (seed, method, ESS_Threshold) with W1/W2/MMD_RBF/Total Var. scores.\n",
    "df = pd.read_csv(io.StringIO(csv_data))\n",
    "\n",
    "# ==========================================\n",
    "# 2. CONFIG & STYLE\n",
    "# ==========================================\n",
    "# Apply seaborn's context/style FIRST: sns.set_style() resets rcParams['font.family']\n",
    "# to sans-serif, so a serif font configured before it would be silently discarded.\n",
    "sns.set_context(\"paper\", font_scale=1.4)\n",
    "sns.set_style(\"whitegrid\")\n",
    "plt.rcParams['font.family'] = 'serif'\n",
    "plt.rcParams['font.serif'] = ['Times New Roman', 'Times', 'DejaVu Serif', 'serif']\n",
    "\n",
    "# Panel titles per metric column; the down arrow marks that lower is better.\n",
    "METRIC_LABELS = {\n",
    "    'W1': r'$W_1$ Distance ($\\downarrow$)',\n",
    "    'W2': r'$W_2$ Distance ($\\downarrow$)',\n",
    "    'MMD_RBF': r'MMD-RBF ($\\downarrow$)',\n",
    "    'Total Var.': r'Total Var. ($\\downarrow$)'\n",
    "}\n",
    "\n",
    "# Line color per method.\n",
    "PALETTE = {\n",
    "    'FKC': '#696969',  # Dark Gray\n",
    "    'ACE': '#D62728'   # Bold Red\n",
    "}\n",
    "\n",
    "# ==========================================\n",
    "# 3. PLOT GENERATION\n",
    "# ==========================================\n",
    "def plot_ess_sensitivity(df, highlight_ess=0.7, save_path=\"ESS_Sensitivity_Plot.png\"):\n",
    "    \"\"\"Plot each metric against the ESS resampling threshold, one panel per metric.\n",
    "\n",
    "    One line per method, colored by PALETTE; repeated rows at a threshold (e.g.\n",
    "    several seeds) are aggregated by seaborn.lineplot's default estimator.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with columns 'ESS_Threshold', 'method', 'W1', 'W2',\n",
    "            'MMD_RBF' and 'Total Var.'.\n",
    "        highlight_ess: threshold marked with a dashed vertical line (the chosen\n",
    "            operating point); None disables the marker.\n",
    "        save_path: output image path (saved at 300 dpi).\n",
    "    \"\"\"\n",
    "    # Long format: one row per (threshold, method, metric) score.\n",
    "    df_melted = df.melt(\n",
    "        id_vars=['ESS_Threshold', 'method'],\n",
    "        value_vars=['W1', 'W2', 'MMD_RBF', 'Total Var.'],\n",
    "        var_name='Metric', value_name='Score'\n",
    "    )\n",
    "\n",
    "    # One column per metric; independent y-axes because the metrics have different scales.\n",
    "    g = sns.FacetGrid(df_melted, col='Metric', hue='method',\n",
    "                      sharey=False, height=4, aspect=1.0, palette=PALETTE)\n",
    "    g.map(sns.lineplot, 'ESS_Threshold', 'Score', marker='o', linewidth=2.5, markersize=8)\n",
    "\n",
    "    # Title each panel from the grid's own column order so labels always match the data.\n",
    "    for ax, metric in zip(g.axes.flat, g.col_names):\n",
    "        ax.set_title(METRIC_LABELS[metric], fontweight='bold', fontsize=14, pad=15)\n",
    "        ax.set_xlabel(r\"ESS Threshold ($\\tau$)\", fontsize=12)\n",
    "        if highlight_ess is not None:\n",
    "            ax.axvline(x=highlight_ess, color='black', linestyle='--', alpha=0.3)\n",
    "\n",
    "    g.add_legend(title=\"Method\", fontsize=12, title_fontsize=12)\n",
    "    plt.subplots_adjust(top=0.85)\n",
    "\n",
    "    plt.savefig(save_path, dpi=300, bbox_inches='tight')\n",
    "    print(f\"Plot saved to {save_path}\")\n",
    "    plt.show()\n",
    "\n",
    "# ==========================================\n",
    "# 4. LATEX TABLE GENERATION\n",
    "# ==========================================\n",
    "def generate_latex_table(df, metrics=('W1', 'W2', 'MMD_RBF')):\n",
    "    \"\"\"Print LaTeX source for an FKC-vs-ACE table across ESS thresholds.\n",
    "\n",
    "    df may contain several rows per (ESS_Threshold, method) pair, e.g. one per\n",
    "    seed as in csv_data; they are averaged before tabulation. In each metric's\n",
    "    FKC/ACE pair the lower (better) value is bolded; ties bold ACE.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with columns 'ESS_Threshold', 'method' (values 'FKC' and\n",
    "            'ACE') and every column named in metrics.\n",
    "        metrics: metric columns to tabulate, in order; each must be a key of\n",
    "            header_labels below. 'Total Var.' is left out by default for\n",
    "            brevity but can be added.\n",
    "    \"\"\"\n",
    "    header_labels = {\n",
    "        'W1': r'$W_1 (\\downarrow)$',\n",
    "        'W2': r'$W_2 (\\downarrow)$',\n",
    "        'MMD_RBF': r'MMD $(\\downarrow)$',\n",
    "        'Total Var.': r'TV $(\\downarrow)$',\n",
    "    }\n",
    "    metrics = list(metrics)\n",
    "    n_metrics = len(metrics)\n",
    "\n",
    "    # DataFrame.pivot raises \"Index contains duplicate entries\" on multi-seed\n",
    "    # results; pivot_table averages the duplicates instead.\n",
    "    pivot_df = df.pivot_table(index='ESS_Threshold', columns='method',\n",
    "                              values=metrics, aggfunc='mean')\n",
    "\n",
    "    col_spec = \"c|\" + \"|\".join([\"cc\"] * n_metrics)\n",
    "    group_headers = []\n",
    "    for i, metric in enumerate(metrics):\n",
    "        align = \"c|\" if i < n_metrics - 1 else \"c\"  # no vertical rule after the last group\n",
    "        group_headers.append(r\"\\multicolumn{2}{\" + align + \"}{\" + header_labels[metric] + \"}\")\n",
    "    method_headers = \" & \".join([\"FKC & ACE\"] * n_metrics)\n",
    "\n",
    "    print(\"\\n=== LaTeX Table Code ===\")\n",
    "    print(r\"\\begin{table}[h]\")\n",
    "    print(r\"\\centering\")\n",
    "    print(r\"\\caption{Performance comparison across varying ESS Thresholds. \\textbf{Bold} indicates the best method.}\")\n",
    "    print(r\"\\label{tab:ess_ablation}\")\n",
    "    print(r\"\\begin{tabular}{\" + col_spec + \"}\")\n",
    "    print(r\"\\toprule\")\n",
    "    print(\"ESS & \" + \" & \".join(group_headers) + r\" \\\\\")\n",
    "    print(\"Threshold & \" + method_headers + r\" \\\\\")\n",
    "    print(r\"\\midrule\")\n",
    "\n",
    "    for ess, row in pivot_df.iterrows():\n",
    "        cells = [f\"{ess:.1f}\"]\n",
    "        for metric in metrics:\n",
    "            fkc_val = row[(metric, 'FKC')]\n",
    "            ace_val = row[(metric, 'ACE')]\n",
    "            fkc_str = f\"{fkc_val:.3f}\"\n",
    "            ace_str = f\"{ace_val:.3f}\"\n",
    "            # Bold the lower (better) score; a tie goes to ACE.\n",
    "            if fkc_val < ace_val:\n",
    "                fkc_str = rf\"\\textbf{{{fkc_str}}}\"\n",
    "            else:\n",
    "                ace_str = rf\"\\textbf{{{ace_str}}}\"\n",
    "            cells += [fkc_str, ace_str]\n",
    "        print(\" & \".join(cells) + r\" \\\\\")\n",
    "\n",
    "    print(r\"\\bottomrule\")\n",
    "    print(r\"\\end{tabular}\")\n",
    "    print(r\"\\end{table}\")\n",
    "\n",
    "# Run: draw the ESS-sensitivity figure (plot_ess_sensitivity saves it as ESS_Sensitivity_Plot.png).\n",
    "plot_ess_sensitivity(df)\n",
    "# Uncomment to also print the LaTeX source of the ESS ablation table.\n",
    "# generate_latex_table(df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ca37ac0",
   "metadata": {},
   "source": [
    "## Figure E.5: Path existence criterion $C(t)$ across various schedule combinations\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c594f37b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Date stamp (YYYYMMDD) embedded in the run-folder name.\n",
    "run_date = datetime.now().strftime('%Y%m%d')\n",
    "exp_dir = get_experiment_dir(f\"ace_demo_runs_{run_date}\", \"figure_E5\")\n",
    "print(f\"Experiment directory: {exp_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9931476",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Anneal weight w used in the criterion C(t) (see check_condition); the trailing '#15' looks like an alternative value that was tried — TODO confirm.\n",
    "ANNEAL_WEIGHT = 1.0 #15\n",
    "\n",
    "\n",
    "# Directory of pretrained toy-model weights (also read by the Figure E.6 cell).\n",
    "model_path = \"PretrainedToyModels\"\n",
    "# NOTE(review): the setup cell at the top only imports load_interpolants_from_json from ace_lib.utils;\n",
    "# confirm load_interpolants_from_json_alpha_only is imported before this cell, otherwise this raises NameError.\n",
    "interpolants = load_interpolants_from_json_alpha_only(\"ace_lib/interpolant_schedules.json\")\n",
    "# Schedule keys; check_condition calls interpolants[key](ts), so values are presumably alpha_t callables.\n",
    "names = list(interpolants.keys())\n",
    "\n",
    "# Not read by the code below in this cell (check_condition builds its own grid and the loops reassign t).\n",
    "t = torch.linspace(0.0, 1.0, 100)\n",
    "\n",
    "# Full LaTeX labels per schedule key; the plots in this cell use the shorter name_eq_plot instead.\n",
    "name_eq = {\n",
    "    \"ddpm_linear\": r\"$\\alpha_t = \\text{DDPM}$\",\n",
    "    \"1-t**2\": r\"$\\alpha_t = 1-t^2$\",\n",
    "    \"sigmoid\": r\"$\\alpha_t = \\text{Sigmoid}$\",\n",
    "    \"default_linear\": r\"$\\alpha_t = 1-t$\",\n",
    "    \"cos_t\": r\"$\\alpha_t=\\cos(\\frac{\\pi}{2}t)$\"\n",
    "}\n",
    "\n",
    "def check_condition(alpha_funcs, n_grid=200, Bump = 0.0, Anneal_weight=1.0, t_max=0.99):\n",
    "    \"\"\"Evaluate the path-existence criterion C(t) for a three-expert composition.\n",
    "\n",
    "    For schedules [a1, a2, a3] (callables mapping a time tensor to alpha_t) and\n",
    "    anneal weight w, C is evaluated on a uniform grid of n_grid points in [0, t_max]:\n",
    "\n",
    "        C(t) = (w + Bump * t * (1 - t)) / a1(t)^2 - w / a2(t)^2 + 1 / a3(t)^2\n",
    "\n",
    "    with 1e-12 added to each denominator. Bump > 0 adds the ACE t(1 - t) bump to\n",
    "    the first exponent; Bump = 0 corresponds to constant exponents.\n",
    "\n",
    "    Args:\n",
    "        alpha_funcs: exactly three alpha-schedule callables [a1, a2, a3].\n",
    "        n_grid: number of grid points.\n",
    "        Bump: amplitude of the t(1 - t) bump on the first exponent.\n",
    "        Anneal_weight: the weight w above.\n",
    "        t_max: right end of the grid (default 0.99, which keeps t = 1 off the grid).\n",
    "\n",
    "    Returns:\n",
    "        (collapse, ts, C, total_negative_length): collapse is a 0-d bool tensor that\n",
    "        is True when C < 0 somewhere on the grid (path collapse); ts and C are the\n",
    "        grid and criterion values; total_negative_length is a float equal to\n",
    "        (number of grid intervals whose left endpoint has C < 0) * grid spacing.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if alpha_funcs does not contain exactly three schedules.\n",
    "    \"\"\"\n",
    "    if len(alpha_funcs) != 3:\n",
    "        raise ValueError(\"Only supports 3 schedules\")\n",
    "\n",
    "    ts = torch.linspace(0.0, t_max, n_grid)\n",
    "    a1, a2, a3 = (f(ts) for f in alpha_funcs)\n",
    "    # Vectorized form of the per-point bump Bump * t * (1 - t).\n",
    "    bumps = Bump * ts * (1 - ts)\n",
    "    C = (Anneal_weight + bumps) / (a1**2 + 1e-12) - Anneal_weight / (a2**2 + 1e-12) + 1 / (a3**2 + 1e-12)\n",
    "\n",
    "    total_negative_length = 0.0\n",
    "    if n_grid > 1:\n",
    "        dt = ts[1] - ts[0]\n",
    "        negative_intervals = C[:-1] < 0\n",
    "        total_negative_length = (torch.sum(negative_intervals.float()) * dt).item()\n",
    "\n",
    "    return (C.min() < 0), ts, C, total_negative_length\n",
    "\n",
    "# Short LaTeX labels per schedule key, used to build the legends of the C(t) plots in this cell.\n",
    "name_eq_plot = {\n",
    "    \"ddpm_linear\": r\"$\\text{DDPM}$\",\n",
    "    \"1-t**2\": r\"$1-t^2$\",\n",
    "    \"sigmoid\": r\"$\\text{Sigmoid}$\",\n",
    "    \"default_linear\": r\"$1-t$\",\n",
    "    \"cos_t\": r\"$\\cos(\\frac{\\pi}{2}t)$\"\n",
    "}\n",
    "\n",
    "\n",
    "def find_valid_combinations(interpolants, Anneal_weight=1.0, Bump=0.0, unique=True):\n",
    "    \"\"\"Return every schedule triple whose criterion C(t) dips below zero.\n",
    "\n",
    "    Despite the name, the result lists the *collapsing* combinations: each triple\n",
    "    (a1, a2, a3) drawn with repetition from the keys of interpolants for which\n",
    "    check_condition reports C < 0, as [a1, a2, a3, total_negative_length].\n",
    "    Entries are ordered by total_negative_length, longest first (stable for ties).\n",
    "    With unique=True only the first entry for each distinct length is kept.\n",
    "    \"\"\"\n",
    "    flagged = []\n",
    "    for triple in itertools.product(interpolants, repeat=3):\n",
    "        schedules = [interpolants[key] for key in triple]\n",
    "        collapses, _, _, neg_length = check_condition(schedules, Anneal_weight=Anneal_weight, Bump=Bump)\n",
    "        if collapses:\n",
    "            flagged.append([*triple, neg_length])\n",
    "\n",
    "    flagged = sorted(flagged, key=lambda combo: combo[3], reverse=True)\n",
    "    if not unique:\n",
    "        return flagged\n",
    "\n",
    "    # dict keeps insertion order, so setdefault retains the first combo per length.\n",
    "    first_per_length = {}\n",
    "    for combo in flagged:\n",
    "        first_per_length.setdefault(combo[3], combo)\n",
    "    return list(first_per_length.values())\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "collapse_combinations = find_valid_combinations(interpolants, Anneal_weight=ANNEAL_WEIGHT, Bump=0.0, unique=False)\n",
    "print(f\"There are {len(collapse_combinations)} combinations that have path collapse:\")\n",
    "with open(os.path.join(exp_dir, \"non_unique_collapse_combinations.json\"), \"w\") as f:\n",
    "    json.dump(collapse_combinations, f, indent=4)\n",
    "\n",
    "# ACE bump search: start at BUMP_START and raise by BUMP_STEP until C(t) >= 0 on the\n",
    "# plotted range, giving up (with a warning) after MAX_BUMP_ITERS attempts.\n",
    "BUMP_START, BUMP_STEP, MAX_BUMP_ITERS = 50.0, 70, 1000\n",
    "C_CLIP = 200  # each ACE curve is cut at the first point where C(t) exceeds this\n",
    "\n",
    "\n",
    "def combo_label(a1, a2, a3):\n",
    "    \"\"\"Legend label for a schedule triple, built from name_eq_plot.\"\"\"\n",
    "    return f'{name_eq_plot[a1]}, {name_eq_plot[a2]}, {name_eq_plot[a3]}'\n",
    "\n",
    "\n",
    "plt.figure(figsize=(15,12))\n",
    "plt.title(f'Criterion C(t) with ACE Correction Three-Expert Composition with w={ANNEAL_WEIGHT}', fontsize=24)\n",
    "for combo in tqdm(collapse_combinations, leave=False):\n",
    "    a1, a2, a3, _ = combo\n",
    "    alpha_funcs = [interpolants[a1], interpolants[a2], interpolants[a3]]\n",
    "    Bump = BUMP_START\n",
    "    for attempt in range(MAX_BUMP_ITERS):\n",
    "        _, t, C, _ = check_condition(alpha_funcs, Anneal_weight=ANNEAL_WEIGHT, Bump=Bump)\n",
    "        exceed_indices = torch.where(C > C_CLIP)[0]\n",
    "        if len(exceed_indices) > 0:\n",
    "            first_exceed_index = exceed_indices[0]\n",
    "            C = C[:first_exceed_index]\n",
    "            t = t[:first_exceed_index]\n",
    "        Bump = Bump + BUMP_STEP\n",
    "        if C.min() >= 0:\n",
    "            break\n",
    "    else:\n",
    "        print(f\"Warning: C(t) still negative for {combo[:3]} after {MAX_BUMP_ITERS} bump increases; plotting the last attempt.\")\n",
    "    plt.plot(t.numpy(), C.numpy(), label=combo_label(a1, a2, a3))\n",
    "plt.ylim((-20,100))\n",
    "plt.grid(True)\n",
    "if collapse_combinations:\n",
    "    plt.legend(loc='upper left', fontsize=12, ncol=4)\n",
    "plt.xlabel('t', fontsize=24)\n",
    "plt.ylabel(r\"$C(t)$\", fontsize=24)\n",
    "plt.tight_layout()\n",
    "plt.savefig(os.path.join(exp_dir, f\"Criterion_plot_ACE_w={ANNEAL_WEIGHT}.png\"))\n",
    "# Progress bars use leave=False, so no clear_output() is needed (it would also erase this figure).\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "plt.figure(figsize=(15,12))\n",
    "plt.title(f'Criterion C(t) with Constant Exponents for Three-Expert Composition with w={ANNEAL_WEIGHT}', fontsize=24)\n",
    "for combo in tqdm(collapse_combinations, leave=False):\n",
    "    a1, a2, a3, _ = combo\n",
    "    _, t, C, _ = check_condition([interpolants[a1], interpolants[a2], interpolants[a3]], Anneal_weight=ANNEAL_WEIGHT, Bump=0.0)\n",
    "    plt.plot(t.numpy(), C.numpy(), label=combo_label(a1, a2, a3))\n",
    "plt.ylim((-20,100))\n",
    "plt.grid(True)\n",
    "if collapse_combinations:\n",
    "    plt.legend(loc='upper left', fontsize=12, ncol=4)\n",
    "plt.xlabel('t', fontsize=24)\n",
    "plt.ylabel(r\"$C(t)$\", fontsize=24)\n",
    "plt.tight_layout()\n",
    "plt.savefig(os.path.join(exp_dir, f\"Criterion_plot_FKC_w={ANNEAL_WEIGHT}.png\"))\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56bcb5c1",
   "metadata": {},
   "source": [
    "## Figure E.6: Quantitative evaluation of Heterogeneous Ratio-of-Densities Sampling\n",
    "\n",
    "The runs below also produce the plots for Figures E.7–E.11, which visualize the generative trajectories and final samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b46698d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The following loop may take a long time to run for all collapse combinations and seeds.\n",
    "# We report five randomly selected collapse combinations in our paper.\n",
    "# Each entry: [alpha schedule of model1 (X|A), of model3 (X), of model2 (XY|B), collapse-length tag];\n",
    "# the last value is only used in the output folder name.\n",
    "collapse_combinations = [[\"ddpm_linear\", \"default_linear\", \"ddpm_linear\", 7.5],\n",
    "                         [\"cos_t\", \"default_linear\", \"ddpm_linear\", 8.9],\n",
    "                         [\"ddpm_linear\", \"1-t**2\", \"default_linear\", 10.9],\n",
    "                         [\"cos_t\", \"sigmoid\", \"cos_t\", 11.4],\n",
    "                         [\"cos_t\", \"default_linear\", \"cos_t\", 48.8]]\n",
    "\n",
    "# Run settings shared by every combination.\n",
    "bs = 10000; n_steps = 1000\n",
    "seeds = [0, 1, 2, 3, 4]\n",
    "ESS_THRESHOLD = 0.7\n",
    "ANNEAL_WEIGHT = 1.0\n",
    "BUMP_VALUE = 30.0\n",
    "RUN_METRIC_EVAL = True\n",
    "AB_CONDITIONS = [(1, 1), (1, 0), (0, 1), (0, 0)]  # (A, B) conditioning values\n",
    "\n",
    "\n",
    "def load_toy_model(z_dim, cond_dim, output_dim, filename):\n",
    "    \"\"\"Build an MLPInstFlexible (width 256, depth 4) on `device`, load model_path/filename, set eval mode.\"\"\"\n",
    "    model = MLPInstFlexible(z_dim=z_dim, cond_dim=cond_dim, width=256, depth=4, output_dim=output_dim).to(device)\n",
    "    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.\n",
    "    model.load_state_dict(torch.load(f\"{model_path}/{filename}\", map_location=device))\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "\n",
    "for names in collapse_combinations:\n",
    "    results = []\n",
    "\n",
    "    subexperiment_id = f\"{names}\"\n",
    "    experiment_id = f\"{exp_dir}/[{names[3]:.4f}]{subexperiment_id}_Visualizations\"\n",
    "    os.makedirs(experiment_id, exist_ok=True)\n",
    "\n",
    "    # Schedule mapping: names[0] -> model1 (X|A), names[2] -> model2 (XY|B), names[1] -> model3 (X).\n",
    "    u_model1 = load_toy_model(1, 1, 1, f\"u_model1_X_given_A_alpha={names[0]}.pth\")\n",
    "    s_model1 = load_toy_model(1, 1, 1, f\"s_model1_X_given_A_alpha={names[0]}.pth\")\n",
    "    u_model2 = load_toy_model(2, 1, 2, f\"u_model2_XY_given_B_alpha={names[2]}.pth\")\n",
    "    s_model2 = load_toy_model(2, 1, 2, f\"s_model2_XY_given_B_alpha={names[2]}.pth\")\n",
    "    u_model3 = load_toy_model(1, 0, 1, f\"u_model3_X_alpha={names[1]}.pth\")\n",
    "    s_model3 = load_toy_model(1, 0, 1, f\"s_model3_X_alpha={names[1]}.pth\")\n",
    "\n",
    "    def v1_fn(x, t, A): return u_model1(x, t, A)\n",
    "    def s1_fn(x, t, A): return s_model1(x, t, A)\n",
    "    def v2_fn(x, t): return u_model3(x, t)\n",
    "    def s2_fn(x, t): return s_model3(x, t)\n",
    "    def v3_fn(z, t, B): return u_model2(z, t, B)\n",
    "    def s3_fn(z, t, B): return s_model2(z, t, B)\n",
    "    def sigma_fn(t): return 0.5 * torch.ones_like(t)\n",
    "\n",
    "    # These lambdas read A and B when called (late binding), i.e. the values set by\n",
    "    # the `for A, B in AB_CONDITIONS` loop further down.\n",
    "    v_fn_list = [\n",
    "        lambda x, t: v1_fn(x[:, :1], t, torch.full((x.size(0), 1), A, device=x.device)), # v1(X|A)\n",
    "        lambda x, t: v2_fn(x[:, :1], t),                                                 # v2(X)\n",
    "        lambda x, t: v3_fn(x, t, torch.full((x.size(0), 1), B, device=x.device))         # v3(Z|B)\n",
    "    ]\n",
    "    s_fn_list = [\n",
    "        lambda x, t: s1_fn(x[:, :1], t, torch.full((x.size(0), 1), A, device=x.device)), # s1(X|A)\n",
    "        lambda x, t: s2_fn(x[:, :1], t),                                                 # s2(X)\n",
    "        lambda x, t: s3_fn(x, t, torch.full((x.size(0), 1), B, device=x.device))         # s3(Z|B)\n",
    "    ]\n",
    "    proj_list = [\n",
    "        lambda z: z[:, :1],    # project to X\n",
    "        lambda z: z[:, :1],    # project to X\n",
    "        lambda z: z            # identity for Z\n",
    "    ]\n",
    "    emb_list = [\n",
    "        lambda x: torch.cat([x, torch.zeros(x.size(0), 1, device=x.device)], dim=1),  # embed X→Z\n",
    "        lambda x: torch.cat([x, torch.zeros(x.size(0), 1, device=x.device)], dim=1),  # embed X→Z\n",
    "        lambda z: z  # identity\n",
    "    ]\n",
    "    print(f\"{names} models loaded.\")\n",
    "\n",
    "    for seed in seeds:\n",
    "        torch.manual_seed(seed)\n",
    "        np.random.seed(seed)\n",
    "        print(f\"Seed set to {seed}\")\n",
    "\n",
    "        for Method_name in [\"FKC\", \"ACE\"]:\n",
    "            print(f\"Method: {Method_name}\")\n",
    "            if Method_name == \"FKC\":\n",
    "                print(\"Simulating FKC (Constant Gammas)\")\n",
    "                gamma_list = [\n",
    "                    lambda t : torch.tensor(1) * ANNEAL_WEIGHT,\n",
    "                    lambda t : torch.tensor(-1) * ANNEAL_WEIGHT,\n",
    "                    lambda t : torch.tensor(1)\n",
    "                ]\n",
    "                d_gamma_list = [\n",
    "                    lambda t: torch.zeros_like(t),\n",
    "                    lambda t: torch.zeros_like(t),\n",
    "                    lambda t: torch.zeros_like(t)\n",
    "                ]\n",
    "            elif Method_name == \"ACE\":\n",
    "                print(\"Simulating ACE (Adaptive Gammas)\")\n",
    "                gamma_list = [\n",
    "                    lambda t : torch.tensor(1) * ANNEAL_WEIGHT + (BUMP_VALUE * t * (1 - t)),\n",
    "                    lambda t : torch.tensor(-1) * ANNEAL_WEIGHT,\n",
    "                    lambda t : torch.tensor(1)\n",
    "                ]\n",
    "                d_gamma_list = [\n",
    "                    lambda t: torch.zeros_like(t) + (BUMP_VALUE * (1 - 2*t)),\n",
    "                    lambda t: torch.zeros_like(t),\n",
    "                    lambda t: torch.zeros_like(t)\n",
    "                ]\n",
    "\n",
    "            for A, B in AB_CONDITIONS:\n",
    "                print(f\"Conditioning on A={A}, B={B}\")\n",
    "\n",
    "                # Drawn from the CPU RNG, then moved to `device` (CPU fallback when CUDA is absent).\n",
    "                x0 = torch.randn(bs, 2).to(device)\n",
    "                samples, logw_final, logw_history, sample_history, resample_history = simulate_ace(\n",
    "                    x0=x0, v_fn_list=v_fn_list, s_fn_list=s_fn_list, proj_list=proj_list, emb_list=emb_list, sigma_fn=sigma_fn,\n",
    "                    v_star=lambda z, t: v3_fn(z, t, torch.full((z.size(0), 1), B, device=z.device)),\n",
    "                    t0=0.0, t1=1.0, n_steps=n_steps, device=str(device), ess_threshold=ESS_THRESHOLD, print_resample_history=True,\n",
    "                    gamma_list=gamma_list,\n",
    "                    d_gamma_list=d_gamma_list,\n",
    "                    resample=True\n",
    "                )\n",
    "                samples = samples.cpu().numpy()\n",
    "                if seed == 0:\n",
    "                    plot_diagnostics(samples, logw_final, logw_history, save_name=f\"{experiment_id}/alpha={names}_{Method_name}_seed{seed}_AB={A}{B}_Bump={BUMP_VALUE}\")\n",
    "                    plot_path_trajectories(sample_history, n_frame=6, resample_history=None, experiment_id=experiment_id, name=f\"alpha={names}_{Method_name}_seed{seed}_AB={A}{B}_Bump={BUMP_VALUE}\", deg=-50)\n",
    "                    plt.close(); clear_output()\n",
    "\n",
    "                if RUN_METRIC_EVAL:\n",
    "                    print(f\"Evaluating {Method_name} for AB={A}{B}\")\n",
    "                    samples_gt = ground_truth_hcg(bs, cond_A=A, cond_B=B).numpy()\n",
    "                    w1, w2, mmd_rbf, total_var = compute_sample_based_metrics(\n",
    "                        torch.tensor(samples_gt), torch.tensor(samples)\n",
    "                    )\n",
    "                    results.append([seed, Method_name, w1, w2, mmd_rbf, total_var, A, B])\n",
    "\n",
    "                    # Rewritten after every evaluation so partial results survive an interrupted run.\n",
    "                    df = pd.DataFrame(results, columns=[\"seed\", \"method\", \"W1\", \"W2\", \"MMD_RBF\", \"Total Var.\", \"A\", \"B\"])\n",
    "                    df.to_csv(f\"{experiment_id}/experiment_results_numseeds{len(seeds)}_bs{bs}_n_steps{n_steps}_ESS{ESS_THRESHOLD}_{names}_ANNEAL={ANNEAL_WEIGHT}_BUMP={BUMP_VALUE}.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fea8b218",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import re\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import ast\n",
    "\n",
    "# ==========================================\n",
    "# CONFIGURATION\n",
    "# ==========================================\n",
    "PARENT_DIR = exp_dir\n",
    "OUTPUT_FILENAME = \"fig_e6.png\"\n",
    "plt.rcParams['savefig.dpi'] = 300 # For saved figure resolution\n",
    "\n",
    "# Metrics to plot (Columns)\n",
    "METRICS = ['W1', 'W2', 'MMD_RBF']\n",
    "METRIC_LABELS = {\n",
    "    'W1': r'$W_1$ Distance ($\\downarrow$)',\n",
    "    'W2': r'$W_2$ Distance ($\\downarrow$)',\n",
    "    'MMD_RBF': r'MMD-RBF ($\\downarrow$)'\n",
    "}\n",
    "\n",
    "# LaTeX Mappings for Schedules\n",
    "NAME_EQ_PLOT = {\n",
    "    \"ddpm_linear\": r\"$\\text{DDPM}$\",\n",
    "    \"1-t**2\": r\"$1-t^2$\",\n",
    "    \"sigmoid\": r\"$\\text{Sigmoid}$\",\n",
    "    \"default_linear\": r\"$1-t$\",\n",
    "    \"cos_t\": r\"$\\cos(\\frac{\\pi}{2}t)$\"\n",
    "}\n",
    "\n",
    "# Methods and Colors (Highlight ACE)\n",
    "PALETTE = {\n",
    "    'NR': '#B0B0B0',   # Light Gray\n",
    "    'FKC': '#696969',  # Dark Gray\n",
    "    'ACE': '#D62728'   # Bold Red\n",
    "}\n",
    "\n",
    "# Order of plotting on X-axis\n",
    "METHOD_ORDER = ['NR', 'FKC', 'ACE']\n",
    "\n",
    "# ==========================================\n",
    "# 1. DATA LOADING\n",
    "# ==========================================\n",
    "def load_all_data(parent_dir):\n",
    "    all_data = []\n",
    "    \n",
    "    # 1. ROBUST FOLDER FINDING\n",
    "    try:\n",
    "        subdirs = [\n",
    "            os.path.join(parent_dir, d) for d in os.listdir(parent_dir) \n",
    "            if os.path.isdir(os.path.join(parent_dir, d)) and d.endswith(\"_Visualizations\")\n",
    "        ]\n",
    "    except FileNotFoundError:\n",
    "        print(f\"Error: Parent directory '{parent_dir}' not found.\")\n",
    "        return pd.DataFrame() \n",
    "    \n",
    "    # Sort them based on the float number in the first bracket\n",
    "    def sort_key(path):\n",
    "        match = re.search(r\"\\[([\\d\\.]+)\\]\", os.path.basename(path))\n",
    "        return float(match.group(1)) if match else 0\n",
    "    \n",
    "    subdirs.sort(key=sort_key)\n",
    "\n",
    "    for i, folder_path in enumerate(subdirs):\n",
    "        folder_name = os.path.basename(folder_path)\n",
    "        \n",
    "        # Extract and parse the schedule list for the Case Label\n",
    "        try:\n",
    "            # The folder structure is usually: [sort_val]['s1', 's2', 's3', len]_Visualizations\n",
    "            # We extract the list part between ']' and '_Visualizations'\n",
    "            raw_list_part = folder_name.split(']')[1].split('_Visualizations')[0]\n",
    "            \n",
    "            if not raw_list_part.endswith(']'):\n",
    "                raw_list_part += ']'\n",
    "            \n",
    "            # Parse list\n",
    "            schedule_list = ast.literal_eval(raw_list_part)\n",
    "            schedules = schedule_list[:3]\n",
    "            schedule_collapse_length = schedule_list[3] if len(schedule_list) > 3 else \"N/A\"\n",
    "            # convert string to float with 1 decimal place\n",
    "            try:\n",
    "                schedule_collapse_length = f\"{float(schedule_collapse_length)*100:.1f}\"\n",
    "            except:\n",
    "                schedule_collapse_length = \"N/A\"\n",
    "            \n",
    "            # 1. Map to LaTeX (strip existing $ so we can wrap it uniformly)\n",
    "            latex_schedules = [NAME_EQ_PLOT.get(s, s).replace('$', '') for s in schedules]\n",
    "            \n",
    "            # 2. Apply Reordering Logic (User specified: [2, 0, 1])\n",
    "            # Original: [s1, s2, s3] -> New: [s3, s1, s2]\n",
    "            reordered = [latex_schedules[2]] + latex_schedules[:2]\n",
    "            \n",
    "            # 3. Format into 3 lines: alpha^(i)_t = val\n",
    "            formatted_lines = []\n",
    "            for idx, val in enumerate(reordered):\n",
    "                # Construct LaTeX string: $\\alpha^{(i)}_t = val$\n",
    "                # Using raw strings to handle backslashes safely\n",
    "                line = r\"$\\alpha^{(\" + str(idx+1) + r\")}_t = \" + val + r\"$\"\n",
    "                formatted_lines.append(line)\n",
    "            formatted_lines.append(f\"\\nCollapse Duration: {schedule_collapse_length}%\")\n",
    "            case_label = \"\\n\".join(formatted_lines)\n",
    "\n",
    "        except Exception as e:\n",
    "            print(f\"Warning: Could not parse label for {folder_name}, using default. Error: {e}\")\n",
    "            case_label = f\"Case {i+1}\"\n",
    "\n",
    "        # 2. ROBUST CSV FINDING\n",
    "        try:\n",
    "            files_in_folder = os.listdir(folder_path)\n",
    "            csv_files = [\n",
    "                os.path.join(folder_path, f) for f in files_in_folder\n",
    "                if f.startswith(\"experiment_results\") and f.endswith(\".csv\")\n",
    "            ]\n",
    "        except OSError:\n",
    "            print(f\"Could not access folder: {folder_name}\")\n",
    "            continue\n",
    "\n",
    "        if not csv_files:\n",
    "            print(f\"Skipping {folder_name}: No results CSV found.\")\n",
    "            continue\n",
    "            \n",
    "        # Read Data\n",
    "        df = pd.read_csv(csv_files[0])\n",
    "        df['Condition'] = df.apply(lambda row: f\"({int(row['A'])},{int(row['B'])})\", axis=1)\n",
    "        df['Case'] = case_label\n",
    "        df['Case_Index'] = i \n",
    "        all_data.append(df)\n",
    "\n",
    "    if not all_data:\n",
    "        raise ValueError(\"No data found! Check your paths.\")\n",
    "        \n",
    "    return pd.concat(all_data, ignore_index=True)\n",
    "\n",
    "# ==========================================\n",
    "# 2. PLOTTING\n",
    "# ==========================================\n",
    "def create_facet_plot(df):\n",
    "    \"\"\"Draw a (cases x metrics) grid of grouped bar charts, save it, and show it.\n",
    "\n",
    "    Rows are the unique ``Case`` labels ordered by ``Case_Index``; columns follow\n",
    "    ``METRICS``. Each panel groups bars by ``Condition`` with one bar per method\n",
    "    (``METHOD_ORDER``) and standard-deviation error bars over the rows of that group.\n",
    "    The figure is written to ``os.path.join(PARENT_DIR, OUTPUT_FILENAME)``.\n",
    "\n",
    "    Args:\n",
    "        df: Output of ``load_all_data``; needs ``Case``, ``Case_Index``, ``Condition``,\n",
    "            ``method`` and one column per entry in ``METRICS``.\n",
    "\n",
    "    Note: modifies global matplotlib/seaborn rc settings (style, context, font).\n",
    "    \"\"\"\n",
    "    # --- STYLE & FONT SETTINGS ---\n",
    "    # sns.set_style() resets 'font.family' to sans-serif, so it must run BEFORE the\n",
    "    # serif font is configured (and before the axes exist, so they pick up the style).\n",
    "    sns.set_style(\"whitegrid\")\n",
    "    sns.set_context(\"paper\", font_scale=2)\n",
    "    # Generic 'serif' with Times New Roman first: uses Times when installed and falls\n",
    "    # back to another serif font (instead of the default sans) when it is not.\n",
    "    plt.rcParams['font.family'] = 'serif'\n",
    "    plt.rcParams['font.serif'] = ['Times New Roman'] + [\n",
    "        f for f in plt.rcParams['font.serif'] if f != 'Times New Roman'\n",
    "    ]\n",
    "\n",
    "    cases = df.sort_values('Case_Index')['Case'].unique()\n",
    "    n_cases = len(cases)\n",
    "    n_metrics = len(METRICS)\n",
    "    if n_cases == 0:\n",
    "        print(\"Nothing to plot: no cases in the data.\")\n",
    "        return\n",
    "\n",
    "    # squeeze=False keeps `axes` 2-D for every grid shape (including 1x1), so panels\n",
    "    # and the legend source can always be indexed as axes[row, col].\n",
    "    # Height is 3.5 per row to fit the multi-line case labels.\n",
    "    fig, axes = plt.subplots(\n",
    "        n_cases, n_metrics,\n",
    "        figsize=(5 * n_metrics, 3.5 * n_cases),\n",
    "        sharex=True, squeeze=False,\n",
    "    )\n",
    "\n",
    "    print(\"Generating Plots...\")\n",
    "\n",
    "    for row_idx, case in enumerate(cases):\n",
    "        subset = df[df['Case'] == case]  # same rows for every metric in this row\n",
    "        for col_idx, metric in enumerate(METRICS):\n",
    "            ax = axes[row_idx, col_idx]\n",
    "\n",
    "            sns.barplot(\n",
    "                data=subset,\n",
    "                x='Condition',\n",
    "                y=metric,\n",
    "                hue='method',\n",
    "                hue_order=METHOD_ORDER,\n",
    "                palette=PALETTE,\n",
    "                ax=ax,\n",
    "                edgecolor='black',\n",
    "                linewidth=0.5,\n",
    "                errorbar='sd',\n",
    "                capsize=0.1,\n",
    "                err_kws={'linewidth': 1}\n",
    "            )\n",
    "\n",
    "            # Column headers (metric names) on the top row only\n",
    "            if row_idx == 0:\n",
    "                ax.set_title(METRIC_LABELS[metric], fontsize=20, fontweight='bold', pad=15)\n",
    "            else:\n",
    "                ax.set_title(\"\")\n",
    "\n",
    "            # Row labels (case names) on the first column only. The label already\n",
    "            # contains newlines; va='center' centres the block and labelpad pushes it left.\n",
    "            if col_idx == 0:\n",
    "                ax.set_ylabel(case, fontsize=18, rotation=0, labelpad=90, va='center')\n",
    "            else:\n",
    "                ax.set_ylabel(\"\")\n",
    "\n",
    "            # x-axis label on the bottom row only\n",
    "            if row_idx == n_cases - 1:\n",
    "                ax.set_xlabel(r\"Condition $(1_A,1_B)$\", fontsize=18)\n",
    "            else:\n",
    "                ax.set_xlabel(\"\")\n",
    "\n",
    "            ax.grid(True, axis='y', linestyle='--', alpha=0.6)\n",
    "            sns.despine(ax=ax, left=True)\n",
    "\n",
    "            # Per-panel legends are replaced by one figure-level legend below\n",
    "            if ax.get_legend():\n",
    "                ax.get_legend().remove()\n",
    "\n",
    "    # Global legend; every panel uses the same hue order, so the first panel suffices\n",
    "    handles, labels = axes[0, 0].get_legend_handles_labels()\n",
    "    fig.legend(handles, labels, loc='lower center', ncol=len(METHOD_ORDER),\n",
    "               bbox_to_anchor=(0.55, -0.01), fontsize=18, frameon=False)\n",
    "\n",
    "    fig.tight_layout()\n",
    "\n",
    "    # Margins: the wide left margin (0.25) holds the multi-line row labels\n",
    "    fig.subplots_adjust(top=0.92, bottom=0.07, left=0.25)\n",
    "\n",
    "    # \"Schedules\" header in the top-left margin, above the row labels (x=0.17)\n",
    "    fig.text(0.17, 0.93, \"Schedules\", fontsize=20, fontweight='bold', ha='center', va='bottom', fontfamily='serif')\n",
    "\n",
    "    output_path = os.path.join(PARENT_DIR, OUTPUT_FILENAME)\n",
    "    fig.savefig(output_path, dpi=300, bbox_inches='tight')\n",
    "    print(f\"Success! Plot saved to: {output_path}\")\n",
    "    plt.show()\n",
    "\n",
    "# In a Jupyter kernel __name__ is \"__main__\", so this runs whenever the cell executes.\n",
    "if __name__ == \"__main__\":\n",
    "    # isdir rather than exists: load_all_data lists PARENT_DIR, which fails if it is a file.\n",
    "    if os.path.isdir(PARENT_DIR):\n",
    "        combined_df = load_all_data(PARENT_DIR)\n",
    "        if not combined_df.empty:\n",
    "            create_facet_plot(combined_df)\n",
    "    else:\n",
    "        print(f\"Error: Directory '{PARENT_DIR}' not found.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ace_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
