{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d2b233b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using matplotlib backend: MacOSX\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.integrate import solve_ivp\n",
    "import matplotlib.patheffects as path_effects\n",
    "\n",
    "%matplotlib inline\n",
    "%matplotlib auto\n",
    "\n",
    "plt.style.use('seaborn-white')\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "\n",
    "def f(x, y):\n",
    "    term1 = x * (y - 0.45)\n",
    "    term2 = 0.25 * x**2 - 0.5 * x**4 + (1/6) * x**6\n",
    "    term3 = 0.25 * y**2 - 0.5 * y**4 + (1/6) * y**6\n",
    "    return term1 + term2 - term3\n",
    "\n",
    "def df_dx(x, y):\n",
    "    return y - 0.45 + 0.5 * x - 2 * x**3 + x**5\n",
    "\n",
    "def df_dy(x, y):\n",
    "    return x - 0.5 * y + 2 * y**3 - y**5\n",
    "\n",
    "def d2f_dx2(x, y):\n",
    "    return 0.5 - 6 * x**2 + 5 * x**4\n",
    "\n",
    "def d2f_dy2(x, y):\n",
    "    return -0.5 + 6 * y**2 - 5 * y**4\n",
    "\n",
    "def d2f_dxdy(x, y):\n",
    "    return 1.0\n",
    "\n",
    "def run_saddle_point_adam(x0, y0, alpha=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8, num_steps=1000):\n",
    "    x, y = x0, y0\n",
    "    m_x, m_y = 0.0, 0.0\n",
    "    v_x, v_y = 0.0, 0.0\n",
    "    trajectory = [(x, y)]\n",
    "    \n",
    "    for t in range(1, num_steps + 1):\n",
    "        g_x, g_y = df_dx(x, y), df_dy(x, y)\n",
    "        m_x = beta1 * m_x + (1 - beta1) * g_x\n",
    "        m_y = beta1 * m_y + (1 - beta1) * g_y\n",
    "        v_x = beta2 * v_x + (1 - beta2) * (g_x**2)\n",
    "        v_y = beta2 * v_y + (1 - beta2) * (g_y**2)\n",
    "        m_hat_x = m_x / (1 - beta1**t)\n",
    "        m_hat_y = m_y / (1 - beta1**t)\n",
    "        v_hat_x = v_x / (1 - beta2**t)\n",
    "        v_hat_y = v_y / (1 - beta2**t)\n",
    "        x = x - alpha * m_hat_x / (np.sqrt(v_hat_x) + epsilon)\n",
    "        y = y + alpha * m_hat_y / (np.sqrt(v_hat_y) + epsilon)\n",
    "        trajectory.append((x, y))\n",
    "        \n",
    "    return np.array(trajectory)\n",
    "\n",
    "\n",
    "def continuous_time_adam_ode(t, state, beta, rho, epsilon, h):\n",
    "    x, y = state\n",
    "    g_x, g_y = df_dx(x, y), df_dy(x, y)\n",
    "    g_xx, g_yy, g_xy = d2f_dx2(x, y), d2f_dy2(x, y), d2f_dxdy(x, y)\n",
    "    \n",
    "    norm_x_eps = np.sqrt(g_x**2 + epsilon)\n",
    "    norm_y_eps = np.sqrt(g_y**2 + epsilon)\n",
    "    \n",
    "    mu_eps = 1.0 / norm_x_eps\n",
    "    nu_eps = 1.0 / norm_y_eps \n",
    "    \n",
    "    c1 = (1+beta)/(1-beta) - (1+rho)/(1-rho)\n",
    "    c2 = epsilon * (1+rho) / (1-rho)\n",
    "    \n",
    "    M_mu = c1 + c2 * mu_eps**2\n",
    "    M_nu = c1 + c2 * nu_eps**2\n",
    "    \n",
    "    grad_N_x = g_xx * mu_eps * g_x - g_xy * nu_eps * g_y\n",
    "    grad_N_y = g_xy * mu_eps * g_x - g_yy * nu_eps * g_y\n",
    "    \n",
    "    dx_dt = -mu_eps * (g_x + (h/2) * M_mu * grad_N_x)\n",
    "    dy_dt =  nu_eps * (g_y + (h/2) * M_nu * grad_N_y) \n",
    "    \n",
    "    return [dx_dt, dy_dt]\n",
    "\n",
    "def run_continuous_time_adam(x0, y0, beta, rho, epsilon, h, t_span, num_points):\n",
    "    ode_func = lambda t, state: continuous_time_adam_ode(t, state, beta, rho, epsilon, h)\n",
    "    t_eval = np.linspace(t_span[0], t_span[1], num_points)\n",
    "    \n",
    "    solution = solve_ivp(fun=ode_func, t_span=t_span, y0=[x0, y0], t_eval=t_eval, method='RK45')\n",
    "    \n",
    "    return solution.y.T\n",
    "\n",
    "def sign_based_ode(t, state):\n",
    "    x, y = state\n",
    "    g_x = df_dx(x, y)\n",
    "    g_y = df_dy(x, y)\n",
    "    \n",
    "    dx_dt = -np.sign(g_x)\n",
    "    dy_dt = np.sign(g_y)\n",
    "    \n",
    "    return [dx_dt, dy_dt]\n",
    "\n",
    "def run_sign_based_flow(x0, y0, t_span, num_points):\n",
    "    t_eval = np.linspace(t_span[0], t_span[1], num_points)\n",
    "    \n",
    "    solution = solve_ivp(fun=sign_based_ode, t_span=t_span, y0=[x0, y0], t_eval=t_eval, method='RK45')\n",
    "    \n",
    "    return solution.y.T\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def plot_trajectories(trajectories_dict, x_range, y_range, num_grid_points=200):\n",
    "    x_grid = np.linspace(x_range[0], x_range[1], num_grid_points)\n",
    "    y_grid = np.linspace(y_range[0], y_range[1], num_grid_points)\n",
    "    X, Y = np.meshgrid(x_grid, y_grid)\n",
    "\n",
    "    plt.style.use('seaborn-white')\n",
    "    plt.figure(figsize=(12, 10))\n",
    "    \n",
    "    U = -df_dx(X, Y)\n",
    "    V = df_dy(X, Y)\n",
    "    \n",
    "    plt.streamplot(X, Y, U, V, color='gray', linewidth=0.75, density=1.5,\n",
    "                   arrowstyle='->', arrowsize=1.2, minlength=0.1)\n",
    "\n",
    "    custom_colors = ['#18315C', '#053154', '#A50D71']\n",
    "    \n",
    "    for i, (name, traj) in enumerate(trajectories_dict.items()):\n",
    "        if name == \"Adam-DA\":\n",
    "            color = custom_colors[0] \n",
    "            line_width = 8 \n",
    "            path_effect_width = 4.5 \n",
    "            \n",
    "            plt.plot(traj[:, 0], traj[:, 1], \n",
    "                     label=name, \n",
    "                     color=color, \n",
    "                     linewidth=line_width, \n",
    "                     alpha=0.5,\n",
    "                     linestyle='--',       \n",
    "                     path_effects=[path_effects.withStroke(linewidth=path_effect_width, foreground='white')])\n",
    "\n",
    "        elif name == \"Continuous Adam-DA\":\n",
    "            line_width = 2.8\n",
    "            color = custom_colors[1] \n",
    "            path_effect_width = 3.5\n",
    "            plt.plot(traj[:, 0], traj[:, 1], label=name, color=color, linewidth=line_width, alpha=0.9,\n",
    "                     path_effects=[path_effects.withStroke(linewidth=path_effect_width, foreground='white')])\n",
    "\n",
    "        else: \n",
    "            line_width = 2.8\n",
    "            color = custom_colors[2] \n",
    "            path_effect_width = 3.5\n",
    "            plt.plot(traj[:, 0], traj[:, 1], label=name, color=color, linewidth=line_width, alpha=0.9,\n",
    "                     path_effects=[path_effects.withStroke(linewidth=path_effect_width, foreground='white')])\n",
    "\n",
    "        plt.plot(traj[-1, 0], traj[-1, 1], 'o', color=color, markersize=8,\n",
    "                 markeredgecolor='white', markeredgewidth=1.5)\n",
    "\n",
    "\n",
    "    start_point_traj = next(iter(trajectories_dict.values()))\n",
    "    plt.plot(start_point_traj[0, 0], start_point_traj[0, 1], 'w*', markersize=15,\n",
    "             markeredgecolor='black', label='Start Point')\n",
    "\n",
    "    plt.title(\"\", fontsize=16)\n",
    "    plt.xlabel(\"x\", fontsize=18)\n",
    "    plt.ylabel(\"y\", fontsize=18)\n",
    "    plt.legend(fontsize=17)\n",
    "    plt.axis('equal')\n",
    "    plt.xlim(x_range)\n",
    "    plt.ylim(y_range)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    start_point = (0.6, 0.6)\n",
    "    \n",
    "    BETA1 = 0\n",
    "    BETA2 = 0.5\n",
    "    EPSILON = 1e-6\n",
    "    \n",
    "    N = 980\n",
    "    step_size = 0.007\n",
    "\n",
    "   \n",
    "    adam_traj = run_saddle_point_adam(\n",
    "        x0=start_point[0], y0=start_point[1], alpha=step_size, \n",
    "        beta1=BETA1, beta2=BETA2, epsilon=EPSILON, \n",
    "        num_steps=N\n",
    "    )\n",
    "    continuous_adam_traj = run_continuous_time_adam(\n",
    "        x0=start_point[0], y0=start_point[1], beta=BETA1, rho=BETA2, \n",
    "        epsilon=EPSILON, h=step_size, \n",
    "        t_span=(0, step_size * N),\n",
    "        num_points=200000\n",
    "    )\n",
    "    sign_flow_traj = run_sign_based_flow(\n",
    "        x0=start_point[0], y0=start_point[1],\n",
    "        t_span=(0, step_size * N),\n",
    "        num_points=200000\n",
    "    )\n",
    "    \n",
    "\n",
    "    all_trajectories = {\n",
    "        \"Adam-DA\": adam_traj,\n",
    "        \"Continuous Adam-DA\": continuous_adam_traj,\n",
    "        \"SignGDA Flow\": sign_flow_traj\n",
    "    }\n",
    "\n",
    "\n",
    "    plot_trajectories(all_trajectories, x_range=[-1.5, 1.7], y_range=[-1.5, 1.7])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "418c5387",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow",
   "language": "python",
   "name": "tensorflow"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
