{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3d05273a-d531-4409-9dde-6f9ed5eb7994",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "import copy\n",
    "import torch\n",
    "from torch import nn\n",
    "import grnewt\n",
    "from grnewt.optimizers import AdamUpdate, SGDUpdate\n",
    "from torch.optim import Adam, SGD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2a6f2cb0-778c-4e93-b9d5-2138f7a9da12",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(type='cuda', index=0)\n",
    "dtype = torch.float32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1d011693-df7e-4768-96ea-766ba3a92a8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Perceptron(torch.nn.Module):\n",
    "    def __init__(self, layers, act_name = 'tanh'):\n",
    "        super(Perceptron, self).__init__()\n",
    "        \n",
    "        if act_name == 'identity':\n",
    "            act_name = 'linear'\n",
    "    \n",
    "        gain = nn.init.calculate_gain(act_name)\n",
    "        \n",
    "        self.layers = torch.nn.ModuleList()\n",
    "        for l_in, l_out in zip(layers[:-1], layers[1:]):\n",
    "            self.layers.append(torch.nn.Linear(l_in, l_out))\n",
    "            with torch.no_grad():\n",
    "                self.layers[-1].weight.mul_(gain)\n",
    "        self.nb_layers = len(self.layers)\n",
    "        \n",
    "        if act_name in ['tanh', 'sigmoid', 'relu']:\n",
    "            self.act_function = torch.__dict__[act_name]\n",
    "        elif act_name == 'linear':\n",
    "            self.act_function = lambda x: x\n",
    "        \n",
    "    def forward(self, x):\n",
    "        for l in self.layers[:-1]:\n",
    "            x = l(x)\n",
    "            x = self.act_function(x)\n",
    "        x = self.layers[-1](x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "c94e7a25-7e98-4d89-9fce-c5df01bf3a4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def copy_model(src, dst):\n",
    "    src_params = dict(src.named_parameters())\n",
    "    dst_params = dict(dst.named_parameters())\n",
    "    with torch.no_grad():\n",
    "        for n, p in src_params.items():\n",
    "            dst_params[n].data.copy_(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7a158545-b153-442a-9697-2e5ed889f9f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model\n",
    "\n",
    "layers = [100, 60, 20, 10]\n",
    "act_name = 'tanh'\n",
    "Din = layers[0]\n",
    "Dout = layers[-1]\n",
    "\n",
    "loss_indiv = lambda y1, y2: (y1 - y2).pow(2).sum(1)\n",
    "loss_sum = lambda y1, y2: loss_indiv(y1, y2).sum(0)\n",
    "loss_mean = torch.nn.MSELoss(reduction = 'mean')\n",
    "\n",
    "model = Perceptron(layers, act_name)\n",
    "model = model.to(device = device, dtype = dtype)\n",
    "\n",
    "dct_params = OrderedDict(model.named_parameters())\n",
    "tup_params = tuple(v for _, v in dct_params.items())\n",
    "shape = tuple(t.size() for t in tup_params)\n",
    "num_params = sum(t.numel() for t in tup_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4a006ac8-217d-4373-9524-9e465d454b6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a toy dataset\n",
    "\n",
    "n_ts = 7\n",
    "x_ts = torch.randn(n_ts, Din, device = device, dtype = dtype)\n",
    "y_ts = torch.randn(n_ts, Dout, device = device, dtype = dtype)\n",
    "\n",
    "n_tr = 3\n",
    "x_tr = torch.randn(n_tr, Din, device = device, dtype = dtype)\n",
    "y_tr = torch.randn(n_tr, Dout, device = device, dtype = dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "f2b798f3-e5ca-4767-905a-759817feb790",
   "metadata": {},
   "outputs": [],
   "source": [
    "model2 = Perceptron(layers, act_name).to(device = device, dtype = dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a526f38d-cb3a-455c-87ca-70e75f79a892",
   "metadata": {},
   "outputs": [],
   "source": [
    "copy_model(model, model2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "48d6ef63-440b-4c2c-8dcb-efccb0d83780",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.equal(model.layers[0].weight.data, model2.layers[0].weight.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "7e4ab500-377d-4a92-9da9-87647a6201bb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "print(id(model.layers[0].weight.data) == id(model2.layers[0].weight.data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "348cbc47-ee2d-471a-b640-1d4e52e7ca7d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0918, -0.1619, -0.1643,  ...,  0.0175,  0.1238,  0.0036],\n",
       "        [ 0.1612, -0.0218, -0.0259,  ..., -0.0433,  0.1504, -0.0217],\n",
       "        [ 0.1402,  0.0008, -0.0873,  ..., -0.1235,  0.0162, -0.1103],\n",
       "        ...,\n",
       "        [ 0.1547, -0.1008,  0.0586,  ..., -0.0610,  0.1551, -0.0530],\n",
       "        [-0.0173,  0.0665, -0.0950,  ..., -0.1231, -0.0419,  0.1482],\n",
       "        [-0.0862,  0.0611, -0.1066,  ...,  0.0273,  0.1483,  0.1467]])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model2.layers[0].weight.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "b8ba0e62-6d7b-4150-a61b-567dd96aff37",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('layers',\n",
       "              ModuleList(\n",
       "                (0): Linear(in_features=100, out_features=60, bias=True)\n",
       "                (1): Linear(in_features=60, out_features=20, bias=True)\n",
       "                (2): Linear(in_features=20, out_features=10, bias=True)\n",
       "              ))])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model._modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "8ae0120b-09fc-430f-ae3d-c13d4c39d565",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-1.7124, -0.5146, -0.3749, -2.0397, -0.6050],\n",
      "        [-0.5018, -0.0056, -0.9185, -0.5233,  1.3444],\n",
      "        [ 0.1072, -2.8564, -2.5753,  0.4318, -1.7363],\n",
      "        [-2.1104,  2.4521, -0.0224, -0.4488, -0.7589],\n",
      "        [ 0.3028, -1.2001,  1.8800, -0.5125, -0.8675],\n",
      "        [-0.7479, -0.8054,  0.6479,  0.1142, -0.5955]])\n",
      "tensor([[1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.]])\n",
      "tensor([[-1.7124, -0.5146, -0.3749, -2.0397, -0.6050],\n",
      "        [-0.5018, -0.0056, -0.9185, -0.5233,  1.3444],\n",
      "        [ 0.1072, -2.8564, -2.5753,  0.4318, -1.7363],\n",
      "        [-2.1104,  2.4521, -0.0224, -0.4488, -0.7589],\n",
      "        [ 0.3028, -1.2001,  1.8800, -0.5125, -0.8675],\n",
      "        [-0.7479, -0.8054,  0.6479,  0.1142, -0.5955]])\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(6, 5)\n",
    "y = torch.ones(6, 5)\n",
    "print(x)\n",
    "print(y)\n",
    "y.copy_(x)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "0e7a189e-8d7d-4162-899d-81f298ffedf3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "139644435613360"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "id(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "5c563489-2ef3-45ed-854b-10abd8f78c84",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "139644435613264"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "id(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d150b2fd-0b99-4653-8dd4-026db85e4a0f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
