{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for sample in samples:\n",
    "#     sample[\"conditions\"] = sample[\"conditions\"][0]\n",
    "#     sample[\"procedures\"] = sample[\"procedures\"][0]\n",
    "#     sample[\"drugs\"] = sample[\"drugs\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "TASK = \"mortality\"\n",
    "DATASET = \"mimic3\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyhealth.datasets import SampleEHRDataset\n",
    "import json\n",
    "\n",
    "with open(f\"/shared/eng/pj20/kelpie_exp_data/ehr_data/{DATASET}_{TASK}_samples_train.json\", \"r\") as f:\n",
    "    samples_train = json.load(f)\n",
    "with open(f\"/shared/eng/pj20/kelpie_exp_data/ehr_data/{DATASET}_{TASK}_samples_test.json\", \"r\") as f:\n",
    "    samples_test = json.load(f)\n",
    "\n",
    "\n",
    "dataset_train = SampleEHRDataset(samples_train, dataset_name=DATASET, task_name=TASK)\n",
    "dataset_test = SampleEHRDataset(samples_test, dataset_name=DATASET, task_name=TASK)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(dataset_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyhealth.datasets import split_by_patient, get_dataloader\n",
    "\n",
    "# train_dataset, val_dataset, test_dataset = split_by_patient(\n",
    "#     dataset, [0.8, 0.1, 0.1], seed=528\n",
    "# )\n",
    "train_dataloader = get_dataloader(dataset_train, batch_size=32, shuffle=True)\n",
    "test_dataloader = get_dataloader(dataset_test, batch_size=32, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyhealth.trainer import Trainer\n",
    "from pyhealth.models import Deepr, AdaCare, StageNet, GRASP, Transformer, RETAIN, RNN\n",
    "import os\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"5\"\n",
    "\n",
    "model = GRASP(\n",
    "    dataset=dataset_train,\n",
    "    feature_keys=[\"conditions\", \"procedures\", \"drugs\"],\n",
    "    label_key=\"label\",\n",
    "    mode=\"binary\",\n",
    "    use_embedding=[True, True, True],\n",
    "    embedding_dim=128,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(model=model, metrics=['accuracy', 'f1', 'pr_auc', 'sensitivity', 'specificity'])\n",
    "trainer.train(\n",
    "    train_dataloader=train_dataloader,\n",
    "    val_dataloader=test_dataloader,\n",
    "    epochs=15,\n",
    "    optimizer_params = {\"lr\": 1e-3},\n",
    "    monitor=\"f1\",\n",
    ")\n",
    "\n",
    "# STEP 5: evaluate\n",
    "print(trainer.evaluate(test_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pyhealth",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
