{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b4208549",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import datetime\n",
    "import json\n",
    "import numpy as np\n",
    "import os\n",
    "import time\n",
    "from pathlib import Path\n",
    "import torch\n",
    "import torch.backends.cudnn as cudnn\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import torchvision.transforms as transforms\n",
    "import timm\n",
    "from torchvision import datasets\n",
    "import util.misc as misc\n",
    "from util.misc import NativeScalerWithGradNormCount as NativeScaler\n",
    "import timm.optim.optim_factory as optim_factory\n",
    "import models_mae_shared\n",
    "from engine_pretrain import accuracy\n",
    "from einops import repeat\n",
    "import tqdm\n",
    "import math\n",
    "from scipy import stats\n",
    "from data import tt_image_folder\n",
    "from matplotlib import pyplot as plt\n",
    "from mae_utils import generate_encoder_attention_maps\n",
    "\n",
    "device = 'cuda'\n",
    "# data_path = '/home/group/ilsvrc/val'\n",
    "data_path = '/scratch/data/imagenet_c/jpeg_compression/5'\n",
    "num_workers = 8\n",
    "pin_mem = True\n",
    "lr = 5.00e-06\n",
    "weight_decay = 0.05\n",
    "train_steps = 10\n",
    "mask_ratio = 0.75\n",
    "steps_per_example = 32\n",
    "input_size = 224\n",
    "batch_size = 32\n",
    "\n",
    "imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
    "imagenet_std = np.array([0.229, 0.224, 0.225])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b995d828",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/anaconda3/envs/taming/lib/python3.8/site-packages/torchvision/transforms/transforms.py:332: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.\n",
      "  warnings.warn(\n",
      "/home/anaconda3/envs/taming/lib/python3.8/site-packages/torchvision/transforms/transforms.py:890: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using dataset /scratch/data/imagenet_c/jpeg_compression/5 with 5120000\n"
     ]
    }
   ],
   "source": [
    "# Array loaded from disk and forwarded unchanged to ExtendedImageFolder as `minimizer`.\n",
    "# NOTE(review): absolute, machine-specific path - consider moving it to the config cell.\n",
    "minimizer = np.load('/home/test_time_training/ttt_mae_v1/models/imagenet/perm.npy')\n",
    "\n",
    "# Pass the InterpolationMode enum instead of the legacy int code 3; the int form makes\n",
    "# torchvision warn 'Argument interpolation should be of type InterpolationMode instead of int'.\n",
    "bicubic = transforms.InterpolationMode.BICUBIC\n",
    "# One Normalize instance shared by both pipelines, built from the config-cell statistics\n",
    "# (mean 0.485/0.456/0.406, std 0.229/0.224/0.225).\n",
    "normalize = transforms.Normalize(mean=imagenet_mean.tolist(), std=imagenet_std.tolist())\n",
    "\n",
    "# Deterministic pipeline: resize, center crop, convert to tensor, normalize.\n",
    "transform_val = transforms.Compose([\n",
    "    transforms.Resize(256, interpolation=bicubic),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "])\n",
    "# Randomized pipeline: random resized crop plus horizontal flip, then tensor + normalize.\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.RandomResizedCrop(input_size, scale=(0.2, 1.0), interpolation=bicubic),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "])\n",
    "\n",
    "# Both datasets read the same root and share every argument except the transform.\n",
    "dataset_kwargs = dict(batch_size=batch_size, steps_per_example=steps_per_example, minimizer=minimizer)\n",
    "dataset_train = tt_image_folder.ExtendedImageFolder(data_path, transform=transform_train, **dataset_kwargs)\n",
    "dataset_val = tt_image_folder.ExtendedImageFolder(data_path, transform=transform_val, **dataset_kwargs)\n",
    "num_classes = 1000\n",
    "print(f'Using dataset {data_path} with {len(dataset_val)}')\n",
    "\n",
    "# From here on the two names are rebound to iterators over unshuffled, batch_size=1 DataLoaders.\n",
    "# NOTE(review): num_workers is hard-coded to 2 here although the config cell defines num_workers / pin_mem - confirm intent.\n",
    "dataset_val = iter(torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=2))\n",
    "dataset_train = iter(torch.utils.data.DataLoader(dataset_train, batch_size=1, shuffle=False, num_workers=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dd4edbe9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_IncompatibleKeys(missing_keys=['classifier_pos_embed', 'classifier_embed.weight', 'classifier_embed.bias', 'classifier_blocks.0.norm1.weight', 'classifier_blocks.0.norm1.bias', 'classifier_blocks.0.attn.qkv.weight', 'classifier_blocks.0.attn.qkv.bias', 'classifier_blocks.0.attn.proj.weight', 'classifier_blocks.0.attn.proj.bias', 'classifier_blocks.0.norm2.weight', 'classifier_blocks.0.norm2.bias', 'classifier_blocks.0.mlp.fc1.weight', 'classifier_blocks.0.mlp.fc1.bias', 'classifier_blocks.0.mlp.fc2.weight', 'classifier_blocks.0.mlp.fc2.bias', 'classifier_blocks.1.norm1.weight', 'classifier_blocks.1.norm1.bias', 'classifier_blocks.1.attn.qkv.weight', 'classifier_blocks.1.attn.qkv.bias', 'classifier_blocks.1.attn.proj.weight', 'classifier_blocks.1.attn.proj.bias', 'classifier_blocks.1.norm2.weight', 'classifier_blocks.1.norm2.bias', 'classifier_blocks.1.mlp.fc1.weight', 'classifier_blocks.1.mlp.fc1.bias', 'classifier_blocks.1.mlp.fc2.weight', 'classifier_blocks.1.mlp.fc2.bias', 'classifier_blocks.2.norm1.weight', 'classifier_blocks.2.norm1.bias', 'classifier_blocks.2.attn.qkv.weight', 'classifier_blocks.2.attn.qkv.bias', 'classifier_blocks.2.attn.proj.weight', 'classifier_blocks.2.attn.proj.bias', 'classifier_blocks.2.norm2.weight', 'classifier_blocks.2.norm2.bias', 'classifier_blocks.2.mlp.fc1.weight', 'classifier_blocks.2.mlp.fc1.bias', 'classifier_blocks.2.mlp.fc2.weight', 'classifier_blocks.2.mlp.fc2.bias', 'classifier_blocks.3.norm1.weight', 'classifier_blocks.3.norm1.bias', 'classifier_blocks.3.attn.qkv.weight', 'classifier_blocks.3.attn.qkv.bias', 'classifier_blocks.3.attn.proj.weight', 'classifier_blocks.3.attn.proj.bias', 'classifier_blocks.3.norm2.weight', 'classifier_blocks.3.norm2.bias', 'classifier_blocks.3.mlp.fc1.weight', 'classifier_blocks.3.mlp.fc1.bias', 'classifier_blocks.3.mlp.fc2.weight', 'classifier_blocks.3.mlp.fc2.bias', 'classifier_blocks.4.norm1.weight', 'classifier_blocks.4.norm1.bias', 'classifier_blocks.4.attn.qkv.weight', 
