{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from models.shared import *\n",
    "\n",
    "class UNet_gate(nn.Module):\n",
    "    \"\"\"\n",
    "    U-Net style encoder/decoder with five prediction heads and four gate heads.\n",
    "\n",
    "    Backbone paper : https://arxiv.org/abs/1505.04597\n",
    "\n",
    "    Args:\n",
    "        in_ch: channels of the input tensor.\n",
    "        out_ch: channels emitted by each ``Conv*_cla`` head.\n",
    "        kernel_size: kernel size passed to the encoder ``conv_block``s.\n",
    "        decoder_ks: kernel size passed to the decoder ``conv_block``s.\n",
    "\n",
    "    ``conv_block`` and ``up_conv`` are not defined in this cell; presumably\n",
    "    they come from the wildcard import of ``models.shared``.\n",
    "    \"\"\"\n",
    "    def __init__(self, in_ch=3, out_ch=1, kernel_size=3, decoder_ks=3):\n",
    "        super().__init__()\n",
    "\n",
    "        base = 64\n",
    "        widths = [base * 2 ** level for level in range(5)]\n",
    "        self.ks = kernel_size\n",
    "\n",
    "        # Construction order is kept stable on purpose: it fixes the parameter\n",
    "        # registration order and the seeded initial weights.\n",
    "        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        # Not referenced by forward().\n",
    "        self.Uppool = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)\n",
    "\n",
    "        # Encoder stages.\n",
    "        self.Conv1 = conv_block(in_ch, widths[0], kernel_size=self.ks)\n",
    "        self.Conv2 = conv_block(widths[0], widths[1], kernel_size=self.ks)\n",
    "        self.Conv3 = conv_block(widths[1], widths[2], kernel_size=self.ks)\n",
    "        self.Conv4 = conv_block(widths[2], widths[3], kernel_size=self.ks)\n",
    "        self.Conv5 = conv_block(widths[3], widths[4], kernel_size=self.ks)\n",
    "\n",
    "        # 1x1 prediction heads: Conv6_cla reads the bottleneck, the others read\n",
    "        # the decoder stage with the matching number.\n",
    "        self.Conv6_cla = nn.Conv2d(widths[4], out_ch, kernel_size=1)\n",
    "        self.Conv5_cla = nn.Conv2d(widths[3], out_ch, kernel_size=1)\n",
    "        self.Conv4_cla = nn.Conv2d(widths[2], out_ch, kernel_size=1)\n",
    "        self.Conv3_cla = nn.Conv2d(widths[1], out_ch, kernel_size=1)\n",
    "        self.Conv2_cla = nn.Conv2d(widths[0], out_ch, kernel_size=1)\n",
    "\n",
    "        # Single-channel gate heads; the 2x2 / stride-2 conv halves the spatial\n",
    "        # size of the encoder feature each one is applied to.\n",
    "        self.Conv6_gate = nn.Conv2d(widths[3], 1, kernel_size=2, stride=2)\n",
    "        self.Conv5_gate = nn.Conv2d(widths[2], 1, kernel_size=2, stride=2)\n",
    "        self.Conv4_gate = nn.Conv2d(widths[1], 1, kernel_size=2, stride=2)\n",
    "        self.Conv3_gate = nn.Conv2d(widths[0], 1, kernel_size=2, stride=2)\n",
    "\n",
    "        # Decoder: every Up_conv consumes the skip feature concatenated with\n",
    "        # the up-sampled feature, hence twice the stage width on its input.\n",
    "        self.Up5 = up_conv(widths[4], widths[3])\n",
    "        self.Up_conv5 = conv_block(2 * widths[3], widths[3], kernel_size=decoder_ks)\n",
    "\n",
    "        self.Up4 = up_conv(widths[3], widths[2])\n",
    "        self.Up_conv4 = conv_block(2 * widths[2], widths[2], kernel_size=decoder_ks)\n",
    "\n",
    "        self.Up3 = up_conv(widths[2], widths[1])\n",
    "        self.Up_conv3 = conv_block(2 * widths[1], widths[1], kernel_size=decoder_ks)\n",
    "\n",
    "        self.Up2 = up_conv(widths[1], widths[0])\n",
    "        self.Up_conv2 = conv_block(2 * widths[0], widths[0], kernel_size=decoder_ks)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Run the network on ``x``.\n",
    "\n",
    "        Returns:\n",
    "            ``(c6, c5, c4, c3, c2), (g6, g5, g4, g3)``: the prediction heads,\n",
    "            ordered from the bottleneck head to the last decoder head, followed\n",
    "            by the gate heads. No activation is applied to either group here.\n",
    "        \"\"\"\n",
    "        enc1 = self.Conv1(x)\n",
    "\n",
    "        enc2 = self.Conv2(self.Maxpool1(enc1))\n",
    "        gate3 = self.Conv3_gate(enc1)\n",
    "\n",
    "        enc3 = self.Conv3(self.Maxpool2(enc2))\n",
    "        gate4 = self.Conv4_gate(enc2)\n",
    "\n",
    "        enc4 = self.Conv4(self.Maxpool3(enc3))\n",
    "        gate5 = self.Conv5_gate(enc3)\n",
    "\n",
    "        bottleneck = self.Conv5(self.Maxpool4(enc4))\n",
    "        cla6 = self.Conv6_cla(bottleneck)\n",
    "        gate6 = self.Conv6_gate(enc4)\n",
    "\n",
    "        # Decoder: up-sample, concatenate the skip feature first, fuse, predict.\n",
    "        dec5 = self.Up_conv5(torch.cat((enc4, self.Up5(bottleneck)), dim=1))\n",
    "        cla5 = self.Conv5_cla(dec5)\n",
    "\n",
    "        dec4 = self.Up_conv4(torch.cat((enc3, self.Up4(dec5)), dim=1))\n",
    "        cla4 = self.Conv4_cla(dec4)\n",
    "\n",
    "        dec3 = self.Up_conv3(torch.cat((enc2, self.Up3(dec4)), dim=1))\n",
    "        cla3 = self.Conv3_cla(dec3)\n",
    "\n",
    "        dec2 = self.Up_conv2(torch.cat((enc1, self.Up2(dec3)), dim=1))\n",
    "        cla2 = self.Conv2_cla(dec2)\n",
    "\n",
    "        return (cla6, cla5, cla4, cla3, cla2), (gate6, gate5, gate4, gate3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randn(1,3,256,256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3 input channels, 5 channels per prediction head.\n",
    "model = UNet_gate(in_ch=3, out_ch=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward pass of `model` on `x`; the result is later indexed as outputs[i][j].\n",
    "outputs = model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 5, 16, 16])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Size of the first tensor in the first group of outputs.\n",
    "outputs[0][0].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "target = torch.zeros(1,5,256,256)\n",
    "target[:,-1] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import loss_functions\n",
    "task = 'training'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hul/.local/lib/python3.7/site-packages/torch/nn/functional.py:3208: UserWarning: nn.functional.upsample_nearest is deprecated. Use nn.functional.interpolate instead.\n",
      "  warnings.warn(\"nn.functional.upsample_nearest is deprecated. Use nn.functional.interpolate instead.\")\n"
     ]
    }
   ],
   "source": [
    "# Coarse-to-fine loss. `not_visited` marks the pixels no coarser scale has\n",
    "# claimed yet; each scale is supervised on (a dilation of) that mask, writes\n",
    "# its prediction into `result` where it claims pixels, and passes the\n",
    "# remaining mask, up-sampled by 2, to the next scale.\n",
    "o, g = outputs\n",
    "# (B, 1, h, w) mask with the dtype/device of the coarsest prediction.\n",
    "not_visited = torch.ones_like(o[0][:, :1])\n",
    "result = torch.zeros_like(o[0])\n",
    "loss = 0\n",
    "for idx, scale_f in enumerate([16, 8, 4, 2, 1]):\n",
    "    target_pool = F.avg_pool2d(target, kernel_size=scale_f, stride=scale_f)\n",
    "    weight = scale_f ** 0.5\n",
    "\n",
    "    if scale_f != 1:\n",
    "        # Dilate the unvisited mask with a 3x3 max-pool before using it as the loss mask.\n",
    "        expand = F.max_pool2d(not_visited.float(), kernel_size=3, stride=1, padding=1).bool()\n",
    "        # Gate target: 1 where the largest pooled channel value reaches 1 - 1/scale_f**2.\n",
    "        gate_target = (torch.max(target_pool, dim=1, keepdim=True)[0] >= (1 - 1. / (scale_f ** 2))).float()\n",
    "        loss += weight * loss_functions.mae_loss(o[idx], target_pool, sign=expand)\n",
    "        loss += weight * loss_functions.mae_loss(g[idx], gate_target, sign=expand)\n",
    "\n",
    "        if task == 'training':\n",
    "            cur_visited = gate_target\n",
    "        else:\n",
    "            # Cast to float: `1 - cur_visited` below is not supported for a bool tensor.\n",
    "            cur_visited = (g[idx] >= 0.5).float()\n",
    "    else:\n",
    "        # Finest scale: it claims every pixel that is still unvisited.\n",
    "        cur_visited = 1\n",
    "        expand = F.max_pool2d(not_visited.float(), kernel_size=5, stride=1, padding=2).bool()\n",
    "        loss += weight * loss_functions.kl_loss(o[idx], target_pool, sign=expand)\n",
    "\n",
    "    # cache for next step feature\n",
    "    result += (not_visited * cur_visited) * o[idx]\n",
    "\n",
    "    not_visited = not_visited * (1 - cur_visited)\n",
    "    if scale_f != 1:\n",
    "        # F.upsample_nearest is deprecated; this is the documented replacement.\n",
    "        not_visited = F.interpolate(not_visited, scale_factor=2, mode='nearest')\n",
    "        result = F.interpolate(result, scale_factor=2, mode='nearest')\n",
    "\n",
    "y_hr_pred = result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 1, 256, 256])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Size of the mask left over after the loss loop.\n",
    "not_visited.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randn(3,4,256,256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split each 256-long spatial axis into 128 blocks of 2.\n",
    "r = x.reshape((3, 4, 128, 2, 128, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bring the two within-block axes next to the channel axis, then merge\n",
    "# them into one axis holding the 2*2 = 4 values of every block.\n",
    "c = (\n",
    "    r.permute(0, 1, 3, 5, 2, 4)\n",
    "     .reshape((3, 4, 4, 128, 128))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 4, 128, 128])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Variance over the 4 values of each block; only the resulting size is shown.\n",
    "c.var(dim=2).size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class Condidence_gate(nn.Module):\n",
    "    def __init__(self,in_channel,out_channel=1,kernel_size=2,stride = 2):\n",
    "        super().__init__()\n",
    "        assert kernel_size == stride\n",
    "        self.ks = kernel_size\n",
    "        self.conv = nn.Sequential(nn.Conv2d(in_channel,in_channel//2,kernel_size=2,stride = 2),\n",
    "                                    nn.BatchNorm2d(in_channel//2),\n",
    "                                    nn.ReLU())\n",
    "        self.conv_var = nn.Sequential(nn.Conv2d(in_channel,in_channel-in_channel//2,kernel_size=1),\n",
    "                                    nn.BatchNorm2d(in_channel-in_channel//2),\n",
    "                                    nn.ReLU())\n",
    "\n",
    "        self.conv_gate = nn.Sequential(nn.Conv2d(in_channel,out_channel,kernel_size=1),nn.Sigmoid())\n",
    "\n",
    "    def forward(self,x):\n",
    "        B,C,H,W = x.shape\n",
    "        v = x.reshape(B,C,H//self.ks,self.ks,W//self.ks,self.ks)\n",
    "        v = v.permute(0,1,3,5,2,4).reshape(B,C,self.ks**2,H//self.ks,W//self.ks)\n",
    "        v= torch.var(v,dim = 2)\n",
    "        print(v.shape)\n",
    "        v = self.conv_var(v)\n",
    "        \n",
    "        x = self.conv(x)\n",
    "\n",
    "        return self.conv_gate(torch.cat([x,v],dim =1))\n",
    "\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gate head for a 3-channel input (rebinds `model`).\n",
    "model = Condidence_gate(in_channel=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 3, 5, 5])\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Given groups=1, weight of size [1, 3, 1, 1], expected input[1, 2, 5, 5] to have 3 channels, but got 2 channels instead",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-40-97da547225c0>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/.local/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-38-7bd0369bfe30>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     23\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv_gate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdim\u001b[0m \u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.7/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    115\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m             \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    118\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.7/site-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    421\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    422\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 423\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_conv_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    424\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    425\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mConv3d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_ConvNd\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.7/site-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36m_conv_forward\u001b[0;34m(self, input, weight)\u001b[0m\n\u001b[1;32m    418\u001b[0m                             _pair(0), self.dilation, self.groups)\n\u001b[1;32m    419\u001b[0m         return F.conv2d(input, weight, self.bias, self.stride,\n\u001b[0;32m--> 420\u001b[0;31m                         self.padding, self.dilation, self.groups)\n\u001b[0m\u001b[1;32m    421\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    422\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Given groups=1, weight of size [1, 3, 1, 1], expected input[1, 2, 5, 5] to have 3 channels, but got 2 channels instead"
     ]
    }
   ],
   "source": [
    "# Smoke test on a random 1x3x10x10 input.\n",
    "model(torch.randn((1, 3, 10, 10)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
