{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9778dd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright authors of TSPulse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "58e51518-ee55-434e-abd2-aab5462a4ff0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import numpy as np\n",
    "from torch.utils.data import DataLoader, default_collate\n",
    "import faiss\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "from models.tspulse import TSPulseForReconstruction\n",
    "from search_utils import SyntheticDataset, QueryDataset, RetrievedData, scaling, get_embedding, calc_precision_k"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c660825c-105a-4465-95e0-009a88d7d1c2",
   "metadata": {},
   "source": [
    "### Load TSPulse model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ea6670b9-aaca-4382-afa6-0e870a28b3f9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at ../../model-binaries/tspulse_hybrid_sign20/tspulse_model were not used when initializing TSPulseForReconstruction: ['backbone.time_masker.mask_token']\n",
      "- This IS expected if you are initializing TSPulseForReconstruction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing TSPulseForReconstruction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "# Runtime settings shared by the cells below\n",
    "device = \"cpu\"\n",
    "batch_size = 128\n",
    "channel_idx = 0  # univariate\n",
    "\n",
    "# Load the pretrained checkpoint (single input channel, mask_ratio=0) and switch to eval mode\n",
    "model_path = \"../../model-binaries/tspulse_hybrid_sign20/tspulse_model\"\n",
    "model = TSPulseForReconstruction.from_pretrained(\n",
    "    model_path,\n",
    "    num_input_channels=1,\n",
    "    mask_ratio=0,\n",
    ")\n",
    "model = model.to(device)\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c215da0-d0c9-438b-9059-ec328c3a9773",
   "metadata": {},
   "source": [
    "### Generate data and create index set of embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4982a553-6ae9-4d0a-97b1-1998a720900c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "096f0635c83e4528b1f7b2aa9692f1b2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/14 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Generate train dataset and query dataset\n",
    "train_dataset = SyntheticDataset()\n",
    "# 20% random shift, 20% scaling of magnitude, 10% noise\n",
    "test_dataset = QueryDataset(train_dataset, max_shift=0.2, max_scale=0.2, noise_ratio=0.1)\n",
    "\n",
    "# shuffle=False keeps the dataset order, so the faiss ids are expected to follow train_dataset order\n",
    "train_loader = DataLoader(\n",
    "    train_dataset, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=device != \"cpu\"\n",
    ")\n",
    "\n",
    "# compute embeddings for train dataset (one numpy array per batch, stacked along the first axis)\n",
    "train_embeddings = np.concatenate(\n",
    "    [get_embedding(device, batch, model).cpu().numpy() for batch in tqdm(train_loader)]\n",
    ")\n",
    "\n",
    "# create index set of embeddings for the selected channel\n",
    "# faiss works on C-contiguous float32 matrices, so normalise the sliced view explicitly\n",
    "embs = np.ascontiguousarray(train_embeddings[:, channel_idx, :], dtype=np.float32)\n",
    "index = faiss.IndexFlatL2(embs.shape[1])\n",
    "index.add(embs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f8941e7-95b0-432a-94a8-c0c782b7c4b7",
   "metadata": {},
   "source": [
    "### Search the index with query embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b7ae3e4f-e744-450d-84a8-f321804838d1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c3e0810e12c24edcb4c00a7f46599659",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/14 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "top_k = 3  # number of nearest neighbours retrieved per query\n",
    "\n",
    "query_loader = DataLoader(\n",
    "    test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=device != \"cpu\"\n",
    ")\n",
    "# D_all / I_all collect one row per query: the distances and index ids returned by faiss\n",
    "D_all, I_all = [], []\n",
    "for batch in tqdm(query_loader):\n",
    "    test_embedding = get_embedding(device, batch, model)\n",
    "    # select the searched channel -> [B, D]: one query vector per sample in the batch\n",
    "    # converted to numpy the same way as the train embeddings that were added to the index\n",
    "    query_vector = test_embedding[:, channel_idx, :].cpu().numpy()\n",
    "    query_vector = np.ascontiguousarray(query_vector, dtype=np.float32)\n",
    "    distances, indices = index.search(query_vector, k=top_k)\n",
    "    D_all.extend(distances)\n",
    "    I_all.extend(indices)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f29dd87e-fd27-4082-bdd2-70b887c77f4f",
   "metadata": {},
   "source": [
    "### Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5c9e015f-8531-4848-aeee-2a9ec58ed3b7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "28bef79a0e90462b822d8db509cfa5e7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/14 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "family_match: PREC@3=0.713\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f73322e617b74749bcb91bb1a5100083",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/14 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fine_grained_match: PREC@3=0.695\n"
     ]
    }
   ],
   "source": [
    "# cut-off shared by RetrievedData, calc_precision_k and the printed metric name\n",
    "# NOTE(review): presumably must not exceed the number of neighbours stored per query in I_all\n",
    "top_k = 3\n",
    "\n",
    "for level in [\"family_match\", \"fine_grained_match\"]:\n",
    "    retrieved = RetrievedData(train_dataset, test_dataset, I_all, level, max_k=top_k)\n",
    "    eval_loader = DataLoader(\n",
    "        retrieved,\n",
    "        batch_size=batch_size,\n",
    "        shuffle=False,\n",
    "        drop_last=False,\n",
    "        pin_memory=device != \"cpu\",\n",
    "        collate_fn=default_collate,\n",
    "    )\n",
    "    # collect, batch by batch, whether each retrieved label equals the query label\n",
    "    matches_per_batch = []\n",
    "    for batch in tqdm(eval_loader):\n",
    "        label_test = batch[\"label_test\"]\n",
    "        labels_train = batch[\"labels_train\"]\n",
    "        # assumes label_test is [B] and labels_train is [B, k], so [:, None] broadcasts per query - TODO confirm\n",
    "        matches_per_batch.append(label_test[:, None] == labels_train)\n",
    "\n",
    "    matches = np.concatenate(matches_per_batch, axis=0)\n",
    "    prec_k = calc_precision_k(matches, k=top_k)\n",
    "\n",
    "    print(f\"{level}: PREC@{top_k}={prec_k:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16463239-5b99-416d-bbca-7a9ca99de1ef",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
