{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Benchmarking Self-supervised Vision Transformers in Astronomy\n",
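    "This demo evaluates our fine-tuned C-MAE models on the galaxy-desi and galaxy-sdss datasets. Running it requires a CUDA GPU and the fine-tuned checkpoints under `./ft/` and `./run/`.\n",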
    "\n",
    "## Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import Dataset\n",
    "import models_vit\n",
    "\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AverageMeter:\n",
    "    \"\"\"Keep the most recent value plus a count-weighted running mean.\n",
    "\n",
    "    Attributes: ``val`` (last value given to ``update``), ``sum`` (weighted\n",
    "    total), ``count`` (total weight) and ``avg`` (``sum / count``).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Zero every statistic.\"\"\"\n",
    "        self.val = self.avg = self.sum = self.count = 0\n",
    "\n",
    "    def update(self, val, n=1):\n",
    "        \"\"\"Record ``val`` as though it had been observed ``n`` times.\"\"\"\n",
    "        self.val = val\n",
    "        self.sum += val * n\n",
    "        self.count += n\n",
    "        # recomputed on every update so ``avg`` is always current\n",
    "        self.avg = self.sum / self.count\n",
    "\n",
    "\n",
    "def accuracy(output, target, topk=(1,)):\n",
    "    \"\"\"Top-k accuracy of ``output`` scores against ``target`` class indices.\n",
    "\n",
    "    Classes are ranked along dim 1 of ``output``. Returns a list with one entry\n",
    "    per k in ``topk``: a 1-element float tensor holding the fraction of rows\n",
    "    whose target is among the k highest-scoring classes.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        largest_k = max(topk)\n",
    "        n_samples = target.size(0)\n",
    "\n",
    "        # indices of the best classes, transposed to (largest_k, batch): row r is the r-th guess\n",
    "        ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()\n",
    "        hits = ranked.eq(target.view(1, -1).expand_as(ranked))\n",
    "\n",
    "        return [\n",
    "            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(1.0 / n_samples)\n",
    "            for k in topk\n",
    "        ]\n",
    "\n",
    "\n",
    "# evaluation loop shared by every model below\n",
    "def evalmodel(model, test_data, batch_size=64, num_workers=8, device=None):\n",
    "    \"\"\"Return the top-1 accuracy of ``model`` on ``test_data``.\n",
    "\n",
    "    The result is a 1-element tensor (index it with ``[0]`` for a scalar): the\n",
    "    sample-weighted mean of per-batch top-1 accuracy, i.e. the overall accuracy.\n",
    "\n",
    "    Args:\n",
    "        model: classifier mapping a batch of images to per-class scores.\n",
    "        test_data: dataset whose items collate into ``(images, labels)`` batches.\n",
    "        batch_size: evaluation batch size.\n",
    "        num_workers: number of DataLoader worker processes.\n",
    "        device: where to run. Defaults to the device of the model's parameters,\n",
    "            so a model already moved with ``.cuda()`` is evaluated on the GPU.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``test_data`` yields no samples.\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        first_param = next(model.parameters(), None)\n",
    "        if first_param is not None:\n",
    "            device = first_param.device\n",
    "        else:\n",
    "            device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "    device = torch.device(device)\n",
    "\n",
    "    val_acc1 = AverageMeter()\n",
    "    # switch layers such as dropout to inference behaviour\n",
    "    model.eval()\n",
    "\n",
    "    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False,\n",
    "                                 num_workers=num_workers,\n",
    "                                 pin_memory=device.type == \"cuda\")\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for images, labels in test_dataloader:\n",
    "            # non_blocking lets host->GPU copies overlap compute when memory is pinned\n",
    "            images = images.to(device, non_blocking=True)\n",
    "            labels = labels.to(device, non_blocking=True)\n",
    "            pred = model(images)\n",
    "\n",
    "            acc1 = accuracy(pred, labels, topk=(1, ))\n",
    "            # weight by batch size so a short final batch is not over-counted\n",
    "            val_acc1.update(acc1[0], images.size(0))\n",
    "\n",
    "    if val_acc1.count == 0:\n",
    "        # otherwise avg stays the int 0 and callers' acc[0] fails obscurely\n",
    "        raise ValueError(\"evalmodel: test_data yielded no samples\")\n",
    "    return val_acc1.avg"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. galaxy-desi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---------galaxy-desi---------\n",
      "16000 3999 8\n",
      "Evaluate ViT-L on galaxy-desi.\n",
      "load weights:  <All keys matched successfully>\n",
      "ViT-L acc: 0.8837209939956665\n",
      "Evaluate ViT-H on galaxy-desi.\n",
      "load weights:  <All keys matched successfully>\n",
      "ViT-H acc: 0.8909727931022644\n",
      "These results are shown in Table 2.\n"
     ]
    }
   ],
   "source": [
    "print('---------galaxy-desi---------')\n",
    "# prepare dataset\n",
    "train_dataset, val_dataset, num_class = Dataset.__dict__[\"galaxy\"](args=None)\n",
    "print(len(train_dataset), len(val_dataset), num_class)\n",
    "\n",
    "# (label, architecture, fine-tuned checkpoint) for every model evaluated on galaxy-desi\n",
    "desi_models = [\n",
    "    (\"ViT-L\", \"vit_large_patch16\", \"./ft/galaxy-desi/large/weights/acc_0.8838_ckpt.pth\"),\n",
    "    (\"ViT-H\", \"vit_huge_patch14\", \"./ft/galaxy-desi/huge/weights/acc_0.8910_ckpt.pth\"),\n",
    "]\n",
    "\n",
    "for label, arch, weights in desi_models:\n",
    "    model = models_vit.__dict__[arch](\n",
    "        img_size=224,\n",
    "        num_classes=num_class,\n",
    "        drop_path_rate=0.1,\n",
    "        global_pool=True,\n",
    "    )\n",
    "    ckpt = torch.load(weights, map_location=\"cpu\")[\"model\"]\n",
    "    print(\"Evaluate {} on galaxy-desi.\".format(label))\n",
    "    msg = model.load_state_dict(ckpt, strict=False)\n",
    "    print(\"load weights: \", msg)\n",
    "    # with strict=False a bad checkpoint would silently evaluate partly random weights\n",
    "    if msg.missing_keys:\n",
    "        raise RuntimeError(\"{}: checkpoint is missing weights: {}\".format(label, msg.missing_keys))\n",
    "    acc = evalmodel(model.cuda(), val_dataset)\n",
    "    print(\"{} acc: {}\".format(label, acc[0]))\n",
    "    # release this model's GPU memory before building the next, larger one\n",
    "    del model, ckpt\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "print(\"These results are shown in Table 2.\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. galaxy-sdss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---------galaxy-sdss---------\n",
      "23037 5754 5\n",
      "Evaluate ViT-B on galaxy-sdss.\n",
      "load weights:  <All keys matched successfully>\n",
      "ViT-B acc: 0.960027813911438\n",
      "Evaluate ViT-L on galaxy-sdss.\n",
      "load weights:  <All keys matched successfully>\n",
      "ViT-L acc: 0.9575946927070618\n",
      "Evaluate ViT-H on galaxy-sdss.\n",
      "load weights:  <All keys matched successfully>\n",
      "ViT-H acc: 0.9602015614509583\n",
      "These results are shown in Table 4.\n"
     ]
    }
   ],
   "source": [
    "print('---------galaxy-sdss---------')\n",
    "# prepare dataset\n",
    "train_dataset, val_dataset, num_class = Dataset.__dict__[\"galaxy_sdss\"](args=None)\n",
    "print(len(train_dataset), len(val_dataset), num_class)\n",
    "\n",
    "# (label, architecture, fine-tuned checkpoint) for every model evaluated on galaxy-sdss\n",
    "sdss_models = [\n",
    "    (\"ViT-B\", \"vit_base_patch16\", \"./run/sdss/base_acc_0.9600/ckpt.pth\"),\n",
    "    (\"ViT-L\", \"vit_large_patch16\", \"./run/sdss/large_acc_0.9576/ckpt.pth\"),\n",
    "    (\"ViT-H\", \"vit_huge_patch14\", \"./run/sdss/huges_acc_0.9602/ckpt.pth\"),\n",
    "]\n",
    "\n",
    "for label, arch, weights in sdss_models:\n",
    "    model = models_vit.__dict__[arch](\n",
    "        img_size=224,\n",
    "        num_classes=num_class,\n",
    "        drop_path_rate=0.1,\n",
    "        global_pool=True,\n",
    "    )\n",
    "    ckpt = torch.load(weights, map_location=\"cpu\")[\"model\"]\n",
    "    print(\"Evaluate {} on galaxy-sdss.\".format(label))\n",
    "    msg = model.load_state_dict(ckpt, strict=False)\n",
    "    print(\"load weights: \", msg)\n",
    "    # with strict=False a bad checkpoint would silently evaluate partly random weights\n",
    "    if msg.missing_keys:\n",
    "        raise RuntimeError(\"{}: checkpoint is missing weights: {}\".format(label, msg.missing_keys))\n",
    "    acc = evalmodel(model.cuda(), val_dataset)\n",
    "    print(\"{} acc: {}\".format(label, acc[0]))\n",
    "    # release this model's GPU memory before building the next, larger one\n",
    "    del model, ckpt\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "print(\"These results are shown in Table 4.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3fb296c518e7aea6f35410d2d75da0f136b9330ebf7cbb2c1bb9ff470df8a33e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
