{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to load tokens from patient data\n",
    "\n",
    "We generally start with a Hydra command that defines where the data is stored and the model training recipe. Here we convert that command directly into a Python configuration so we can tinker with everything locally.\n",
    "\n",
    "To prepend static data to the sequences, add `data.do_prepend_static_data=true` to the `cmd` string.\n",
    "\n",
    "By default, the dataset extracts subsequences of limited length from each patient's full data. Setting `data.max_seq_len=1000000` effectively disables this: the maximum sequence length is then larger than any actual patient sequence, so you get the complete data for each patient.\n",
    "\n",
    "To also return the model's logits, add `model.return_logits=true`."
   ]
  },
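  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of `data.max_seq_len` can be pictured with a small, stdlib-only sketch. The function `extract_subsequence` below is hypothetical and is not the meds-torch implementation: it assumes `from_start` keeps the earliest `max_seq_len` tokens, and its `random` strategy (also an assumption) keeps a contiguous window at a random offset. Either way, once `max_seq_len` exceeds the patient's sequence length, the whole sequence comes back unchanged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "def extract_subsequence(codes, max_seq_len, strategy='from_start', rng=None):\n",
    "    # Hypothetical sketch of subsequence sampling; not the meds-torch implementation.\n",
    "    if len(codes) <= max_seq_len:\n",
    "        # The whole patient history fits, so nothing is cut off.\n",
    "        return list(codes)\n",
    "    if strategy == 'from_start':\n",
    "        # Keep the earliest max_seq_len tokens.\n",
    "        return list(codes[:max_seq_len])\n",
    "    if strategy == 'random':\n",
    "        # Keep a contiguous window starting at a random offset.\n",
    "        rng = rng or random.Random()\n",
    "        start = rng.randint(0, len(codes) - max_seq_len)\n",
    "        return list(codes[start:start + max_seq_len])\n",
    "    raise ValueError(f'unknown strategy: {strategy}')\n",
    "\n",
    "\n",
    "patient_codes = list(range(1, 21))  # toy sequence of 20 token ids\n",
    "print(extract_subsequence(patient_codes, max_seq_len=5))\n",
    "print(len(extract_subsequence(patient_codes, max_seq_len=1000000)))"
   ]
  },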
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shlex\n",
    "\n",
    "import hydra\n",
    "import polars as pl\n",
    "\n",
    "from meds_torch.latest_dir import get_latest_directory\n",
    "\n",
    "# We need to import the resolvers so the cfg object has access to them\n",
    "from meds_torch.utils.resolvers import setup_resolvers\n",
    "\n",
    "setup_resolvers()\n",
    "\n",
    "ROOT_DIR = \"/storage/shared/mimic-iv/meds_v0.3.2/\"  # Replace with your actual root directory\n",
    "PRETRAIN_OUTPUT_DIR = f\"{ROOT_DIR}/results/zero_shot/eic_hparam_sweep\"\n",
    "MODEL_SWEEP_DIR = get_latest_directory(PRETRAIN_OUTPUT_DIR)\n",
    "BEST_CHECKPOINT = f\"{MODEL_SWEEP_DIR}/checkpoints/best_model.ckpt\"\n",
    "BEST_CONFIG = f\"{MODEL_SWEEP_DIR}/best_config.json\"\n",
    "MEDS_DIR = f\"{ROOT_DIR}/meds/\"\n",
    "TENSOR_DIR = f\"{ROOT_DIR}/eic_tensors/\"\n",
    "N = 10\n",
    "M = 10  # generation budget: maximum number of tokens generated per patient\n",
    "OUTPUT_DIR = f\"{ROOT_DIR}/results/zero_shot/inference/eic/debug/\"\n",
    "TUTORIAL_DIR = os.getcwd()\n",
    "\n",
    "cmd = f\"\"\"\n",
    "meds-torch-generate model=eic_forecasting experiment=eic_forecast_mtr \\\n",
    "    model.max_tokens_budget={M} data.subsequence_sampling_strategy=from_start data.max_seq_len=1000000 \\\n",
    "    data.dataloader.batch_size=512 model.generate_id=0 trainer.devices=[0] data.predict_dataset=test \\\n",
    "    data.do_include_subject_id=true data.do_include_prediction_time=true data.do_include_end_time=true \\\n",
    "    paths.meds_cohort_dir={MEDS_DIR} ckpt_path={BEST_CHECKPOINT} \\\n",
    "    paths.data_dir={TENSOR_DIR} paths.output_dir={OUTPUT_DIR} \\\n",
    "    \"hydra.searchpath=[pkg://meds_torch.configs,{TUTORIAL_DIR}/configs/]\"\n",
    "\"\"\"\n",
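    "# Drop the program name, parse the rest as Hydra overrides, and compose the generation config.\n",
    "# Note: `get_args` lives in Hydra's private `hydra._internal.utils` module and may change between versions.\n",
    "with hydra.initialize(version_base=\"1.3\", config_path=\"../src/meds_torch/configs\"):\n",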
    "    args = shlex.split(cmd)[1:]\n",
    "    overrides = hydra._internal.utils.get_args(args).overrides\n",
    "    cfg = hydra.compose(config_name=\"generate_trajectories\", return_hydra_config=True, overrides=overrides)"
   ]
  },
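  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The config cell above splits `cmd` into shell-style tokens and drops the first one (the program name), leaving Hydra override strings. The stdlib-only sketch below reproduces just that tokenization on a shortened, made-up command; it does not call Hydra. It shows that the quotes around the `hydra.searchpath` override are stripped and that each remaining token is a `key=value` pair, split here on the first `=` only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shlex\n",
    "\n",
    "# Illustrative command with placeholder values; shlex treats newlines as whitespace.\n",
    "example_cmd = '''\n",
    "meds-torch-generate model=eic_forecasting\n",
    "    data.max_seq_len=1000000 data.dataloader.batch_size=512\n",
    "    'hydra.searchpath=[pkg://meds_torch.configs,/tmp/configs/]'\n",
    "'''\n",
    "\n",
    "tokens = shlex.split(example_cmd)\n",
    "program, overrides = tokens[0], tokens[1:]  # same slicing as shlex.split(cmd)[1:]\n",
    "\n",
    "# Each override is a key=value string; split on the first '=' only.\n",
    "parsed = {}\n",
    "for override in overrides:\n",
    "    key, _, value = override.partition('=')\n",
    "    parsed[key] = value\n",
    "\n",
    "print(program)\n",
    "for key, value in parsed.items():\n",
    "    print(f'{key} -> {value}')"
   ]
  },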
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Boom! We have a config. Let's use it to load the dataset and pull some data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from meds_torch.data.components.pytorch_dataset import PytorchDataset\n",
    "\n",
    "datamodule = hydra.utils.instantiate(cfg.data)\n",
    "datamodule.setup()\n",
    "val_pytorch_dataset: PytorchDataset = datamodule.data_val  # data_test or data_train works too"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the first patient and collate it into a batch of size 1\n",
    "data = val_pytorch_dataset.collate([val_pytorch_dataset[0]])\n",
    "print(data.keys())\n",
    "print(data[\"code\"])  # sequence of observation codes (vocabulary indices) for the patient\n",
    "print(data[\"subject_id\"])  # patient id\n",
    "print(data[\"end_time\"])  # timestamp of the patient's last observation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To interpret the codes above, we can load the code metadata, which maps each vocabulary index to three pieces of information:\n",
    "\n",
    "- `code`: the human-readable string representation\n",
    "- `description`: a detailed explanation of what the code means\n",
    "- `values/mean`: for codes that represent binned numeric values (like vital signs or lab results), the mean value within that bin, computed below as `values/sum / values/n_occurrences`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
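    "metadata_df = pl.read_parquet(cfg.data.code_metadata_fp)\n",
    "# Add each code's mean numeric value (the bin mean for binned numeric codes); the result must be reassigned\n",
    "metadata_df = metadata_df.with_columns(\n",
    "    (pl.col(\"values/sum\") / pl.col(\"values/n_occurrences\")).alias(\"values/mean\")\n",
    ")\n",
    "\n",
    "print(metadata_df.select(\"code\", \"description\", \"code/vocab_index\", \"values/mean\"))\n",
    "# Map each vocabulary index back to its human-readable code string\n",
    "index_to_code = dict(zip(metadata_df[\"code/vocab_index\"], metadata_df[\"code\"]))\n",
    "interpretable_data = [index_to_code[index] for index in data[\"code\"].flatten().tolist()]\n",
    "interpretable_data"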
   ]
  },
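  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same decoding can be sketched without Polars on a tiny metadata table whose codes and numbers are invented for illustration. It computes each code's mean as `values/sum / values/n_occurrences` (returning `None` instead of dividing by zero for codes without numeric values) and skips index 0, assuming, as the padding later in this notebook suggests, that 0 is the padding token and has no metadata entry. Skipping padding matters once a batch holds patients of different lengths."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy code metadata: vocab index -> (code, values/sum, values/n_occurrences).\n",
    "# Every entry is invented for illustration.\n",
    "toy_metadata = {\n",
    "    1: ('HEART_RATE//bin_1', 1950.0, 30),\n",
    "    2: ('HEART_RATE//bin_2', 2250.0, 30),\n",
    "    3: ('ADMISSION', 0.0, 0),\n",
    "}\n",
    "\n",
    "PAD_INDEX = 0  # assumption: 0 is the padding token and has no metadata row\n",
    "\n",
    "\n",
    "def values_mean(total, count):\n",
    "    # values/sum / values/n_occurrences, or None when the code has no numeric values.\n",
    "    return total / count if count else None\n",
    "\n",
    "\n",
    "index_to_code = {index: code for index, (code, _, _) in toy_metadata.items()}\n",
    "index_to_mean = {index: values_mean(total, count) for index, (_, total, count) in toy_metadata.items()}\n",
    "\n",
    "\n",
    "def decode(batch_codes):\n",
    "    # Map each row of token ids to code strings, dropping padding tokens.\n",
    "    return [[index_to_code[i] for i in row if i != PAD_INDEX] for row in batch_codes]\n",
    "\n",
    "\n",
    "print(decode([[3, 1, 2], [2, 0, 0]]))\n",
    "print(index_to_mean)"
   ]
  },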
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also inspect a patient's raw MEDS data before any preprocessing. Here we list the diagnosis codes for one subject:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
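    "subject_id = 10000108  # replace with a subject_id from your own data\n",
    "# Lazily scan all MEDS shards so the filter is applied before the data is materialized\n",
    "med_df = pl.scan_parquet(f\"{ROOT_DIR}/meds/data/**/*.parquet\")\n",
    "med_df.filter(pl.col(\"subject_id\").eq(subject_id) & pl.col(\"code\").str.starts_with(\"DIAGNOSIS//\")).collect()"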
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's load the model, run a custom sequence of codes through it, and see what it generates!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
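    "import torch\n",
    "\n",
    "model = hydra.utils.instantiate(cfg.model)\n",
    "model.eval()  # inference mode: disables dropout\n",
    "# Lightning checkpoints can contain non-tensor objects (e.g. hyperparameters), which the\n",
    "# weights_only=True default of newer PyTorch versions may reject; only load checkpoints you trust\n",
    "checkpoint = torch.load(cfg.ckpt_path, map_location=\"cpu\", weights_only=False)\n",
    "model.load_state_dict(checkpoint[\"state_dict\"])"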
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's get some generated data. Each row in the generated table is one of the next `M` (here 10) predicted medical observations for a patient, with these columns:\n",
    "\n",
    "- `subject_id`: Unique identifier for each patient\n",
    "- `prediction_time`: Timestamp of the last real observation in the patient's history that is used to make predictions; it stays constant across all predictions in the same generated sequence.\n",
    "- `time`: Predicted timestamp of this observation (the model's predicted time delta added to `prediction_time`)\n",
    "- `code`: Human-readable name of the medical code (e.g., \"Heart Rate\", \"Blood Pressure\")\n",
    "- `code/vocab_index`: Internal integer id the model uses to represent each medical code\n",
    "- `numeric_value`: For quantitative measurements (like lab values or vital signs), the mean value within the predicted categorical bin (e.g., if heart rates are binned into ranges of 60-70, 70-80, etc., this is the mean of the predicted bin)"
   ]
  },
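  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The relationship between these columns can be sketched in plain Python. Following the `time` description above, each row's `time` is `prediction_time` plus a predicted delta, and `numeric_value` is the bin mean looked up for the predicted code. The generated tokens, the delta unit (days), and the lookup tables are all hypothetical; the real table comes out of the model call below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "\n",
    "prediction_time = datetime.datetime(2000, 1, 1)\n",
    "\n",
    "# Hypothetical model outputs for three generated tokens:\n",
    "# (vocab index, predicted time delta in days since prediction_time).\n",
    "generated = [(7, 0.5), (2, 1.0), (9, 2.25)]\n",
    "\n",
    "# Hypothetical lookup tables built from the code metadata.\n",
    "index_to_code = {2: 'HEART_RATE//bin_2', 7: 'ADMISSION', 9: 'LAB//bin_4'}\n",
    "index_to_mean = {2: 75.0, 7: None, 9: 4.1}\n",
    "\n",
    "rows = []\n",
    "for vocab_index, delta_days in generated:\n",
    "    rows.append(\n",
    "        {\n",
    "            'prediction_time': prediction_time,\n",
    "            'time': prediction_time + datetime.timedelta(days=delta_days),\n",
    "            'code': index_to_code[vocab_index],\n",
    "            'code/vocab_index': vocab_index,\n",
    "            'numeric_value': index_to_mean[vocab_index],\n",
    "        }\n",
    "    )\n",
    "\n",
    "for row in rows:\n",
    "    print(row['time'], row['code'], row['numeric_value'])"
   ]
  },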
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "\n",
    "subject_id = torch.tensor([1, 2])\n",
    "custom_codes_patient_1 = torch.tensor([4, 5, 8, 182])\n",
    "custom_codes_patient_2 = torch.tensor([188, 48, 78, 178, 1827, 2973])\n",
    "\n",
    "print(\"Pad the code sequences to equal length with the padding token (0)\")\n",
    "custom_codes = torch.nn.utils.rnn.pad_sequence(\n",
    "    [custom_codes_patient_1, custom_codes_patient_2], batch_first=True\n",
    ")\n",
    "print(custom_codes)\n",
    "mask = custom_codes != 0\n",
    "\n",
    "print(\"The mask is True for non-padding tokens\")\n",
    "print(mask)\n",
    "\n",
    "print(\"Add an end_time per patient; the generated trajectory starts from this time.\")\n",
    "patient_1_end_time = datetime.datetime(year=2000, month=1, day=1)\n",
    "patient_2_end_time = datetime.datetime(year=2000, month=1, day=1)\n",
    "end_time = [patient_1_end_time, patient_2_end_time]\n",
    "\n",
    "print(\"Passing the batch to the model returns a dict with the following keys:\")\n",
    "batch = dict(code=custom_codes, mask=mask, prediction_time=end_time, end_time=end_time, subject_id=subject_id)\n",
    "with torch.no_grad():  # inference only, no gradients needed\n",
    "    output = model(batch)\n",
    "print(output.keys())\n",
    "\n",
    "print(\"And we can interpret the generated trajectory:\")\n",
    "output[\"GENERATE//0\"]"
   ]
  },
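  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`torch.nn.utils.rnn.pad_sequence(..., batch_first=True)` right-pads every sequence with `padding_value` (0 by default) to the length of the longest one. The hypothetical helper `pad_and_mask` below reproduces that padding, plus the `custom_codes != 0` mask, in plain Python for the two custom patients, which can help when checking a hand-built batch by eye. Like the cell above, it assumes 0 is the padding token."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_and_mask(sequences, padding_value=0):\n",
    "    # Right-pad each sequence to the length of the longest one (batch_first layout).\n",
    "    max_len = max(len(seq) for seq in sequences)\n",
    "    padded = [list(seq) + [padding_value] * (max_len - len(seq)) for seq in sequences]\n",
    "    # True marks real tokens, False marks padding, like `custom_codes != 0`.\n",
    "    mask = [[token != padding_value for token in row] for row in padded]\n",
    "    return padded, mask\n",
    "\n",
    "\n",
    "codes, mask = pad_and_mask([[4, 5, 8, 182], [188, 48, 78, 178, 1827, 2973]])\n",
    "print(codes)\n",
    "print(mask)"
   ]
  },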
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want the logits, make sure `model.return_logits=true` is set in `cmd`; you'll find them in:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output[\"MODEL//LOGITS_SEQUENCE\"]"
   ]
  },
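  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Logits are unnormalized scores over the vocabulary. Assuming the last dimension of `MODEL//LOGITS_SEQUENCE` indexes vocabulary entries (check its shape on your own output), a softmax turns the scores at one position into next-token probabilities. The sketch below does this in plain Python on four invented scores; it is an illustration, not how meds-torch samples tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def softmax(logits):\n",
    "    # Subtract the max before exponentiating for numerical stability.\n",
    "    peak = max(logits)\n",
    "    exps = [math.exp(x - peak) for x in logits]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "\n",
    "toy_logits = [2.0, 1.0, 0.1, -1.0]  # invented scores for a 4-token vocabulary\n",
    "probs = softmax(toy_logits)\n",
    "top_index = max(range(len(probs)), key=probs.__getitem__)  # most likely next token\n",
    "print([round(p, 3) for p in probs], top_index)"
   ]
  },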
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can always get the last embedding too:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output[\"BACKBONE//EMBEDDINGS\"]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "meds-torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
