{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "# Random seed shared by the stochastic estimators in this notebook.\n",
    "SEED = 17\n",
    "\n",
    "# Embedding models to process. Alternative full model list (disabled):\n",
    "#   \"HGDNA_2k\", \"HGDNA_32k\", \"hyenadna-medium-160k-seqlen-hf\", \"DNABERT-2-117M\",\n",
    "#   \"nucleotide-transformer-v2-100m-multi-species\",\n",
    "#   \"caduceus-ps_seqlen-131k_d_model-256_n_layer-16\"\n",
    "modelNames = [\n",
    "    \"HGDNA\",\n",
    "]\n",
    "\n",
    "# Datasets to process; each name is a sub-directory of result/<model>/.\n",
    "dataNames = [\n",
    "    \"species_16384_cls\",\n",
    "]\n",
    "\n",
    "# Class name -> integer class id (insertion order is the display order).\n",
    "labelDict = dict(human=0, lemur=1, mouse=2, pig=3, hippo=4)\n",
    "\n",
    "# Every input and output path is resolved against the current working directory.\n",
    "curDir = os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "process HGDNA-species_16384_cls done\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import json\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn import svm\n",
    "from sklearn.metrics import classification_report, matthews_corrcoef\n",
    "\n",
    "# Load the shared plot settings (font family and font size) from the working directory.\n",
    "plot_settings_path = os.path.join(curDir, \"plot_settings.json\")\n",
    "with open(plot_settings_path, \"r\", encoding=\"utf-8\") as f:\n",
    "    plot_settings = json.load(f)\n",
    "\n",
    "plt.rcParams[\"font.family\"] = plot_settings[\"font\"]\n",
    "plt.rcParams[\"font.size\"] = plot_settings[\"fontsize\"]\n",
    "\n",
    "# Class names / ids in labelDict order, plus the reverse lookup used to name legend entries.\n",
    "classNames = list(labelDict.keys())\n",
    "classIds = list(labelDict.values())\n",
    "idToName = {classId: className for className, classId in labelDict.items()}\n",
    "\n",
    "for dataName in dataNames:\n",
    "    for modelName in modelNames:\n",
    "        # Each .emb file is a dict read here through the keys \"pred\", \"label\" and \"wall_clock\".\n",
    "        trainDict = torch.load(os.path.join(curDir, f\"result/{modelName}/{dataName}/train.emb\"), map_location=\"cpu\", weights_only=True)\n",
    "        testDict = torch.load(os.path.join(curDir, f\"result/{modelName}/{dataName}/test.emb\"), map_location=\"cpu\", weights_only=True)\n",
    "\n",
    "        # Cast to float32 before the NumPy conversion.\n",
    "        trainEmb = trainDict[\"pred\"].to(torch.float32).numpy()\n",
    "        testEmb = testDict[\"pred\"].to(torch.float32).numpy()\n",
    "        trainLabel = trainDict[\"label\"].numpy()\n",
    "        testLabel = testDict[\"label\"].numpy()\n",
    "        trainEmbTime = trainDict[\"wall_clock\"]\n",
    "\n",
    "        # t-SNE plotting: 2-D projection of the train embeddings, colored by label\n",
    "        tsne = TSNE(n_components=2, random_state=SEED, n_jobs=-1)\n",
    "        x_tsne = tsne.fit_transform(trainEmb)\n",
    "        fig, ax = plt.subplots(figsize=(8, 8))\n",
    "        scatter = ax.scatter(x_tsne[:, 0], x_tsne[:, 1], c=trainLabel, cmap=plt.get_cmap(\"tab10\"), s=10)\n",
    "\n",
    "        # remove x and y ticks\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        # remove all four spines so only the points remain\n",
    "        for spine in ax.spines.values():\n",
    "            spine.set_visible(False)\n",
    "\n",
    "        # legend_elements() returns one handle per distinct label value present in the\n",
    "        # data (in sorted order), so the names are derived from those values instead of\n",
    "        # assuming every class of labelDict occurs in the train split.\n",
    "        presentIds = np.unique(trainLabel)\n",
    "        legendNames = [idToName.get(int(classId), str(classId)) for classId in presentIds]\n",
    "        ax.legend(handles=scatter.legend_elements()[0], labels=legendNames, loc=\"upper right\")\n",
    "        fig.savefig(os.path.join(curDir, f\"result/{modelName}/{dataName}_emb.pdf\"), dpi=600, format=\"pdf\")\n",
    "        # Close the figure explicitly so figures do not accumulate across loop iterations.\n",
    "        plt.close(fig)\n",
    "\n",
    "        # Fit an SVM classifier on the train embeddings and predict the test labels.\n",
    "        # NOTE(review): the \"_SNV\" suffix of the output file is kept unchanged; it\n",
    "        # presumably refers to this SVM evaluation - confirm.\n",
    "        model = svm.SVC(random_state=SEED)\n",
    "        model.fit(trainEmb, trainLabel)\n",
    "        predLabel = model.predict(testEmb)\n",
    "\n",
    "        # Passing labels explicitly keeps target_names aligned with the class ids even\n",
    "        # when a class is missing from both the test labels and the predictions.\n",
    "        res = classification_report(testLabel, predLabel, labels=classIds, target_names=classNames, output_dict=True)\n",
    "        mcc = matthews_corrcoef(testLabel, predLabel)\n",
    "\n",
    "        # save to json: per-class report plus MCC and the stored \"wall_clock\" value\n",
    "        res[\"MCC\"] = mcc\n",
    "        res[\"wall_clock\"] = trainEmbTime\n",
    "        with open(os.path.join(curDir, f\"result/{modelName}/{dataName}_SNV.json\"), \"w\") as f:\n",
    "            json.dump(res, f, indent=4)\n",
    "\n",
    "        print(f\"process {modelName}-{dataName} done\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fla",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
