{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8cbf61e0-145e-4d7c-819d-47d97bb58e62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment the line below to install exlib into the kernel's environment\n",
    "# %pip install exlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81e9f56d-6d72-43ec-b23f-617b0d3f7be0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import yaml\n",
    "import argparse\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import sys\n",
    "sys.path.insert(0, \"../../src\")\n",
    "import exlib\n",
    "import math\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from datasets import load_dataset\n",
    "from collections import namedtuple\n",
    "from exlib.datasets.supernova import SupernovaDataset, SupernovaClsModel, SupernovaFixScore, get_supernova_scores\n",
    "from exlib.datasets.supernova_helper import *\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Baselines\n",
    "from exlib.features.time_series import *\n",
    "# Prefer the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac1f3a45-00b0-473b-82a5-1f2d04a7b187",
   "metadata": {},
   "source": [
    "### Overview\n",
    "* The objective is to classify time-varying astronomical sources into different classes."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b440c7c7-6959-4968-95f3-4b8ac902923c",
   "metadata": {},
   "source": [
    "### Load datasets and pre-trained models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7c09859e-2bbd-4dda-9c89-d5785612c409",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num labels: 14\n",
      "Using Fourier PE\n",
      "classifier dropout: 0.2\n"
     ]
    }
   ],
   "source": [
    "# Test split of the dataset and the pre-trained classification model.\n",
    "test_dataset = SupernovaDataset(\n",
    "    data_dir=\"anonymized-dataset\",\n",
    "    split=\"test\",\n",
    ")\n",
    "model = SupernovaClsModel(model_path=\"anonymized-model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f567d39-ad30-4821-b953-fddef4696a4f",
   "metadata": {},
   "source": [
    "### Model prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "280f4d98-0c0d-4023-a2ad-21c7b97ee9f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "original dataset size: 792\n",
      "remove nans dataset size: 792\n"
     ]
    }
   ],
   "source": [
    "# Put the model on the selected device and build the dataloader used for prediction.\n",
    "model = model.to(device)\n",
    "test_dataloader = create_test_dataloader(dataset=test_dataset, batch_size=5, compute_loss=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "263e9d59-1528-45dd-ab05-5dd94bac1258",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████| 159/159 [00:01<00:00, 131.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy: 0.7967171717171717\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the classifier on the test set, collecting true and predicted labels.\n",
    "y_true = []\n",
    "y_pred = []\n",
    "# Inference only, so gradient tracking is switched off.\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm(test_dataloader, total=len(test_dataloader)):\n",
    "        # \"objid\" is dropped before the forward pass; every other entry is moved to `device`.\n",
    "        batch = {k: v.to(device) for k, v in batch.items() if k != \"objid\"}\n",
    "        outputs = model(**batch)\n",
    "        y_true.extend(batch['labels'].cpu().numpy())\n",
    "        # reshape(-1) instead of squeeze(): squeeze() would also drop a batch dimension\n",
    "        # of size 1, leaving a 0-d array that list.extend() cannot iterate over.\n",
    "        # Assumes one prediction per sample after the argmax over dim=2 -- TODO confirm.\n",
    "        y_pred.extend(torch.argmax(outputs.logits, dim=2).reshape(-1).cpu().numpy())\n",
    "\n",
    "num_correct = sum(1 for truth, pred in zip(y_true, y_pred) if truth == pred)\n",
    "print(f\"accuracy: {num_correct / len(y_true)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f51151e-a6d5-43db-b9d4-1d7ab68eea0f",
   "metadata": {},
   "source": [
    "### Feature alignment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9446606d-46c6-46e7-b3a3-b306c769e111",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "original dataset size: 792\n",
      "remove nans dataset size: 792\n"
     ]
    }
   ],
   "source": [
    "# Build a dataloader over the same test set with create_test_dataloader_raw.\n",
    "# NOTE(review): this rebinds `test_dataloader`; re-running the prediction cell after\n",
    "# this one would iterate the raw loader instead -- confirm that is intended.\n",
    "test_dataloader = create_test_dataloader_raw(dataset=test_dataset, batch_size=5, compute_loss=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26a073ab-5965-4598-b724-353c2bdd50f6",
   "metadata": {},
   "source": [
    "### Baselines\n",
    "- Identity\n",
    "- Random\n",
    "- 5 slices\n",
    "- 10 slices\n",
    "- 15 slices\n",
    "- Clustering\n",
    "- Archipelago"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7a0f0293-fb24-4352-96d1-6a876a00c49e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "original dataset size: 792\n",
      "remove nans dataset size: 792\n",
      "num labels: 14\n",
      "Using Fourier PE\n",
      "classifier dropout: 0.2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 159/159 [04:37<00:00,  1.74s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Avg alignment of identity features: 0.0152\n",
      "Avg alignment of random features: 0.0358\n",
      "Avg alignment of 5 features: 0.0337\n",
      "Avg alignment of 10 features: 0.0555\n",
      "Avg alignment of 15 features: 0.0554\n",
      "Avg alignment of clustering features: 0.2622\n",
      "Avg alignment of archipelago features: 0.2563\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Score the baselines listed above; the call reports an average alignment per baseline\n",
    "# (see this cell's output).\n",
    "scores = get_supernova_scores(batch_size=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "813af631-b74e-4121-823f-4e32c40a78e7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
