{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5e0d5ca3-4077-4662-88cb-1166eb288da9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib\n",
    "from matplotlib import pyplot as plt\n",
    "from grnewt import compute_Hg, nesterov_lrs, fullbatch_gradient\n",
    "from grnewt import NewtonSummaryVanilla\n",
    "from grnewt import optimizers\n",
    "from grnewt import partition as build_partition\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "046b90a1-0af8-45df-9f3b-c81f7f89eb47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build dummy regression dataset\n",
    "\n",
    "total_size = 200\n",
    "\n",
    "size_in = 5\n",
    "size_out = 4\n",
    "batch_size = 10\n",
    "\n",
    "data_in = torch.randn(total_size, size_in)\n",
    "data_tar = torch.randn(total_size, size_out)\n",
    "\n",
    "dataset = torch.utils.data.TensorDataset(data_in, data_tar)\n",
    "data_loader = torch.utils.data.DataLoader(dataset, batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d394f1a7-246c-4ef1-999b-49eaa8531da3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define simple model\n",
    "\n",
    "size_hidden = 6\n",
    "act_function_cl = torch.nn.Tanh\n",
    "\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model, self).__init__()\n",
    "\n",
    "        self.hidden_layer = torch.nn.Linear(size_in, size_hidden)\n",
    "        self.activation = act_function_cl()\n",
    "        self.out_layer = torch.nn.Linear(size_hidden, size_out)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.hidden_layer(x)\n",
    "        x = self.activation(x)\n",
    "        return self.out_layer(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f44e7af1-c574-4751-812f-4aa8b02ad3c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build losses\n",
    "\n",
    "loss_fn = lambda x, y: (x - y).pow(2).mean() #.sqrt()\n",
    "full_loss = lambda x, y: loss_fn(model(x), y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7ada2348-87d7-43a7-8ece-c78e08e6873d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Options passed to NewtonSummaryVanilla as dct_nesterov (presumably controlling its\n",
    "# Nesterov-style learning-rate computation; see grnewt). The NewtonSummaryVanilla\n",
    "# cell overrides damping_int for its own run.\n",
    "\n",
    "# NOTE(review): hg_batch_size is not used anywhere in this notebook; the\n",
    "# NewtonSummaryVanilla cell builds data_loader_hg with batch_size * 10 instead.\n",
    "hg_batch_size = 1000\n",
    "\n",
    "dct_nesterov = {'use': True,\n",
    "                'damping_int': 1.,\n",
    "                # NOTE(review): confirm grnewt really expects the trailing underscore in this key.\n",
    "                'mom_order3_': 0.}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "67497233-e53a-4cd4-b3db-6c637565155f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 0.9769\n",
      "Epoch 1: 0.9329\n",
      "Epoch 2: 0.8707\n",
      "Epoch 3: 0.8371\n",
      "Epoch 4: 0.8188\n",
      "Epoch 5: 0.8024\n",
      "Epoch 6: 0.7835\n",
      "Epoch 7: 0.7669\n",
      "Epoch 8: 0.7523\n",
      "Epoch 9: 0.7395\n",
      "Epoch 10: 0.7291\n",
      "Epoch 11: 0.7210\n",
      "Epoch 12: 0.7145\n",
      "Epoch 13: 0.7075\n",
      "Epoch 14: 0.7016\n",
      "Epoch 15: 0.6966\n",
      "Epoch 16: 0.6929\n",
      "Epoch 17: 0.6881\n",
      "Epoch 18: 0.6841\n",
      "Epoch 19: 0.6770\n",
      "Epoch 20: 0.6716\n",
      "Epoch 21: 0.6661\n",
      "Epoch 22: 0.6635\n",
      "Epoch 23: 0.6604\n",
      "Epoch 24: 0.6584\n",
      "Epoch 25: 0.6561\n",
      "Epoch 26: 0.6545\n",
      "Epoch 27: 0.6529\n",
      "Epoch 28: 0.6523\n",
      "Epoch 29: 0.6515\n",
      "Epoch 30: 0.6519\n",
      "Epoch 31: 0.6517\n",
      "Epoch 32: 0.6532\n",
      "Epoch 33: 0.6530\n",
      "Epoch 34: 0.6551\n",
      "Epoch 35: 0.6541\n",
      "Epoch 36: 0.6554\n",
      "Epoch 37: 0.6536\n",
      "Epoch 38: 0.6530\n",
      "Epoch 39: 0.6511\n",
      "Epoch 40: 0.6491\n",
      "Epoch 41: 0.6477\n",
      "Epoch 42: 0.6455\n",
      "Epoch 43: 0.6450\n",
      "Epoch 44: 0.6431\n",
      "Epoch 45: 0.6431\n",
      "Epoch 46: 0.6417\n",
      "Epoch 47: 0.6418\n",
      "Epoch 48: 0.6408\n",
      "Epoch 49: 0.6409\n",
      "Epoch 50: 0.6402\n",
      "Epoch 51: 0.6402\n",
      "Epoch 52: 0.6396\n",
      "Epoch 53: 0.6397\n",
      "Epoch 54: 0.6392\n",
      "Epoch 55: 0.6393\n",
      "Epoch 56: 0.6387\n",
      "Epoch 57: 0.6389\n",
      "Epoch 58: 0.6383\n",
      "Epoch 59: 0.6386\n",
      "Epoch 60: 0.6380\n",
      "Epoch 61: 0.6383\n",
      "Epoch 62: 0.6377\n",
      "Epoch 63: 0.6380\n",
      "Epoch 64: 0.6375\n",
      "Epoch 65: 0.6377\n",
      "Epoch 66: 0.6373\n",
      "Epoch 67: 0.6375\n",
      "Epoch 68: 0.6371\n",
      "Epoch 69: 0.6373\n",
      "Epoch 70: 0.6369\n",
      "Epoch 71: 0.6371\n",
      "Epoch 72: 0.6367\n",
      "Epoch 73: 0.6369\n",
      "Epoch 74: 0.6366\n",
      "Epoch 75: 0.6367\n",
      "Epoch 76: 0.6364\n",
      "Epoch 77: 0.6365\n",
      "Epoch 78: 0.6363\n",
      "Epoch 79: 0.6364\n",
      "Epoch 80: 0.6362\n",
      "Epoch 81: 0.6362\n",
      "Epoch 82: 0.6361\n",
      "Epoch 83: 0.6361\n",
      "Epoch 84: 0.6360\n",
      "Epoch 85: 0.6359\n",
      "Epoch 86: 0.6359\n",
      "Epoch 87: 0.6358\n",
      "Epoch 88: 0.6358\n",
      "Epoch 89: 0.6357\n",
      "Epoch 90: 0.6357\n",
      "Epoch 91: 0.6356\n",
      "Epoch 92: 0.6356\n",
      "Epoch 93: 0.6355\n",
      "Epoch 94: 0.6355\n",
      "Epoch 95: 0.6354\n",
      "Epoch 96: 0.6355\n",
      "Epoch 97: 0.6353\n",
      "Epoch 98: 0.6354\n",
      "Epoch 99: 0.6352\n"
     ]
    }
   ],
   "source": [
    "epochs=100\n",
    "lr = .1\n",
    "\n",
    "# Build model\n",
    "model = Model()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr = lr)\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    loss_tot = 0\n",
    "    n = 0\n",
    "    for x, y in data_loader:\n",
    "        model.zero_grad()\n",
    "        yhat = model(x)\n",
    "        loss = loss_fn(yhat, y)\n",
    "        #print(loss.item())\n",
    "        loss_tot += loss.item()\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        n += 1\n",
    "    print(f\"Epoch {epoch}: {loss_tot/n:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "64892e12-a9e9-4486-a7e4-9d77ee5863d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 0.9737\n",
      "Epoch 1: 0.8949\n",
      "Epoch 2: 0.8495\n",
      "Epoch 3: 0.8391\n",
      "Epoch 4: 0.8303\n",
      "Epoch 5: 0.8246\n",
      "Epoch 6: 0.8210\n",
      "Epoch 7: 0.8181\n",
      "Epoch 8: 0.8146\n",
      "Epoch 9: 0.8100\n",
      "Epoch 10: 0.8041\n",
      "Epoch 11: 0.7967\n",
      "Epoch 12: 0.7879\n",
      "Epoch 13: 0.7761\n",
      "Epoch 14: 0.7674\n",
      "Epoch 15: 0.7598\n",
      "Epoch 16: 0.7529\n",
      "Epoch 17: 0.7470\n",
      "Epoch 18: 0.7419\n",
      "Epoch 19: 0.7374\n",
      "Epoch 20: 0.7333\n",
      "Epoch 21: 0.7296\n",
      "Epoch 22: 0.7261\n",
      "Epoch 23: 0.7229\n",
      "Epoch 24: 0.7200\n",
      "Epoch 25: 0.7173\n",
      "Epoch 26: 0.7148\n",
      "Epoch 27: 0.7126\n",
      "Epoch 28: 0.7105\n",
      "Epoch 29: 0.7086\n",
      "Epoch 30: 0.7068\n",
      "Epoch 31: 0.7051\n",
      "Epoch 32: 0.7036\n",
      "Epoch 33: 0.7022\n",
      "Epoch 34: 0.7009\n",
      "Epoch 35: 0.6997\n",
      "Epoch 36: 0.6986\n",
      "Epoch 37: 0.6975\n",
      "Epoch 38: 0.6965\n",
      "Epoch 39: 0.6956\n",
      "Epoch 40: 0.6947\n",
      "Epoch 41: 0.6939\n",
      "Epoch 42: 0.6931\n",
      "Epoch 43: 0.6924\n",
      "Epoch 44: 0.6917\n",
      "Epoch 45: 0.6911\n",
      "Epoch 46: 0.6904\n",
      "Epoch 47: 0.6898\n",
      "Epoch 48: 0.6893\n",
      "Epoch 49: 0.6887\n",
      "Epoch 50: 0.6882\n",
      "Epoch 51: 0.6877\n",
      "Epoch 52: 0.6873\n",
      "Epoch 53: 0.6868\n",
      "Epoch 54: 0.6864\n",
      "Epoch 55: 0.6860\n",
      "Epoch 56: 0.6856\n",
      "Epoch 57: 0.6852\n",
      "Epoch 58: 0.6848\n",
      "Epoch 59: 0.6844\n",
      "Epoch 60: 0.6841\n",
      "Epoch 61: 0.6838\n",
      "Epoch 62: 0.6834\n",
      "Epoch 63: 0.6831\n",
      "Epoch 64: 0.6828\n",
      "Epoch 65: 0.6825\n",
      "Epoch 66: 0.6823\n",
      "Epoch 67: 0.6820\n",
      "Epoch 68: 0.6817\n",
      "Epoch 69: 0.6815\n",
      "Epoch 70: 0.6812\n",
      "Epoch 71: 0.6810\n",
      "Epoch 72: 0.6807\n",
      "Epoch 73: 0.6805\n",
      "Epoch 74: 0.6803\n",
      "Epoch 75: 0.6801\n",
      "Epoch 76: 0.6798\n",
      "Epoch 77: 0.6796\n",
      "Epoch 78: 0.6794\n",
      "Epoch 79: 0.6792\n",
      "Epoch 80: 0.6790\n",
      "Epoch 81: 0.6788\n",
      "Epoch 82: 0.6787\n",
      "Epoch 83: 0.6785\n",
      "Epoch 84: 0.6783\n",
      "Epoch 85: 0.6781\n",
      "Epoch 86: 0.6780\n",
      "Epoch 87: 0.6778\n",
      "Epoch 88: 0.6776\n",
      "Epoch 89: 0.6775\n",
      "Epoch 90: 0.6773\n",
      "Epoch 91: 0.6772\n",
      "Epoch 92: 0.6770\n",
      "Epoch 93: 0.6769\n",
      "Epoch 94: 0.6767\n",
      "Epoch 95: 0.6766\n",
      "Epoch 96: 0.6764\n",
      "Epoch 97: 0.6763\n",
      "Epoch 98: 0.6762\n",
      "Epoch 99: 0.6760\n"
     ]
    }
   ],
   "source": [
    "epochs=100\n",
    "Cl_Updater = optimizers.AdamUpdate # AdamUpdate, SGDUpdate\n",
    "lr = 1e0\n",
    "momentum = .9\n",
    "dampening = .9\n",
    "betas = (.9, .999)\n",
    "\n",
    "hg_period=10\n",
    "hg_damping=.1\n",
    "hg_mom_lrs=0.\n",
    "dct_nesterov[\"damping_int\"] = 10.\n",
    "hg_remove_negative=False\n",
    "\n",
    "data_loader_hg = torch.utils.data.DataLoader(dataset, batch_size * 10)\n",
    "\n",
    "# Build model\n",
    "model = Model()\n",
    "param_groups, name_groups = build_partition.canonical(model) # canonical, trivial, wb\n",
    "\n",
    "if Cl_Updater == optimizers.SGDUpdate:\n",
    "    updater = Cl_Updater(model.parameters(), lr = lr, momentum = momentum, dampening = dampening)\n",
    "elif Cl_Updater == optimizers.AdamUpdate:\n",
    "    updater = Cl_Updater(model.parameters(), lr = lr, betas = betas)\n",
    "else:\n",
    "    raise ValueError()\n",
    "\n",
    "optimizer = NewtonSummaryVanilla(param_groups, full_loss, data_loader_hg, updater, \n",
    "                     damping = hg_damping, period_hg = hg_period, mom_lrs = hg_mom_lrs,\n",
    "                     dct_nesterov = dct_nesterov, remove_negative = hg_remove_negative, \n",
    "                     maintain_true_lrs = True)\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    loss_tot = 0\n",
    "    n = 0\n",
    "    for x, y in data_loader:\n",
    "        model.zero_grad()\n",
    "        yhat = model(x)\n",
    "        loss = loss_fn(yhat, y)\n",
    "        #print(loss.item())\n",
    "        loss_tot += loss.item()\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        n += 1\n",
    "    print(f\"Epoch {epoch}: {loss_tot/n:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93356c19-5ff2-41d5-95b6-804c35c209c2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
