{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c8f9a130-fa86-46fb-ad0d-5d71019e0566",
   "metadata": {},
   "source": [
    "# Filter regLM-generated promoters using the matched regression model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7ed5c74-e5c7-4d0d-b196-41959defd662",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import importlib\n",
    "\n",
    "# Restrict this process to one GPU. CUDA only reads CUDA_VISIBLE_DEVICES when it is\n",
    "# initialised, so set it before torch (or anything importing torch) is loaded.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from plotnine import *\n",
    "\n",
    "# Allow TF32 for matmuls and cuDNN convolutions: faster on supported GPUs at slightly\n",
    "# reduced float precision.\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "# Make the `scripts` package (two directories up) importable.\n",
    "sys.path.append('../../')\n",
    "import scripts.viz, scripts.regression"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33d4d128-93e5-4c0b-86d5-f2dabfee0f46",
   "metadata": {},
   "source": [
    "## Load training set sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cb5c7c4-6851-4339-b5c7-b30c2fd285cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep columns 0, 1, 4 and 5 of the training CSV (column 0 becomes the index) and read\n",
    "# everything as strings. Later cells use the `seq`, `complex_token` and `defined_token` columns.\n",
    "train = pd.read_csv(\n",
    "    'lm_data/train.csv',\n",
    "    usecols=[0, 1, 4, 5],\n",
    "    index_col=0,\n",
    "    dtype='str',\n",
    ")\n",
    "train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac80604b-4ebd-4f47-a91b-2c73d35caf77",
   "metadata": {},
   "source": [
    "## Use the LM matched regression model to make predictions for real sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62a0efc8-457e-430d-8faf-feca39e93e1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the regression model matched to the LM.\n",
    "ckpt_path = 'reg_joint/checkpoints/epoch=8-step=65331.ckpt'\n",
    "lm_model = scripts.regression.LightningModel.load_from_checkpoint(ckpt_path)\n",
    "\n",
    "# Predict complex- and defined-medium activity for every real training sequence.\n",
    "# devices=[0] is the first *visible* GPU, i.e. physical GPU 2 given CUDA_VISIBLE_DEVICES.\n",
    "ds = scripts.regression.SeqDataset(train.seq.tolist())\n",
    "train_preds = lm_model.predict_on_dataset(\n",
    "    ds, devices=[0], batch_size=512, num_workers=8)\n",
    "train[['lm_Complex', 'lm_Defined']] = train_preds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7740059f-11f3-4566-a7f6-0fd66ab0b48d",
   "metadata": {},
   "source": [
    "## Get mean and standard deviation of predicted activity for each token in each medium"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c203174-5303-4478-8f95-afaf03e5356f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each medium, summarise the predicted activity of real sequences per token and\n",
    "# define an acceptance interval of mean +/- N_STD standard deviations.\n",
    "# NOTE: a token with a single training sequence gets std=NaN, so its bounds are NaN and\n",
    "# every generated sequence carrying that token will fail the in-distribution check.\n",
    "N_STD = 2\n",
    "\n",
    "token_stats = []\n",
    "for cond, tok_col, pred_col in [\n",
    "    ('complex', 'complex_token', 'lm_Complex'),\n",
    "    ('defined', 'defined_token', 'lm_Defined'),\n",
    "]:\n",
    "    stats = train.groupby(tok_col)[pred_col].agg(['mean', 'std']).reset_index()\n",
    "    stats.columns = ['tok', 'mean', 'std']\n",
    "    stats['cond'] = cond\n",
    "    token_stats.append(stats)\n",
    "\n",
    "# ignore_index gives `means` a unique 0..n-1 index instead of repeated labels.\n",
    "means = pd.concat(token_stats, ignore_index=True)\n",
    "means['lower'] = means['mean'] - N_STD * means['std']\n",
    "means['upper'] = means['mean'] + N_STD * means['std']\n",
    "means['tok'] = means['tok'].astype(str)\n",
    "means"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "917a98d9-80f8-4bfc-9ce5-e07560e20939",
   "metadata": {},
   "source": [
    "## Load generated sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1370ff6-677e-4073-84d0-01215ccc1714",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generated promoters: every column is read as a string and the first column becomes the\n",
    "# index. Later cells use the `Sequence` and `label` columns.\n",
    "gen = pd.read_csv(\n",
    "    'synthetic_promoters/lm.csv',\n",
    "    dtype='str',\n",
    "    index_col=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "049d837f-b3cf-4851-8d3a-7b97c7e8d3e7",
   "metadata": {},
   "source": [
    "## Use the LM matched regression model to make predictions for generated sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "106c1dba-4d35-46cc-a59a-6d88bb17ff81",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the generated sequences with the same model used for the real ones.\n",
    "gen_preds = lm_model.predict_on_dataset(\n",
    "    scripts.regression.SeqDataset(gen.Sequence.tolist()),\n",
    "    batch_size=512,\n",
    "    devices=[0],\n",
    "    num_workers=8,\n",
    ")\n",
    "gen[['lm_Complex', 'lm_Defined']] = gen_preds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db10a456-707a-4dc8-82d2-056b7e907d98",
   "metadata": {},
   "source": [
    "## Keep generated sequences whose predicted activity lies within mean ± 2 SD of real sequences with the same token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0e2853c-ad9e-42ec-bfcd-6e796298072f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def in_dist(row, type='complex'):\n",
    "    \"\"\"Return whether a generated sequence's prediction lies inside the real-sequence interval.\n",
    "\n",
    "    The interval for (medium, token) is read from the global `means` table\n",
    "    (columns `cond`, `tok`, `lower`, `upper`). It is open: a prediction exactly\n",
    "    equal to a bound is rejected.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    row : pandas.Series\n",
    "        Row of `gen` with a `label` string (complex token at position 0, defined\n",
    "        token at position 1) and the `lm_Complex` / `lm_Defined` predictions.\n",
    "    type : {'complex', 'defined'}\n",
    "        Medium to check. The name shadows the builtin but is kept for compatibility.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    bool-like\n",
    "        True when lower < prediction < upper.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `type` is not 'complex' or 'defined'.\n",
    "    KeyError\n",
    "        If `means` has no interval for this medium and token.\n",
    "    \"\"\"\n",
    "    if type == 'complex':\n",
    "        token, pred = row.label[0], row.lm_Complex\n",
    "    elif type == 'defined':\n",
    "        token, pred = row.label[1], row.lm_Defined\n",
    "    else:\n",
    "        raise ValueError(f\"type must be 'complex' or 'defined', got {type!r}\")\n",
    "\n",
    "    # Single lookup of the (medium, token) interval; first match wins, as before.\n",
    "    bounds = means.loc[(means.cond == type) & (means.tok == token), ['lower', 'upper']]\n",
    "    if bounds.empty:\n",
    "        raise KeyError(f'no {type} interval for token {token!r} in means')\n",
    "    lower, upper = bounds.iloc[0]\n",
    "\n",
    "    return (pred > lower) & (pred < upper)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cca76257-c532-4639-87c7-e456187ab164",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flag, separately for each medium, whether each generated sequence lies inside the\n",
    "# interval of real sequences with the same token.\n",
    "for flag_col, medium in (('in_cdist', 'complex'), ('in_ddist', 'defined')):\n",
    "    gen[flag_col] = gen.apply(in_dist, axis=1, args=(medium,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "489bb5ae-fd7a-4b8d-8103-bcddb757ec6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of generated sequences in each (in_cdist, in_ddist) combination.\n",
    "gen.value_counts(subset=['in_cdist', 'in_ddist'], normalize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e39f537-599a-46f7-bd95-2e9f8293abc1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Keep sequences that are in-distribution in both media. Then keep only the first two\n",
    "# columns, which come from lm.csv (the lm_* / in_* columns were appended after them).\n",
    "# NOTE(review): assumes lm.csv has exactly two data columns, presumably Sequence and\n",
    "# label - TODO confirm; selecting them by name would be less brittle.\n",
    "keep = gen['in_cdist'] & gen['in_ddist']\n",
    "gen = gen.loc[keep].iloc[:, :2]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0101b887-112b-406b-9e7f-129b4cc91389",
   "metadata": {},
   "source": [
    "## Sample 100 filtered sequences per label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cf2ed4f-5f3a-4793-a69d-a401d744dfb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw a fixed number of filtered sequences for each label whose complex and defined\n",
    "# tokens match. The seed makes the draw reproducible across re-runs.\n",
    "N_PER_LABEL = 100\n",
    "SAMPLE_SEED = 0\n",
    "LABELS = [\"00\", \"11\", \"22\", \"33\", \"44\"]\n",
    "\n",
    "# Fail with a clear message if filtering left too few sequences for some label.\n",
    "label_counts = gen.label.value_counts()\n",
    "short = {lab: int(label_counts.get(lab, 0)) for lab in LABELS\n",
    "         if label_counts.get(lab, 0) < N_PER_LABEL}\n",
    "assert not short, f'fewer than {N_PER_LABEL} filtered sequences for labels: {short}'\n",
    "\n",
    "gen = pd.concat([\n",
    "    gen[gen.label == lab].sample(N_PER_LABEL, random_state=SAMPLE_SEED) for lab in LABELS\n",
    "])\n",
    "gen.label.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b60028ce-cd9d-4864-aee4-ba05e90fabc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the index with 0..n-1 before writing the filtered, sampled sequences.\n",
    "gen_out = gen.reset_index(drop=True)\n",
    "gen_out.to_csv('synthetic_promoters/lm_filtered.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5148f86-50f8-4e16-b91a-9aa997456853",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