'classifier_blocks.4.attn.qkv.bias', 'classifier_blocks.4.attn.proj.weight', 'classifier_blocks.4.attn.proj.bias', 'classifier_blocks.4.norm2.weight', 'classifier_blocks.4.norm2.bias', 'classifier_blocks.4.mlp.fc1.weight', 'classifier_blocks.4.mlp.fc1.bias', 'classifier_blocks.4.mlp.fc2.weight', 'classifier_blocks.4.mlp.fc2.bias', 'classifier_blocks.5.norm1.weight', 'classifier_blocks.5.norm1.bias', 'classifier_blocks.5.attn.qkv.weight', 'classifier_blocks.5.attn.qkv.bias', 'classifier_blocks.5.attn.proj.weight', 'classifier_blocks.5.attn.proj.bias', 'classifier_blocks.5.norm2.weight', 'classifier_blocks.5.norm2.bias', 'classifier_blocks.5.mlp.fc1.weight', 'classifier_blocks.5.mlp.fc1.bias', 'classifier_blocks.5.mlp.fc2.weight', 'classifier_blocks.5.mlp.fc2.bias', 'classifier_blocks.6.norm1.weight', 'classifier_blocks.6.norm1.bias', 'classifier_blocks.6.attn.qkv.weight', 'classifier_blocks.6.attn.qkv.bias', 'classifier_blocks.6.attn.proj.weight', 'classifier_blocks.6.attn.proj.bias', 'classifier_blocks.6.norm2.weight', 'classifier_blocks.6.norm2.bias', 'classifier_blocks.6.mlp.fc1.weight', 'classifier_blocks.6.mlp.fc1.bias', 'classifier_blocks.6.mlp.fc2.weight', 'classifier_blocks.6.mlp.fc2.bias', 'classifier_blocks.7.norm1.weight', 'classifier_blocks.7.norm1.bias', 'classifier_blocks.7.attn.qkv.weight', 'classifier_blocks.7.attn.qkv.bias', 'classifier_blocks.7.attn.proj.weight', 'classifier_blocks.7.attn.proj.bias', 'classifier_blocks.7.norm2.weight', 'classifier_blocks.7.norm2.bias', 'classifier_blocks.7.mlp.fc1.weight', 'classifier_blocks.7.mlp.fc1.bias', 'classifier_blocks.7.mlp.fc2.weight', 'classifier_blocks.7.mlp.fc2.bias', 'classifier_blocks.8.norm1.weight', 'classifier_blocks.8.norm1.bias', 'classifier_blocks.8.attn.qkv.weight', 'classifier_blocks.8.attn.qkv.bias', 'classifier_blocks.8.attn.proj.weight', 'classifier_blocks.8.attn.proj.bias', 'classifier_blocks.8.norm2.weight', 'classifier_blocks.8.norm2.bias', 'classifier_blocks.8.mlp.fc1.weight', 
'classifier_blocks.8.mlp.fc1.bias', 'classifier_blocks.8.mlp.fc2.weight', 'classifier_blocks.8.mlp.fc2.bias', 'classifier_blocks.9.norm1.weight', 'classifier_blocks.9.norm1.bias', 'classifier_blocks.9.attn.qkv.weight', 'classifier_blocks.9.attn.qkv.bias', 'classifier_blocks.9.attn.proj.weight', 'classifier_blocks.9.attn.proj.bias', 'classifier_blocks.9.norm2.weight', 'classifier_blocks.9.norm2.bias', 'classifier_blocks.9.mlp.fc1.weight', 'classifier_blocks.9.mlp.fc1.bias', 'classifier_blocks.9.mlp.fc2.weight', 'classifier_blocks.9.mlp.fc2.bias', 'classifier_blocks.10.norm1.weight', 'classifier_blocks.10.norm1.bias', 'classifier_blocks.10.attn.qkv.weight', 'classifier_blocks.10.attn.qkv.bias', 'classifier_blocks.10.attn.proj.weight', 'classifier_blocks.10.attn.proj.bias', 'classifier_blocks.10.norm2.weight', 'classifier_blocks.10.norm2.bias', 'classifier_blocks.10.mlp.fc1.weight', 'classifier_blocks.10.mlp.fc1.bias', 'classifier_blocks.10.mlp.fc2.weight', 'classifier_blocks.10.mlp.fc2.bias', 'classifier_blocks.11.norm1.weight', 'classifier_blocks.11.norm1.bias', 'classifier_blocks.11.attn.qkv.weight', 'classifier_blocks.11.attn.qkv.bias', 'classifier_blocks.11.attn.proj.weight', 'classifier_blocks.11.attn.proj.bias', 'classifier_blocks.11.norm2.weight', 'classifier_blocks.11.norm2.bias', 'classifier_blocks.11.mlp.fc1.weight', 'classifier_blocks.11.mlp.fc1.bias', 'classifier_blocks.11.mlp.fc2.weight', 'classifier_blocks.11.mlp.fc2.bias', 'classifier_norm.weight', 'classifier_norm.bias', 'classifier_pred.weight', 'classifier_pred.bias'], unexpected_keys=['mask_token', 'decoder_pos_embed', 'decoder_embed.weight', 'decoder_embed.bias', 'decoder_blocks.0.norm1.weight', 'decoder_blocks.0.norm1.bias', 'decoder_blocks.0.attn.qkv.weight', 'decoder_blocks.0.attn.qkv.bias', 'decoder_blocks.0.attn.proj.weight', 'decoder_blocks.0.attn.proj.bias', 'decoder_blocks.0.norm2.weight', 'decoder_blocks.0.norm2.bias', 'decoder_blocks.0.mlp.fc1.weight', 'decoder_blocks.0.mlp.fc1.bias', 
'decoder_blocks.0.mlp.fc2.weight', 'decoder_blocks.0.mlp.fc2.bias', 'decoder_blocks.1.norm1.weight', 'decoder_blocks.1.norm1.bias', 'decoder_blocks.1.attn.qkv.weight', 'decoder_blocks.1.attn.qkv.bias', 'decoder_blocks.1.attn.proj.weight', 'decoder_blocks.1.attn.proj.bias', 'decoder_blocks.1.norm2.weight', 'decoder_blocks.1.norm2.bias', 'decoder_blocks.1.mlp.fc1.weight', 'decoder_blocks.1.mlp.fc1.bias', 'decoder_blocks.1.mlp.fc2.weight', 'decoder_blocks.1.mlp.fc2.bias', 'decoder_blocks.2.norm1.weight', 'decoder_blocks.2.norm1.bias', 'decoder_blocks.2.attn.qkv.weight', 'decoder_blocks.2.attn.qkv.bias', 'decoder_blocks.2.attn.proj.weight', 'decoder_blocks.2.attn.proj.bias', 'decoder_blocks.2.norm2.weight', 'decoder_blocks.2.norm2.bias', 'decoder_blocks.2.mlp.fc1.weight', 'decoder_blocks.2.mlp.fc1.bias', 'decoder_blocks.2.mlp.fc2.weight', 'decoder_blocks.2.mlp.fc2.bias', 'decoder_blocks.3.norm1.weight', 'decoder_blocks.3.norm1.bias', 'decoder_blocks.3.attn.qkv.weight', 'decoder_blocks.3.attn.qkv.bias', 'decoder_blocks.3.attn.proj.weight', 'decoder_blocks.3.attn.proj.bias', 'decoder_blocks.3.norm2.weight', 'decoder_blocks.3.norm2.bias', 'decoder_blocks.3.mlp.fc1.weight', 'decoder_blocks.3.mlp.fc1.bias', 'decoder_blocks.3.mlp.fc2.weight', 'decoder_blocks.3.mlp.fc2.bias', 'decoder_blocks.4.norm1.weight', 'decoder_blocks.4.norm1.bias', 'decoder_blocks.4.attn.qkv.weight', 'decoder_blocks.4.attn.qkv.bias', 'decoder_blocks.4.attn.proj.weight', 'decoder_blocks.4.attn.proj.bias', 'decoder_blocks.4.norm2.weight', 'decoder_blocks.4.norm2.bias', 'decoder_blocks.4.mlp.fc1.weight', 'decoder_blocks.4.mlp.fc1.bias', 'decoder_blocks.4.mlp.fc2.weight', 'decoder_blocks.4.mlp.fc2.bias', 'decoder_blocks.5.norm1.weight', 'decoder_blocks.5.norm1.bias', 'decoder_blocks.5.attn.qkv.weight', 'decoder_blocks.5.attn.qkv.bias', 'decoder_blocks.5.attn.proj.weight', 'decoder_blocks.5.attn.proj.bias', 'decoder_blocks.5.norm2.weight', 'decoder_blocks.5.norm2.bias', 'decoder_blocks.5.mlp.fc1.weight', 
'decoder_blocks.5.mlp.fc1.bias', 'decoder_blocks.5.mlp.fc2.weight', 'decoder_blocks.5.mlp.fc2.bias', 'decoder_blocks.6.norm1.weight', 'decoder_blocks.6.norm1.bias', 'decoder_blocks.6.attn.qkv.weight', 'decoder_blocks.6.attn.qkv.bias', 'decoder_blocks.6.attn.proj.weight', 'decoder_blocks.6.attn.proj.bias', 'decoder_blocks.6.norm2.weight', 'decoder_blocks.6.norm2.bias', 'decoder_blocks.6.mlp.fc1.weight', 'decoder_blocks.6.mlp.fc1.bias', 'decoder_blocks.6.mlp.fc2.weight', 'decoder_blocks.6.mlp.fc2.bias', 'decoder_blocks.7.norm1.weight', 'decoder_blocks.7.norm1.bias', 'decoder_blocks.7.attn.qkv.weight', 'decoder_blocks.7.attn.qkv.bias', 'decoder_blocks.7.attn.proj.weight', 'decoder_blocks.7.attn.proj.bias', 'decoder_blocks.7.norm2.weight', 'decoder_blocks.7.norm2.bias', 'decoder_blocks.7.mlp.fc1.weight', 'decoder_blocks.7.mlp.fc1.bias', 'decoder_blocks.7.mlp.fc2.weight', 'decoder_blocks.7.mlp.fc2.bias', 'decoder_norm.weight', 'decoder_norm.bias', 'decoder_pred.weight', 'decoder_pred.bias'])\n",
      "Missing keys {'classifier_blocks.9.norm1.bias', 'classifier_blocks.11.attn.proj.weight', 'classifier_blocks.10.attn.proj.bias', 'classifier_blocks.3.mlp.fc2.weight', 'classifier_blocks.1.attn.qkv.bias', 'classifier_blocks.8.norm1.weight', 'classifier_blocks.10.attn.qkv.weight', 'classifier_blocks.1.attn.proj.weight', 'classifier_blocks.5.norm1.weight', 'classifier_blocks.11.mlp.fc1.bias', 'classifier_blocks.4.mlp.fc1.bias', 'classifier_blocks.6.attn.qkv.bias', 'classifier_blocks.2.mlp.fc1.weight', 'classifier_blocks.9.norm2.weight', 'classifier_blocks.11.attn.qkv.weight', 'classifier_blocks.0.attn.proj.bias', 'classifier_blocks.7.norm1.weight', 'classifier_blocks.4.attn.proj.bias', 'classifier_blocks.0.attn.qkv.weight', 'classifier_blocks.7.attn.qkv.weight', 'classifier_blocks.2.attn.proj.bias', 'classifier_blocks.9.attn.proj.weight', 'classifier_blocks.0.mlp.fc1.bias', 'classifier_blocks.2.norm2.weight', 'classifier_blocks.7.mlp.fc2.weight', 'classifier_blocks.5.attn.proj.weight', 'classifier_embed.weight', 'classifier_blocks.3.norm2.bias', 'classifier_blocks.9.attn.proj.bias', 'classifier_norm.bias', 'classifier_blocks.5.norm2.weight', 'classifier_blocks.6.norm1.bias', 'classifier_blocks.0.attn.qkv.bias', 'classifier_blocks.2.attn.proj.weight', 'classifier_blocks.8.mlp.fc1.bias', 'classifier_blocks.9.norm2.bias', 'classifier_blocks.3.norm2.weight', 'classifier_blocks.11.mlp.fc2.bias', 'classifier_blocks.1.norm1.bias', 'classifier_blocks.3.attn.qkv.weight', 'classifier_blocks.1.mlp.fc2.weight', 'classifier_blocks.5.attn.qkv.weight', 'classifier_blocks.9.norm1.weight', 'classifier_norm.weight', 'classifier_blocks.7.mlp.fc2.bias', 'classifier_blocks.3.norm1.weight', 'classifier_blocks.10.mlp.fc2.weight', 'classifier_blocks.10.mlp.fc2.bias', 'classifier_blocks.5.norm1.bias', 'classifier_blocks.2.attn.qkv.weight', 'classifier_pos_embed', 'classifier_blocks.4.attn.proj.weight', 'classifier_blocks.6.norm2.bias', 'classifier_blocks.11.attn.proj.bias', 
