{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<All keys matched successfully>\n",
      "66093066\n"
     ]
    }
   ],
   "source": [
    "from train_config import BaseConfig\n",
    "import torch\n",
    "from Classifer_Model import create_encoder_unet\n",
    "CHECKPOINT = './guided_classifier/epoch152.pth'\n",
    "config = BaseConfig()\n",
    "model = create_encoder_unet(config)\n",
    "# strict=False tolerates key mismatches, so print the report to confirm every key matched.\n",
    "# map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only machine.\n",
    "load_report = model.load_state_dict(torch.load(CHECKPOINT, map_location='cpu'), strict=False)\n",
    "print(load_report)\n",
    "n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "print(n_trainable)\n",
    "# NOTE(review): TimeStep_Accuracy in the next cell loads './guided_classifier/epoch302.pth'\n",
    "# into this model, overriding these weights -- confirm which checkpoint should be evaluated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(2.0527, device='cuda:0') tensor(-1.2146, device='cuda:0')\n",
      "tensor(2.2365, device='cuda:0') tensor(-1.1087, device='cuda:0')\n",
      "tensor(2.0094, device='cuda:0') tensor(-1.1259, device='cuda:0')\n",
      "tensor(2.1003, device='cuda:0') tensor(-1.2265, device='cuda:0')\n",
      "tensor(2.0912, device='cuda:0') tensor(-0.9740, device='cuda:0')\n",
      "tensor(2.0556, device='cuda:0') tensor(-1.1058, device='cuda:0')\n",
      "tensor(2.0196, device='cuda:0') tensor(-1.0819, device='cuda:0')\n",
      "tensor(2.0097, device='cuda:0') tensor(-1.0765, device='cuda:0')\n",
      "tensor(2.0235, device='cuda:0') tensor(-1.0677, device='cuda:0')\n",
      "tensor(1.9983, device='cuda:0') tensor(-0.9580, device='cuda:0')\n",
      "tensor(2.0514, device='cuda:0') tensor(-1.1411, device='cuda:0')\n",
      "tensor(2.0635, device='cuda:0') tensor(-0.9974, device='cuda:0')\n",
      "tensor(2.0967, device='cuda:0') tensor(-1.1001, device='cuda:0')\n",
      "tensor(2.2424, device='cuda:0') tensor(-1.1216, device='cuda:0')\n",
      "tensor(2.1994, device='cuda:0') tensor(-1.1037, device='cuda:0')\n",
      "tensor(2.1575, device='cuda:0') tensor(-1.0932, device='cuda:0')\n",
      "tensor(2.2646, device='cuda:0') tensor(-1.1529, device='cuda:0')\n",
      "tensor(2.0912, device='cuda:0') tensor(-1.1009, device='cuda:0')\n",
      "tensor(2.0528, device='cuda:0') tensor(-1.0736, device='cuda:0')\n",
      "tensor(2.2126, device='cuda:0') tensor(-1.0980, device='cuda:0')\n",
      "tensor(2.0358, device='cuda:0') tensor(-1.1645, device='cuda:0')\n",
      "tensor(2.2588, device='cuda:0') tensor(-1.0622, device='cuda:0')\n",
      "tensor(2.0685, device='cuda:0') tensor(-1.1375, device='cuda:0')\n",
      "tensor(2.2909, device='cuda:0') tensor(-1.2610, device='cuda:0')\n",
      "tensor(2.1078, device='cuda:0') tensor(-1.1037, device='cuda:0')\n",
      "tensor(2.1054, device='cuda:0') tensor(-1.0728, device='cuda:0')\n",
      "tensor(2.0469, device='cuda:0') tensor(-1.1038, device='cuda:0')\n",
      "tensor(2.2414, device='cuda:0') tensor(-1.0824, device='cuda:0')\n",
      "tensor(2.1097, device='cuda:0') tensor(-1.3129, device='cuda:0')\n",
      "tensor(2.0912, device='cuda:0') tensor(-1.3362, device='cuda:0')\n",
      "tensor(2.0311, device='cuda:0') tensor(-1.0566, device='cuda:0')\n",
      "tensor(2.1403, device='cuda:0') tensor(-1.1693, device='cuda:0')\n",
      "tensor(2.1224, device='cuda:0') tensor(-1.1220, device='cuda:0')\n",
      "tensor(2.1269, device='cuda:0') tensor(-1.2331, device='cuda:0')\n",
      "tensor(2.3735, device='cuda:0') tensor(-1.1897, device='cuda:0')\n",
      "tensor(2.0470, device='cuda:0') tensor(-1.2536, device='cuda:0')\n",
      "tensor(2.1911, device='cuda:0') tensor(-0.9937, device='cuda:0')\n",
      "tensor(2.0427, device='cuda:0') tensor(-1.0172, device='cuda:0')\n",
      "tensor(2.0662, device='cuda:0') tensor(-1.2001, device='cuda:0')\n",
      "tensor(2.0838, device='cuda:0') tensor(-1.3120, device='cuda:0')\n",
      "tensor(2.0017, device='cuda:0') tensor(-1.0713, device='cuda:0')\n",
      "tensor(2.1755, device='cuda:0') tensor(-1.2058, device='cuda:0')\n",
      "tensor(2.1188, device='cuda:0') tensor(-1.1388, device='cuda:0')\n",
      "tensor(2.0152, device='cuda:0') tensor(-1.3908, device='cuda:0')\n",
      "tensor(2.0502, device='cuda:0') tensor(-1.0357, device='cuda:0')\n",
      "tensor(2.2613, device='cuda:0') tensor(-1.2216, device='cuda:0')\n",
      "tensor(2.1779, device='cuda:0') tensor(-1.2305, device='cuda:0')\n",
      "tensor(2.0324, device='cuda:0') tensor(-1.1472, device='cuda:0')\n",
      "tensor(2.1094, device='cuda:0') tensor(-1.1103, device='cuda:0')\n",
      "tensor(1.9682, device='cuda:0') tensor(-1.1270, device='cuda:0')\n",
      "tensor(2.1429, device='cuda:0') tensor(-1.1083, device='cuda:0')\n",
      "tensor(2.0325, device='cuda:0') tensor(-1.2500, device='cuda:0')\n",
      "tensor(2.0457, device='cuda:0') tensor(-1.1618, device='cuda:0')\n",
      "tensor(2.2580, device='cuda:0') tensor(-1.1676, device='cuda:0')\n",
      "tensor(2.0384, device='cuda:0') tensor(-1.1161, device='cuda:0')\n",
      "tensor(2.1368, device='cuda:0') tensor(-1.0155, device='cuda:0')\n",
      "tensor(2.1837, device='cuda:0') tensor(-1.1954, device='cuda:0')\n",
      "tensor(2.1148, device='cuda:0') tensor(-1.2499, device='cuda:0')\n",
      "tensor(2.0651, device='cuda:0') tensor(-1.0732, device='cuda:0')\n",
      "tensor(2.2243, device='cuda:0') tensor(-0.9731, device='cuda:0')\n",
      "tensor(2.2037, device='cuda:0') tensor(-1.1796, device='cuda:0')\n",
      "tensor(2.0508, device='cuda:0') tensor(-1.2623, device='cuda:0')\n",
      "tensor(2.0623, device='cuda:0') tensor(-1.1480, device='cuda:0')\n",
      "tensor(2.0263, device='cuda:0') tensor(-1.2301, device='cuda:0')\n",
      "tensor(2.0491, device='cuda:0') tensor(-1.2067, device='cuda:0')\n",
      "tensor(2.0431, device='cuda:0') tensor(-1.0579, device='cuda:0')\n",
      "tensor(2.0381, device='cuda:0') tensor(-1.1452, device='cuda:0')\n",
      "tensor(2.2071, device='cuda:0') tensor(-1.0544, device='cuda:0')\n",
      "tensor(2.3522, device='cuda:0') tensor(-1.0614, device='cuda:0')\n",
      "tensor(1.9300, device='cuda:0') tensor(-1.1427, device='cuda:0')\n",
      "tensor(1.9880, device='cuda:0') tensor(-1.1743, device='cuda:0')\n",
      "tensor(2.2820, device='cuda:0') tensor(-1.0350, device='cuda:0')\n",
      "tensor(2.0579, device='cuda:0') tensor(-1.1900, device='cuda:0')\n",
      "tensor(2.3975, device='cuda:0') tensor(-1.1232, device='cuda:0')\n",
      "tensor(2.1058, device='cuda:0') tensor(-1.1475, device='cuda:0')\n",
      "tensor(2.1561, device='cuda:0') tensor(-1.1189, device='cuda:0')\n",
      "tensor(2.1530, device='cuda:0') tensor(-1.1056, device='cuda:0')\n",
      "tensor(2.1335, device='cuda:0') tensor(-1.0853, device='cuda:0')\n",
      "tensor(2.0024, device='cuda:0') tensor(-1.0300, device='cuda:0')\n",
      "tensor(2.0529, device='cuda:0') tensor(-1.0541, device='cuda:0')\n",
      "tensor(2.1020, device='cuda:0') tensor(-1.0961, device='cuda:0')\n",
      "tensor(2.2234, device='cuda:0') tensor(-1.1260, device='cuda:0')\n",
      "tensor(2.1131, device='cuda:0') tensor(-1.1480, device='cuda:0')\n",
      "tensor(2.1810, device='cuda:0') tensor(-1.0425, device='cuda:0')\n",
      "tensor(2.1354, device='cuda:0') tensor(-1.0748, device='cuda:0')\n",
      "tensor(2.1411, device='cuda:0') tensor(-1.1449, device='cuda:0')\n",
      "tensor(1.9662, device='cuda:0') tensor(-1.1329, device='cuda:0')\n",
      "tensor(2.3854, device='cuda:0') tensor(-1.0531, device='cuda:0')\n",
      "tensor(1.9218, device='cuda:0') tensor(-1.0659, device='cuda:0')\n",
      "tensor(2.0333, device='cuda:0') tensor(-1.1444, device='cuda:0')\n",
      "tensor(2.2902, device='cuda:0') tensor(-1.1996, device='cuda:0')\n",
      "tensor(2.1032, device='cuda:0') tensor(-1.0408, device='cuda:0')\n",
      "tensor(2.0636, device='cuda:0') tensor(-1.3515, device='cuda:0')\n",
      "tensor(2.1947, device='cuda:0') tensor(-1.1995, device='cuda:0')\n",
      "tensor(2.1483, device='cuda:0') tensor(-1.0402, device='cuda:0')\n",
      "tensor(2.0474, device='cuda:0') tensor(-1.1621, device='cuda:0')\n",
      "tensor(2.1313, device='cuda:0') tensor(-1.0890, device='cuda:0')\n",
      "tensor(2.3145, device='cuda:0') tensor(-1.1148, device='cuda:0')\n",
      "tensor(1.9494, device='cuda:0') tensor(-1.4048, device='cuda:0')\n",
      "tensor(2.3577, device='cuda:0') tensor(-1.1534, device='cuda:0')\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[1;32mIn [7]\u001b[0m, in \u001b[0;36m<cell line: 57>\u001b[1;34m()\u001b[0m\n\u001b[0;32m     55\u001b[0m trainset \u001b[38;5;241m=\u001b[39m torchvision\u001b[38;5;241m.\u001b[39mdatasets\u001b[38;5;241m.\u001b[39mCIFAR10(root\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mD:/Living_and_Study_In_University/Dataset/CIFA-10\u001b[39m\u001b[38;5;124m'\u001b[39m, train\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,transform\u001b[38;5;241m=\u001b[39mtransform)\n\u001b[0;32m     56\u001b[0m trainloader\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mDataLoader(trainset, batch_size\u001b[38;5;241m=\u001b[39mconfig\u001b[38;5;241m.\u001b[39mbatch_size\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m2\u001b[39m,shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m---> 57\u001b[0m \u001b[43mTimeStep_Accuracy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrainloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfp\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbf16\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minverval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m20\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcuda\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "Input \u001b[1;32mIn [7]\u001b[0m, in \u001b[0;36mTimeStep_Accuracy\u001b[1;34m(model, train_loader, fp, inverval, device)\u001b[0m\n\u001b[0;32m     24\u001b[0m samples\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m\n\u001b[0;32m     25\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m---> 26\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m batch, label \u001b[38;5;129;01min\u001b[39;00m train_loader:\n\u001b[0;32m     27\u001b[0m         \u001b[38;5;66;03m# batch = batch.to(device)\u001b[39;00m\n\u001b[0;32m     28\u001b[0m         \u001b[38;5;66;03m# label = label.to(device)\u001b[39;00m\n\u001b[0;32m     29\u001b[0m         \u001b[38;5;66;03m# if fp == \"fp16\":\u001b[39;00m\n\u001b[0;32m     30\u001b[0m         \u001b[38;5;66;03m#     batch = batch.half().to(device)\u001b[39;00m\n\u001b[0;32m     31\u001b[0m         \u001b[38;5;66;03m#     label = label.half().to(device)\u001b[39;00m\n\u001b[0;32m     32\u001b[0m         times \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandint(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m100\u001b[39m, (batch\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m],), dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mint64)\u001b[38;5;241m.\u001b[39mto(device)\u001b[38;5;241m.\u001b[39mlong()\n\u001b[0;32m     34\u001b[0m         noise_shape \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn_like(batch)\u001b[38;5;241m.\u001b[39mto(device)\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\accelerate\\data_loader.py:458\u001b[0m, in \u001b[0;36mDataLoaderShard.__iter__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    456\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m    457\u001b[0m     current_batch \u001b[38;5;241m=\u001b[39m send_to_device(current_batch, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m--> 458\u001b[0m next_batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mdataloader_iter\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    459\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batch_index \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mskip_batches:\n\u001b[0;32m    460\u001b[0m     \u001b[38;5;28;01myield\u001b[39;00m current_batch\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:630\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    627\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m    628\u001b[0m     \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[0;32m    629\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset()  \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[1;32m--> 630\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    631\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m    632\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[0;32m    633\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[0;32m    634\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:674\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    672\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m    673\u001b[0m     index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index()  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m--> 674\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m    675\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[0;32m    676\u001b[0m         data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_device)\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:51\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m     49\u001b[0m         data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[0;32m     50\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m---> 51\u001b[0m         data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[0;32m     52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m     53\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:51\u001b[0m, in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m     49\u001b[0m         data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[0;32m     50\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m---> 51\u001b[0m         data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[0;32m     52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m     53\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\torchvision\\datasets\\cifar.py:115\u001b[0m, in \u001b[0;36mCIFAR10.__getitem__\u001b[1;34m(self, index)\u001b[0m\n\u001b[0;32m    111\u001b[0m img, target \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdata[index], \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtargets[index]\n\u001b[0;32m    113\u001b[0m \u001b[38;5;66;03m# doing this so that it is consistent with all other datasets\u001b[39;00m\n\u001b[0;32m    114\u001b[0m \u001b[38;5;66;03m# to return a PIL Image\u001b[39;00m\n\u001b[1;32m--> 115\u001b[0m img \u001b[38;5;241m=\u001b[39m \u001b[43mImage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfromarray\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    117\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransform \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m    118\u001b[0m     img \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransform(img)\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\PIL\\Image.py:2834\u001b[0m, in \u001b[0;36mfromarray\u001b[1;34m(obj, mode)\u001b[0m\n\u001b[0;32m   2831\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m   2832\u001b[0m         obj \u001b[38;5;241m=\u001b[39m obj\u001b[38;5;241m.\u001b[39mtostring()\n\u001b[1;32m-> 2834\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfrombuffer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mraw\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrawmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\PIL\\Image.py:2761\u001b[0m, in \u001b[0;36mfrombuffer\u001b[1;34m(mode, size, data, decoder_name, *args)\u001b[0m\n\u001b[0;32m   2758\u001b[0m         im\u001b[38;5;241m.\u001b[39mreadonly \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m   2759\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m im\n\u001b[1;32m-> 2761\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfrombytes\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdecoder_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\PIL\\Image.py:2707\u001b[0m, in \u001b[0;36mfrombytes\u001b[1;34m(mode, size, data, decoder_name, *args)\u001b[0m\n\u001b[0;32m   2704\u001b[0m     args \u001b[38;5;241m=\u001b[39m mode\n\u001b[0;32m   2706\u001b[0m im \u001b[38;5;241m=\u001b[39m new(mode, size)\n\u001b[1;32m-> 2707\u001b[0m \u001b[43mim\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrombytes\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdecoder_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   2708\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m im\n",
      "File \u001b[1;32md:\\Anaconda\\lib\\site-packages\\PIL\\Image.py:781\u001b[0m, in \u001b[0;36mImage.frombytes\u001b[1;34m(self, data, decoder_name, *args)\u001b[0m\n\u001b[0;32m    779\u001b[0m d \u001b[38;5;241m=\u001b[39m _getdecoder(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmode, decoder_name, args)\n\u001b[0;32m    780\u001b[0m d\u001b[38;5;241m.\u001b[39msetimage(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mim)\n\u001b[1;32m--> 781\u001b[0m s \u001b[38;5;241m=\u001b[39m \u001b[43md\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    783\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m s[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m    784\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnot enough image data\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from diffusers import DDPMScheduler\n",
    "from accelerate import Accelerator\n",
    "def correct_idx(idx, prediction, lables):\n",
    "    \"\"\"Return a bool tensor that is True where the argmax of ``prediction[idx]`` equals ``lables[idx]``.\"\"\"\n",
    "    predicted_class = prediction[idx].argmax(dim=1)\n",
    "    return predicted_class.eq(lables[idx])\n",
    "\n",
    "def TimeStep_Accuracy(model, train_loader, fp=\"fp32\", inverval=50, device=\"cuda\",\n",
    "                      checkpoint=\"./guided_classifier/epoch302.pth\"):\n",
    "    \"\"\"Measure classifier accuracy on DDPM-noised images, bucketed by timestep.\n",
    "\n",
    "    Every image is noised with ``scheduler.add_noise`` at a timestep drawn uniformly\n",
    "    from [0, 1000) and the classifier's argmax prediction is compared with its label;\n",
    "    hits are grouped into timestep buckets of width ``inverval``.\n",
    "\n",
    "    Args:\n",
    "        model: classifier called as ``model(noisy_images, timesteps)``. Assumed to\n",
    "            return logits of shape (batch, num_classes) -- TODO confirm.\n",
    "        train_loader: iterable of ``(images, labels)`` batches.\n",
    "        fp: mixed-precision mode for ``Accelerator`` (\"fp32\" means no mixed precision).\n",
    "        inverval: width of each timestep bucket (misspelled name kept for callers).\n",
    "        device: kept for backward compatibility; tensors are created on the device\n",
    "            the Accelerator-prepared loader yields batches on.\n",
    "        checkpoint: state-dict path loaded (strict=False) before evaluating; pass\n",
    "            None to evaluate the model's current weights instead.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping ``(start, end)`` timestep ranges to accuracy in [0, 1];\n",
    "        buckets that received no samples map to ``float('nan')``.\n",
    "    \"\"\"\n",
    "    import math\n",
    "    import warnings\n",
    "\n",
    "    if fp != \"fp32\":\n",
    "        accelerate = Accelerator(mixed_precision=fp)\n",
    "    else:\n",
    "        accelerate = Accelerator()\n",
    "    num_timesteps = 1000\n",
    "    scheduler = DDPMScheduler(num_train_timesteps=num_timesteps)\n",
    "    # ceil so a final partial bucket is kept when inverval does not divide num_timesteps.\n",
    "    num_buckets = math.ceil(num_timesteps / inverval)\n",
    "\n",
    "    if checkpoint is not None:\n",
    "        # Load before prepare() so parameter names are not changed by any wrapper\n",
    "        # accelerate may add (strict=False would silently skip renamed keys), and\n",
    "        # map to CPU so GPU-saved checkpoints load on any machine.\n",
    "        state_dict = torch.load(checkpoint, map_location=\"cpu\")\n",
    "        load_result = model.load_state_dict(state_dict, strict=False)\n",
    "        if load_result.missing_keys or load_result.unexpected_keys:\n",
    "            warnings.warn(\n",
    "                f\"{checkpoint}: missing keys {load_result.missing_keys}, \"\n",
    "                f\"unexpected keys {load_result.unexpected_keys}\"\n",
    "            )\n",
    "    model, train_loader = accelerate.prepare(model, train_loader)\n",
    "\n",
    "    # Evaluate in eval mode: in train mode dropout is active and BatchNorm layers\n",
    "    # update their running statistics even under torch.no_grad().\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    samples_per_bucket = torch.zeros(num_buckets, dtype=torch.long, device=accelerate.device)\n",
    "    correct_per_bucket = torch.zeros(num_buckets, dtype=torch.long, device=accelerate.device)\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            for batch, label in train_loader:\n",
    "                # Sample over the full schedule so every reported bucket gets data.\n",
    "                times = torch.randint(0, num_timesteps, (batch.shape[0],),\n",
    "                                      dtype=torch.long, device=batch.device)\n",
    "                noise = torch.randn_like(batch)\n",
    "                noise_image = scheduler.add_noise(batch, noise, times)\n",
    "                predictions = model(noise_image, times)\n",
    "                hits = torch.argmax(predictions, dim=1) == label\n",
    "                bucket = times // inverval\n",
    "                # bincount accumulates per-bucket counts on-device, with no host sync per bucket.\n",
    "                samples_per_bucket += torch.bincount(bucket, minlength=num_buckets)\n",
    "                correct_per_bucket += torch.bincount(bucket[hits], minlength=num_buckets)\n",
    "    finally:\n",
    "        model.train(was_training)\n",
    "\n",
    "    accuracy = {}\n",
    "    for i, (n_seen, n_correct) in enumerate(zip(samples_per_bucket.tolist(), correct_per_bucket.tolist())):\n",
    "        key = (i * inverval, (i + 1) * inverval)\n",
    "        accuracy[key] = n_correct / n_seen if n_seen else float(\"nan\")\n",
    "    return accuracy\n",
    "#Cifar-10\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torch\n",
    "from diffusers import DDPMScheduler\n",
    "import os\n",
    "# Dataset location: set the CIFAR10_ROOT environment variable to override the local default.\n",
    "DATA_ROOT = os.environ.get(\"CIFAR10_ROOT\", r'D:/Living_and_Study_In_University/Dataset/CIFA-10')\n",
    "transform = transforms.Compose([transforms.ToTensor()])\n",
    "trainset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, transform=transform)\n",
    "# NOTE(review): ToTensor yields pixels in [0, 1]; confirm this matches the input range the\n",
    "# classifier was trained on before trusting the per-timestep accuracies.\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size // 2, shuffle=False)\n",
    "TimeStep_Accuracy(model, trainloader, fp=\"bf16\", inverval=20, device=\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
