{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ee157b02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "import argparse\n",
    "import datetime\n",
    "import json\n",
    "import numpy as np\n",
    "import os\n",
    "import random\n",
    "import time\n",
    "from pathlib import Path\n",
    "import torch\n",
    "import torch.backends.cudnn as cudnn\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import torchvision.transforms as transforms\n",
    "import timm\n",
    "from torchvision import datasets\n",
    "import util.misc as misc\n",
    "from util.misc import NativeScalerWithGradNormCount as NativeScaler\n",
    "import timm.optim.optim_factory as optim_factory\n",
    "import models_mae_shared\n",
    "from engine_pretrain import accuracy\n",
    "from einops import repeat\n",
    "import tqdm\n",
    "import math\n",
    "from data import tt_image_folder\n",
    "from matplotlib import pyplot as plt\n",
    "from mae_utils import generate_encoder_attention_maps\n",
    "\n",
    "# --- Experiment configuration (edit here to change the run) ---\n",
    "device = 'cuda'\n",
    "# Alternative evaluation sets:\n",
    "# data_path = '/home/group/ilsvrc/val'\n",
    "# data_path = '/home/group/imagenet_c/jpeg_compression/5'\n",
    "data_path = '/scratch/data/imagenet_c/zoom_blur/5'\n",
    "num_workers = 8  # NOTE(review): not read below; the DataLoaders hard-code num_workers=2\n",
    "pin_mem = True   # NOTE(review): not read anywhere below\n",
    "lr = 5.00e-03    # alternative value: 1.00e-02\n",
    "weight_decay = 0.0\n",
    "batch_size = 128\n",
    "mask_ratio = 0.75         # mask ratio for the self-supervised (MAE) training step; evaluation uses 0\n",
    "steps_per_example = 100   # test-time-training steps taken per test image\n",
    "input_size = 224\n",
    "\n",
    "# Seed every RNG the run may consume (Python's `random` included, in case the\n",
    "# data pipeline uses it -- TODO confirm whether tt_image_folder does).\n",
    "seed = 0\n",
    "random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "# TODO(review): an unfinished note was left here (\"restart all the ...\"); clarify what was meant.\n",
    "\n",
    "# ImageNet per-channel normalization statistics (same values as the Normalize transform uses).\n",
    "imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
    "imagenet_std = np.array([0.229, 0.224, 0.225])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "411b8331",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/anaconda3/envs/taming/lib/python3.8/site-packages/torchvision/transforms/transforms.py:332: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using dataset /scratch/data/imagenet_c/zoom_blur/5 with 5000\n"
     ]
    }
   ],
   "source": [
    "# Array passed to ExtendedImageFolder as `minimizer` (project-local; its semantics are not visible here).\n",
    "minimizer = np.load(\"/home/test_time_training/ttt_mae_v1/models/imagenet/perm.npy\")\n",
    "\n",
    "\n",
    "def make_eval_transform():\n",
    "    \"\"\"Deterministic ImageNet preprocessing: bicubic resize to 256, center-crop to `input_size`, normalize.\"\"\"\n",
    "    return transforms.Compose([\n",
    "        # Use the InterpolationMode enum: the legacy int code (3 == bicubic) raises a UserWarning.\n",
    "        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),\n",
    "        transforms.CenterCrop(input_size),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=imagenet_mean.tolist(), std=imagenet_std.tolist())])\n",
    "\n",
    "\n",
    "# Evaluation and test-time training use the same deterministic preprocessing.\n",
    "transform_val = make_eval_transform()\n",
    "transform_train = make_eval_transform()\n",
    "dataset_train = tt_image_folder.ExtendedImageFolder(data_path, transform=transform_train,\n",
    "                                                    batch_size=batch_size,\n",
    "                                                    steps_per_example=steps_per_example,\n",
    "                                                    minimizer=minimizer)\n",
    "dataset_val = tt_image_folder.ExtendedImageFolder(data_path, transform=transform_val,\n",
    "                                                  minimizer=minimizer)\n",
    "num_classes = 1000\n",
    "print(f'Using dataset {data_path} with {len(dataset_val)}')\n",
    "# Both names are rebound to iterators over batch_size=1 loaders; the TTT loop pulls from them with next().\n",
    "dataset_val = iter(torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=2))\n",
    "dataset_train = iter(torch.utils.data.DataLoader(dataset_train, batch_size=1, shuffle=False, num_workers=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "301875aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the 'mae_vit_large_patch16' model with a 'vit_head' classifier and load weights from two checkpoints:\n",
    "#   * resume_model    -> its 'model' state dict is the base for the whole model\n",
    "#   * resume_finetune -> only its 'classifier*' entries are copied over the base before loading\n",
    "# (An alternative linear-head setup -- head_type='linear', norm_pix_loss=False, with BN/linear weights\n",
    "#  taken from the 'head.0.*' / 'head.1.*' entries of a fine-tuned checkpoint -- was previously kept here.)\n",
    "resume_model = '/home/test_time_training/ttt_mae_v1/models/imagenet/vit_head_2_layers/mae_pretrain_vit_large_full.pth'\n",
    "resume_finetune = '/home/test_time_training/ttt_mae_v1/models/imagenet/vit_head_2_layers/checkpoint-19.pth'\n",
    "model_name = 'mae_vit_large_patch16'\n",
    "norm_pix_loss = True\n",
    "\n",
    "model_factory = models_mae_shared.__dict__[model_name]\n",
    "model = model_factory(\n",
    "    num_classes=1000,\n",
    "    head_type='vit_head',\n",
    "    norm_pix_loss=norm_pix_loss,\n",
    "    classifier_depth=12,\n",
    ")\n",
    "\n",
    "model_checkpoint = torch.load(resume_model, map_location='cpu')['model']\n",
    "head_checkpoint = torch.load(resume_finetune, map_location='cpu')['model']\n",
    "\n",
    "# Overlay the fine-tuned classifier weights onto the base state dict, then load it (strict by default).\n",
    "model_checkpoint.update(\n",
    "    {key: value for key, value in head_checkpoint.items() if key.startswith('classifier')}\n",
    ")\n",
    "model.load_state_dict(model_checkpoint)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1a5b48d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freeze the decoder; every parameter that still requires grad is optimized.\n",
    "for name, p in model.named_parameters():\n",
    "    if name.startswith('decoder'):\n",
    "        p.requires_grad = False\n",
    "parameters = [p for p in model.parameters() if p.requires_grad]\n",
    "\n",
    "# SGD with momentum; weight decay is taken from the config cell (currently 0.0, i.e. none).\n",
    "# Also tried: plain SGD, and AdamW / Adam with betas=(0.9, 0.95).\n",
    "optimizer = torch.optim.SGD(parameters, lr=lr, momentum=0.9, weight_decay=weight_decay)\n",
    "optimizer.zero_grad()\n",
    "# Loss scaler from util.misc (wraps a torch.cuda.amp GradScaler); invoked once per training step.\n",
    "loss_scaler = NativeScaler()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "53cc716d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test\n",
      "step: 0, acc 100.0, loss {'classification': tensor(0.5363, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2746, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 1, acc 100.0, loss {'classification': tensor(0.6305, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2636, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 2, acc 100.0, loss {'classification': tensor(0.6118, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2678, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 3, acc 100.0, loss {'classification': tensor(0.5843, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2608, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 4, acc 100.0, loss {'classification': tensor(0.5681, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2632, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 5, acc 100.0, loss {'classification': tensor(0.5459, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2615, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 6, acc 100.0, loss {'classification': tensor(0.5321, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2521, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 7, acc 100.0, loss {'classification': tensor(0.5885, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2531, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 8, acc 100.0, loss {'classification': tensor(0.6581, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2499, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 9, acc 0.0, loss {'classification': tensor(0.7252, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2459, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 10, acc 0.0, loss {'classification': tensor(0.7759, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2383, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 11, acc 0.0, loss {'classification': tensor(0.7872, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2382, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 12, acc 0.0, loss {'classification': tensor(0.7808, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2343, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 13, acc 0.0, loss {'classification': tensor(0.7700, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2306, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 14, acc 0.0, loss {'classification': tensor(0.7380, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2284, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 15, acc 0.0, loss {'classification': tensor(0.7271, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2216, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 16, acc 0.0, loss {'classification': tensor(0.7407, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2265, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 17, acc 0.0, loss {'classification': tensor(0.7756, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2229, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 18, acc 0.0, loss {'classification': tensor(0.8017, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2162, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 19, acc 0.0, loss {'classification': tensor(0.8029, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2155, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 20, acc 0.0, loss {'classification': tensor(0.7976, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2149, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 21, acc 0.0, loss {'classification': tensor(0.7936, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2153, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 22, acc 0.0, loss {'classification': tensor(0.7998, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2099, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 23, acc 0.0, loss {'classification': tensor(0.8225, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2096, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 24, acc 0.0, loss {'classification': tensor(0.8501, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2060, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 25, acc 0.0, loss {'classification': tensor(0.8621, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2096, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 26, acc 0.0, loss {'classification': tensor(0.8671, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2040, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 27, acc 0.0, loss {'classification': tensor(0.8655, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.1973, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 28, acc 0.0, loss {'classification': tensor(0.8690, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.2020, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 29, acc 0.0, loss {'classification': tensor(0.8783, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.1985, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 30, acc 0.0, loss {'classification': tensor(0.8749, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.1936, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 31, acc 0.0, loss {'classification': tensor(0.8704, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.1941, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 32, acc 0.0, loss {'classification': tensor(0.8701, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.1921, device='cuda:0', grad_fn=<DivBackward0>)}\n",
      "Test\n",
      "step: 33, acc 0.0, loss {'classification': tensor(0.8759, device='cuda:0')}\n",
      "loss!!! {'mae': tensor(0.1910, device='cuda:0', grad_fn=<DivBackward0>)}\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[0;32mIn [5]\u001b[0m, in \u001b[0;36m<cell line: 6>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     45\u001b[0m         sys\u001b[38;5;241m.\u001b[39mexit(\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m     46\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss!!!\u001b[39m\u001b[38;5;124m'\u001b[39m, loss_dict)\n\u001b[0;32m---> 47\u001b[0m     \u001b[43mloss_scaler\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparameters\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparameters\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     48\u001b[0m \u001b[43m                \u001b[49m\u001b[43mupdate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m     49\u001b[0m     optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m     52\u001b[0m latents \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mconcatenate(latents)\n",
      "File \u001b[0;32m~/test_time_training/ttt_mae_v1/util/misc.py:267\u001b[0m, in \u001b[0;36mNativeScalerWithGradNormCount.__call__\u001b[0;34m(self, loss, optimizer, clip_grad, parameters, create_graph, update_grad)\u001b[0m\n\u001b[1;32m    265\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_scaler\u001b[38;5;241m.\u001b[39munscale_(optimizer)\n\u001b[1;32m    266\u001b[0m         norm \u001b[38;5;241m=\u001b[39m get_grad_norm_(parameters)\n\u001b[0;32m--> 267\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_scaler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    268\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_scaler\u001b[38;5;241m.\u001b[39mupdate()\n\u001b[1;32m    269\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
      "File \u001b[0;32m~/anaconda3/envs/taming/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py:338\u001b[0m, in \u001b[0;36mGradScaler.step\u001b[0;34m(self, optimizer, *args, **kwargs)\u001b[0m\n\u001b[1;32m    334\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39munscale_(optimizer)\n\u001b[1;32m    336\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(optimizer_state[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfound_inf_per_device\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo inf checks were recorded for this optimizer.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 338\u001b[0m retval \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_maybe_opt_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    340\u001b[0m optimizer_state[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstage\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m OptState\u001b[38;5;241m.\u001b[39mSTEPPED\n\u001b[1;32m    342\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m retval\n",
      "File \u001b[0;32m~/anaconda3/envs/taming/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py:284\u001b[0m, in \u001b[0;36mGradScaler._maybe_opt_step\u001b[0;34m(self, optimizer, optimizer_state, *args, **kwargs)\u001b[0m\n\u001b[1;32m    282\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_maybe_opt_step\u001b[39m(\u001b[38;5;28mself\u001b[39m, optimizer, optimizer_state, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    283\u001b[0m     retval \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 284\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;43msum\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43moptimizer_state\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mfound_inf_per_device\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m    285\u001b[0m         retval \u001b[38;5;241m=\u001b[39m optimizer\u001b[38;5;241m.\u001b[39mstep(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    286\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m retval\n",
      "File \u001b[0;32m~/anaconda3/envs/taming/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py:284\u001b[0m, in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    282\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_maybe_opt_step\u001b[39m(\u001b[38;5;28mself\u001b[39m, optimizer, optimizer_state, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    283\u001b[0m     retval \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 284\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28msum\u001b[39m(\u001b[43mv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m optimizer_state[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfound_inf_per_device\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mvalues()):\n\u001b[1;32m    285\u001b[0m         retval \u001b[38;5;241m=\u001b[39m optimizer\u001b[38;5;241m.\u001b[39mstep(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    286\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m retval\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "latents = []\n",
    "all_results = []  # top-1 accuracy on the test image, recorded before every TTT step\n",
    "\n",
    "# Fetch one test image and its label. The loader yields batches of size 1; `[0]` strips that\n",
    "# extra dimension (presumably each dataset item is already a batch -- TODO confirm).\n",
    "(test_samples, labels) = next(dataset_val)\n",
    "test_samples = test_samples.to(device, non_blocking=True)[0]\n",
    "# Convert the label once here rather than re-wrapping it on every step of the loop.\n",
    "labels = torch.LongTensor([labels]).to(device, non_blocking=True)\n",
    "\n",
    "for data_iter_step in range(steps_per_example):\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    # Evaluate the current weights on the test image: no masking, no gradients.\n",
    "    with torch.no_grad():\n",
    "        model.eval()\n",
    "        print('Test')\n",
    "        loss_dict, _, _, pred = model(test_samples, labels, mask_ratio=0)\n",
    "        (acc1, acc5) = accuracy(pred, labels, topk=(1, 5))\n",
    "        all_results.append(acc1.detach().cpu().numpy())\n",
    "        print('step: {}, acc {}, loss {}'.format(data_iter_step, acc1.detach().cpu().numpy(), loss_dict))\n",
    "\n",
    "    # One self-supervised step on the next training batch (no labels, masked input).\n",
    "    model.train(True)\n",
    "    samples, _ = next(dataset_train)\n",
    "    samples = samples.to(device, non_blocking=True)[0]\n",
    "    loss_dict, _, latent, pred = model(samples, None, mask_ratio=mask_ratio)\n",
    "    latents.append(latent.detach().cpu().numpy())\n",
    "    loss = torch.stack(list(loss_dict.values())).sum()\n",
    "    loss_value = loss.item()\n",
    "\n",
    "    if not math.isfinite(loss_value):\n",
    "        # Raise rather than sys.exit(): SystemExit does not cleanly stop a cell in a Jupyter kernel.\n",
    "        raise RuntimeError(\"Loss is {}, stopping training\".format(loss_value))\n",
    "    print('loss!!!', loss_dict)\n",
    "    # Pass only the trainable parameters (decoder is frozen) to the scaler's grad-norm computation.\n",
    "    loss_scaler(loss, optimizer, parameters=parameters, update_grad=True)\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "latents = np.concatenate(latents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2bfec4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older matplotlib\n",
    "from sklearn import decomposition\n",
    "# Do not import sklearn's `datasets` here: it would shadow torchvision's `datasets` from the first cell.\n",
    "\n",
    "# Group label for each latent row: 10 consecutive groups of 32 rows.\n",
    "# TODO(review): confirm this grouping matches how `latents` was collected in the TTT loop.\n",
    "n_groups, group_size = 10, 32\n",
    "y = np.repeat(np.arange(n_groups), group_size)\n",
    "\n",
    "X = latents\n",
    "assert len(X) == len(y), f'latents has {len(X)} rows but {len(y)} group labels were built'\n",
    "\n",
    "# Project the latents onto their first 3 principal components.\n",
    "pca = decomposition.PCA(n_components=3)\n",
    "X_pca = pca.fit(X).transform(X)\n",
    "\n",
    "fig = plt.figure(figsize=(4, 3))\n",
    "# Create the 3-D axes through the figure: constructing Axes3D(fig, ...) directly is deprecated\n",
    "# and, in recent matplotlib, no longer attaches the axes to the figure (blank plot).\n",
    "ax = fig.add_axes([0, 0, 0.95, 1], projection='3d')\n",
    "ax.view_init(elev=48, azim=134)\n",
    "\n",
    "# Write each group's label at the group's mean position (nudged +1.5 along the 2nd component).\n",
    "for label in range(n_groups):\n",
    "    members = X_pca[y == label]\n",
    "    ax.text3D(\n",
    "        members[:, 0].mean(),\n",
    "        members[:, 1].mean() + 1.5,\n",
    "        members[:, 2].mean(),\n",
    "        str(label),\n",
    "        horizontalalignment=\"center\",\n",
    "        bbox=dict(alpha=0.5, edgecolor=\"w\", facecolor=\"w\"),\n",
    "    )\n",
    "ax.scatter(X_pca[:, 0], X_pca[:, 1], X_pca[:, 2], c=y, cmap=plt.cm.nipy_spectral, edgecolor=\"k\")\n",
    "\n",
    "# `ax.w_xaxis` & co. are deprecated aliases (removed in matplotlib 3.8); use the axis objects directly.\n",
    "ax.xaxis.set_ticklabels([])\n",
    "ax.yaxis.set_ticklabels([])\n",
    "ax.zaxis.set_ticklabels([])\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8084abbd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "218bc4ca",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "942f0894",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "00ad61eb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5000,)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "308045cc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
