{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ee157b02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import datetime\n",
    "import json\n",
    "import numpy as np\n",
    "import os\n",
    "import time\n",
    "from pathlib import Path\n",
    "import torch\n",
    "import torch.backends.cudnn as cudnn\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import torchvision.transforms as transforms\n",
    "import timm\n",
    "from torchvision import datasets\n",
    "import util.misc as misc\n",
    "from util.misc import NativeScalerWithGradNormCount as NativeScaler\n",
    "import timm.optim.optim_factory as optim_factory\n",
    "import models_mae_shared\n",
    "from engine_pretrain import accuracy\n",
    "from einops import repeat\n",
    "import tqdm\n",
    "import math\n",
    "from scipy import stats\n",
    "from data import tt_image_folder\n",
    "from matplotlib import pyplot as plt\n",
    "from mae_utils import generate_encoder_attention_maps\n",
    "\n",
    "device = 'cuda'\n",
    "# data_path = '/home/group/ilsvrc/val'\n",
    "data_path = '/scratch/imagenet_c/imagenet_c/jpeg_compression/5'\n",
    "num_workers = 8\n",
    "pin_mem = True\n",
    "lr = 5.00e-06\n",
    "weight_decay = 0.05\n",
    "train_steps = 10\n",
    "mask_ratio = 0.75\n",
    "steps_per_example = 32\n",
    "input_size = 224\n",
    "batch_size = 32\n",
    "\n",
    "imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
    "imagenet_std = np.array([0.229, 0.224, 0.225])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "411b8331",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using dataset /scratch/imagenet_c/imagenet_c/jpeg_compression/5 with 5120000\n"
     ]
    }
   ],
   "source": [
    "minimizer = np.load('/home/code/test_time_training/ttt_mae/models/perm.npy') \n",
    "transform_val = transforms.Compose([\n",
    "        transforms.Resize(256, interpolation=3),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])\n",
    "transform_train = transforms.Compose([\n",
    "        transforms.RandomResizedCrop(input_size, scale=(0.2, 1.0), interpolation=3),  # 3 is bicubic\n",
    "        transforms.RandomHorizontalFlip(),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])\n",
    "dataset_train = tt_image_folder.ExtendedImageFolder(data_path, transform=transform_train, \n",
    "                                                    batch_size=batch_size, steps_per_example=steps_per_example, minimizer=minimizer)\n",
    "dataset_val = tt_image_folder.ExtendedImageFolder(data_path, transform=transform_val, batch_size=batch_size, steps_per_example=steps_per_example, minimizer=minimizer)\n",
    "num_classes = 1000\n",
    "print(f'Using dataset {data_path} with {len(dataset_val)}')\n",
    "dataset_val = iter(torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=2))\n",
    "dataset_train = iter(torch.utils.data.DataLoader(dataset_train, batch_size=1, shuffle=False, num_workers=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "301875aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "resume_model = '/home/code/deep_transformer_prior/mae_orig/demo/mae_visualize_vit_large.pth'\n",
    "resume_finetune = '/home/code/deep_transformer_prior/mae_orig/output_dir/mae_orig_finetune_with_decoder/checkpoint-89.pth'\n",
    "model_name = 'mae_vit_large_patch16'\n",
    "\n",
    "model = models_mae_shared.__dict__[model_name]()\n",
    "model_checkpoint = torch.load(resume_model, map_location='cpu')['model']\n",
    "head_checkpoint = torch.load(resume_finetune, map_location='cpu')['model']\n",
    "\n",
    "model_checkpoint[\"bn.running_mean\"] = head_checkpoint[\"head.0.running_mean\"]\n",
    "model_checkpoint[\"bn.running_var\"] = head_checkpoint[\"head.0.running_var\"]\n",
    "model_checkpoint[\"head.weight\"] = head_checkpoint[\"head.1.weight\"]\n",
    "model_checkpoint[\"head.bias\"] = head_checkpoint[\"head.1.bias\"]\n",
    "model.load_state_dict(model_checkpoint)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1a5b48d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "for name, p in model.named_parameters():\n",
    "    if name.startswith('decoder'):\n",
    "        p.requires_grad = False\n",
    "parameters = [p for p in model.parameters() if p.requires_grad]\n",
    "optimizer = torch.optim.AdamW(parameters, lr=lr, betas=(0.9, 0.95))\n",
    "optimizer.zero_grad()\n",
    "loss_scaler = NativeScaler()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "53cc716d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [30:43<00:00,  2.71it/s]\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "all_results = []\n",
    "for i in tqdm.trange(5000):\n",
    "    (test_samples, labels) = next(dataset_train)\n",
    "    test_samples = test_samples.to(device, non_blocking=True)[0]\n",
    "    with torch.no_grad():\n",
    "        labels = labels.to(device, non_blocking=True)\n",
    "        labels = labels.repeat(test_samples.shape[0])\n",
    "        latent, _, _ = model.forward_encoder(test_samples, mask_ratio=0.05, input_mask=None)\n",
    "        #print(latent.shape)\n",
    "#         latent = torch.mean(latent[:, 0], axis=0, keepdim=True)\n",
    "        pred, _ = model.forward_head(latent[:, 0], labels)\n",
    "        all_acc = []\n",
    "#         for i in range(batch_size):\n",
    "#             (acc1, acc5) = accuracy(pred[i:i+1], labels[i:i+1], topk=(1, 5))\n",
    "#             all_acc.append(acc1.detach().cpu().numpy())\n",
    "        all_results.append((stats.mode(pred.argmax(axis=1).detach().cpu().numpy()).mode[0] == labels[0].cpu().detach().numpy()) * 100.)\n",
    "#         all_results.append((np.sum(all_acc) > batch_size * 50) * 100)\n",
    "#         print('step: {}, acc {}, loss {}'.format(, acc1.detach().cpu().numpy(), loss_dict))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2c90bb15",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "38.28\n"
     ]
    }
   ],
   "source": [
    "print(np.mean(all_results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "93084308",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same CROP:\n",
    "\n",
    "# One random input %75: 27.28\n",
    "# What it was trained for: 38.86\n",
    "# Ensemble (32, 75%): 26.74 (mean)\n",
    "# Ensemble (32, 50%): 36.22 (mean)\n",
    "# Ensemble (32, 25%): 38.98 (mean)\n",
    "# Ensemble (32, 10%): 39.02 (mean)\n",
    "# Ensemble (32,  5%): 39.12 (mean)\n",
    "# Ensemble (32, 50%): 40.9 (max)\n",
    "# Ensemble (32, 5%): 43.28 (max)\n",
    "\n",
    "# Different CROPs\n",
    "# Ensemble (32, 5%): 35.96 (mean)\n",
    "# Ensemble (32, 50%): 35.26 (mean)\n",
    "# Ensemble (32, 5%): 38.28 (max)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22e37af1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd047a03",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c22da2f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7781ff1c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
