{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "0eb780cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "d279c146",
   "metadata": {},
   "outputs": [],
   "source": [
    "layer = nn.Linear(in_features=9, out_features=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "fce06619",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "assert layer.in_features == layer.weight.shape[1]  # nn.Linear weight is (out_features, in_features)\n",
    "layer.in_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "87a6b366",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "assert layer.out_features == layer.weight.shape[0]  # nn.Linear weight is (out_features, in_features)\n",
    "layer.out_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "494da758",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the same x\n",
    "x = torch.randn(3, 2, 4, 4)  # (N, C, H, W) with C=2 channels for the BatchNorm layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a5b72782",
   "metadata": {},
   "outputs": [],
   "source": [
    "bn1d = nn.BatchNorm1d(num_features=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a0ce3bee",
   "metadata": {},
   "outputs": [],
   "source": [
    "bn2d = nn.BatchNorm2d(num_features=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "bc931309",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.9569, -0.7949,  1.3136,  0.6220, -0.3353,  1.6615,  0.3000,\n",
       "           2.1551,  0.5807, -2.0479,  0.8196,  0.0826,  0.4183,  0.7555,\n",
       "           0.7565, -1.3567],\n",
       "         [ 0.5158, -0.5265,  0.0332,  2.1364, -0.5130, -0.4107,  2.0007,\n",
       "          -0.9213,  0.5166,  0.1016,  0.2755,  0.8551,  0.6120, -1.4044,\n",
       "          -1.4384, -1.0647]],\n",
       "\n",
       "        [[-0.4135, -0.4451,  0.8586,  0.4593, -0.4868, -0.7337, -1.2463,\n",
       "           1.2187, -1.0884,  0.2791, -0.1157, -0.1326, -0.7610,  0.7755,\n",
       "           0.3624,  0.1180],\n",
       "         [-0.5167,  1.5329,  0.4519, -0.6117,  0.2868, -0.4266,  0.1023,\n",
       "           0.7071,  1.1832, -2.0488,  0.0033, -0.4413, -0.6110,  0.8175,\n",
       "          -0.0995,  0.6308]],\n",
       "\n",
       "        [[-0.6305,  0.7671, -1.4042,  1.5287,  0.0228, -0.3188, -1.3402,\n",
       "          -0.9405, -0.9813,  0.2581, -0.1096, -0.8756, -2.6717,  1.4622,\n",
       "           1.1261, -0.4286],\n",
       "         [ 1.4402,  0.8704,  1.0997, -0.2029, -1.5159,  1.0008,  0.6271,\n",
       "          -0.2882, -0.2003, -1.7213,  0.9306, -0.8571, -1.9241,  1.0537,\n",
       "          -1.4706, -0.5702]]], grad_fn=<NativeBatchNormBackward0>)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bn1d(torch.flatten(x, start_dim=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "9b7807f2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.9569, -0.7949,  1.3136,  0.6220, -0.3353,  1.6615,  0.3000,\n",
       "           2.1551,  0.5807, -2.0479,  0.8196,  0.0826,  0.4183,  0.7555,\n",
       "           0.7565, -1.3567],\n",
       "         [ 0.5158, -0.5265,  0.0332,  2.1364, -0.5130, -0.4107,  2.0007,\n",
       "          -0.9213,  0.5166,  0.1016,  0.2755,  0.8551,  0.6120, -1.4044,\n",
       "          -1.4384, -1.0647]],\n",
       "\n",
       "        [[-0.4135, -0.4451,  0.8586,  0.4593, -0.4868, -0.7337, -1.2463,\n",
       "           1.2187, -1.0884,  0.2791, -0.1157, -0.1326, -0.7610,  0.7755,\n",
       "           0.3624,  0.1180],\n",
       "         [-0.5167,  1.5329,  0.4519, -0.6117,  0.2868, -0.4266,  0.1023,\n",
       "           0.7071,  1.1832, -2.0488,  0.0033, -0.4413, -0.6110,  0.8175,\n",
       "          -0.0995,  0.6308]],\n",
       "\n",
       "        [[-0.6305,  0.7671, -1.4042,  1.5287,  0.0228, -0.3188, -1.3402,\n",
       "          -0.9405, -0.9813,  0.2581, -0.1096, -0.8756, -2.6717,  1.4622,\n",
       "           1.1261, -0.4286],\n",
       "         [ 1.4402,  0.8704,  1.0997, -0.2029, -1.5159,  1.0008,  0.6271,\n",
       "          -0.2882, -0.2003, -1.7213,  0.9306, -0.8571, -1.9241,  1.0537,\n",
       "          -1.4706, -0.5702]]], grad_fn=<ReshapeAliasBackward0>)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.flatten(bn2d(x), start_dim=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ecb1bc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# In training mode BatchNorm2d normalizes each channel over (N, H, W); BatchNorm1d on the\n",
    "# (N, C, H*W) view normalizes each channel over (N, H*W) -- the same elements, so they must agree.\n",
    "torch.testing.assert_close(bn1d(x.flatten(2)), bn2d(x).flatten(2))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jointist",
   "language": "python",
   "name": "jointist"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
