{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms, models\n",
    "from torchvision.models.resnet import ResNet,BasicBlock,Bottleneck"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "21289802\n",
      "11181642\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\Anaconda\\lib\\site-packages\\torchvision\\models\\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
      "  warnings.warn(\n",
      "d:\\Anaconda\\lib\\site-packages\\torchvision\\models\\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n",
      "  warnings.warn(msg)\n"
     ]
    }
   ],
   "source": [
    "# Build ResNet-18 directly from the ResNet class. [2,2,2,2] is the ResNet-18 block layout\n",
    "# ([3,4,6,3] builds ResNet-34), so the two counts printed below should match.\n",
    "model=ResNet(BasicBlock,[2,2,2,2],num_classes=10)\n",
    "print(sum(p.numel() for p in model.parameters() if p.requires_grad))\n",
    "# Reference model from the torchvision factory; weights=None replaces the deprecated pretrained=False.\n",
    "model=models.resnet18(weights=None,num_classes=10)\n",
    "print(sum(p.numel() for p in model.parameters() if p.requires_grad))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate a synthetic dataset with a diffusion model and a classifier: the diffusion model produces each image and the classifier assigns its label.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "23354698\n",
      "Files already downloaded and verified\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n",
      "torch.Size([128, 3, 32, 32])\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[1;32mIn [6]\u001b[0m, in \u001b[0;36m<cell line: 13>\u001b[1;34m()\u001b[0m\n\u001b[0;32m     11\u001b[0m dataloader\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mDataLoader(dataset,batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m128\u001b[39m,shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m     12\u001b[0m accu\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m\n\u001b[1;32m---> 13\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i,(data,target) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(dataloader):\n\u001b[0;32m     14\u001b[0m     \u001b[38;5;28mprint\u001b[39m(data\u001b[38;5;241m.\u001b[39mshape)\n\u001b[0;32m     15\u001b[0m     output\u001b[38;5;241m=\u001b[39mmodel(data\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m\"\u001b[39m),torch\u001b[38;5;241m.\u001b[39mTensor([\u001b[38;5;241m0\u001b[39m])\u001b[38;5;241m.\u001b[39mlong()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:630\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    627\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m    628\u001b[0m     \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[0;32m    629\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset()  \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[1;32m--> 630\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    631\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m    632\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[0;32m    633\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[0;32m    634\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:674\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    672\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m    673\u001b[0m     index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index()  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m--> 674\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m    675\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[0;32m    676\u001b[0m         data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_device)\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:51\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m     49\u001b[0m         data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[0;32m     50\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m---> 51\u001b[0m         data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[0;32m     52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m     53\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:51\u001b[0m, in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m     49\u001b[0m         data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[0;32m     50\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m---> 51\u001b[0m         data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[0;32m     52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m     53\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\torchvision\\datasets\\cifar.py:118\u001b[0m, in \u001b[0;36mCIFAR10.__getitem__\u001b[1;34m(self, index)\u001b[0m\n\u001b[0;32m    115\u001b[0m img \u001b[38;5;241m=\u001b[39m Image\u001b[38;5;241m.\u001b[39mfromarray(img)\n\u001b[0;32m    117\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransform \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m--> 118\u001b[0m     img \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    120\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_transform \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m    121\u001b[0m     target \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_transform(target)\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\torchvision\\transforms\\transforms.py:95\u001b[0m, in \u001b[0;36mCompose.__call__\u001b[1;34m(self, img)\u001b[0m\n\u001b[0;32m     93\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, img):\n\u001b[0;32m     94\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m t \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransforms:\n\u001b[1;32m---> 95\u001b[0m         img \u001b[38;5;241m=\u001b[39m \u001b[43mt\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     96\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m img\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\torchvision\\transforms\\transforms.py:277\u001b[0m, in \u001b[0;36mNormalize.forward\u001b[1;34m(self, tensor)\u001b[0m\n\u001b[0;32m    269\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, tensor: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m    270\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m    271\u001b[0m \u001b[38;5;124;03m    Args:\u001b[39;00m\n\u001b[0;32m    272\u001b[0m \u001b[38;5;124;03m        tensor (Tensor): Tensor image to be normalized.\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    275\u001b[0m \u001b[38;5;124;03m        Tensor: Normalized Tensor image.\u001b[39;00m\n\u001b[0;32m    276\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[1;32m--> 277\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnormalize\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensor\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstd\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minplace\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\torchvision\\transforms\\functional.py:363\u001b[0m, in \u001b[0;36mnormalize\u001b[1;34m(tensor, mean, std, inplace)\u001b[0m\n\u001b[0;32m    360\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensor, torch\u001b[38;5;241m.\u001b[39mTensor):\n\u001b[0;32m    361\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimg should be Tensor Image. Got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(tensor)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m--> 363\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF_t\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnormalize\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensor\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmean\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstd\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstd\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minplace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minplace\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\torchvision\\transforms\\_functional_tensor.py:922\u001b[0m, in \u001b[0;36mnormalize\u001b[1;34m(tensor, mean, std, inplace)\u001b[0m\n\u001b[0;32m    920\u001b[0m mean \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mas_tensor(mean, dtype\u001b[38;5;241m=\u001b[39mdtype, device\u001b[38;5;241m=\u001b[39mtensor\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[0;32m    921\u001b[0m std \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mas_tensor(std, dtype\u001b[38;5;241m=\u001b[39mdtype, device\u001b[38;5;241m=\u001b[39mtensor\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m--> 922\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[43mstd\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m)\u001b[38;5;241m.\u001b[39many():\n\u001b[0;32m    923\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstd evaluated to zero after conversion to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdtype\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, leading to division by zero.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m    924\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m mean\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m:\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# ResNet-34 with time embedding: build the model and load the resnet34.pth checkpoint into it.\n",
    "from ResNet_Time import ResNet_Time,BasicBlock_Time\n",
    "from create_models import check_fined_loaded_parameters\n",
    "\n",
    "# NOTE(review): machine-specific absolute path; point this at your local copy of the checkpoint.\n",
    "CHECKPOINT_PATH=r'D:\\Living_and_Study_In_University\\Dataset\\CIFA-10\\resnet34.pth'\n",
    "# Use the GPU when one is available; otherwise fall back to the CPU instead of failing on .to(\"cuda\").\n",
    "device=\"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "model=ResNet_Time(BasicBlock_Time,[3,4,6,3],num_classes=10)\n",
    "print(sum(p.numel() for p in model.parameters() if p.requires_grad))\n",
    "\n",
    "# map_location=\"cpu\" lets a checkpoint saved on a GPU load on any machine; the model is moved to `device` afterwards.\n",
    "# torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "state_dict=torch.load(CHECKPOINT_PATH,map_location=\"cpu\")\n",
    "# strict=False: keys that do not match between the checkpoint and this model are skipped rather than raising.\n",
    "model.load_state_dict(state_dict,strict=False)\n",
    "model.to(device)\n",
    "# Alternative loader left disabled in the original notebook:\n",
    "#model.load_state_dict_from_original(torch.load(CHECKPOINT_PATH),strict=True)\n",
    "check_fined_loaded_parameters(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "23644362\n",
      "23528522\n"
     ]
    }
   ],
   "source": [
    "# Build a ResNet-50-shaped network (Bottleneck blocks, [3,4,6,3]) with ResNet_Time and compare its\n",
    "# trainable-parameter count with the one from torchvision's resnet50.\n",
    "model=ResNet_Time(Bottleneck,[3,4,6,3],num_classes=10)\n",
    "print(sum(p.numel() for p in model.parameters() if p.requires_grad))\n",
    "# weights=None replaces the deprecated pretrained=False (no pretrained weights are loaded).\n",
    "model=models.resnet50(weights=None,num_classes=10)\n",
    "print(sum(p.numel() for p in model.parameters() if p.requires_grad))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
