{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fd3bf76-0dfc-4b36-aaf0-e24ebc14be3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Restrict the process to GPU 0 *before* torch (or anything importing it) is\n",
    "# loaded, so the setting is already in place whenever CUDA gets initialised.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "import importlib\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from plotnine import *\n",
    "from scipy.stats import mannwhitneyu\n",
    "\n",
    "# Permit TF32 for CUDA matmuls and cuDNN ops.\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "# Extend the import path so that `reglm` and `scripts` resolve\n",
    "# (paths are relative to the notebook's working directory).\n",
    "sys.path.append(\"../../regLM/\")\n",
    "import reglm.dataset, reglm.lightning\n",
    "sys.path.append(\"../..\")\n",
    "import scripts.viz\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5de159a1-660d-463f-9b79-410dab5a9436",
   "metadata": {},
   "source": [
    "## Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83ed70f1-70b6-4efb-9ef1-d82dfa970064",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the trained LightningModel from its checkpoint, then move it to CUDA device 0.\n",
    "ckpt_path = 'lightning_logs/version_10/checkpoints/epoch=9-step=580648.ckpt'\n",
    "model = reglm.lightning.LightningModel.load_from_checkpoint(ckpt_path)\n",
    "model = model.to(torch.device(0))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fd76087-30fb-432f-808d-0a7ff42dfb0e",
   "metadata": {},
   "source": [
    "## Load test set sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf3ad4a4-30c2-4498-a597-1b060320a902",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read only CSV columns 0, 1 and 6, all as strings; column 0 becomes the index.\n",
    "test = pd.read_csv(\n",
    "    'lm_data/test.csv',\n",
    "    usecols=[0, 1, 6],\n",
    "    index_col=0,\n",
    "    dtype=str,\n",
    ")\n",
    "test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd666afe-93e2-4bf3-890a-52b971e22877",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of rows in the test set.\n",
    "test.shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71ef259c-359a-4d06-840e-5dac87209eee",
   "metadata": {},
   "source": [
    "## Accuracy per class on test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "077fddc4-7eb2-4249-bd3a-81e5936e1029",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the test sequences and their labels in a CharDataset (rc=False).\n",
    "ds = reglm.dataset.CharDataset(\n",
    "    seqs=test['seq'].tolist(),\n",
    "    labels=test['label'].tolist(),\n",
    "    rc=False,\n",
    ")\n",
    "\n",
    "# 'samplewise' averaging -- presumably one accuracy value per sequence (confirm\n",
    "# against reglm); the squeezed result is stored as the `acc` column.\n",
    "acc_per_seq = model.compute_accuracy_on_dataset(\n",
    "    ds, batch_size=8, multidim_average='samplewise'\n",
    ")\n",
    "test['acc'] = acc_per_seq.squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "552b9b90-49e7-440f-b515-d56a6f32f00a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean of the `acc` column over all test rows.\n",
    "test.acc.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "413843ca-99e4-4e49-8cc0-167c3c603c7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subsample 800 rows from each group with a fixed seed so that the p-value\n",
    "# is identical on every Restart & Run All instead of changing per execution.\n",
    "sample_seed = 0\n",
    "n_per_group = 800\n",
    "acc_44 = test.acc[test.label == '44'].sample(n_per_group, random_state=sample_seed)\n",
    "acc_rest = test.acc[test.label != '44'].sample(n_per_group, random_state=sample_seed)\n",
    "\n",
    "# One-sided Mann-Whitney U test: is accuracy for label '44' greater than for the rest?\n",
    "mannwhitneyu(acc_44, acc_rest, alternative='greater').pvalue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0c6dae9-44fa-4929-bc89-f876c0a8ab6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot accuracy only for the labels listed here.\n",
    "shown_labels = [\"00\", \"11\", \"22\", \"33\", \"44\"]\n",
    "\n",
    "# Boxplot of `acc`, with a y-axis title and a text annotation added on top.\n",
    "(\n",
    "    scripts.viz.boxplot(test[test.label.isin(shown_labels)], value_col='acc')\n",
    "    + ylab(\"Nucleotide prediction accuracy\")\n",
    "    + annotate('text', x=2.6, y=0.6, label='p(44 vs. others) < 1e-250', size=9)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5810e804-bb70-4be5-b9ab-7d2833277dc3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
