{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "00a46aaa",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lianhai/anaconda3/envs/mytorch/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "#import necessary packages\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "import torchvision\n",
    "from torchvision.transforms import ToTensor\n",
    "import pytorch_lightning as PL\n",
    "\n",
    "import tqdm\n",
    "\n",
    "from Models import ResNet50_one_run,LoKr_ResNet50_one_run,ViTb16_one_run\n",
    "from ModifiedModel import MPCmodel,LossModifiedModel_test,LossModifiedModel_final"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c67f7fb6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cuda:2 device\n"
     ]
    }
   ],
   "source": [
    "# Pick the accelerator: CUDA card number `gpu` when CUDA is available,\n",
    "# else Apple-silicon MPS, else CPU. `gpu` is also stored in the run configs.\n",
    "gpu = 0\n",
    "if torch.cuda.is_available():\n",
    "    device = f\"cuda:{gpu}\"\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "print(f\"Using {device} device\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4d3eeaea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared hyperparameters, consumed by the data loaders and the base configs.\n",
    "learning_rate = 1e-3  # optimizer step size\n",
    "batch_size = 64       # mini-batch size for both train and test loaders\n",
    "epochs = 30           # training epochs per run\n",
    "loss_fn = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25f745e9",
   "metadata": {},
   "source": [
    "## Test ResNet50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1c4d115d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# Resize(224) scales the shorter side to 224; CIFAR images are square, so the\n",
    "# result is 224x224. Normalisation uses the standard ImageNet channel mean/std.\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize(224),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "training_data = datasets.CIFAR100(root=\"../data\", train=True, download=True, transform=transform)\n",
    "test_data = datasets.CIFAR100(root=\"../data\", train=False, download=True, transform=transform)\n",
    "\n",
    "# Only the training loader is shuffled; evaluation order stays fixed.\n",
    "train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=2)\n",
    "test_dataloader = DataLoader(test_data, batch_size=batch_size, num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c570682d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base hyperparameters for the ResNet50 sweeps; each run copies and extends this dict.\n",
    "baseconfig_resnet50 = dict(\n",
    "    learning_rate=learning_rate,\n",
    "    architecture=\"ResNet50\",\n",
    "    conv_channels=[3, 64, 64, 128, 256, 512],\n",
    "    conv_repeats=[1, 3, 4, 6, 3],\n",
    "    dataset=\"CIFAR100\",\n",
    "    epochs=epochs,\n",
    "    gpu=gpu,\n",
    "    optimizer='sgd',\n",
    "    momentum=0.9,\n",
    ")\n",
    "# Data objects passed unchanged to the ResNet50 and LoKr run loops.\n",
    "dataconfig = dict(\n",
    "    train_dataloader=train_dataloader,\n",
    "    test_dataloader=test_dataloader,\n",
    "    training_data=training_data,\n",
    "    loss_fn=loss_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "316e8e97",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n",
      "\n",
      "  | Name              | Type             | Params\n",
      "-------------------------------------------------------\n",
      "0 | model             | ResNet50         | 23.7 M\n",
      "1 | loss_fn           | CrossEntropyLoss | 0     \n",
      "2 | blocks            | ModuleList       | 23.5 M\n",
      "3 | stem              | Identity         | 0     \n",
      "4 | lossblocks        | ModuleList       | 1.5 M \n",
      "5 | metrics           | MetricCollection | 0     \n",
      "6 | val_metrics       | MetricCollection | 0     \n",
      "7 | train_loss_metric | MeanMetric       | 0     \n",
      "8 | val_loss_metric   | MeanMetric       | 0     \n",
      "-------------------------------------------------------\n",
      "25.0 M    Trainable params\n",
      "0         Non-trainable params\n",
      "25.0 M    Total params\n",
      "100.106   Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Sanity Checking: 0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lianhai/anaconda3/envs/mytorch/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
      "  rank_zero_warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                                           \r"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lianhai/anaconda3/envs/mytorch/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
      "  rank_zero_warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:09<00:00, 16.24it/s].800, acc=0.312, acc5=0.625, val_loss=2.080, val_acc=0.472, val_acc5=0.795]\n",
      "Epoch 1:  60%|██████    | 470/782 [01:08<00:45,  6.82it/s, v_num=0, loss=2.360, acc=0.328, acc5=0.797, val_loss=2.080, val_acc=0.472, val_acc5=0.795]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lianhai/anaconda3/envs/mytorch/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:52: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown...\n",
      "  rank_zero_warn(\"Detected KeyboardInterrupt, attempting graceful shutdown...\")\n",
      "\n",
      "KeyboardInterrupt\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Sweep over run ids and horizons. Every (idx, h) pair trains the baseline MPCmodel;\n",
    "# LossModifiedModel_final is added for h < 17 and LossModifiedModel_test for h == 1.\n",
    "for idx in [1, 2, 3]:\n",
    "    for h in [1, 2, 5, 17]:\n",
    "        Model_dict = {'None': MPCmodel}\n",
    "        if h < 17:\n",
    "            Model_dict['Modify'] = LossModifiedModel_final\n",
    "        if h == 1:\n",
    "            Model_dict['Test'] = LossModifiedModel_test\n",
    "        for modifyname, Modeltype in Model_dict.items():\n",
    "            # Per-run config: the shared base plus this run's sweep settings.\n",
    "            config = {\n",
    "                **baseconfig_resnet50,\n",
    "                \"horizon\": h,\n",
    "                \"stride\": 1,\n",
    "                'idx': idx,\n",
    "                'modify': modifyname,\n",
    "                'device': device,\n",
    "            }\n",
    "            if modifyname == 'Modify':\n",
    "                config['lambda_modify'] = 1.\n",
    "                modifyname += f\"_lamb{config['lambda_modify']}\"\n",
    "            # The baseline run gets no modifier tag in its name.\n",
    "            tag = '' if modifyname == 'None' else '_' + modifyname\n",
    "            name = f\"resnet50_{h}{tag}_id{config['idx']}\"\n",
    "            ResNet50_one_run(name, dataconfig, config, Modeltype)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17aa1860",
   "metadata": {},
   "source": [
    "## Test LoKr ResNet50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93f47d2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same sweep as the ResNet50 section, but with the LoKr_ResNet50 architecture.\n",
    "# Every (idx, h) pair trains the baseline MPCmodel; LossModifiedModel_final is\n",
    "# added for h < 17 and LossModifiedModel_test for h == 1.\n",
    "for idx in [1, 2, 3]:\n",
    "    for h in [1, 2, 5, 17]:\n",
    "        Model_dict = {'None': MPCmodel}\n",
    "        if h < 17:\n",
    "            Model_dict['Modify'] = LossModifiedModel_final\n",
    "        if h == 1:\n",
    "            Model_dict['Test'] = LossModifiedModel_test\n",
    "        for modifyname, Modeltype in Model_dict.items():\n",
    "            # Per-run config: the ResNet50 base with the architecture overridden.\n",
    "            config = {\n",
    "                **baseconfig_resnet50,\n",
    "                \"architecture\": \"LoKr_ResNet50\",\n",
    "                \"horizon\": h,\n",
    "                \"stride\": 1,\n",
    "                'idx': idx,\n",
    "                'modify': modifyname,\n",
    "                'device': device,\n",
    "            }\n",
    "            if modifyname == 'Modify':\n",
    "                config['lambda_modify'] = 1.\n",
    "                modifyname += f\"_lamb{config['lambda_modify']}\"\n",
    "            # The baseline run gets no modifier tag in its name.\n",
    "            tag = '' if modifyname == 'None' else '_' + modifyname\n",
    "            name = f\"lokr_resnet50_{h}{tag}_id{config['idx']}\"\n",
    "            LoKr_ResNet50_one_run(name, dataconfig, config, Modeltype)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c621148d",
   "metadata": {},
   "source": [
    "## Test ViT-b16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "f390da57",
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "import peft"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "0502b4e2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lianhai/anaconda3/envs/mytorch/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# ViT preprocessing: resize to the checkpoint's input size and normalise with the\n",
    "# mean/std reported by its image processor.\n",
    "checkpoint = \"google/vit-base-patch16-224-in21k\"\n",
    "image_processor = transformers.AutoProcessor.from_pretrained(checkpoint)\n",
    "normalize = transforms.Normalize(mean=image_processor.image_mean, std=image_processor.image_std)\n",
    "# The processor exposes either a shortest-edge size or an explicit (height, width).\n",
    "size = (\n",
    "    image_processor.size[\"shortest_edge\"]\n",
    "    if \"shortest_edge\" in image_processor.size\n",
    "    else (image_processor.size[\"height\"], image_processor.size[\"width\"])\n",
    ")\n",
    "transform = transforms.Compose([transforms.Resize(size), transforms.ToTensor(), normalize])\n",
    "\n",
    "# Same dataset root as the ResNet section, so CIFAR-100 is stored only once.\n",
    "training_data = datasets.CIFAR100(\n",
    "    root=\"../data\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=transform,\n",
    ")\n",
    "\n",
    "test_data = datasets.CIFAR100(\n",
    "    root=\"../data\",\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=transform,\n",
    ")\n",
    "\n",
    "train_dataloader = DataLoader(training_data, batch_size=batch_size, num_workers=2, shuffle=True)\n",
    "test_dataloader = DataLoader(test_data, batch_size=batch_size, num_workers=2)\n",
    "\n",
    "# Rebuild dataconfig: the dict created in the ResNet section still references the\n",
    "# loaders built there with the ResNet transform, not the ViT loaders defined above.\n",
    "dataconfig = {\n",
    "    'train_dataloader': train_dataloader,\n",
    "    'test_dataloader': test_dataloader,\n",
    "    'training_data': training_data,\n",
    "    'loss_fn': loss_fn,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "7b06d093",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base hyperparameters for the ViT-B/16 experiments.\n",
    "# NOTE(review): ViT-B/16 has 12 encoder layers and hidden size 768. By analogy with\n",
    "# conv_channels/conv_repeats in the ResNet config, these two values look swapped.\n",
    "# Confirm how the ViT run code reads them before changing anything.\n",
    "baseconfig_vitb16 = dict(\n",
    "    learning_rate=learning_rate,\n",
    "    architecture=\"ViT-b16\",\n",
    "    transformer_channels=[12],\n",
    "    transformer_repeats=[768],\n",
    "    dataset=\"CIFAR100\",\n",
    "    epochs=epochs,\n",
    "    optimizer='sgd',\n",
    "    momentum=0.9,\n",
    "    gpu=gpu,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d985fc4c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/home/lianhai/anaconda3/envs/mytorch/lib/python3.9/site-packages/pytorch_lightning/loggers/wandb.py:396: UserWarning: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.\n",
      "  rank_zero_warn(\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n",
      "\n",
      "  | Name        | Type             | Params\n",
      "-------------------------------------------------\n",
      "0 | model       | Vit_B16          | 85.9 M\n",
      "1 | loss_fn     | CrossEntropyLoss | 0     \n",
      "2 | metrics     | MetricCollection | 0     \n",
      "3 | val_metrics | MetricCollection | 0     \n",
      "4 | train_loss  | MeanMetric       | 0     \n",
      "5 | val_loss    | MeanMetric       | 0     \n",
      "6 | blocks      | ModuleList       | 85.1 M\n",
      "7 | stem        | ViTEmbeddings    | 742 K \n",
      "8 | lossblocks  | ModuleList       | 941 K \n",
      "-------------------------------------------------\n",
      "978 K     Trainable params\n",
      "85.8 M    Non-trainable params\n",
      "86.8 M    Total params\n",
      "347.101   Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.38it/s]s=4.530, acc=0.0625, val_loss=4.550, val_acc=0.0465]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.27it/s]s=4.550, acc=0.000, val_loss=4.490, val_acc=0.138]  \n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.43it/s]s=4.450, acc=0.0625, val_loss=4.440, val_acc=0.269]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.26it/s]s=4.350, acc=0.500, val_loss=4.390, val_acc=0.396] \n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.26it/s]s=4.370, acc=0.312, val_loss=4.330, val_acc=0.496]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.26it/s]s=4.280, acc=0.750, val_loss=4.280, val_acc=0.567]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.24it/s]s=4.250, acc=0.500, val_loss=4.220, val_acc=0.619]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:30<00:00,  5.15it/s]s=4.120, acc=0.625, val_loss=4.160, val_acc=0.659]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.38it/s]s=4.110, acc=0.688, val_loss=4.090, val_acc=0.687]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:30<00:00,  5.16it/s]s=3.990, acc=0.688, val_loss=4.010, val_acc=0.708]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:30<00:00,  5.23it/s]ss=3.950, acc=0.750, val_loss=3.930, val_acc=0.724]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:30<00:00,  5.23it/s]ss=3.840, acc=0.812, val_loss=3.850, val_acc=0.736]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:30<00:00,  5.21it/s]ss=3.770, acc=0.750, val_loss=3.760, val_acc=0.745]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.36it/s]ss=3.710, acc=0.688, val_loss=3.670, val_acc=0.755]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.24it/s]ss=3.370, acc=0.812, val_loss=3.580, val_acc=0.762]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.38it/s]ss=3.520, acc=0.938, val_loss=3.480, val_acc=0.768]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.29it/s]ss=3.430, acc=0.750, val_loss=3.390, val_acc=0.772]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.25it/s]ss=3.500, acc=0.625, val_loss=3.300, val_acc=0.773]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:30<00:00,  5.18it/s]ss=3.320, acc=0.750, val_loss=3.230, val_acc=0.772]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:30<00:00,  5.18it/s]ss=3.250, acc=0.875, val_loss=3.220, val_acc=0.761]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:30<00:00,  5.22it/s]ss=3.050, acc=0.875, val_loss=3.170, val_acc=0.762]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.46it/s]ss=3.240, acc=0.625, val_loss=3.110, val_acc=0.759]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.51it/s]ss=3.150, acc=0.500, val_loss=3.070, val_acc=0.707]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.41it/s]ss=2.520, acc=0.875, val_loss=3.000, val_acc=0.693]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.40it/s]ss=3.270, acc=0.562, val_loss=2.870, val_acc=0.698]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.42it/s]ss=2.580, acc=0.625, val_loss=2.760, val_acc=0.696]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.39it/s]ss=2.320, acc=0.750, val_loss=2.650, val_acc=0.695]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.42it/s]ss=2.330, acc=0.875, val_loss=2.560, val_acc=0.692]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.41it/s]ss=2.950, acc=0.500, val_loss=2.490, val_acc=0.692]\n",
      "Epoch 29: 100%|██████████| 782/782 [05:01<00:00,  2.59it/s, v_num=fs7k, loss=2.950, acc=0.500, val_loss=2.490, val_acc=0.692]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=30` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 29: 100%|██████████| 782/782 [05:02<00:00,  2.59it/s, v_num=fs7k, loss=2.950, acc=0.500, val_loss=2.490, val_acc=0.692]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/home/lianhai/anaconda3/envs/mytorch/lib/python3.9/site-packages/pytorch_lightning/loggers/wandb.py:396: UserWarning: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.\n",
      "  rank_zero_warn(\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n",
      "\n",
      "  | Name        | Type             | Params\n",
      "-------------------------------------------------\n",
      "0 | model       | Vit_B16          | 85.9 M\n",
      "1 | loss_fn     | CrossEntropyLoss | 0     \n",
      "2 | metrics     | MetricCollection | 0     \n",
      "3 | val_metrics | MetricCollection | 0     \n",
      "4 | train_loss  | MeanMetric       | 0     \n",
      "5 | val_loss    | MeanMetric       | 0     \n",
      "6 | blocks      | ModuleList       | 85.1 M\n",
      "7 | stem        | ViTEmbeddings    | 742 K \n",
      "8 | lossblocks  | ModuleList       | 941 K \n",
      "-------------------------------------------------\n",
      "978 K     Trainable params\n",
      "85.8 M    Non-trainable params\n",
      "86.8 M    Total params\n",
      "347.101   Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.52it/s]s=4.600, acc=0.000, val_loss=4.560, val_acc=0.039]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.43it/s]s=4.500, acc=0.125, val_loss=4.500, val_acc=0.109] \n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.46it/s]s=4.430, acc=0.438, val_loss=4.450, val_acc=0.232] \n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.43it/s]s=4.400, acc=0.375, val_loss=4.400, val_acc=0.361]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.48it/s]s=4.350, acc=0.375, val_loss=4.340, val_acc=0.463]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.37it/s]s=4.300, acc=0.438, val_loss=4.290, val_acc=0.540]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.40it/s]s=4.330, acc=0.312, val_loss=4.230, val_acc=0.599]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.48it/s]s=4.150, acc=0.688, val_loss=4.170, val_acc=0.648]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.45it/s]s=4.170, acc=0.562, val_loss=4.100, val_acc=0.683]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.48it/s]s=4.100, acc=0.562, val_loss=4.030, val_acc=0.713]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.51it/s]ss=4.010, acc=0.625, val_loss=3.970, val_acc=0.727]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.51it/s]ss=3.920, acc=0.688, val_loss=3.910, val_acc=0.733]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.40it/s]ss=3.910, acc=0.688, val_loss=3.860, val_acc=0.728]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.38it/s]ss=3.950, acc=0.688, val_loss=3.810, val_acc=0.722]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.52it/s]ss=3.600, acc=0.938, val_loss=3.740, val_acc=0.729]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.44it/s]ss=3.590, acc=0.750, val_loss=3.650, val_acc=0.742]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.41it/s]ss=3.550, acc=0.688, val_loss=3.560, val_acc=0.753]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.47it/s]ss=3.400, acc=0.812, val_loss=3.470, val_acc=0.758]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.41it/s]ss=3.590, acc=0.625, val_loss=3.380, val_acc=0.760]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.48it/s]ss=2.970, acc=0.938, val_loss=3.280, val_acc=0.760]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.45it/s]ss=2.850, acc=0.750, val_loss=3.170, val_acc=0.762]\n",
      "Epoch 21:  88%|████████▊ | 690/782 [07:37<01:01,  1.51it/s, v_num=dqb0, loss=2.980, acc=0.812, val_loss=3.170, val_acc=0.762]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.49it/s]ss=2.390, acc=0.938, val_loss=2.470, val_acc=0.808]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.54it/s]ss=2.240, acc=0.875, val_loss=2.310, val_acc=0.814]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.50it/s]ss=2.180, acc=0.688, val_loss=2.160, val_acc=0.817]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.45it/s]ss=2.070, acc=0.750, val_loss=2.000, val_acc=0.818]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.54it/s]ss=1.530, acc=0.875, val_loss=1.870, val_acc=0.816]\n",
      "Epoch 29: 100%|██████████| 782/782 [14:40<00:00,  1.13s/it, v_num=a81x, loss=1.530, acc=0.875, val_loss=1.870, val_acc=0.816]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=30` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 29: 100%|██████████| 782/782 [14:41<00:00,  1.13s/it, v_num=a81x, loss=1.530, acc=0.875, val_loss=1.870, val_acc=0.816]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/home/lianhai/anaconda3/envs/mytorch/lib/python3.9/site-packages/pytorch_lightning/loggers/wandb.py:396: UserWarning: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.\n",
      "  rank_zero_warn(\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n",
      "\n",
      "  | Name        | Type             | Params\n",
      "-------------------------------------------------\n",
      "0 | model       | Vit_B16          | 85.9 M\n",
      "1 | loss_fn     | CrossEntropyLoss | 0     \n",
      "2 | metrics     | MetricCollection | 0     \n",
      "3 | val_metrics | MetricCollection | 0     \n",
      "4 | train_loss  | MeanMetric       | 0     \n",
      "5 | val_loss    | MeanMetric       | 0     \n",
      "6 | blocks      | ModuleList       | 85.1 M\n",
      "7 | stem        | ViTEmbeddings    | 742 K \n",
      "8 | lossblocks  | ModuleList       | 941 K \n",
      "-------------------------------------------------\n",
      "978 K     Trainable params\n",
      "85.8 M    Non-trainable params\n",
      "86.8 M    Total params\n",
      "347.101   Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.55it/s]s=4.560, acc=0.000, val_loss=4.560, val_acc=0.0391]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.53it/s]s=4.520, acc=0.0625, val_loss=4.500, val_acc=0.115] \n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.59it/s]s=4.440, acc=0.250, val_loss=4.450, val_acc=0.229] \n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.58it/s]s=4.430, acc=0.312, val_loss=4.400, val_acc=0.356]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.48it/s]s=4.360, acc=0.375, val_loss=4.340, val_acc=0.469]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.47it/s]s=4.270, acc=0.500, val_loss=4.290, val_acc=0.554]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.51it/s]s=4.210, acc=0.625, val_loss=4.230, val_acc=0.614]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.48it/s]s=4.130, acc=0.688, val_loss=4.170, val_acc=0.658]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.49it/s]s=4.050, acc=0.750, val_loss=4.110, val_acc=0.697]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.56it/s]s=4.080, acc=0.688, val_loss=4.040, val_acc=0.728]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.46it/s]ss=4.000, acc=0.688, val_loss=3.970, val_acc=0.752]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.58it/s]ss=3.940, acc=0.750, val_loss=3.900, val_acc=0.771]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.52it/s]ss=3.930, acc=0.812, val_loss=3.840, val_acc=0.780]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.48it/s]ss=3.780, acc=0.875, val_loss=3.760, val_acc=0.786]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.52it/s]ss=3.700, acc=0.688, val_loss=3.680, val_acc=0.793]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.50it/s]ss=3.630, acc=0.812, val_loss=3.580, val_acc=0.801]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.50it/s]ss=3.460, acc=0.938, val_loss=3.490, val_acc=0.808]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.56it/s]ss=3.330, acc=0.812, val_loss=3.380, val_acc=0.812]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.56it/s]ss=3.160, acc=0.688, val_loss=3.280, val_acc=0.814]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.53it/s]ss=3.310, acc=0.875, val_loss=3.170, val_acc=0.819]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.56it/s]ss=3.460, acc=0.562, val_loss=3.060, val_acc=0.818]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.47it/s]ss=3.130, acc=0.875, val_loss=2.940, val_acc=0.818]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.55it/s]ss=2.980, acc=0.688, val_loss=2.820, val_acc=0.817]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.49it/s]ss=2.700, acc=0.875, val_loss=2.680, val_acc=0.816]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:29<00:00,  5.38it/s]ss=2.500, acc=0.938, val_loss=2.540, val_acc=0.815]\n",
      "Epoch 25:  77%|███████▋  | 600/782 [13:01<03:57,  1.30s/it, v_num=1osz, loss=2.490, acc=0.734, val_loss=2.540, val_acc=0.815]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.53it/s]ss=2.020, acc=0.812, val_loss=1.680, val_acc=0.842]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.47it/s]ss=1.240, acc=1.000, val_loss=1.540, val_acc=0.841]\n",
      "Epoch 29: 100%|██████████| 782/782 [14:47<00:00,  1.13s/it, v_num=5ud8, loss=1.240, acc=1.000, val_loss=1.540, val_acc=0.841]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=30` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 29: 100%|██████████| 782/782 [14:47<00:00,  1.14s/it, v_num=5ud8, loss=1.240, acc=1.000, val_loss=1.540, val_acc=0.841]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/home/lianhai/anaconda3/envs/mytorch/lib/python3.9/site-packages/pytorch_lightning/loggers/wandb.py:396: UserWarning: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.\n",
      "  rank_zero_warn(\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n",
      "\n",
      "  | Name        | Type             | Params\n",
      "-------------------------------------------------\n",
      "0 | model       | Vit_B16          | 85.9 M\n",
      "1 | loss_fn     | CrossEntropyLoss | 0     \n",
      "2 | metrics     | MetricCollection | 0     \n",
      "3 | val_metrics | MetricCollection | 0     \n",
      "4 | train_loss  | MeanMetric       | 0     \n",
      "5 | val_loss    | MeanMetric       | 0     \n",
      "6 | blocks      | ModuleList       | 85.1 M\n",
      "7 | stem        | ViTEmbeddings    | 742 K \n",
      "8 | lossblocks  | ModuleList       | 941 K \n",
      "-------------------------------------------------\n",
      "978 K     Trainable params\n",
      "85.8 M    Non-trainable params\n",
      "86.8 M    Total params\n",
      "347.101   Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.58it/s]s=4.520, acc=0.125, val_loss=4.560, val_acc=0.0274]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.56it/s]s=4.570, acc=0.0625, val_loss=4.510, val_acc=0.0983]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.57it/s]s=4.490, acc=0.0625, val_loss=4.460, val_acc=0.215] \n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.55it/s]s=4.410, acc=0.250, val_loss=4.400, val_acc=0.347] \n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.61it/s]s=4.310, acc=0.562, val_loss=4.350, val_acc=0.456]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.60it/s]s=4.310, acc=0.500, val_loss=4.290, val_acc=0.542]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:27<00:00,  5.63it/s]s=4.230, acc=0.625, val_loss=4.240, val_acc=0.604]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:27<00:00,  5.65it/s]s=4.060, acc=0.875, val_loss=4.180, val_acc=0.651]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:27<00:00,  5.62it/s]s=4.160, acc=0.625, val_loss=4.120, val_acc=0.685]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.52it/s]s=4.050, acc=0.562, val_loss=4.060, val_acc=0.714]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.57it/s]ss=4.110, acc=0.500, val_loss=4.030, val_acc=0.722]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.44it/s]ss=4.040, acc=0.812, val_loss=3.970, val_acc=0.743]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.51it/s]ss=3.930, acc=0.812, val_loss=3.900, val_acc=0.759]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.59it/s]ss=3.860, acc=0.812, val_loss=3.830, val_acc=0.772]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.59it/s]ss=3.720, acc=0.938, val_loss=3.750, val_acc=0.782]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.54it/s]ss=3.600, acc=0.938, val_loss=3.680, val_acc=0.792]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.58it/s]ss=3.590, acc=0.625, val_loss=3.580, val_acc=0.798]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.58it/s]ss=3.580, acc=0.875, val_loss=3.490, val_acc=0.804]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.55it/s]ss=3.230, acc=0.938, val_loss=3.390, val_acc=0.808]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.57it/s]ss=3.050, acc=0.812, val_loss=3.290, val_acc=0.813]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.60it/s]ss=3.020, acc=0.875, val_loss=3.170, val_acc=0.816]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.54it/s]ss=3.020, acc=0.625, val_loss=3.050, val_acc=0.818]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.55it/s]ss=2.770, acc=0.812, val_loss=2.910, val_acc=0.819]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.57it/s]ss=2.500, acc=0.875, val_loss=2.750, val_acc=0.822]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.60it/s]ss=2.900, acc=0.750, val_loss=2.580, val_acc=0.823]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.58it/s]ss=2.370, acc=0.875, val_loss=2.380, val_acc=0.827]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.54it/s]ss=2.290, acc=0.812, val_loss=2.190, val_acc=0.831]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.51it/s]ss=1.920, acc=0.750, val_loss=2.010, val_acc=0.837]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.45it/s]ss=1.970, acc=0.812, val_loss=1.840, val_acc=0.838]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.53it/s]ss=1.650, acc=0.875, val_loss=1.680, val_acc=0.839]\n",
      "Epoch 29: 100%|██████████| 782/782 [09:21<00:00,  1.39it/s, v_num=8mos, loss=1.650, acc=0.875, val_loss=1.680, val_acc=0.839]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=30` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 29: 100%|██████████| 782/782 [09:22<00:00,  1.39it/s, v_num=8mos, loss=1.650, acc=0.875, val_loss=1.680, val_acc=0.839]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/home/lianhai/anaconda3/envs/mytorch/lib/python3.9/site-packages/pytorch_lightning/loggers/wandb.py:396: UserWarning: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.\n",
      "  rank_zero_warn(\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n",
      "\n",
      "  | Name        | Type             | Params\n",
      "-------------------------------------------------\n",
      "0 | model       | Vit_B16          | 85.9 M\n",
      "1 | loss_fn     | CrossEntropyLoss | 0     \n",
      "2 | metrics     | MetricCollection | 0     \n",
      "3 | val_metrics | MetricCollection | 0     \n",
      "4 | train_loss  | MeanMetric       | 0     \n",
      "5 | val_loss    | MeanMetric       | 0     \n",
      "6 | blocks      | ModuleList       | 85.1 M\n",
      "7 | stem        | ViTEmbeddings    | 742 K \n",
      "8 | lossblocks  | ModuleList       | 941 K \n",
      "-------------------------------------------------\n",
      "978 K     Trainable params\n",
      "85.8 M    Non-trainable params\n",
      "86.8 M    Total params\n",
      "347.101   Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.57it/s]s=4.540, acc=0.125, val_loss=4.560, val_acc=0.0342]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.60it/s]s=4.520, acc=0.0625, val_loss=4.510, val_acc=0.102] \n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.53it/s]s=4.470, acc=0.250, val_loss=4.450, val_acc=0.228] \n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.46it/s]s=4.410, acc=0.375, val_loss=4.400, val_acc=0.359]\n",
      "Validation DataLoader 0: 100%|██████████| 157/157 [00:28<00:00,  5.50it/s]s=4.300, acc=0.500, val_loss=4.350, val_acc=0.477]\n",
      "Epoch 5:   3%|▎         | 20/782 [00:08<05:12,  2.44it/s, v_num=xqge, loss=4.340, acc=0.422, val_loss=4.350, val_acc=0.477] "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Train ViT-B/16 for every (seed index, horizon, variant) combination.\n",
    "# Variants run in this order: the plain MPC baseline at every horizon, the\n",
    "# loss-modified model for every horizon below 12, and the test variant only at h == 1.\n",
    "for idx in [1, 2, 3]:\n",
    "    for h in [1, 2, 5, 12]:\n",
    "        variants = [('None', MPCmodel)]\n",
    "        if h < 12:\n",
    "            variants.append(('Modify', LossModifiedModel_final))\n",
    "        if h == 1:\n",
    "            variants.append(('Test', LossModifiedModel_test))\n",
    "\n",
    "        for modifyname, Modeltype in variants:\n",
    "            config = baseconfig_vitb16.copy()\n",
    "            config.update({'horizon': h, 'stride': 1, 'idx': idx, 'modify': modifyname, 'r': 1, 'alpha': 4})\n",
    "            if modifyname == 'Modify':\n",
    "                config['lambda_modify'] = 1.\n",
    "                # Only the run name carries the lambda; config['modify'] keeps the bare label.\n",
    "                modifyname = f\"{modifyname}_lamb{config['lambda_modify']}\"\n",
    "            # The baseline run gets no variant tag in its name.\n",
    "            tag = '' if modifyname == 'None' else f'_{modifyname}'\n",
    "            name = f\"vit-b16_{h}{tag}_id{config['idx']}\"\n",
    "            ViTb16_one_run(name, dataconfig, config, Modeltype)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e250acc2",
   "metadata": {},
   "source": [
    "## Measure Peak Training Memory Usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "6b79cb66",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "import numpy as np\n",
    "from torch.autograd.profiler import record_function\n",
    "\n",
    "# get_max_gpu_usage dispatches on this class, so import it here rather than\n",
    "# relying on a later cell having been run first.\n",
    "from ModifiedModel import BaseLossModifiedModel\n",
    "\n",
    "# Fixed synthetic batch reused for every measurement: 64 random 3x224x224 inputs\n",
    "# and 64 integer labels drawn from [0, 9).\n",
    "test_input=torch.rand((64,3,224,224),device=device)\n",
    "test_output=torch.randint(high=9,size=(64,),device=device)\n",
    "\n",
    "\n",
    "def get_max_gpu_usage(model,optimizer):\n",
    "    \"\"\"Profile a short training run of `model` and return its peak memory.\n",
    "\n",
    "    Runs 15 iterations on the fixed `test_input`/`test_output` batch under\n",
    "    `torch.profiler` with memory profiling enabled. A `BaseLossModifiedModel`\n",
    "    uses `train_small_batches` for the first two iterations and then\n",
    "    `train_one_batch` with `lambda_scale=1.`; any other model always uses\n",
    "    `train_one_batch`. Every iteration ends with `optimizer.step()` and\n",
    "    `optimizer.zero_grad(set_to_none=True)`.\n",
    "\n",
    "    Args:\n",
    "        model: training wrapper exposing `train_one_batch` (and\n",
    "            `train_small_batches` for `BaseLossModifiedModel`); its `count`\n",
    "            attribute is set to the current iteration index.\n",
    "        optimizer: optimizer passed to the model's training calls and stepped\n",
    "            after every iteration.\n",
    "\n",
    "    Returns:\n",
    "        The largest total (summed over memory categories) in the profiler's\n",
    "        memory timeline for `device`; the caller converts it with `/1024**2`.\n",
    "    \"\"\"\n",
    "    with torch.profiler.profile(\n",
    "        activities=[\n",
    "            torch.profiler.ProfilerActivity.CPU,\n",
    "            torch.profiler.ProfilerActivity.CUDA,\n",
    "        ],\n",
    "        # NOTE(review): with active=6 and repeat=1 only six profiler steps are\n",
    "        # recorded, and prof.step() is called at the start of each iteration, so\n",
    "        # only the first few of the 15 iterations reach the timeline -- confirm\n",
    "        # this window is intended.\n",
    "        schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),\n",
    "        record_shapes=True,\n",
    "        profile_memory=True,\n",
    "        with_stack=True,\n",
    "    ) as prof:\n",
    "        for i in range(15):\n",
    "            model.count=i\n",
    "            prof.step()\n",
    "            if isinstance(model,BaseLossModifiedModel):\n",
    "                if i<2:\n",
    "                    with record_function(\"## loss modify ##\"):\n",
    "                        model.train_small_batches(test_input,test_output,i,optimizer=optimizer)\n",
    "                else:\n",
    "                    model.lambda_scale=1.\n",
    "                    with record_function(\"## mpc ##\"):\n",
    "                        model.train_one_batch(test_input,test_output,i,optimizer=optimizer)\n",
    "            else:\n",
    "                with record_function(\"## mpc ##\"):\n",
    "                    model.train_one_batch(test_input,test_output,i,optimizer=optimizer)\n",
    "\n",
    "            with record_function(\"## optimizer ##\"):\n",
    "                optimizer.step()\n",
    "                optimizer.zero_grad(set_to_none=True)\n",
    "\n",
    "    # Export the memory timeline as JSON into a private temporary directory instead\n",
    "    # of a fixed ./tmp.json that other runs could overwrite. With a `.json` suffix the\n",
    "    # profiler writes `[times, sizes]`, one list of per-category sizes per timestamp.\n",
    "    with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "        timeline_path=os.path.join(tmp_dir,'memory_timeline.json')\n",
    "        prof.export_memory_timeline(timeline_path, device=device)\n",
    "        with open(timeline_path, 'r') as f:\n",
    "            memory_data = json.load(f)\n",
    "    # Total memory at each timestamp, then the peak over the recorded window.\n",
    "    return np.array(memory_data[1]).sum(1).max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76a26f44",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lianhai/anaconda3/envs/mytorch/lib/python3.9/site-packages/torch/profiler/profiler.py:406: UserWarning: Profiler won't be using warmup, this can skew profiler results\n",
      "  warn(\"Profiler won't be using warmup, this can skew profiler results\")\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from Models import Create_LoKr_ResNet50,Create_ResNet50,Create_ViTb16\n",
    "from ModifiedModel import BaseMPCmodel,BaseLossModifiedModel_a_lambda,BaseLossModifiedModel\n",
    "\n",
    "# Peak memory (profiler timeline peak / 1024**2) for every (model, seed index,\n",
    "# method, horizon) combination. Results are rewritten to MEMORY_CSV after every\n",
    "# measurement so partial results survive an interrupted run; set RESUME=True to\n",
    "# reload that file and skip combinations that were already measured.\n",
    "MEMORY_CSV='memory_usage'\n",
    "RESUME=False\n",
    "\n",
    "records=pd.read_csv(MEMORY_CSV,index_col=0).to_dict('records') if RESUME else []\n",
    "measured={(r['model'],r['method'],r['horizon'],r['idx']) for r in records}\n",
    "df_memory=pd.DataFrame(records)\n",
    "for model_name,create_model_fn in [('ResNet-50',Create_ResNet50),('LoKr-ResNet-50',Create_LoKr_ResNet50),('ViT-b16',Create_ViTb16)]:\n",
    "    print(model_name)\n",
    "    for idx in [1,2,3]:\n",
    "        print(idx)\n",
    "        for name,Modeltype in [('mpc',BaseMPCmodel),('Modify',BaseLossModifiedModel_a_lambda)]:\n",
    "            for h in [17,5,2,1]:\n",
    "                if (model_name,name,h,idx) in measured:\n",
    "                    print(name,h,'already measured, skipping')\n",
    "                    continue\n",
    "                config={'horizon':h,'stride':1,'learning_rate':learning_rate}\n",
    "                model=create_model_fn(config,dataconfig,Modeltype).to(device)\n",
    "                model.device=device\n",
    "                # Fixed settings applied to every model under test.\n",
    "                model.period=300\n",
    "                model._small_batches=2\n",
    "                model.batch_size=64\n",
    "                model.modify_batch=10\n",
    "                opt=torch.optim.SGD(model.model.parameters(),lr=learning_rate,momentum=0.9)\n",
    "                memory=get_max_gpu_usage(model,opt)/1024**2\n",
    "                # Keep results as plain records and rebuild the frame from them, instead\n",
    "                # of growing it with pd.concat one row at a time.\n",
    "                records.append({'model':model_name,'method':name,'horizon':h,'memory':memory,'idx':idx})\n",
    "                df_memory=pd.DataFrame(records)\n",
    "                df_memory.to_csv(MEMORY_CSV)\n",
    "                print(h,name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5087513",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:mytorch]",
   "language": "python",
   "name": "conda-env-mytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
