{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ee157b02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "import argparse\n",
    "import datetime\n",
    "import json\n",
    "import numpy as np\n",
    "import os\n",
    "import time\n",
    "from pathlib import Path\n",
    "import torch\n",
    "import torch.backends.cudnn as cudnn\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "import timm\n",
    "from torchvision import datasets\n",
    "import util.misc as misc\n",
    "from util.misc import NativeScalerWithGradNormCount as NativeScaler\n",
    "import timm.optim.optim_factory as optim_factory\n",
    "import models_mae_shared\n",
    "from engine_pretrain import accuracy\n",
    "from einops import repeat\n",
    "import tqdm\n",
    "import math\n",
    "from data import tt_image_folder\n",
    "from matplotlib import pyplot as plt\n",
    "from mae_utils import generate_encoder_attention_maps\n",
    "\n",
    "# ---- Runtime ----\n",
    "device = 'cuda'\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# ---- Data ----\n",
    "# Alternative data paths:\n",
    "#   '/home/group/ilsvrc/val'\n",
    "#   '/scratch/data/imagenet_c/jpeg_compression/5'\n",
    "data_path = '/scratch/data/imagenet_c/shot_noise/5'\n",
    "input_size = 224\n",
    "num_workers = 8\n",
    "pin_mem = True\n",
    "# ImageNet per-channel normalization statistics.\n",
    "imagenet_mean, imagenet_std = (np.array([0.485, 0.456, 0.406]),\n",
    "                               np.array([0.229, 0.224, 0.225]))\n",
    "\n",
    "# ---- Test-time optimization ----\n",
    "lr = 5e-3  # alternative value noted in the original: 1e-2\n",
    "weight_decay = 0.0\n",
    "batch_size = 128\n",
    "mask_ratio = 0.75\n",
    "steps_per_example = 11\n",
    "# TODO: restart all the ...  (note left unfinished in the original)\n",
    "\n",
    "# ---- Model / checkpoints ----\n",
    "model_name = 'mae_vit_large_patch16'\n",
    "norm_pix_loss = True\n",
    "resume_model = '/home/test_time_training/ttt_mae_v1/models/imagenet/vit_head_2_layers/mae_pretrain_vit_large_full.pth'\n",
    "resume_finetune = '/home/test_time_training/ttt_mae_v1/models/imagenet/vit_head_2_layers/checkpoint-19.pth'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "411b8331",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using dataset /scratch/data/imagenet_c/shot_noise/5 with 50\n",
      "Class classifier_pred.bias 68.0 70.0 /scratch/data/imagenet_c/shot_noise/5\n",
      "Using dataset /scratch/data/imagenet_c/shot_noise/5 with 50\n",
      "Class classifier_pred.bias 78.0 82.0 /scratch/data/imagenet_c/shot_noise/5\n",
      "Using dataset /scratch/data/imagenet_c/shot_noise/5 with 50\n",
      "Class classifier_pred.bias 80.0 82.0 /scratch/data/imagenet_c/shot_noise/5\n",
      "Using dataset /scratch/data/imagenet_c/shot_noise/5 with 50\n",
      "Class classifier_pred.bias 94.0 94.0 /scratch/data/imagenet_c/shot_noise/5\n",
      "Using dataset /scratch/data/imagenet_c/shot_noise/5 with 50\n"
     ]
    }
   ],
   "source": [
    "# Test-time training (TTT) with the shared MAE model on selected ImageNet(-C) classes.\n",
    "# For each class the model is rebuilt from the pretrained MAE checkpoint plus the\n",
    "# fine-tuned classifier head, then adapted online over that class's images: every\n",
    "# image gets `steps_per_example` optimizer steps on the loss the model returns without\n",
    "# labels. Top-1 accuracy on the unmasked image is recorded before the image's first\n",
    "# step (`baselines`) and before its last step (`all_results`). Adaptation carries over\n",
    "# between images of the same class; it is reset only when the next class starts.\n",
    "\n",
    "# Alternative class set (top five):\n",
    "# class_numbers = [599, 546, 577, 949, 107]\n",
    "# class_names = ['honeycomb', 'electric guitar', 'gong, tam-tam', 'strawberry', 'jellyfish']\n",
    "\n",
    "class_numbers = [107, 597, 116, 25, 940]\n",
    "class_names = ['Jellyfish', 'Holster', 'Chiton', 'European Fire Salamander', 'Spaghetti squash']\n",
    "# class_numbers = [940]\n",
    "# class_names = ['Spaghetti squash']\n",
    "num_classes = 1000\n",
    "\n",
    "# --- Loop-invariant setup (hoisted out of the per-class loop) ---\n",
    "# Train and val use the same deterministic resize/crop/normalize pipeline.\n",
    "transform_val = transforms.Compose([\n",
    "        transforms.Resize(256, interpolation=3),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=imagenet_mean.tolist(), std=imagenet_std.tolist())])\n",
    "transform_train = transform_val\n",
    "\n",
    "# Load both checkpoints once: the pretrained MAE weights, with every `classifier*`\n",
    "# entry taken from the fine-tuned head checkpoint instead.\n",
    "model_checkpoint = torch.load(resume_model, map_location='cpu')['model']\n",
    "head_checkpoint = torch.load(resume_finetune, map_location='cpu')['model']\n",
    "for key in head_checkpoint:\n",
    "    if key.startswith('classifier'):\n",
    "        model_checkpoint[key] = head_checkpoint[key]\n",
    "del head_checkpoint\n",
    "\n",
    "for class_name, class_number in zip(class_names, class_numbers):\n",
    "    # Presumably the dataset indices of this class's 50 images (assumes 50 images per\n",
    "    # class stored in class order) -- TODO confirm against tt_image_folder.\n",
    "    minimizer = (50*class_number + np.arange(50)).astype(np.int64)\n",
    "    dataset_train = tt_image_folder.ExtendedImageFolder(data_path, transform=transform_train,\n",
    "                                                        batch_size=batch_size,\n",
    "                                                        steps_per_example=steps_per_example,\n",
    "                                                        minimizer=minimizer)\n",
    "    dataset_val = tt_image_folder.ExtendedImageFolder(data_path, transform=transform_val,\n",
    "                                                     minimizer=minimizer)\n",
    "    print(f'Using dataset {data_path} with {len(dataset_val)}')\n",
    "    val_iter = iter(torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=2))\n",
    "    train_iter = iter(torch.utils.data.DataLoader(dataset_train, batch_size=1, shuffle=False, num_workers=2))\n",
    "\n",
    "    # Fresh model per class. load_state_dict copies the weights into the model, so the\n",
    "    # shared checkpoint dict stays untouched for the next class.\n",
    "    model = models_mae_shared.__dict__[model_name](num_classes=num_classes,\n",
    "                                                   head_type='vit_head',\n",
    "                                                   norm_pix_loss=norm_pix_loss,\n",
    "                                                   classifier_depth=12)\n",
    "    model.load_state_dict(model_checkpoint)\n",
    "    model = model.to(device)\n",
    "    # Freeze the decoder. Use a distinct loop variable: reusing `name` here used to\n",
    "    # overwrite the class name printed at the end of the iteration.\n",
    "    for param_name, p in model.named_parameters():\n",
    "        if param_name.startswith('decoder'):\n",
    "            p.requires_grad = False\n",
    "    parameters = [p for p in model.parameters() if p.requires_grad]\n",
    "    optimizer = torch.optim.SGD(parameters, lr=lr, momentum=0.9)\n",
    "    optimizer.zero_grad()\n",
    "    loss_scaler = NativeScaler()\n",
    "\n",
    "    baselines = []    # top-1 per image, measured before its first adaptation step\n",
    "    all_results = []  # top-1 per image, measured before its last adaptation step\n",
    "    for test_samples, labels in val_iter:\n",
    "        test_samples = test_samples.to(device, non_blocking=True)[0]\n",
    "        # Build the label tensor once per image instead of re-wrapping it every step.\n",
    "        label_tensor = torch.LongTensor([labels]).to(device, non_blocking=True)\n",
    "        for data_iter_step in range(steps_per_example):\n",
    "            optimizer.zero_grad()\n",
    "            # Evaluate on the unmasked image (mask_ratio=0) without gradients.\n",
    "            with torch.no_grad():\n",
    "                model.eval()\n",
    "                _, _, _, pred = model(test_samples, label_tensor, mask_ratio=0)\n",
    "                acc1, _ = accuracy(pred, label_tensor, topk=(1, 5))\n",
    "                if data_iter_step == 0:\n",
    "                    baselines.append(acc1.detach().cpu().numpy())\n",
    "                elif data_iter_step == steps_per_example - 1:\n",
    "                    all_results.append(acc1.detach().cpu().numpy())\n",
    "            # Adapt: one step on the sum of all losses the model returns without labels\n",
    "            # (presumably the masked-reconstruction loss).\n",
    "            model.train(True)\n",
    "            samples, _ = next(train_iter)\n",
    "            samples = samples.to(device, non_blocking=True)[0]\n",
    "            loss_dict, _, _, _ = model(samples, None, mask_ratio=mask_ratio)\n",
    "            loss = torch.stack([loss_dict[l] for l in loss_dict]).sum()\n",
    "            loss_value = loss.item()\n",
    "            if not math.isfinite(loss_value):\n",
    "                # Raise rather than sys.exit(): in a notebook that only yields an opaque SystemExit.\n",
    "                raise RuntimeError('Loss is {}, stopping training'.format(loss_value))\n",
    "            loss_scaler(loss, optimizer, parameters=model.parameters(),\n",
    "                        update_grad=True)\n",
    "            optimizer.zero_grad()\n",
    "    print('Class', class_name, np.mean(baselines), np.mean(all_results), data_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc005008",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aed8ea75",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43c31afd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c113f8c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