'classifier_blocks.4.norm1.bias', 'classifier_blocks.6.attn.proj.bias', 'classifier_blocks.9.attn.qkv.bias', 'classifier_blocks.1.mlp.fc1.weight', 'classifier_blocks.11.mlp.fc2.weight', 'classifier_blocks.9.mlp.fc2.bias', 'classifier_blocks.7.mlp.fc1.weight', 'classifier_blocks.0.norm2.weight', 'classifier_blocks.4.mlp.fc1.weight', 'classifier_blocks.0.mlp.fc2.bias', 'classifier_blocks.6.mlp.fc2.bias', 'classifier_blocks.6.mlp.fc2.weight', 'classifier_blocks.4.norm2.bias', 'classifier_blocks.2.norm1.weight', 'classifier_blocks.10.mlp.fc1.bias', 'classifier_blocks.1.norm1.weight', 'classifier_blocks.2.norm1.bias', 'classifier_blocks.1.norm2.weight', 'classifier_blocks.5.mlp.fc2.weight', 'classifier_blocks.7.norm1.bias', 'classifier_blocks.8.mlp.fc2.bias', 'classifier_blocks.1.attn.qkv.weight', 'classifier_blocks.7.mlp.fc1.bias', 'classifier_blocks.8.norm2.weight', 'classifier_blocks.1.mlp.fc1.bias', 'classifier_blocks.3.attn.proj.weight', 'classifier_blocks.5.mlp.fc1.weight', 'classifier_blocks.10.norm1.bias', 'classifier_blocks.5.mlp.fc2.bias', 'classifier_blocks.9.mlp.fc1.bias', 'classifier_blocks.0.mlp.fc1.weight', 'classifier_blocks.7.norm2.bias', 'classifier_blocks.9.mlp.fc2.weight', 'classifier_blocks.6.norm2.weight', 'classifier_blocks.3.attn.proj.bias', 'classifier_blocks.11.mlp.fc1.weight', 'classifier_blocks.7.attn.proj.bias', 'classifier_blocks.11.norm2.weight', 'classifier_pred.weight', 'classifier_blocks.11.norm2.bias', 'classifier_blocks.8.norm2.bias', 'classifier_blocks.0.norm2.bias', 'classifier_blocks.3.mlp.fc1.weight', 'classifier_blocks.7.attn.proj.weight', 'classifier_blocks.0.norm1.weight', 'classifier_blocks.1.attn.proj.bias', 'classifier_blocks.5.mlp.fc1.bias', 'classifier_blocks.4.mlp.fc2.weight', 'classifier_blocks.10.norm1.weight', 'classifier_blocks.11.norm1.bias', 'classifier_blocks.7.attn.qkv.bias', 'classifier_blocks.10.attn.qkv.bias', 'classifier_blocks.6.norm1.weight', 'classifier_blocks.0.mlp.fc2.weight', 
'classifier_blocks.4.norm2.weight', 'classifier_blocks.10.attn.proj.weight', 'classifier_blocks.8.attn.proj.weight', 'classifier_blocks.7.norm2.weight', 'classifier_pred.bias', 'classifier_blocks.0.attn.proj.weight', 'classifier_blocks.10.mlp.fc1.weight', 'classifier_blocks.8.attn.qkv.bias', 'classifier_blocks.3.attn.qkv.bias', 'classifier_blocks.9.mlp.fc1.weight', 'classifier_blocks.6.attn.proj.weight', 'classifier_blocks.3.mlp.fc1.bias', 'classifier_blocks.8.mlp.fc2.weight', 'classifier_blocks.8.attn.qkv.weight', 'classifier_blocks.1.norm2.bias', 'classifier_blocks.4.mlp.fc2.bias', 'classifier_blocks.3.norm1.bias', 'classifier_blocks.5.attn.qkv.bias', 'classifier_blocks.6.attn.qkv.weight', 'classifier_blocks.0.norm1.bias', 'classifier_blocks.2.attn.qkv.bias', 'classifier_embed.bias', 'classifier_blocks.2.norm2.bias', 'classifier_blocks.5.attn.proj.bias', 'classifier_blocks.10.norm2.weight', 'classifier_blocks.1.mlp.fc2.bias', 'classifier_blocks.8.mlp.fc1.weight', 'classifier_blocks.6.mlp.fc1.weight', 'classifier_blocks.8.attn.proj.bias', 'classifier_blocks.4.norm1.weight', 'classifier_blocks.11.attn.qkv.bias', 'classifier_blocks.2.mlp.fc1.bias', 'classifier_blocks.10.norm2.bias', 'classifier_blocks.9.attn.qkv.weight', 'classifier_blocks.8.norm1.bias', 'classifier_blocks.2.mlp.fc2.bias', 'classifier_blocks.5.norm2.bias', 'classifier_blocks.4.attn.qkv.bias', 'classifier_blocks.2.mlp.fc2.weight', 'classifier_blocks.6.mlp.fc1.bias', 'classifier_blocks.11.norm1.weight', 'classifier_blocks.4.attn.qkv.weight', 'classifier_blocks.3.mlp.fc2.bias'}\n"
     ]
    }
   ],
   "source": [
    "model_name = 'mae_vit_large_patch16'\n",
    "\n",
    "# Read the checkpoint onto the CPU and take its 'model' entry as the state dict.\n",
    "checkpoint = torch.load('models/cifar10/mae_visualize_vit_large.pth', map_location='cpu')\n",
    "checkpoint_model = checkpoint['model']\n",
    "\n",
    "# Look the constructor up by name in the module namespace and build the model.\n",
    "model = vars(models_mae_shared)[model_name](head_type='vit_head', no_decoder=True)\n",
    "\n",
    "# strict=False: key mismatches are returned in `msg` instead of raising.\n",
    "msg = model.load_state_dict(checkpoint_model, strict=False)\n",
    "print(msg)\n",
    "print('Missing keys', set(msg.missing_keys))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4f3a4559",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print every missing key whose name does not start with 'classifier'.\n",
    "for key in msg.missing_keys:\n",
    "    if not key.startswith('classifier'):\n",
    "        print(key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2febb94c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
