{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9bffc4a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import datetime\n",
    "import json\n",
    "import numpy as np\n",
    "import os\n",
    "import time\n",
    "from pathlib import Path\n",
    "import torch\n",
    "import torch.backends.cudnn as cudnn\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import torchvision.transforms as transforms\n",
    "import timm\n",
    "from torchvision import datasets\n",
    "import util.misc as misc\n",
    "from util.misc import NativeScalerWithGradNormCount as NativeScaler\n",
    "import timm.optim.optim_factory as optim_factory\n",
    "import models_mae_shared\n",
    "from engine_pretrain import accuracy\n",
    "from einops import repeat\n",
    "import tqdm\n",
    "import math\n",
    "\n",
    "# Runtime device and dataset location.\n",
    "device = 'cuda'\n",
    "data_path = '/shared/group/ilsvrc/train'\n",
    "\n",
    "# Data-loading settings (input_size is presumably the image side length in pixels -- confirm).\n",
    "batch_size = 128\n",
    "num_workers = 16\n",
    "pin_mem = True\n",
    "input_size = 224\n",
    "\n",
    "# NOTE(review): these look carried over from a training script -- confirm this notebook still needs them.\n",
    "weight_decay = 0.05\n",
    "train_steps = 10\n",
    "mask_ratio = 0.75"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c9656116",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoints: `resume_model` provides the base weights; `resume_finetune` provides the\n",
    "# batch-norm statistics and head parameters that are copied into them below.\n",
    "resume_model = '/home/test_time_training/ttt_mae/models/cifar10/mae_visualize_vit_large.pth'\n",
    "resume_finetune = '/home/test_time_training/ttt_mae/models/imagenet/checkpoint-89.pth'\n",
    "model_name = 'mae_vit_large_patch16'\n",
    "\n",
    "# Look the constructor up by name in the models module and build it with default arguments.\n",
    "build_model = vars(models_mae_shared)[model_name]\n",
    "model = build_model()\n",
    "\n",
    "# Both files are loaded onto the CPU; only their 'model' entry (the state dict) is used.\n",
    "model_checkpoint = torch.load(resume_model, map_location='cpu')['model']\n",
    "head_checkpoint = torch.load(resume_finetune, map_location='cpu')['model']\n",
    "\n",
    "# Destination key in `model_checkpoint` -> source key in `head_checkpoint`.\n",
    "head_key_map = {\n",
    "    'bn.running_mean': 'head.0.running_mean',\n",
    "    'bn.running_var': 'head.0.running_var',\n",
    "    'head.weight': 'head.1.weight',\n",
    "    'head.bias': 'head.1.bias',\n",
    "}\n",
    "for target_key, source_key in head_key_map.items():\n",
    "    model_checkpoint[target_key] = head_checkpoint[source_key]\n",
    "\n",
    "# Strict load of the merged state dict, then move the model to the configured device.\n",
    "model.load_state_dict(model_checkpoint)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "13ae42e4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/anaconda3/envs/taming/lib/python3.8/site-packages/torchvision/transforms/transforms.py:332: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using dataset /shared/group/ilsvrc/train with 1281167\n"
     ]
    }
   ],
   "source": [
    "# Deterministic preprocessing (resize -> center crop -> tensor -> normalize); no augmentation.\n",
    "# The resize target keeps the original 256:224 ratio relative to the crop size.\n",
    "crop_pct = 224 / 256\n",
    "transform_train = transforms.Compose([\n",
    "    # InterpolationMode.BICUBIC replaces the deprecated integer code 3.\n",
    "    transforms.Resize(int(input_size / crop_pct),\n",
    "                      interpolation=transforms.InterpolationMode.BICUBIC),\n",
    "    transforms.CenterCrop(input_size),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "dataset_train = datasets.ImageFolder(data_path, transform=transform_train)\n",
    "num_classes = 1000\n",
    "\n",
    "# Loader settings come from the config cell instead of being repeated as literals.\n",
    "data_loader = torch.utils.data.DataLoader(\n",
    "    dataset_train,\n",
    "    shuffle=False,  # fixed order, so a batch index always refers to the same images\n",
    "    batch_size=batch_size,\n",
    "    num_workers=num_workers,\n",
    "    pin_memory=pin_mem,  # pinned host memory lets non_blocking=True copies overlap with compute\n",
    "    drop_last=False,\n",
    ")\n",
    "print(f'Using dataset {data_path} with {len(dataset_train)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "49133ba4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "6000it [34:17,  1.88s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving more... 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "7000it [1:08:24,  2.23s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving more... 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "8000it [1:48:47,  2.48s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving more... 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "9000it [2:34:30,  2.88s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving more... 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "10000it [3:26:34,  3.15s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving more... 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "10010it [3:27:10,  1.24s/it]\n"
     ]
    }
   ],
   "source": [
    "output_dir = '/home/test_time_training/ttt_mae/models/imagenet'\n",
    "resume_index = 5156  # batches before this index are skipped (presumably saved by an earlier run -- confirm)\n",
    "save_every = 1000    # a chunk is flushed whenever the batch index is a multiple of this\n",
    "counter = 2          # suffix of the next chunk file to write\n",
    "\n",
    "\n",
    "def save_chunk(chunk_id, labels_buf, latents_buf):\n",
    "    \"\"\"Write the buffered batches to labels_<chunk_id>.npy and latents_<chunk_id>.npy.\n",
    "\n",
    "    Args:\n",
    "        chunk_id: integer suffix used in both file names.\n",
    "        labels_buf: non-empty list of per-batch label arrays.\n",
    "        latents_buf: non-empty list of per-batch latent arrays.\n",
    "    \"\"\"\n",
    "    with open(f'{output_dir}/labels_{chunk_id}.npy', 'wb') as w:\n",
    "        np.save(w, np.concatenate(labels_buf))\n",
    "    with open(f'{output_dir}/latents_{chunk_id}.npy', 'wb') as w:\n",
    "        np.save(w, np.concatenate(latents_buf))\n",
    "\n",
    "\n",
    "# Inference only. The checkpoint loaded earlier carries bn.running_mean / bn.running_var, so the\n",
    "# model presumably has a batch-norm layer; eval mode stops it from normalising with batch\n",
    "# statistics and from updating those buffers during extraction.\n",
    "model.eval()\n",
    "\n",
    "label_batches, latent_batches = [], []\n",
    "for index, (samples, labels) in tqdm.tqdm(enumerate(data_loader), total=len(data_loader)):\n",
    "    # Skipped batches are still read by the loader; only the forward pass is avoided.\n",
    "    if index < resume_index:\n",
    "        continue\n",
    "    samples = samples.to(device, non_blocking=True)\n",
    "    with torch.no_grad():\n",
    "        _, _, latent, _ = model(samples, None, mask_ratio=0)\n",
    "    # Flush everything gathered before this batch; the current batch opens the next chunk.\n",
    "    if index % save_every == 0 and latent_batches:\n",
    "        print('Saving more...', counter)\n",
    "        save_chunk(counter, label_batches, latent_batches)\n",
    "        label_batches, latent_batches = [], []\n",
    "        counter += 1\n",
    "    # Collect per-batch arrays in lists and concatenate once per chunk (linear, not quadratic).\n",
    "    # Labels are only stored, so they never need to leave the CPU.\n",
    "    label_batches.append(labels.cpu().numpy())\n",
    "    latent_batches.append(latent.detach().cpu().numpy())\n",
    "\n",
    "if not latent_batches:\n",
    "    raise RuntimeError(f'no batches at or after index {resume_index}; nothing to extract')\n",
    "\n",
    "# Unsaved remainder, kept under the original names so later cells can save and inspect it.\n",
    "all_labels = np.concatenate(label_batches)\n",
    "all_latents = np.concatenate(latent_batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "264fc9a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write whatever is still held in memory as the final chunk, using the current counter value.\n",
    "final_outputs = [\n",
    "    (f'/home/test_time_training/ttt_mae/models/imagenet/labels_{counter}.npy', all_labels),\n",
    "    (f'/home/test_time_training/ttt_mae/models/imagenet/latents_{counter}.npy', all_latents),\n",
    "]\n",
    "for out_path, array in final_outputs:\n",
    "    with open(out_path, 'wb') as w:\n",
    "        np.save(w, array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4ed1492c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(621327, 1024)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Quick look at the shape of the latents array currently held in memory.\n",
    "all_latents.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a69d4a1a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5156.0"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cb84ab45",
   "metadata": {},
   "outputs": [
    {
     "ename": "SyntaxError",
     "evalue": "invalid syntax (935587369.py, line 1)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;36m  Input \u001b[0;32mIn [8]\u001b[0;36m\u001b[0m\n\u001b[0;31m    392832 +\u001b[0m\n\u001b[0m             ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n"
     ]
    }
   ],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
