{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-01-28T07:16:18.681485Z",
     "iopub.status.busy": "2026-01-28T07:16:18.680063Z",
     "iopub.status.idle": "2026-01-28T07:16:27.561449Z",
     "shell.execute_reply": "2026-01-28T07:16:27.560400Z",
     "shell.execute_reply.started": "2026-01-28T07:16:18.681429Z"
    },
    "trusted": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import mpmath as mp\n",
    "import random\n",
    "from mpmath import expm1, tanh, exp, sech\n",
    "\n",
    "\n",
    "mp.mp.dps = 200 \n",
    "\n",
    "seed=0\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)\n",
    "\n",
    "\n",
    "\n",
    "def forward(X, a, b, c):\n",
    " \n",
    "    w = mp.matrix(len(a), 1)\n",
    "    for i in range(len(a)):\n",
    "        if i < 2: \n",
    "            w[i] =2*a[i]*b[i] + (a[i]**2 + b[i]**2 + c[i]**2)*c[i]\n",
    "        else:     \n",
    "            \n",
    "            w[i] = a[i]*tanh(a[i]) - (expm1(b[i]))**2\n",
    "            \n",
    "    return X * w, w\n",
    "\n",
    "def compute_gradients(X, y_train, y_pred, a, b, c, w):\n",
    "\n",
    "    m = X.rows\n",
    "    n = X.cols\n",
    "    error = y_pred - y_train\n",
    "    \n",
    "\n",
    "    grad_L_w = (2 / m) * (X.T * error)\n",
    "    \n",
    "\n",
    "    grad_w_a = mp.matrix(n, 1)\n",
    "    grad_w_b = mp.matrix(n, 1)\n",
    "    grad_w_c = mp.matrix(n, 1)\n",
    "    \n",
    "    for i in range(n):\n",
    "        if i < 2:  \n",
    "            grad_w_a[i] = 2*b[i] + 2 * a[i] * c[i]\n",
    "       \n",
    "            grad_w_b[i] = 2*a[i] + 2 * b[i] * c[i]\n",
    "          \n",
    "            grad_w_c[i] = a[i]**2 + b[i]**2 + 3 * c[i]**2\n",
    "        else:     \n",
    "            tanh_ai = tanh(a[i])\n",
    "            grad_w_a[i] = tanh_ai + a[i] * (1 - tanh_ai**2)\n",
    "\n",
    "            grad_w_b[i] = -2 * expm1(b[i]) * exp(b[i])\n",
    "\n",
    "            grad_w_c[i] = 0\n",
    "\n",
    "    grad_L_a = mp.matrix([grad_L_w[i] * grad_w_a[i] for i in range(n)])\n",
    "    grad_L_b = mp.matrix([grad_L_w[i] * grad_w_b[i] for i in range(n)])\n",
    "    grad_L_c = mp.matrix([grad_L_w[i] * grad_w_c[i] for i in range(n)])\n",
    "    \n",
    "    return grad_L_a, grad_L_b, grad_L_c\n",
    "\n",
    "\n",
    "n = 4\n",
    "m = 2\n",
    "\n",
    "X_np = np.array([[1, 0.5, 0.7, 0], [0.5, 1, 0.1, 0.7]])\n",
    "y_np = np.array([1, 0])\n",
    "\n",
    "X_train = mp.matrix(X_np)\n",
    "y_train = mp.matrix(y_np)\n",
    "\n",
    "\n",
    "a = mp.matrix([mp.rand() * 2e-60 - 1e-60 for _ in range(n)])\n",
    "b = mp.matrix([mp.rand() * 2e-60 - 1e-60 for _ in range(n)])\n",
    "c = mp.matrix([mp.rand() * 2e-60 - 1e-60 for _ in range(n)])\n",
    "\n",
    "\n",
    "epochs = 5001\n",
    "lr = mp.mpf('0.1')\n",
    "loss_history = []\n",
    "w_history = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "\n",
    "    y_pred, w = forward(X_train, a, b, c)\n",
    "\n",
    "    error = y_pred - y_train\n",
    "    loss = sum(e**2 for e in error) / m\n",
    "    \n",
    "    loss_history.append(loss)\n",
    "    w_history.append(list(w))\n",
    "\n",
    "\n",
    "    grad_a, grad_b, grad_c = compute_gradients(X_train, y_train, y_pred, a, b, c, w)\n",
    "\n",
    "\n",
    "    a -= lr * grad_a\n",
    "    b -= lr * grad_b\n",
    "    c -= lr * grad_c\n",
    "\n",
    "\n",
    "\n",
    "plt.figure(figsize=(14, 6))\n",
    "plt.style.use('seaborn-v0_8-whitegrid')\n",
    "\n",
    "\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.plot([float(l) for l in loss_history], color='blue')\n",
    "plt.title(\"Loss Curve (MSE)\", fontsize=14)\n",
    "plt.xlabel(\"Epoch\", fontsize=12)\n",
    "plt.ylabel(\"Loss\", fontsize=12)\n",
    "plt.yscale('log')\n",
    "\n",
    "\n",
    "plt.subplot(1, 2, 2)\n",
    "w_history_transposed = list(zip(*w_history))\n",
    "for i in range(n):\n",
    "\n",
    "    w_i_abs_history = [float(abs(val)) for val in w_history_transposed[i]]\n",
    "    w_i_history = [float(val) for val in w_history_transposed[i]]\n",
    "\n",
    "    plt.plot(w_i_abs_history, label=f'$|w_{i+1}|$') \n",
    "\n",
    "plt.title(\"Evolution of $|w_i|$\", fontsize=14)\n",
    "plt.yscale('log')\n",
    "plt.xlabel(\"Epoch\", fontsize=12)\n",
    "plt.ylabel(\"Parameter Value (abs)\", fontsize=12)\n",
    "plt.legend()\n",
    "plt.grid(True)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-01-28T07:24:59.409670Z",
     "iopub.status.busy": "2026-01-28T07:24:59.409324Z",
     "iopub.status.idle": "2026-01-28T07:25:01.121729Z",
     "shell.execute_reply": "2026-01-28T07:25:01.120675Z",
     "shell.execute_reply.started": "2026-01-28T07:24:59.409645Z"
    },
    "trusted": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "n = 4\n",
    "index = 4600\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6.4, 6.4)) \n",
    "\n",
    "\n",
    "plt.style.use('seaborn-v0_8-whitegrid')\n",
    "\n",
    "\n",
    "ax1.plot([float(l) for l in loss_history[:index]], color='black')\n",
    "ax1.set_title(\"Loss Curve\", fontsize=14)\n",
    "ax1.set_xlabel(\"Epoch\", fontsize=12)\n",
    "ax1.set_ylabel(\"Loss\", fontsize=12)\n",
    "ax1.set_yscale('log')\n",
    "\n",
    "\n",
    "\n",
    "ax1.axvspan(0, 760, color='orange', alpha=0.2)\n",
    "ax1.text(400, 1e-4, 'Feature\\n Selection\\nPhase 1', \n",
    "         horizontalalignment='center', fontsize=10, color='darkorange', weight='bold')\n",
    "\n",
    "\n",
    "ax1.axvspan(800, 2530, color='green', alpha=0.2)\n",
    "ax1.text(1650, 1e-4, 'Feature Selection\\nPhase 2', \n",
    "         horizontalalignment='center', fontsize=10, color='darkgreen', weight='bold')\n",
    "\n",
    "\n",
    "\n",
    "ax1.axvspan(2870, 4350, color='blue', alpha=0.2)\n",
    "ax1.text(3500, 1e-4, 'Feature Selection\\nPhase 3', \n",
    "         horizontalalignment='center', fontsize=10, color='darkblue', weight='bold')\n",
    "\n",
    "\n",
    "ax1.annotate('Learning Phase 1', \n",
    "             xy=(770, 0.2),  \n",
    "             xytext=(1500, 1e-0),\n",
    "             arrowprops=dict(facecolor='red', shrink=0.05, width=2, headwidth=8),\n",
    "             fontsize=10, horizontalalignment='center', \n",
    "             color='darkred')\n",
    "\n",
    "\n",
    "ax1.annotate('Learning Phase 2', \n",
    "             xy=(2630, 0.05), \n",
    "             xytext=(2800, 1), \n",
    "             arrowprops=dict(facecolor='red', shrink=0.05, width=1, headwidth=8),\n",
    "             fontsize=10, horizontalalignment='center',color='darkred')\n",
    "\n",
    "ax1.annotate('Learning Phase 3', \n",
    "             xy=(4475, 1e-5), \n",
    "             xytext=(4600, 1e-1), \n",
    "             arrowprops=dict(facecolor='red', shrink=0.05, width=1, headwidth=8),\n",
    "             fontsize=10, horizontalalignment='center',color='darkred')\n",
    "\n",
    "ax1.set_xlim(0, 5000)  \n",
    "ax1.set_ylim(1e-13, 1e2) \n",
    "\n",
    "\n",
    "\n",
    "def format_value(value):\n",
    " \n",
    "    \n",
    "    val_abs = float(abs(value))\n",
    "    if val_abs > 0 and val_abs < 1e-3:\n",
    "        return f'{value:.2e}'  \n",
    "    else:\n",
    "        return f'{value:.2f}'  \n",
    "\n",
    "\n",
    "def format_value(value):\n",
    "    val_abs = abs(value)\n",
    "    if val_abs > 0 and val_abs < 1e-5:\n",
    "        return f'{value:.2e}'\n",
    "    else:\n",
    "        return f'{value:.3f}'\n",
    "\n",
    "\n",
    "\n",
    "plt.style.use('seaborn-v0_8-whitegrid')\n",
    "\n",
    "w_history_transposed = list(zip(*w_history))\n",
    "colors = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "\n",
    "for i in range(n):\n",
    "    w_i_history = [float(abs(val)) for val in w_history_transposed[i]]\n",
    "    ax2.plot(w_i_history, label=f'$k_{i+1}$', color=colors[i], lw=2)\n",
    "\n",
    "\n",
    "ax2.set_title(\"Evolution of $|k_i|$\", fontsize=16)\n",
    "ax2.set_yscale('log')\n",
    "ax2.set_xlabel(\"Epoch\", fontsize=12)\n",
    "ax2.set_ylabel(\"Parameter Value\", fontsize=12)\n",
    "ax2.legend()\n",
    "ax2.grid(True, which=\"both\", ls=\"-\", alpha=0.5)\n",
    "\n",
    "ax2.set_xlim(0, 5000)\n",
    "\n",
    "epochs_to_mark = [1000, 3500, 4700]\n",
    "\n",
    "\n",
    "combined_annotation_placements = {\n",
    "    1000: {'x_offset': 200, 'y_target': 1e-60, 'y_text': 1e-60},\n",
    "    3500: {'x_offset': -200, 'y_target': 1e-60, 'y_text': 1e-60},\n",
    "    4700: {'x_offset': -200, 'y_target': 1e-40, 'y_text': 1e-40}\n",
    "}\n",
    "\n",
    "for epoch in epochs_to_mark:\n",
    "\n",
    "    ax2.axvline(x=epoch, color='gray', linestyle='--', linewidth=1.0, alpha=0.9)\n",
    "\n",
    "\n",
    "    text_lines = []\n",
    "    for param_idx in range(n):\n",
    "        mpf_value = w_history[epoch][param_idx]\n",
    "        value_float = float(mpf_value)\n",
    "        formatted_val = format_value(value_float)\n",
    " \n",
    "        text_lines.append(f\"$\\\\bf k_{param_idx+1}$: {formatted_val}\")\n",
    "\n",
    "\n",
    "    combined_text = '\\n'.join(text_lines)\n",
    "    \n",
    "\n",
    "    placement = combined_annotation_placements[epoch]\n",
    "    x_offset = placement['x_offset']\n",
    "    \n",
    "\n",
    "    horizontal_alignment = 'left' if x_offset > 0 else 'right'\n",
    "\n",
    "\n",
    "    ax2.annotate(\n",
    "        combined_text,\n",
    "        xy=(epoch, placement['y_target']), \n",
    "        xytext=(epoch + x_offset, placement['y_text']), \n",
    "        arrowprops=dict(\n",
    "            arrowstyle=\"->\",\n",
    "            color='black',\n",
    "            connectionstyle=\"arc3,rad=0.2\",\n",
    "            alpha=0.7\n",
    "        ),\n",
    "        fontsize=9,\n",
    "\n",
    "        ha=horizontal_alignment,\n",
    "        va='center', \n",
    "        multialignment='left',\n",
    "        bbox=dict(\n",
    "            boxstyle='round,pad=0.5',\n",
    "            fc='whitesmoke', \n",
    "            ec='black',\n",
    "            lw=1,\n",
    "            alpha=0.05\n",
    "        )\n",
    "    )\n",
    "\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('experiment.png')\n",
    "plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kaggle": {
   "accelerator": "none",
   "dataSources": [],
   "dockerImageVersionId": 31153,
   "isGpuEnabled": false,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "display_name": "mybrian",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "-1.-1.-1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
