{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a1d3c363-3247-4cc4-b078-9f48555bdb31",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import grnewt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b5a2def-a983-40b5-bfef-3e7a976ac885",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Perceptron(torch.nn.Module):\n",
    "    def __init__(self, layers, act_name = 'tanh'):\n",
    "        super(Perceptron, self).__init__()\n",
    "        \n",
    "        if act_name == 'identity':\n",
    "            act_name = 'linear'\n",
    "    \n",
    "        gain = nn.init.calculate_gain(act_name)\n",
    "        \n",
    "        self.layers = torch.nn.ModuleList()\n",
    "        for l_in, l_out in zip(layers[:-1], layers[1:]):\n",
    "            self.layers.append(torch.nn.Linear(l_in, l_out))\n",
    "            with torch.no_grad():\n",
    "                self.layers[-1].weight.mul_(gain)\n",
    "        self.nb_layers = len(self.layers)\n",
    "        \n",
    "        if act_name in ['tanh', 'sigmoid', 'relu']:\n",
    "            self.act_function = torch.__dict__[act_name]\n",
    "        elif act_name == 'linear':\n",
    "            self.act_function = lambda x: x\n",
    "        \n",
    "    def forward(self, x):\n",
    "        for l in self.layers[:-1]:\n",
    "            x = l(x)\n",
    "            x = self.act_function(x)\n",
    "        x = self.layers[-1](x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "cb2f33db-a49a-4af5-a167-6b31f08483e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_taylor_n_scalar_linear():\n",
    "    \"\"\"Check grnewt.taylor_n on the affine map y = a * x + b.\n",
    "\n",
    "    taylor_n is called with order 1 and the single direction list [v] for the\n",
    "    parameters (a, b); the expected output is [y, x * v] (presumably the value\n",
    "    and the first-order term - confirm against grnewt's documentation).\n",
    "    \"\"\"\n",
    "    dtype = torch.float64\n",
    "\n",
    "    a = torch.nn.Parameter(torch.randn(1, dtype = dtype))\n",
    "    b = torch.nn.Parameter(torch.randn(1, dtype = dtype))\n",
    "    x = torch.randn(1, dtype = dtype)\n",
    "    y = a * x + b\n",
    "    v = torch.randn(1, dtype = dtype)\n",
    "    result = grnewt.taylor_n(y, (a, b), 1, [v])\n",
    "\n",
    "    # Concatenate detached tensors rather than calling torch.tensor([...]) on\n",
    "    # grad-tracking tensors, which copies via Python scalars and only works while\n",
    "    # every element holds exactly one entry.\n",
    "    expected = torch.cat([y.detach(), (x * v).detach()])\n",
    "    assert torch.allclose(result, expected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "5e0b6c10-352f-44dd-b703-e30902c35d97",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_taylor_n_scalar_nonlinear():\n",
    "    dtype = torch.float64\n",
    "\n",
    "    N = 7\n",
    "    a = torch.nn.Parameter(torch.randn(1, dtype = dtype))\n",
    "    lst_a = torch.concat([a.pow(i) for i in range(N + 1)])\n",
    "    lst_x = torch.randn(N + 1, dtype = dtype)\n",
    "    y = torch.dot(lst_a, lst_x)\n",
    "    v = torch.randn(1, dtype = dtype)\n",
    "    result = grnewt.taylor_n(y, (a,), N, [v] * N)\n",
    "\n",
    "    expected = torch.empty(N + 1, dtype = dtype)\n",
    "    factors = lst_x.clone()\n",
    "    powers = torch.tensor([i for i in range(N + 1)], dtype = dtype)\n",
    "    with torch.no_grad():\n",
    "        for i in range(N + 1):\n",
    "            expected[i] = sum([f * a.pow(p) for f, p in zip(factors, powers)])\n",
    "            for i in range(N + 1):\n",
    "                factors[i] *= powers[i]\n",
    "                if powers[i] > 0:\n",
    "                    powers[i] -= 1\n",
    "\n",
    "    for i in range(N + 1):\n",
    "        expected[i] = expected[i] * v.pow(i)\n",
    "\n",
    "    assert torch.allclose(result, expected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "69c0e891-4285-41ff-a013-aa5905d3ba6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the nonlinear Taylor check; the assert inside raises AssertionError on mismatch.\n",
    "test_taylor_n_scalar_nonlinear()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "027f5a60-a6d4-4610-acc6-0642a5c0ae17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the linear Taylor check; the assert inside raises AssertionError on mismatch.\n",
    "test_taylor_n_scalar_linear()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e06642b6-4525-4ea3-8363-0b8412478586",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
