{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49ae70c8",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "η_QK/η_OV=0.01: tau_end=2.55e+06, mu_OV 78.797->81.669, mu_QK 0.521->0.559, alpha 0.144->0.149, loss 2.215e-04->1.000e-04\n",
      "η_QK/η_OV=0.1: tau_end=4.9e+06, mu_OV 41.216->42.282, mu_QK 1.334->1.395, alpha 0.275->0.287, loss 2.253e-04->1.000e-04\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.font_manager as fm\n",
    "\n",
    "# ============================================================\n",
    "#                     PARAMETERS\n",
    "# ============================================================\n",
    "# Synthetic setting (saved figures use the _1 filename suffix for it).\n",
    "# m and n enter alpha = m e^{mu_QK} / (m e^{mu_QK} + n), b scales both ODEs,\n",
    "# and M enters the loss log(1 + (M-1) e^{-logit}).\n",
    "m, n, b, M = 5, 50, 50, 20\n",
    "\n",
    "# Learning rates: eta_OV stays fixed while eta_QK = ratio * eta_OV is swept.\n",
    "eta_OV = 1.0\n",
    "ratios = [0.01, 0.1, 1, 10, 100]\n",
    "\n",
    "# Integration in normalized time tau = t * min(eta_OV, eta_QK).\n",
    "T_tau_init = 2e5     # first horizon; doubled (up to MAX_T_TAU) until TARGET_LOSS is hit\n",
    "dt_tau = 10          # RK4 step in tau (smaller => more accurate, but slower)\n",
    "TARGET_LOSS = 1e-4   # every ratio is integrated until loss <= TARGET_LOSS ...\n",
    "MAX_T_TAU = 5e7      # ... or until tau reaches this hard cap\n",
    "\n",
    "# Numerical safeguards.\n",
    "EXP_CLIP = 60        # exponent arguments are clipped to [-EXP_CLIP, EXP_CLIP]\n",
    "EPS = 1e-12          # added to denominators to avoid division by zero\n",
    "\n",
    "# Stopping-loss thresholds: 50 log-spaced values from L_max down to L_min\n",
    "# (plotted on a log-scaled, inverted x-axis).\n",
    "L_max = 1\n",
    "L_min = 1e-4\n",
    "stop_thresholds = np.logspace(np.log10(L_max), np.log10(L_min), 50)\n",
    "\n",
    "# Plot styling.\n",
    "LABEL_FONTSIZE = 172\n",
    "TICK_FONTSIZE = 130\n",
    "LEGEND_FONTSIZE = 140\n",
    "LINEWIDTH = 40\n",
    "\n",
    "# ============================================================\n",
    "#                     UTILITIES\n",
    "# ============================================================\n",
    "def safe_exp(x):\n",
    "    \"\"\"Return exp(x) after clipping x to [-EXP_CLIP, EXP_CLIP] so it cannot overflow.\"\"\"\n",
    "    bounded = np.clip(x, -EXP_CLIP, EXP_CLIP)\n",
    "    return np.exp(bounded)\n",
    "\n",
    "# ============================================================\n",
    "#              ODEs (original time t)\n",
    "# ============================================================\n",
    "def alpha_of_mu_QK(mu_QK):\n",
    "    \"\"\"\n",
    "    alpha = m e^{mu_QK} / (m e^{mu_QK} + n), with EPS added to the\n",
    "    denominator and the exponential clipped via safe_exp.\n",
    "    \"\"\"\n",
    "    weighted = m * safe_exp(mu_QK)\n",
    "    return weighted / (weighted + n + EPS)\n",
    "\n",
    "def dmu_OV_dt(mu_OV, mu_QK, eta_OV_):\n",
    "    \"\"\"\n",
    "    Time derivative of mu_OV in original time t:\n",
    "\n",
    "        d mu_OV / dt = eta_OV * alpha / ( b ( exp(alpha*mu_OV) + M - 1 ) )\n",
    "    \"\"\"\n",
    "    a = alpha_of_mu_QK(mu_QK)\n",
    "    growth = safe_exp(a * mu_OV) + M - 1\n",
    "    return (eta_OV_ * a) / (b * growth + EPS)\n",
    "\n",
    "def dmu_QK_dt(mu_OV, mu_QK, eta_QK_):\n",
    "    \"\"\"\n",
    "    Time derivative of mu_QK in original time t:\n",
    "\n",
    "        d mu_QK / dt = eta_QK (M-1) alpha (1-alpha) mu_OV\n",
    "                       / ( M b^2 ( exp(alpha*mu_OV) + M - 1 ) )\n",
    "    \"\"\"\n",
    "    a = alpha_of_mu_QK(mu_QK)\n",
    "    numerator = eta_QK_ * (M - 1) * a * (1 - a) * mu_OV\n",
    "    denominator = (b * b) * M * (safe_exp(a * mu_OV) + M - 1) + EPS\n",
    "    return numerator / denominator\n",
    "\n",
    "# ============================================================\n",
    "#                 LOSS FUNCTION (correct class)\n",
    "# ============================================================\n",
    "def stable_loss(mu_OV, mu_QK):\n",
    "    \"\"\"\n",
    "    Loss log(1 + (M-1) exp(-mu_OV * alpha)).\n",
    "\n",
    "    log1p together with the clipped exponential keeps the value finite.\n",
    "    \"\"\"\n",
    "    logit = mu_OV * alpha_of_mu_QK(mu_QK)\n",
    "    return np.log1p((M - 1) * safe_exp(-logit))\n",
    "\n",
    "# ============================================================\n",
    "#          RK4 in normalized time tau = t * min_eta\n",
    "# ============================================================\n",
    "def rhs_tau(mu_OV, mu_QK, eta_OV_, eta_QK_, min_eta):\n",
    "    \"\"\"\n",
    "    [d mu_OV/dtau, d mu_QK/dtau] for tau = t * min_eta, i.e. the\n",
    "    original-time derivatives divided by min_eta.\n",
    "    \"\"\"\n",
    "    derivs = (\n",
    "        dmu_OV_dt(mu_OV, mu_QK, eta_OV_),\n",
    "        dmu_QK_dt(mu_OV, mu_QK, eta_QK_),\n",
    "    )\n",
    "    return np.array([d / min_eta for d in derivs], dtype=float)\n",
    "\n",
    "def rk4_step(mu_OV, mu_QK, eta_OV_, eta_QK_, min_eta, dtau):\n",
    "    \"\"\"\n",
    "    One classical RK4 step of size dtau (in tau) for the state (mu_OV, mu_QK).\n",
    "\n",
    "    Each updated component is clamped to [0, 1e6]: the theory keeps both\n",
    "    quantities non-negative, and the upper cap only matters if dt_tau is too\n",
    "    large. Returns the new (mu_OV, mu_QK) pair.\n",
    "    \"\"\"\n",
    "    def f(y):\n",
    "        return rhs_tau(y[0], y[1], eta_OV_, eta_QK_, min_eta)\n",
    "\n",
    "    half = 0.5 * dtau\n",
    "    y = np.array([mu_OV, mu_QK], dtype=float)\n",
    "    k1 = f(y)\n",
    "    k2 = f(y + half * k1)\n",
    "    k3 = f(y + half * k2)\n",
    "    k4 = f(y + dtau * k3)\n",
    "    y_next = y + (dtau / 6.0) * (k1 + 2*k2 + 2*k3 + k4)\n",
    "\n",
    "    for i in range(2):\n",
    "        y_next[i] = min(max(float(y_next[i]), 0.0), 1e6)\n",
    "\n",
    "    new_mu_OV, new_mu_QK = y_next\n",
    "    return new_mu_OV, new_mu_QK\n",
    "\n",
    "# ============================================================\n",
    "#        SIMULATION ROUTINE (auto-extend until target loss)\n",
    "# ============================================================\n",
    "def run_simulation_until(eta_OV_, eta_QK_, init_mu_OV=0.0, init_mu_QK=0.0,\n",
    "                         target_loss=1e-5, T_tau_init=2e5, max_T_tau=5e7):\n",
    "    \"\"\"\n",
    "    Integrate the (mu_OV, mu_QK) ODEs in tau = t * min(eta_OV_, eta_QK_)\n",
    "    until loss <= target_loss.\n",
    "\n",
    "    The horizon starts at T_tau_init and is doubled (capped at max_T_tau)\n",
    "    whenever it is exhausted without reaching target_loss. If the target is\n",
    "    still unmet once the horizon is already >= max_T_tau, a [WARN] line is\n",
    "    printed and integration stops. Steps use RK4 with the notebook-level\n",
    "    step size dt_tau. A one-line summary (midpoint -> final values) is\n",
    "    printed at the end.\n",
    "\n",
    "    Returns (tau, mu_OV, mu_QK, loss, alpha_t, logit) as float arrays with\n",
    "    one entry per step, including the initial state at tau = 0.\n",
    "    \"\"\"\n",
    "    # Normalization constant for tau; passed through to rk4_step / rhs_tau.\n",
    "    min_eta = min(eta_OV_, eta_QK_)\n",
    "\n",
    "    # Trajectory storage as Python lists so the horizon can grow freely.\n",
    "    tau_list = [0.0]\n",
    "    mu_OV_list = [float(init_mu_OV)]\n",
    "    mu_QK_list = [float(init_mu_QK)]\n",
    "\n",
    "    alpha0 = alpha_of_mu_QK(mu_QK_list[0])\n",
    "    logit0 = mu_OV_list[0] * alpha0\n",
    "    loss0 = stable_loss(mu_OV_list[0], mu_QK_list[0])\n",
    "\n",
    "    alpha_list = [float(alpha0)]\n",
    "    logit_list = [float(logit0)]\n",
    "    loss_list  = [float(loss0)]\n",
    "\n",
    "    # horizon in tau for this run (will grow)\n",
    "    T_tau_current = float(T_tau_init)\n",
    "\n",
    "    # Advance one RK4 step from the last stored state and append the new\n",
    "    # state together with its alpha, logit and loss.\n",
    "    def step_once():\n",
    "        muov, muqk = rk4_step(mu_OV_list[-1], mu_QK_list[-1], eta_OV_, eta_QK_, min_eta, dt_tau)\n",
    "        tau_new = tau_list[-1] + dt_tau\n",
    "\n",
    "        tau_list.append(tau_new)\n",
    "        mu_OV_list.append(muov)\n",
    "        mu_QK_list.append(muqk)\n",
    "\n",
    "        a = alpha_of_mu_QK(muqk)\n",
    "        lgt = muov * a\n",
    "        ls = stable_loss(muov, muqk)\n",
    "\n",
    "        alpha_list.append(float(a))\n",
    "        logit_list.append(float(lgt))\n",
    "        loss_list.append(float(ls))\n",
    "\n",
    "    # Integrate to the current horizon (stopping early at the target). If the\n",
    "    # target is still unmet, either give up at the cap or double the horizon.\n",
    "    while True:\n",
    "        while tau_list[-1] < T_tau_current:\n",
    "            step_once()\n",
    "            if loss_list[-1] <= target_loss:\n",
    "                break\n",
    "\n",
    "        if loss_list[-1] <= target_loss:\n",
    "            break\n",
    "\n",
    "        if T_tau_current >= max_T_tau:\n",
    "            print(f\"[WARN] η_QK/η_OV={eta_QK_/eta_OV_:.3g}: \"\n",
    "                  f\"did NOT reach loss <= {target_loss} by tau={T_tau_current:.3g}. \"\n",
    "                  f\"Final loss={loss_list[-1]:.3e}\")\n",
    "            break\n",
    "\n",
    "        T_tau_current = min(2.0 * T_tau_current, max_T_tau)\n",
    "\n",
    "    # Convert to arrays\n",
    "    tau = np.array(tau_list, dtype=float)\n",
    "    mu_OV = np.array(mu_OV_list, dtype=float)\n",
    "    mu_QK = np.array(mu_QK_list, dtype=float)\n",
    "    alpha_t = np.array(alpha_list, dtype=float)\n",
    "    logit = np.array(logit_list, dtype=float)\n",
    "    loss = np.array(loss_list, dtype=float)\n",
    "\n",
    "    # Summary line: values at the midpoint index vs. the final step.\n",
    "    mid = len(tau) // 2\n",
    "    print(f\"η_QK/η_OV={eta_QK_/eta_OV_:.2g}: \"\n",
    "          f\"tau_end={tau[-1]:.3g}, \"\n",
    "          f\"mu_OV {mu_OV[mid]:.3f}->{mu_OV[-1]:.3f}, \"\n",
    "          f\"mu_QK {mu_QK[mid]:.3f}->{mu_QK[-1]:.3f}, \"\n",
    "          f\"alpha {alpha_t[mid]:.3f}->{alpha_t[-1]:.3f}, \"\n",
    "          f\"loss {loss[mid]:.3e}->{loss[-1]:.3e}\")\n",
    "\n",
    "    return tau, mu_OV, mu_QK, loss, alpha_t, logit\n",
    "\n",
    "# ============================================================\n",
    "#                     RUN SWEEP (auto-extend)\n",
    "# ============================================================\n",
    "# One run per ratio: eta_OV stays fixed and eta_QK = eta_OV * ratio.\n",
    "eta_pairs = [(eta_OV, eta_OV * ratio) for ratio in ratios]\n",
    "sweep_kwargs = dict(target_loss=TARGET_LOSS, T_tau_init=T_tau_init, max_T_tau=MAX_T_TAU)\n",
    "\n",
    "# results[(eta_OV, eta_QK)] -> (tau, mu_OV, mu_QK, loss, alpha_t, logit)\n",
    "results = {}\n",
    "for eov, eqk in eta_pairs:\n",
    "    results[(eov, eqk)] = run_simulation_until(eov, eqk, **sweep_kwargs)\n",
    "\n",
    "# ============================================================\n",
    "#                        PLOTTING HELPERS\n",
    "# ============================================================\n",
    "# Bold, legend-sized font (NOTE(review): not referenced by the plotting cells).\n",
    "font_properties = fm.FontProperties(size=LEGEND_FONTSIZE, weight='bold')\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# ============================================================\n",
    "#                ORIGINAL TIME-SERIES PLOTS (optional)\n",
    "# ============================================================\n",
    "\n",
    "\n",
    "# # ============================================================\n",
    "# #                NEW STOPPING-LOSS PLOTS (requested)\n",
    "# # ============================================================\n",
    "# plot_stopping_curve_linear_x(r\"$\\bf\\mu_{OV}\\ \\text{at stop}$\",\n",
    "#                              lambda muov_at, muqk_at, alpha_at: muov_at,\n",
    "#                              \"mu_OV_vs_stopping_loss_linearx.pdf\")\n",
    "\n",
    "# plot_stopping_curve_linear_x(r\"$\\bf\\mu_{QK}\\ \\text{at stop}$\",\n",
    "#                              lambda muov_at, muqk_at, alpha_at: muqk_at,\n",
    "#                              \"mu_QK_vs_stopping_loss_linearx.pdf\")\n",
    "\n",
    "# plot_stopping_curve_linear_x(r\"$\\bf\\alpha(\\tau)\\ \\text{at stop}$\",\n",
    "#                              lambda muov_at, muqk_at, alpha_at: alpha_at,\n",
    "#                              \"alpha_vs_stopping_loss_linearx.pdf\")\n",
    "\n",
    "# print(\"Done.\")\n",
    "# print(f\"Params: m={m}, n={n}, b={b}, M={M}, eta_OV={eta_OV}, ratios={ratios}\")\n",
    "# print(f\"Stopping thresholds: {stop_thresholds}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02912a77",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _log1p_xticks(T_tau_for_ticks):\n",
    "    \"\"\"\n",
    "    Tick positions/labels for an x = log10(1 + tau) axis.\n",
    "\n",
    "    Ticks sit at tau = 0, 1, 10, ..., 10^k with k = floor(log10(T_tau_for_ticks))\n",
    "    (k = 0 when T_tau_for_ticks <= 1). Returns (positions, labels); the labels\n",
    "    are generated from the same powers, so both lists always have equal length.\n",
    "    \"\"\"\n",
    "    max_power = int(np.floor(np.log10(max(1.0, T_tau_for_ticks))))\n",
    "    ticks_tau = [0] + [10**k for k in range(0, max_power + 1)]\n",
    "    positions = [np.log10(1.0 + t) for t in ticks_tau]\n",
    "    labels = [r\"$0$\", r\"$1$\"] + [r\"$10$\" if k == 1 else rf\"$10^{{{k}}}$\"\n",
    "                                 for k in range(1, max_power + 1)]\n",
    "    return positions, labels\n",
    "\n",
    "\n",
    "def plot_quantity_log1p(ylabel, extractor, filename, T_tau_for_ticks=None):\n",
    "    \"\"\"\n",
    "    Plot one quantity of every run in `results` against x = log10(1 + tau)\n",
    "    and save the figure to `filename`.\n",
    "\n",
    "    ylabel          : y-axis label (mathtext allowed).\n",
    "    extractor       : called as extractor(mu_OV, mu_QK, loss, alpha_t, logit)\n",
    "                      for each run; returns the y-values for that run's tau grid.\n",
    "    filename        : output path passed to savefig.\n",
    "    T_tau_for_ticks : largest tau to label on the x-axis; defaults to the\n",
    "                      largest final tau over all runs.\n",
    "\n",
    "    Reads the notebook globals `results` and `eta_pairs`.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(48, 48))\n",
    "\n",
    "    if T_tau_for_ticks is None:\n",
    "        T_tau_for_ticks = max(results[pair][0][-1] for pair in eta_pairs)\n",
    "\n",
    "    for (eov, eqk) in eta_pairs:\n",
    "        tau, mu_OV, mu_QK, loss, alpha_t, logit = results[(eov, eqk)]\n",
    "        y = extractor(mu_OV, mu_QK, loss, alpha_t, logit)\n",
    "        ax.plot(np.log10(1.0 + tau), y, linewidth=LINEWIDTH,\n",
    "                label=rf'$\\bf r={eqk/eov:.3g}$', solid_capstyle=\"round\")\n",
    "\n",
    "    ax.set_xlabel(r\"Normalized time $t$\", fontweight=\"bold\", fontsize=LABEL_FONTSIZE, labelpad=24)\n",
    "    ax.set_ylabel(ylabel, fontweight=\"bold\", fontsize=LABEL_FONTSIZE, labelpad=22)\n",
    "\n",
    "    # Labels come from the same powers as the positions. A fixed 8-label list\n",
    "    # only matches when the longest run ends in [1e6, 1e7); any other horizon\n",
    "    # gives a tick/label count mismatch (an error in recent matplotlib).\n",
    "    tick_pos, tick_labels = _log1p_xticks(T_tau_for_ticks)\n",
    "    ax.set_xticks(tick_pos)\n",
    "    ax.set_xticklabels(tick_labels)\n",
    "\n",
    "    # Thick outward ticks; pad keeps tick labels away from the plotted data.\n",
    "    ax.tick_params(axis=\"both\", which=\"major\", labelsize=TICK_FONTSIZE,\n",
    "                   width=24, length=36, direction=\"out\", pad=18)\n",
    "    # tick_params does not set the font weight, so bold the labels explicitly.\n",
    "    for t in ax.get_xticklabels() + ax.get_yticklabels():\n",
    "        t.set_fontweight(\"bold\")\n",
    "    # Thicken the axis box to match the thick ticks and lines.\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_linewidth(10)\n",
    "    # Small internal margin so curves do not hug the axes.\n",
    "    ax.margins(x=0.02, y=0.08)\n",
    "    ax.legend(frameon=True, fontsize=LEGEND_FONTSIZE, loc=\"upper left\", bbox_to_anchor=(0.02, 0.98))\n",
    "\n",
    "    fig.tight_layout(pad=0.4)\n",
    "    fig.savefig(filename, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26c87f59",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Time-series plots of each run on the log10(1 + tau) axis.\n",
    "for ylabel, extract, fname in [\n",
    "    (r\"$\\bf\\mu_{OV}(t)$\", lambda mu_OV, mu_QK, loss, alpha_t, logit: mu_OV, \"mu_OV_1.pdf\"),\n",
    "    (r\"$\\bf\\mu_{QK}(t)$\", lambda mu_OV, mu_QK, loss, alpha_t, logit: mu_QK, \"mu_QK_1.pdf\"),\n",
    "    (r\"Loss\", lambda mu_OV, mu_QK, loss, alpha_t, logit: loss, \"loss_t_1.pdf\"),\n",
    "]:\n",
    "    plot_quantity_log1p(ylabel, extract, fname)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e70c2e4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plot_quantity_log1p(\n",
    "    ylabel=r\"$\\alpha(\\tau)$\",\n",
    "    extractor=lambda mu_OV, mu_QK, loss, alpha_t, logit: alpha_t,\n",
    "    filename=\"alpha_t_1.pdf\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae703dd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def stopping_value_vs_threshold(mu_OV, mu_QK, loss, alpha_t, thresholds):\n",
    "    \"\"\"\n",
    "    State at the first step whose loss reaches each stopping threshold.\n",
    "\n",
    "    For each threshold L, find the earliest index i with loss[i] <= L and\n",
    "    report mu_OV[i], mu_QK[i], alpha_t[i]; thresholds never reached give NaN.\n",
    "    Inputs may be any 1-D array-like (lists included); mu_OV, mu_QK, loss and\n",
    "    alpha_t must share the same step index.\n",
    "\n",
    "    Returns (muov_at, muqk_at, alpha_at): float arrays shaped like thresholds.\n",
    "    \"\"\"\n",
    "    mu_OV = np.asarray(mu_OV, dtype=float)\n",
    "    mu_QK = np.asarray(mu_QK, dtype=float)\n",
    "    loss = np.asarray(loss, dtype=float)\n",
    "    alpha_t = np.asarray(alpha_t, dtype=float)\n",
    "    thresholds = np.asarray(thresholds, dtype=float)\n",
    "\n",
    "    muov_at = np.full(thresholds.shape, np.nan, dtype=float)\n",
    "    muqk_at = np.full(thresholds.shape, np.nan, dtype=float)\n",
    "    alpha_at = np.full(thresholds.shape, np.nan, dtype=float)\n",
    "\n",
    "    for j, L in enumerate(thresholds):\n",
    "        reached = loss <= L\n",
    "        if reached.any():\n",
    "            # argmax of a boolean mask is the index of its first True entry.\n",
    "            i0 = int(np.argmax(reached))\n",
    "            muov_at[j] = mu_OV[i0]\n",
    "            muqk_at[j] = mu_QK[i0]\n",
    "            alpha_at[j] = alpha_t[i0]\n",
    "\n",
    "    return muov_at, muqk_at, alpha_at"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7f55c00",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_stopping_curve_linear_x(ylabel, y_selector, filename):\n",
    "    \"\"\"\n",
    "    Plot, for every run in `results`, a quantity evaluated at the stopping\n",
    "    step against the stopping-loss threshold, and save to `filename`.\n",
    "\n",
    "    For each run, stopping_value_vs_threshold gives mu_OV, mu_QK and alpha at\n",
    "    the first step with loss <= each value of the global `stop_thresholds`;\n",
    "    y_selector(muov_at, muqk_at, alpha_at) picks the array to plot. Despite\n",
    "    the name, the x-axis is log-scaled and inverted, so thresholds decrease\n",
    "    from left to right.\n",
    "\n",
    "    Reads the notebook globals `results`, `eta_pairs` and `stop_thresholds`.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(48, 48))\n",
    "\n",
    "    for (eov, eqk) in eta_pairs:\n",
    "        _, mu_OV, mu_QK, loss, alpha_t, _ = results[(eov, eqk)]\n",
    "        muov_at, muqk_at, alpha_at = stopping_value_vs_threshold(\n",
    "            mu_OV, mu_QK, loss, alpha_t, stop_thresholds\n",
    "        )\n",
    "        y = y_selector(muov_at, muqk_at, alpha_at)\n",
    "        # Draw on this figure's own axes. The legend needs \\bf followed by a\n",
    "        # space: \\bfr is not a TeX control word.\n",
    "        ax.plot(stop_thresholds, y, linewidth=LINEWIDTH, marker='o',\n",
    "                label=rf'$\\bf r={eqk/eov:.3g}$')\n",
    "\n",
    "    # Log-scaled threshold axis, inverted so thresholds decrease left -> right.\n",
    "    ax.set_xscale(\"log\")\n",
    "    ax.invert_xaxis()\n",
    "\n",
    "    ax.set_xlabel(r\"Stopping loss threshold\", fontweight=\"bold\", fontsize=LABEL_FONTSIZE, labelpad=24)\n",
    "    ax.set_ylabel(ylabel, fontweight=\"bold\", fontsize=LABEL_FONTSIZE, labelpad=22)\n",
    "\n",
    "    # Thick outward ticks; pad keeps tick labels away from the plotted data.\n",
    "    ax.tick_params(axis=\"both\", which=\"major\", labelsize=TICK_FONTSIZE,\n",
    "                   width=24, length=36, direction=\"out\", pad=18)\n",
    "    # tick_params does not set the font weight, so bold the labels explicitly.\n",
    "    for t in ax.get_xticklabels() + ax.get_yticklabels():\n",
    "        t.set_fontweight(\"bold\")\n",
    "    # Thicken the axis box to match the thick ticks and lines.\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_linewidth(10)\n",
    "    # Small internal margin so curves do not hug the axes.\n",
    "    ax.margins(x=0.02, y=0.08)\n",
    "    ax.legend(frameon=True, fontsize=LEGEND_FONTSIZE, loc=\"upper left\", bbox_to_anchor=(0.02, 0.98))\n",
    "\n",
    "    fig.tight_layout(pad=0.4)\n",
    "    fig.savefig(filename, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd374103",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# ============================================================\n",
    "for ylabel, select, fname in [\n",
    "    (r\"$\\bf{\\mu_{OV}}$\", lambda muov_at, muqk_at, alpha_at: muov_at, \"mu_OV_stopping_loss_1.pdf\"),\n",
    "    (r\"$\\bf{\\mu_{QK}}$\", lambda muov_at, muqk_at, alpha_at: muqk_at, \"mu_QK_stopping_loss_1.pdf\"),\n",
    "]:\n",
    "    plot_stopping_curve_linear_x(ylabel, select, fname)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0441f0cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_stopping_curve_linear_x(\n",
    "    ylabel=r\"$\\bf{\\alpha(\\tau)}$\",\n",
    "    y_selector=lambda muov_at, muqk_at, alpha_at: alpha_at,\n",
    "    filename=\"alpha_stopping_loss_1.pdf\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88722a0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filename suffix _1 denotes the synthetic setting m, n, b, M = 5, 50, 50, 20"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
