{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "task = \"drugrec\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "if task == \"mortality\" or task == \"readmission\":\n",
    "    with open(\"/data/pj20/exp_data/ccscm_ccsproc_atc3/clusters_th015.json\", \"r\") as f:\n",
    "        clusters = json.load(f)\n",
    "elif task == \"lenofstay\" or task == \"drugrec\":\n",
    "    with open(\"/data/pj20/exp_data/ccscm_ccsproc/clusters_th015.json\", \"r\") as f:\n",
    "        clusters = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def cosine_similarity(u, v):\n",
    "    return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "from get_emb import embedding_retriever\n",
    "\n",
    "terms = None\n",
    "if task == \"mortality\":\n",
    "    terms = ['death', 'mortality', 'cause death', 'lead to death', 'high risk', \"deadly\"]\n",
    "elif task == \"readmission\":\n",
    "    terms = ['rehospitalization', 'readmission']\n",
    "elif task == \"lenofstay\":\n",
    "    terms = [\"length of stay'\", \"bed days\", \"time in hospital\"]\n",
    "elif task == \"drugrec\":\n",
    "    terms = [\"drug recommendation\", \"prescription\", \"drug\", \"medication\", \"treatment\"]\n",
    "\n",
    "term_embs = []\n",
    "\n",
    "for term in terms:\n",
    "    term_embs.append(embedding_retriever(term))\n",
    "\n",
    "tmp = {}\n",
    "\n",
    "for clus in clusters.keys():\n",
    "    tmp[clus] = 0\n",
    "    for term_emb in term_embs:\n",
    "        tmp[clus] += cosine_similarity(clusters[clus]['embedding'], term_emb)   \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4.134378024520449"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "float(tmp['0'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "min_similarity = min(tmp.values())\n",
    "max_similarity = max(tmp.values())\n",
    "\n",
    "for clus in clusters.keys():\n",
    "    tmp[clus] = (float(tmp[clus]) - min_similarity) / (max_similarity - min_similarity)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([3.43326015])"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "min_similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "for clus in clusters.keys():\n",
    "    tmp[clus] = float(tmp[clus] ** 3)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "max(tmp.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "for clus in clusters.keys():\n",
    "    if task == \"mortality\":\n",
    "        clusters[clus]['attention_mortality'] = tmp[clus]\n",
    "    elif task == \"readmission\":\n",
    "        clusters[clus]['attention_readmission'] = tmp[clus]\n",
    "    elif task == \"lenofstay\":\n",
    "        clusters[clus]['attention_lenofstay'] = tmp[clus]\n",
    "    elif task == \"drugrec\":\n",
    "        clusters[clus]['attention_drugrec'] = tmp[clus]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "if task == \"mortality\" or task == \"readmission\":\n",
    "\n",
    "    with open(\"/data/pj20/exp_data/ccscm_ccsproc_atc3/clusters_th015.json\", \"w\") as f:\n",
    "        json.dump(clusters, f)\n",
    "\n",
    "elif task == \"lenofstay\" or task == \"drugrec\":\n",
    "    \n",
    "        with open(\"/data/pj20/exp_data/ccscm_ccsproc/clusters_th015.json\", \"w\") as f:\n",
    "            json.dump(clusters, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "\n",
    "if task == \"mortality\" or task == \"readmission\":\n",
    "    attn_file = f\"/data/pj20/exp_data/ccscm_ccsproc_atc3/attention_weights_{task}.pkl\"\n",
    "elif task == \"lenofstay\" or task == \"drugrec\":\n",
    "    attn_file = f\"/data/pj20/exp_data/ccscm_ccsproc/attention_weights_{task}.pkl\"\n",
    "attn = np.ndarray(shape=(len(clusters), 1))\n",
    "\n",
    "for i in range(len(clusters)):\n",
    "    idx = str(i)\n",
    "    attn[i] = clusters[idx][f'attention_{task}']\n",
    "\n",
    "with open(attn_file, \"wb\") as f:\n",
    "    pickle.dump(attn, f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('kgc')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3d0509d9aa81f2882b18eeb72d4d23c32cae9029e9b99f63cde94ba86c35ac78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
