{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5e0d5ca3-4077-4662-88cb-1166eb288da9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib\n",
    "from matplotlib import pyplot as plt\n",
    "from grnewt import compute_Hg, nesterov_lrs, fullbatch_gradient\n",
    "from grnewt import NewtonSummaryVanilla, NewtonSummaryUniformMean\n",
    "from grnewt import optimizers\n",
    "from grnewt import partition as build_partition\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d948f449-fedd-4bb4-8084-343c8792a179",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "046b90a1-0af8-45df-9f3b-c81f7f89eb47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset\n",
    "\n",
    "def build_dataset(total_size, size_in, size_out, batch_size):\n",
    "    data_in = torch.randn(total_size, size_in)\n",
    "    data_tar = torch.randn(total_size, size_out)\n",
    "    dataset = torch.utils.data.TensorDataset(data_in, data_tar)\n",
    "    return dataset, torch.utils.data.DataLoader(dataset, batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4bafbb15-6d11-4e28-b94b-2cc20c8fbeed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simple model\n",
    "\n",
    "size_hidden = 6\n",
    "act_function_cl = torch.nn.Tanh\n",
    "\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model, self).__init__()\n",
    "\n",
    "        self.hidden_layer = torch.nn.Linear(size_in, size_hidden)\n",
    "        self.activation = act_function_cl()\n",
    "        self.out_layer = torch.nn.Linear(size_hidden, size_out)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.hidden_layer(x)\n",
    "        x = self.activation(x)\n",
    "        return self.out_layer(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a8424cd3-d9a3-4d1c-8b91-4d35998b984d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset parameters\n",
    "size_in = 5\n",
    "size_out = 4\n",
    "\n",
    "# Build\n",
    "dct_nesterov = {'use': True, \n",
    "                'damping_int': 1.,\n",
    "                'mom_order3_': 0.}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "67497233-e53a-4cd4-b3db-6c637565155f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_adam(epochs, data_loader, lr, show = False):\n",
    "    model = Model().to(device = device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr = lr)\n",
    "\n",
    "    loss_fn = lambda x, y: (x - y).pow(2).mean() #.sqrt()\n",
    "    full_loss = lambda x, y: loss_fn(model(x), y)\n",
    "    \n",
    "    for epoch in range(epochs):\n",
    "        loss_tot = 0\n",
    "        n = 0\n",
    "        for x, y in data_loader:\n",
    "            x, y = x.to(device = device), y.to(device = device)\n",
    "            model.zero_grad()\n",
    "            yhat = model(x)\n",
    "            loss = loss_fn(yhat, y)\n",
    "            loss_tot += loss.item()\n",
    "    \n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "    \n",
    "            n += 1\n",
    "        if show:\n",
    "            print(f\"Epoch {epoch}: {loss_tot/n:.4f}\")\n",
    "\n",
    "    loss_tot = 0\n",
    "    n = 0\n",
    "    with torch.no_grad():    \n",
    "        for x, y in data_loader:\n",
    "            x, y = x.to(device = device), y.to(device = device)\n",
    "            yhat = model(x)\n",
    "            loss = loss_fn(yhat, y)\n",
    "            loss_tot += loss.item()\n",
    "\n",
    "            n += 1\n",
    "    \n",
    "    return loss_tot/n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dc2f6cec-1aac-4c4f-9821-53a63df06d0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_NS(epochs, data_loader, momentum = 0., dampening = 0., hg_period = 10, hg_mom_lrs = 0.,\n",
    "            hg_remove_negative = False, dct_nesterov = None, hg_damping = 0., dct_uniform_mean = None,\n",
    "            data_loader_hg = None, show = False):\n",
    "    lr = 1e0\n",
    "    Cl_Updater = optimizers.SGDUpdate\n",
    "\n",
    "    # Build model\n",
    "    model = Model().to(device = device)\n",
    "    param_groups, name_groups = build_partition.canonical(model) # canonical, trivial, wb\n",
    "\n",
    "    loss_fn = lambda x, y: (x - y).pow(2).mean() #.sqrt()\n",
    "    full_loss = lambda x, y: loss_fn(model(x), y)\n",
    "    \n",
    "    updater = optimizers.SGDUpdate(model.parameters(), lr = lr, momentum = momentum, dampening = dampening)\n",
    "    \n",
    "    optimizer = NewtonSummaryUniformMean(param_groups, full_loss, data_loader_hg, updater, \n",
    "                         damping = hg_damping, period_hg = hg_period, mom_lrs = hg_mom_lrs,\n",
    "                         dct_nesterov = dct_nesterov, remove_negative = hg_remove_negative,\n",
    "                         dct_uniform_mean = dct_uniform_mean)\n",
    "    \n",
    "    for epoch in range(epochs):\n",
    "        loss_tot = 0\n",
    "        n = 0\n",
    "        for x, y in data_loader:\n",
    "            x, y = x.to(device = device), y.to(device = device)\n",
    "            model.zero_grad()\n",
    "            yhat = model(x)\n",
    "            loss = loss_fn(yhat, y)\n",
    "            loss_tot += loss.item()\n",
    "    \n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            \n",
    "            n += 1\n",
    "        if show:\n",
    "            print(f\"Epoch {epoch}: {loss_tot/n:.4f}\")\n",
    "\n",
    "    loss_tot = 0\n",
    "    n = 0\n",
    "    with torch.no_grad():    \n",
    "        for x, y in data_loader:\n",
    "            x, y = x.to(device = device), y.to(device = device)\n",
    "            yhat = model(x)\n",
    "            loss = loss_fn(yhat, y)\n",
    "            loss_tot += loss.item()\n",
    "\n",
    "            n += 1\n",
    "    \n",
    "    return loss_tot/n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "92d47d48-737c-4a62-9011-6deccb61d840",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_expes = 10\n",
    "\n",
    "epochs = 100\n",
    "\n",
    "total_size = 400\n",
    "batch_size = 10\n",
    "\n",
    "lr_adam = 1e-2\n",
    "\n",
    "momentum = 0.9 #.9\n",
    "dampening = 0.9 #.9\n",
    "hg_period=8\n",
    "hg_damping=.3\n",
    "hg_mom_lrs=0.\n",
    "hg_batch_size=batch_size*2\n",
    "dct_nesterov[\"damping_int\"] = 1.\n",
    "hg_remove_negative=True\n",
    "dct_uniform_mean = {\"use\": True, \"period\": 20, \"warmup\": 5}\n",
    "\n",
    "results_adam = []\n",
    "results_ns = []\n",
    "\n",
    "for i in range(n_expes):\n",
    "    dataset, data_loader = build_dataset(total_size, size_in, size_out, batch_size)\n",
    "    data_loader_hg = torch.utils.data.DataLoader(dataset, hg_batch_size)\n",
    "    results_adam.append(train_adam(epochs, data_loader, lr_adam))\n",
    "    results_ns.append(train_NS(epochs, data_loader, momentum = momentum, dampening = momentum, hg_period = hg_period, \n",
    "                               hg_mom_lrs = hg_mom_lrs, hg_remove_negative = hg_remove_negative, dct_nesterov = dct_nesterov, \n",
    "                               hg_damping = hg_damping, dct_uniform_mean = dct_uniform_mean, data_loader_hg = data_loader_hg))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7b7d68ae-1525-4b99-bc6d-7b75d96f56f8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9583953260630368 0.9817768289893865\n"
     ]
    }
   ],
   "source": [
    "print(np.array(results_adam).mean(), np.array(results_ns).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3ac87d1c-992e-493c-b512-d5a10779e9fa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 1.0538\n",
      "Epoch 1: 1.0119\n",
      "Epoch 2: 1.0072\n",
      "Epoch 3: 1.0035\n",
      "Epoch 4: 1.0010\n",
      "Epoch 5: 0.9993\n",
      "Epoch 6: 0.9979\n",
      "Epoch 7: 0.9968\n",
      "Epoch 8: 0.9959\n",
      "Epoch 9: 0.9952\n",
      "Epoch 10: 0.9946\n",
      "Epoch 11: 0.9941\n",
      "Epoch 12: 0.9937\n",
      "Epoch 13: 0.9934\n",
      "Epoch 14: 0.9931\n",
      "Epoch 15: 0.9928\n",
      "Epoch 16: 0.9925\n",
      "Epoch 17: 0.9923\n",
      "Epoch 18: 0.9920\n",
      "Epoch 19: 0.9918\n",
      "Epoch 20: 0.9915\n",
      "Epoch 21: 0.9912\n",
      "Epoch 22: 0.9910\n",
      "Epoch 23: 0.9907\n",
      "Epoch 24: 0.9904\n",
      "Epoch 25: 0.9900\n",
      "Epoch 26: 0.9897\n",
      "Epoch 27: 0.9893\n",
      "Epoch 28: 0.9890\n",
      "Epoch 29: 0.9886\n",
      "Epoch 30: 0.9882\n",
      "Epoch 31: 0.9877\n",
      "Epoch 32: 0.9872\n",
      "Epoch 33: 0.9867\n",
      "Epoch 34: 0.9862\n",
      "Epoch 35: 0.9856\n",
      "Epoch 36: 0.9849\n",
      "Epoch 37: 0.9842\n",
      "Epoch 38: 0.9835\n",
      "Epoch 39: 0.9827\n",
      "Epoch 40: 0.9819\n",
      "Epoch 41: 0.9811\n",
      "Epoch 42: 0.9803\n",
      "Epoch 43: 0.9794\n",
      "Epoch 44: 0.9786\n",
      "Epoch 45: 0.9777\n",
      "Epoch 46: 0.9769\n",
      "Epoch 47: 0.9760\n",
      "Epoch 48: 0.9752\n",
      "Epoch 49: 0.9745\n",
      "Epoch 50: 0.9737\n",
      "Epoch 51: 0.9730\n",
      "Epoch 52: 0.9724\n",
      "Epoch 53: 0.9717\n",
      "Epoch 54: 0.9711\n",
      "Epoch 55: 0.9706\n",
      "Epoch 56: 0.9701\n",
      "Epoch 57: 0.9696\n",
      "Epoch 58: 0.9691\n",
      "Epoch 59: 0.9687\n",
      "Epoch 60: 0.9683\n",
      "Epoch 61: 0.9679\n",
      "Epoch 62: 0.9675\n",
      "Epoch 63: 0.9672\n",
      "Epoch 64: 0.9668\n",
      "Epoch 65: 0.9665\n",
      "Epoch 66: 0.9662\n",
      "Epoch 67: 0.9659\n",
      "Epoch 68: 0.9656\n",
      "Epoch 69: 0.9653\n",
      "Epoch 70: 0.9650\n",
      "Epoch 71: 0.9647\n",
      "Epoch 72: 0.9644\n",
      "Epoch 73: 0.9642\n",
      "Epoch 74: 0.9639\n",
      "Epoch 75: 0.9637\n",
      "Epoch 76: 0.9634\n",
      "Epoch 77: 0.9632\n",
      "Epoch 78: 0.9630\n",
      "Epoch 79: 0.9627\n",
      "Epoch 80: 0.9625\n",
      "Epoch 81: 0.9623\n",
      "Epoch 82: 0.9621\n",
      "Epoch 83: 0.9619\n",
      "Epoch 84: 0.9617\n",
      "Epoch 85: 0.9614\n",
      "Epoch 86: 0.9612\n",
      "Epoch 87: 0.9610\n",
      "Epoch 88: 0.9608\n",
      "Epoch 89: 0.9606\n",
      "Epoch 90: 0.9605\n",
      "Epoch 91: 0.9603\n",
      "Epoch 92: 0.9601\n",
      "Epoch 93: 0.9599\n",
      "Epoch 94: 0.9597\n",
      "Epoch 95: 0.9595\n",
      "Epoch 96: 0.9594\n",
      "Epoch 97: 0.9592\n",
      "Epoch 98: 0.9590\n",
      "Epoch 99: 0.9589\n",
      "\n",
      "Epoch 0: 1.1367\n",
      "Epoch 1: 1.0964\n",
      "Epoch 2: 1.0172\n",
      "Epoch 3: 1.0129\n",
      "Epoch 4: 0.9921\n",
      "Epoch 5: 0.9876\n",
      "Epoch 6: 0.9866\n",
      "Epoch 7: 0.9858\n",
      "Epoch 8: 0.9939\n",
      "Epoch 9: 0.9940\n",
      "Epoch 10: 0.9942\n",
      "Epoch 11: 0.9958\n",
      "Epoch 12: 1.0001\n",
      "Epoch 13: 0.9926\n",
      "Epoch 14: 0.9923\n",
      "Epoch 15: 0.9948\n",
      "Epoch 16: 1.0011\n",
      "Epoch 17: 0.9883\n",
      "Epoch 18: 0.9882\n",
      "Epoch 19: 0.9906\n",
      "Epoch 20: 0.9972\n",
      "Epoch 21: 0.9794\n",
      "Epoch 22: 0.9816\n",
      "Epoch 23: 0.9818\n",
      "Epoch 24: 0.9893\n",
      "Epoch 25: 0.9785\n",
      "Epoch 26: 0.9771\n",
      "Epoch 27: 0.9771\n",
      "Epoch 28: 0.9906\n",
      "Epoch 29: 0.9846\n",
      "Epoch 30: 0.9771\n",
      "Epoch 31: 0.9762\n",
      "Epoch 32: 0.9888\n",
      "Epoch 33: 0.9958\n",
      "Epoch 34: 0.9858\n",
      "Epoch 35: 0.9841\n",
      "Epoch 36: 0.9909\n",
      "Epoch 37: 0.9825\n",
      "Epoch 38: 0.9776\n",
      "Epoch 39: 0.9772\n",
      "Epoch 40: 0.9830\n",
      "Epoch 41: 0.9940\n",
      "Epoch 42: 0.9959\n",
      "Epoch 43: 0.9927\n",
      "Epoch 44: 0.9932\n",
      "Epoch 45: 0.9879\n",
      "Epoch 46: 0.9911\n",
      "Epoch 47: 0.9922\n",
      "Epoch 48: 0.9901\n",
      "Epoch 49: 0.9860\n",
      "Epoch 50: 0.9923\n",
      "Epoch 51: 0.9925\n",
      "Epoch 52: 0.9983\n",
      "Epoch 53: 0.9936\n",
      "Epoch 54: 0.9954\n",
      "Epoch 55: 0.9936\n",
      "Epoch 56: 0.9941\n",
      "Epoch 57: 0.9867\n",
      "Epoch 58: 0.9891\n",
      "Epoch 59: 0.9912\n",
      "Epoch 60: 0.9887\n",
      "Epoch 61: 0.9871\n",
      "Epoch 62: 0.9927\n",
      "Epoch 63: 0.9912\n",
      "Epoch 64: 0.9977\n",
      "Epoch 65: 0.9934\n",
      "Epoch 66: 0.9944\n",
      "Epoch 67: 0.9943\n",
      "Epoch 68: 0.9939\n",
      "Epoch 69: 0.9841\n",
      "Epoch 70: 0.9853\n",
      "Epoch 71: 0.9880\n",
      "Epoch 72: 0.9851\n",
      "Epoch 73: 0.9856\n",
      "Epoch 74: 0.9918\n",
      "Epoch 75: 0.9901\n",
      "Epoch 76: 0.9974\n",
      "Epoch 77: 0.9935\n",
      "Epoch 78: 0.9939\n",
      "Epoch 79: 0.9933\n",
      "Epoch 80: 0.9934\n",
      "Epoch 81: 0.9853\n",
      "Epoch 82: 0.9832\n",
      "Epoch 83: 0.9835\n",
      "Epoch 84: 0.9819\n",
      "Epoch 85: 0.9790\n",
      "Epoch 86: 0.9824\n",
      "Epoch 87: 0.9880\n",
      "Epoch 88: 0.9948\n",
      "Epoch 89: 0.9969\n",
      "Epoch 90: 0.9895\n",
      "Epoch 91: 0.9817\n",
      "Epoch 92: 0.9889\n",
      "Epoch 93: 0.9866\n",
      "Epoch 94: 0.9753\n",
      "Epoch 95: 0.9745\n",
      "Epoch 96: 0.9888\n",
      "Epoch 97: 0.9925\n",
      "Epoch 98: 0.9787\n",
      "Epoch 99: 0.9790\n",
      "0.9409182131290436\n",
      "0.9694820076227189\n"
     ]
    }
   ],
   "source": [
    "n_expes = 1\n",
    "show = True\n",
    "\n",
    "epochs = 100\n",
    "\n",
    "total_size = 400\n",
    "batch_size = 10\n",
    "\n",
    "lr_adam = 1e-2\n",
    "\n",
    "momentum = 0.9 #.9\n",
    "dampening = 0.9 #.9\n",
    "hg_period=8\n",
    "hg_damping=.3\n",
    "hg_mom_lrs=0.\n",
    "hg_batch_size=batch_size*2\n",
    "dct_nesterov[\"damping_int\"] = 1.\n",
    "hg_remove_negative=True\n",
    "dct_uniform_mean = {\"use\": True, \"period\": 20, \"warmup\": 5}\n",
    "\n",
    "results_adam = []\n",
    "results_ns = []\n",
    "\n",
    "for i in range(n_expes):\n",
    "    dataset, data_loader = build_dataset(total_size, size_in, size_out, batch_size)\n",
    "    data_loader_hg = torch.utils.data.DataLoader(dataset, hg_batch_size)\n",
    "    results_adam.append(train_adam(epochs, data_loader, lr_adam, show = show))\n",
    "    if show: print(\"\")\n",
    "    results_ns.append(train_NS(epochs, data_loader, momentum = momentum, dampening = momentum, hg_period = hg_period, \n",
    "                               hg_mom_lrs = hg_mom_lrs, hg_remove_negative = hg_remove_negative, dct_nesterov = dct_nesterov, \n",
    "                               hg_damping = hg_damping, dct_uniform_mean = dct_uniform_mean, data_loader_hg = data_loader_hg, \n",
    "                               show = show))\n",
    "\n",
    "print(results_adam[0])\n",
    "print(results_ns[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86d414f3-44be-460e-aa48-52dcbb03fafd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
