{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import rootutils\n",
    "\n",
    "root = rootutils.setup_root(os.path.abspath(\"\"), dotenv=True, pythonpath=True, cwd=True)\n",
    "\n",
    "from dataclasses import dataclass, field\n",
    "from datetime import datetime\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import polars as pl\n",
    "from dateutil.relativedelta import relativedelta\n",
    "from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict\n",
    "from omegaconf import DictConfig\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class DummyConfig:\n",
    "    \"\"\"Dummy configuration for testing MEDS dataset\"\"\"\n",
    "\n",
    "    schema_files_root: str\n",
    "    task_label_path: str\n",
    "    data_dir: str\n",
    "    task_name: str = \"dummy_task\"\n",
    "    max_seq_len: int = 64\n",
    "    do_prepend_static_data: bool = False\n",
    "    postpend_eos_token: bool = False\n",
    "    do_flatten_tensors: bool = True\n",
    "    EOS_TOKEN_ID: int = 5\n",
    "    do_include_subject_id: bool = True\n",
    "    do_include_subsequence_indices: bool = True\n",
    "    do_include_start_time_min: bool = True\n",
    "    do_include_end_time: bool = True\n",
    "    do_include_prediction_time: bool = True\n",
    "    subsequence_sampling_strategy: str = \"from_start\"\n",
    "    code_metadata_fp: str = field(init=False)\n",
    "\n",
    "    def __post_init__(self):\n",
    "        self.code_metadata_fp = self.data_dir + \"/metadata.parquet\"\n",
    "\n",
    "\n",
    "def create_dummy_dataset(\n",
    "    base_dir: str | Path, n_subjects: int = 3, split: str = \"train\", seed: int | None = 42, n_repeats: int = 3\n",
    ") -> DummyConfig:\n",
    "    if seed is not None:\n",
    "        np.random.seed(seed)\n",
    "\n",
    "    base_dir = Path(base_dir)\n",
    "\n",
    "    # Create directories\n",
    "    schema_dir = base_dir / \"schema\" / split\n",
    "    schema_dir.mkdir(parents=True, exist_ok=True)\n",
    "    base_dir.joinpath(\"data\").mkdir(exist_ok=True)\n",
    "\n",
    "    # Create static data\n",
    "    base_datetime = datetime(1995, 1, 1)\n",
    "    static_data = []\n",
    "    for subject_id in range(n_subjects):\n",
    "        static_data.append(\n",
    "            {\n",
    "                \"subject_id\": subject_id,\n",
    "                \"start_time\": base_datetime,\n",
    "                \"time\": [base_datetime + relativedelta(days=i) for i in range(8 * n_repeats)],\n",
    "                \"code\": [1, 2, 3],\n",
    "                \"numeric_value\": [0.1, 0.2, 0.3],\n",
    "            }\n",
    "        )\n",
    "    static_df = pl.DataFrame(static_data)\n",
    "    static_df.write_parquet(schema_dir / \"shard_0.parquet\", use_pyarrow=True)\n",
    "\n",
    "    # Create dynamic data with consistent sequence lengths\n",
    "    subject_dynamic_data = []\n",
    "    for subject_id in range(n_subjects):\n",
    "        length = np.random.randint(8, 8 * n_repeats)\n",
    "        dynamic_data = JointNestedRaggedTensorDict(\n",
    "            raw_tensors={\n",
    "                \"code\": ([[1], [2], [1], [2], [1], [2], [1], [3]] * n_repeats)[:length],\n",
    "                \"numeric_value\": (\n",
    "                    [\n",
    "                        [np.nan],\n",
    "                        [np.nan],\n",
    "                        [np.nan],\n",
    "                        [np.nan],\n",
    "                        [np.nan],\n",
    "                        [np.nan],\n",
    "                        [np.nan],\n",
    "                        [np.nan],\n",
    "                    ]\n",
    "                    * n_repeats\n",
    "                )[:length],\n",
    "                \"time_delta_days\": ([1, 1, 1, 1, 1, 1, 1, 1] * n_repeats)[:length],\n",
    "            }\n",
    "        )\n",
    "        subject_dynamic_data.append(dynamic_data)\n",
    "    dynamic_data = JointNestedRaggedTensorDict.vstack(subject_dynamic_data)\n",
    "\n",
    "    nrt_output_dir = base_dir / \"data\" / split\n",
    "    nrt_output_dir.mkdir(parents=True, exist_ok=True)\n",
    "    dynamic_data.save(nrt_output_dir / \"shard_0.nrt\")\n",
    "\n",
    "    # Create task labels\n",
    "    task_df = pl.DataFrame(\n",
    "        {\n",
    "            \"subject_id\": list(range(n_subjects)),\n",
    "            \"prediction_time\": [base_datetime + relativedelta(years=3)] * n_subjects,\n",
    "            \"boolean_value\": [i % 2 for i in range(n_subjects)],\n",
    "        }\n",
    "    )\n",
    "\n",
    "    task_fp = base_dir / \"task_labels.parquet\"\n",
    "    task_df.write_parquet(task_fp, use_pyarrow=True)\n",
    "\n",
    "    config = DummyConfig(\n",
    "        schema_files_root=str(base_dir / \"schema\"),\n",
    "        task_label_path=str(task_fp),\n",
    "        data_dir=str(base_dir),\n",
    "    )\n",
    "\n",
    "    metadata_df = pl.DataFrame(\n",
    "        {\n",
    "            \"code\": [\"a\", \"b\", \"c\"],\n",
    "            \"code/vocab_index\": [1, 2, 3],\n",
    "        }\n",
    "    )\n",
    "\n",
    "    metadata_df.write_parquet(config.code_metadata_fp)\n",
    "\n",
    "    return config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "from meds_torch.data.components.pytorch_dataset import PytorchDataset\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n",
    "\n",
    "\n",
    "tmp_dir = tempfile.TemporaryDirectory()\n",
    "data_config = create_dummy_dataset(tmp_dir.name, n_subjects=64)\n",
    "dataset = PytorchDataset(data_config, split=\"train\")\n",
    "\n",
    "dynamic_data, subject_id, st, end = dataset.load_subject_dynamic_data(0)\n",
    "print(\"every patient has identical data:\")\n",
    "print(dynamic_data.flatten().to_dense()[\"code\"])\n",
    "print(\"the length of the data is:\", str(len(dynamic_data.flatten().to_dense()[\"code\"])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"This file prepares config fixtures for other tests.\"\"\"\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "import hydra\n",
    "from hydra import compose, initialize\n",
    "\n",
    "from meds_torch.utils.resolvers import setup_resolvers\n",
    "\n",
    "setup_resolvers()\n",
    "\n",
    "\n",
    "def create_cfg(overrides, config_name=\"train.yaml\") -> DictConfig:\n",
    "    \"\"\"Helper function to create Hydra DictConfig with given overrides and common settings.\"\"\"\n",
    "    with initialize(version_base=\"1.3\", config_path=\"../src/meds_torch/configs\"):\n",
    "        cfg = compose(config_name=config_name, return_hydra_config=True, overrides=overrides)\n",
    "    return cfg\n",
    "\n",
    "\n",
    "output_dir = Path(tmp_dir.name) / \"output\"\n",
    "overrides = [\n",
    "    \"model=eic_forecasting\",\n",
    "    \"model/backbone=eic_transformer_decoder_alibi\",\n",
    "    \"trainer=gpu\",\n",
    "    \"experiment=eic_forecast_mtr\",\n",
    "    \"data.subsequence_sampling_strategy=random\",\n",
    "    f\"data.code_metadata_fp={data_config.code_metadata_fp}\",\n",
    "    \"model.optimizer.lr=0.001\",\n",
    "    \"trainer.max_epochs=30\",\n",
    "    f\"paths.output_dir={output_dir}\",\n",
    "    \"model.top_k_acc=[1]\",\n",
    "    f\"hydra.searchpath=[pkg://meds_torch.configs,{root}/MIMICIV_INDUCTIVE_EXPERIMENTS/configs/meds-torch-configs/]\",\n",
    "]\n",
    "cfg = create_cfg(overrides)\n",
    "model = hydra.utils.instantiate(cfg.model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data.dataloader import DataLoader\n",
    "\n",
    "train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=dataset.collate)\n",
    "val_dataloader = DataLoader(dataset, batch_size=8, shuffle=False, collate_fn=dataset.collate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# callbacks = instantiate_callbacks(cfg.get(\"callbacks\"))\n",
    "# logger = instantiate_loggers(cfg.get(\"logger\"))\n",
    "\n",
    "trainer = hydra.utils.instantiate(cfg.trainer)\n",
    "trainer.fit(\n",
    "    model=model,\n",
    "    train_dataloaders=train_dataloader,\n",
    "    val_dataloaders=val_dataloader,\n",
    "    ckpt_path=cfg.get(\"ckpt_path\"),\n",
    ")\n",
    "trainer.validate(model=model, dataloaders=val_dataloader, ckpt_path=cfg.get(\"ckpt_path\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from meds_torch.latest_dir import get_latest_directory\n",
    "\n",
    "print(\"Enter the following in terminal to view the tensorboard logs:\")\n",
    "print(\"tensorboard --logdir=%s\" % get_latest_directory(cfg.paths.output_dir) + \"/lightning_logs/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from meds_torch.input_encoder import INPUT_ENCODER_MASK_KEY, INPUT_ENCODER_TOKENS_KEY\n",
    "\n",
    "batch = dataset.collate([dataset[i] for i in range(8)])\n",
    "input_batch = model.input_encoder.forward(batch)\n",
    "prompts, mask = input_batch[INPUT_ENCODER_TOKENS_KEY], input_batch[INPUT_ENCODER_MASK_KEY]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "prompt_lengths = mask.sum(dim=-1)\n",
    "prompt_lengths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from x_transformers.autoregressive_wrapper import align_right\n",
    "\n",
    "align_right(prompts, prompt_lengths, pad_id=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\n",
    "    \"Let's generate the future conditioned on the past 24 tokens (using a sliding window with a max sequence length of 24):\"\n",
    ")\n",
    "future = model.model.generate(\n",
    "    prompts=prompts[:1, :],\n",
    "    mask=mask[:1, :],\n",
    "    get_next_token_time=None,\n",
    "    time_offset_years=None,\n",
    "    temperature=model.cfg.temperature,\n",
    "    eos_tokens=model.cfg.eos_tokens,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_batch.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_batch = dict(\n",
    "    time_delta_days=input_batch[\"time_delta_days\"], code=input_batch[\"code\"], mask=input_batch[\"mask\"]\n",
    ")\n",
    "torch.functional.F.softmax(model.forward(test_batch)[\"MODEL//LOGITS_SEQUENCE\"], dim=-1).argmax(dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"We observe that the future follos the repeating pattern 1,2,1,2,1,2,1,3, great!\")\n",
    "print(future[0][0, :24])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "meds-torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
