{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d6eadb26",
   "metadata": {},
   "source": [
    "# Tutorial: Evaluation Tasks and Datasets\n",
    "\n",
    "This notebook details the experimental tasks and associated evaluation environments as part of the `leon` package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20ea4248",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import leon\n",
    "from sklearn.metrics import mean_squared_error, r2_score\n",
    "from pathlib import Path\n",
    "from typing import Dict, NamedTuple, Optional, Tuple, Union"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15576935",
   "metadata": {},
   "source": [
    "The function `leon.pprint_registry()` returns a list of all of the implemented tasks available in `leon`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd9a8375",
   "metadata": {},
   "outputs": [],
   "source": [
    "leon.pprint_registry()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b31c0287",
   "metadata": {},
   "source": [
    "Each of the above tasks consists of a [`BaseTask`](../src/leon/envs/base.py) object, which itself inherits from the [`Env`](https://gymnasium.farama.org/api/env/) Gymnasium class for implementing RL environments. Importantly, we highlight that each of the above tasks are *not* RL tasks! Each `BaseTask` is implemented in a way such that there is no sequential interaction with the environment. Instead, we take advantage of the well-documented and time-tested `gymnasium` package for implementation, while taking inspiration from the [`design-bench`](https://github.com/brandontrabucco/design-bench) package with myopic evaluation environments for the core logic.\n",
    "\n",
    "Each `BaseTask` implements the following public properties:\n",
    "  - `train`: a [`BaseDataset`](../src/leon/data/base.py) dataset of conditional designs and associated objective values. The full `train` dataset is made available to each optimization method.\n",
    "  - `test`: a [`BaseDataset`](../src/leon/data/base.py) dataset of different conditional designs and objective values. The underlying objective values have undergone a task-dependent distribution shift. Each optimization method is evaluated on the `test` dataset.\n",
    "Furthermore, each `BaseTask` implements the following public methods:\n",
    "  - `__call__()`: evaluates the quality of an input design(s) according to the public underlying train objective function. This is the objective function that exactly describes the objective in the training dataset.\n",
    "  - `predict()`: evaluates the quality of an input design(s) according to the private underlying test objective function. This is the objective function that exactly describes the objective in the test dataset. Note that this function can only be called *exactly once* for each task instantiation, and should only be used for evaluation purposes.\n",
    "Both of the above methods should be treated as black-box functions and do not make gradient information available to optimization methods.\n",
    "\n",
    "We explore each implemented task below and include some helpful utility functions to reproduce the figures in our work:"
   ]
  },
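  {
   "cell_type": "markdown",
   "id": "a3c1e7f0",
   "metadata": {},
   "source": [
    "The contract described above (a public train objective exposed through `__call__()`, and a private test objective exposed through `predict()` that may be queried exactly once) can be illustrated with a toy class. This is only a sketch under stated assumptions: `ToyTask`, its quadratic objectives, and the `RuntimeError` raised on a second call are hypothetical and are not taken from the `leon` source, which may enforce the single-call rule differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7d2f4a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "class ToyTask:\n",
    "    \"\"\"Hypothetical stand-in for a task; this is not the `leon` implementation.\"\"\"\n",
    "\n",
    "    def __init__(self, shift: float = 0.5) -> None:\n",
    "        self._shift = shift  # Assumed shift between the train and test objectives.\n",
    "        self._predict_called = False\n",
    "\n",
    "    def __call__(self, designs: np.ndarray) -> np.ndarray:\n",
    "        # Public train objective: may be queried any number of times.\n",
    "        return -np.square(np.asarray(designs, dtype=float))\n",
    "\n",
    "    def predict(self, designs: np.ndarray) -> np.ndarray:\n",
    "        # Private test objective: a second call raises an error in this sketch.\n",
    "        if self._predict_called:\n",
    "            raise RuntimeError(\"predict() may only be called once per task.\")\n",
    "        self._predict_called = True\n",
    "        return -np.square(np.asarray(designs, dtype=float) - self._shift)\n",
    "\n",
    "\n",
    "toy_task = ToyTask()\n",
    "candidates = np.linspace(-1.0, 1.0, num=5)\n",
    "# Select a design using the train objective only, then evaluate it once.\n",
    "best = candidates[np.argmax(toy_task(candidates))]\n",
    "print(\"Best design under the train objective:\", best)\n",
    "print(\"Test objective at that design:\", toy_task.predict(np.array([best]))[0])"
   ]
  },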
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "325f6a6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model_fit(task: leon.envs.BaseTask) -> Dict[str, np.ndarray]:\n",
    "    \"\"\"\n",
    "    Evaluate the model fit to the evaluationg datums for a given task.\n",
    "    Input:\n",
    "        task: the task to evaluate the model fit for.\n",
    "    Returns:\n",
    "        A dictionary containing the model fit evaluation results by datum.\n",
    "    \"\"\"\n",
    "    y_off_dist = np.array(task([task.test[i] for i in range(len(task.test))]))\n",
    "    y_in_dist = np.array(task.predict([task.test[i] for i in range(len(task.test))]))\n",
    "    y_gt = np.array(task.test.target.astype(float))\n",
    "    if task.test._mu is not NotImplemented:\n",
    "        y_off_dist = task.test.unnormalize(y_off_dist)\n",
    "        y_in_dist = task.test.unnormalize(y_in_dist)\n",
    "        y_gt = task.test.unnormalize(y_gt)\n",
    "    if task.task_name in [\"HIVDB-v0\"]:\n",
    "        # Convert negative viral loads to positive values.\n",
    "        y_off_dist, y_in_dist, y_gt = -1.0 * y_off_dist, -1.0 * y_in_dist, -1.0 * y_gt\n",
    "    return {\"y_off_dist\": y_off_dist, \"y_in_dist\": y_in_dist, \"y_gt\": y_gt}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02a361fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_summative_statistics(\n",
    "    model_fit_results: Dict[str, np.ndarray], task: leon.envs.BaseTask, verbose: bool = True\n",
    ") -> Dict[str, float]:\n",
    "    \"\"\"\n",
    "    Print summative statistics about the fit of a model for a task.\n",
    "    Input:\n",
    "        model_fit_results: a dictionary containing the model fit results.\n",
    "        task: the task for which the model fit results were obtained.\n",
    "        verbose: whether to print the results to stdout.\n",
    "    Returns:\n",
    "        A dictionary containing the mean squared error and R^2 score for both\n",
    "        off-distribution and in-distribution datasets.\n",
    "    \"\"\"\n",
    "    y_off_dist = model_fit_results[\"y_off_dist\"]\n",
    "    y_in_dist = model_fit_results[\"y_in_dist\"]\n",
    "    y_gt = model_fit_results[\"y_gt\"]\n",
    "\n",
    "    nmse_off_dist = mean_squared_error(y_off_dist, y_gt) / np.abs(y_gt.mean())\n",
    "    r2_off_dist = r2_score(y_off_dist, y_gt)\n",
    "    nmse_in_dist = mean_squared_error(y_in_dist, y_gt) / np.abs(y_gt.mean())\n",
    "    r2_in_dist = r2_score(y_in_dist, y_gt)\n",
    "\n",
    "    if verbose:\n",
    "        print(\"Size of In-Distribution Dataset:\", len(task.train))\n",
    "        print(\"Size of Out-of-Distribution Dataset:\", len(task.test), end=\"\\n\\n\")\n",
    "        print(\"Model Trained on Off-Distribution Patients:\")\n",
    "        print(\"    Normalized Mean Squared Error (NMSE):\", nmse_off_dist)\n",
    "        print(\"    $R^2$ Score:\", r2_off_dist, end=\"\\n\\n\")\n",
    "        print(\"Model Trained on On-Distribution Patients:\")\n",
    "        print(\"    Normalized Mean Squared Error (NMSE):\", nmse_in_dist)\n",
    "        print(\"    $R^2$ Score:\", r2_in_dist)\n",
    "\n",
    "    return {\n",
    "        \"nmse_off_dist\": nmse_off_dist,\n",
    "        \"r2_off_dist\": r2_off_dist,\n",
    "        \"nmse_in_dist\": nmse_in_dist,\n",
    "        \"r2_in_dist\": r2_in_dist\n",
    "    }"
   ]
  },
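  {
   "cell_type": "markdown",
   "id": "c4e9a1b6",
   "metadata": {},
   "source": [
    "As a quick sanity check of the two metrics, the following sketch computes them on a small synthetic example. The arrays are made up for illustration only. The NMSE follows the same convention as `get_summative_statistics()` (the MSE divided by the absolute mean of the ground truth). Note that scikit-learn's `r2_score(y_true, y_pred)` is not symmetric in its arguments, whereas `mean_squared_error` is."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8f3b2c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import mean_squared_error, r2_score\n",
    "\n",
    "y_true_toy = np.array([1.0, 2.0, 3.0, 4.0])  # Synthetic ground truth values.\n",
    "y_pred_toy = np.array([1.5, 2.0, 2.5, 4.5])  # Synthetic model predictions.\n",
    "\n",
    "mse_toy = mean_squared_error(y_true_toy, y_pred_toy)\n",
    "nmse_toy = mse_toy / np.abs(y_true_toy.mean())\n",
    "print(\"NMSE:\", nmse_toy)\n",
    "# The order of the arguments matters for R^2 but not for the MSE.\n",
    "print(\"R^2 (y_true, y_pred):\", r2_score(y_true_toy, y_pred_toy))\n",
    "print(\"R^2 (y_pred, y_true):\", r2_score(y_pred_toy, y_true_toy))"
   ]
  },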
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c8ef644",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_bland_altman(\n",
    "    model_fit_results: Dict[str, np.ndarray],\n",
    "    task: leon.envs.BaseTask,\n",
    "    task_name: str,\n",
    "    y_name: str,\n",
    "    figsize: Tuple[float, float] = (4, 4),\n",
    "    xy_annotation: Tuple[float, float] = (0.98, 0.76),\n",
    "    savepath: Optional[Union[Path, str]] = None\n",
    ") -> None:\n",
    "    \"\"\"\n",
    "    Plots the Bland-Altman plot for the train and test models.\n",
    "    Input:\n",
    "        model_fit_results: a dictionary containing the model fit results.\n",
    "        task: the task for which the model fit results were obtained.\n",
    "        task_name: the name of the task.\n",
    "        y_name: the name of the target variable.\n",
    "        figsize: the size of the figure to plot.\n",
    "        xy_annotation: the xy coordinates of the annotation.\n",
    "        savepath: the optional path to save the figure to.\n",
    "    Returns:\n",
    "        None.\n",
    "    \"\"\"\n",
    "    y_off_dist = model_fit_results[\"y_off_dist\"]\n",
    "    y_in_dist = model_fit_results[\"y_in_dist\"]\n",
    "    y_gt = model_fit_results[\"y_gt\"]\n",
    "\n",
    "    summative_stats = get_summative_statistics(model_fit_results, task, verbose=False)\n",
    "\n",
    "    mean_values_train = (y_off_dist + y_gt) / 2.0\n",
    "    diff_values_train = y_off_dist - y_gt\n",
    "    mean_values_test = (y_in_dist + y_gt) / 2.0\n",
    "    diff_values_test = y_in_dist - y_gt\n",
    "\n",
    "    mean_diff_train, std_diff_train = diff_values_train.mean(), diff_values_train.std()\n",
    "    mean_diff_test, std_diff_test = diff_values_test.mean(), diff_values_test.std()\n",
    "\n",
    "    plt.figure(figsize=figsize)\n",
    "    plt.scatter(mean_values_train, diff_values_train, color=\"red\", alpha=0.3, label=\"Source Model\", rasterized=True)\n",
    "    plt.scatter(mean_values_test, diff_values_test, color=\"green\", marker=\"^\", alpha=0.3, label=\"Target Model\", rasterized=True)\n",
    "    plt.axhline(mean_diff_train, color=\"red\", linestyle=\"--\", alpha=0.5)\n",
    "    plt.axhline(mean_diff_train + 1.96 * std_diff_train, color=\"red\", linestyle=\"--\", alpha=0.3)\n",
    "    plt.axhline(mean_diff_train - 1.96 * std_diff_train, color=\"red\", linestyle=\"--\", alpha=0.3)\n",
    "    plt.axhline(mean_diff_test, color=\"green\", linestyle=\"--\", alpha=0.5)\n",
    "    plt.axhline(mean_diff_test + 1.96 * std_diff_test, color=\"green\", linestyle=\"--\", alpha=0.3)\n",
    "    plt.axhline(mean_diff_test - 1.96 * std_diff_test, color=\"green\", linestyle=\"--\", alpha=0.3)\n",
    "\n",
    "    plt.xlabel(f\"Mean of Predicted and Ground Truth {y_name}\", fontweight=\"bold\")\n",
    "    plt.ylabel(f\"Predicted - Ground Truth {y_name}\", fontweight=\"bold\")\n",
    "    plt.title(task_name, fontweight=\"bold\")\n",
    "\n",
    "    plt.annotate(\n",
    "        fr\"Source Model NMSE = {summative_stats['nmse_off_dist']:.3f}\",\n",
    "        xy=(xy_annotation[0], xy_annotation[1] + 0.055),\n",
    "        xycoords=\"axes fraction\",\n",
    "        color=\"red\",\n",
    "        horizontalalignment=\"right\",\n",
    "        verticalalignment=\"top\"\n",
    "    )\n",
    "    plt.annotate(\n",
    "        fr\"Target Model NMSE = {summative_stats['nmse_in_dist']:.3f}\",\n",
    "        xy=xy_annotation,\n",
    "        xycoords=\"axes fraction\",\n",
    "        color=\"green\",\n",
    "        horizontalalignment=\"right\",\n",
    "        verticalalignment=\"top\"\n",
    "    )\n",
    "\n",
    "    plt.legend()\n",
    "\n",
    "    plt.grid(alpha=0.3)\n",
    "    plt.tight_layout()\n",
    "    if savepath is not None:\n",
    "        plt.savefig(savepath, dpi=600, bbox_inches=\"tight\", transparent=True)\n",
    "    else:\n",
    "        plt.show()"
   ]
  },
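  {
   "cell_type": "markdown",
   "id": "e2a7c9d4",
   "metadata": {},
   "source": [
    "The dashed horizontal lines in the Bland-Altman plot are the mean difference (bias) and the 95% limits of agreement, computed as the mean plus or minus 1.96 standard deviations of the differences. A minimal sketch on synthetic data follows. The random arrays and the helper name `bland_altman_limits` are illustrative assumptions and are not part of `leon`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5b1d8e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from typing import Tuple\n",
    "\n",
    "\n",
    "def bland_altman_limits(y_pred: np.ndarray, y_true: np.ndarray) -> Tuple[float, float, float]:\n",
    "    \"\"\"Returns the bias and the lower and upper 95% limits of agreement.\"\"\"\n",
    "    diff = y_pred - y_true\n",
    "    bias, spread = diff.mean(), diff.std()\n",
    "    return bias, bias - 1.96 * spread, bias + 1.96 * spread\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "y_true_demo = rng.normal(loc=5.0, scale=1.0, size=200)\n",
    "# Synthetic predictions with a constant offset of 0.3 plus noise.\n",
    "y_pred_demo = y_true_demo + 0.3 + rng.normal(scale=0.5, size=200)\n",
    "bias, lower, upper = bland_altman_limits(y_pred_demo, y_true_demo)\n",
    "print(f\"Bias: {bias:.3f}, limits of agreement: [{lower:.3f}, {upper:.3f}]\")"
   ]
  },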
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd642441",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_correlation(\n",
    "    model_fit_results: Dict[str, np.ndarray],\n",
    "    task: leon.envs.BaseTask,\n",
    "    task_name: str,\n",
    "    y_name: str,\n",
    "    figsize: Tuple[float, float] = (4, 4),\n",
    "    xy_annotation: Tuple[float, float] = (0.95, 0.05),\n",
    "    savepath: Optional[Union[Path, str]] = None\n",
    ") -> None:\n",
    "    \"\"\"\n",
    "    Plots a correlation plot for the train and test models.\n",
    "    Input:\n",
    "        model_fit_results: a dictionary containing the model fit results.\n",
    "        task: the task for which the model fit results were obtained.\n",
    "        task_name: the name of the task.\n",
    "        y_name: the name of the target variable.\n",
    "        figsize: the size of the figure to plot.\n",
    "        xy_annotation: the xy coordinates of the annotation.\n",
    "        savepath: the optional path to save the figure to.\n",
    "    Returns:\n",
    "        None.\n",
    "    \"\"\"\n",
    "    y_off_dist = model_fit_results[\"y_off_dist\"]\n",
    "    y_in_dist = model_fit_results[\"y_in_dist\"]\n",
    "    y_gt = model_fit_results[\"y_gt\"]\n",
    "\n",
    "    plt.figure(figsize=figsize)\n",
    "    plt.scatter(y_gt, y_off_dist, color=\"red\", alpha=0.3, label=\"Source Model\", rasterized=True)\n",
    "    plt.scatter(y_gt, y_in_dist, color=\"green\", marker=\"^\", alpha=0.3, label=\"Target Model\", rasterized=True)\n",
    "\n",
    "    summative_stats = get_summative_statistics(model_fit_results, task, verbose=False)\n",
    "\n",
    "    def add_identity(axes, *line_args, **line_kwargs):\n",
    "        identity, = axes.plot([], [], *line_args, **line_kwargs)\n",
    "        def callback(axes):\n",
    "            low_x, high_x = axes.get_xlim()\n",
    "            low_y, high_y = axes.get_ylim()\n",
    "            low = max(low_x, low_y)\n",
    "            high = min(high_x, high_y)\n",
    "            identity.set_data([low, high], [low, high])\n",
    "        callback(axes)\n",
    "        axes.callbacks.connect(\"xlim_changed\", callback)\n",
    "        axes.callbacks.connect(\"ylim_changed\", callback)\n",
    "        return axes\n",
    "\n",
    "    add_identity(plt.gca(), color=\"k\", linestyle=\"--\", alpha=0.3)\n",
    "\n",
    "    plt.xlabel(f\"Ground Truth {y_name}\", fontweight=\"bold\")\n",
    "    plt.ylabel(f\"Model Predicted {y_name}\", fontweight=\"bold\")\n",
    "    plt.title(task_name, fontweight=\"bold\")\n",
    "\n",
    "    plt.annotate(\n",
    "        fr\"Source Model $R^2$ = {summative_stats['r2_off_dist']:.3f}\",\n",
    "        color=\"red\",\n",
    "        xy=(xy_annotation[0], xy_annotation[1] + 0.055),\n",
    "        xycoords=\"axes fraction\",\n",
    "        ha=\"right\",\n",
    "        va=\"bottom\"\n",
    "    )\n",
    "    plt.annotate(\n",
    "        fr\"Target Model $R^2$ = {summative_stats['r2_in_dist']:.3f}\",\n",
    "        color=\"green\",\n",
    "        xy=(xy_annotation[0], xy_annotation[1]),\n",
    "        xycoords=\"axes fraction\",\n",
    "        ha=\"right\",\n",
    "        va=\"bottom\"\n",
    "    )\n",
    "\n",
    "    plt.legend()\n",
    "\n",
    "    plt.grid(alpha=0.3)\n",
    "    plt.tight_layout()\n",
    "    if savepath is not None:\n",
    "        plt.savefig(savepath, dpi=600, bbox_inches=\"tight\", transparent=True)\n",
    "    else:\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff777439",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_train_test_distribution_shift(\n",
    "    task: leon.envs.BaseTask,\n",
    "    task_name: str,\n",
    "    y_name: str,\n",
    "    figsize: Tuple[float, float] = (4, 4),\n",
    "    savepath: Optional[Union[Path, str]] = None,\n",
    "    nbins: int = 50\n",
    ") -> None:\n",
    "    \"\"\"\n",
    "    Plots the distribution shift between the train and test datasets.\n",
    "    Input:\n",
    "        task: the task to plot the distribution shift for.\n",
    "        task_name: the name of the task.\n",
    "        y_name: the name of the target variable.\n",
    "        figsize: the size of the figure to plot.\n",
    "        savepath: the optional path to save the figure to.\n",
    "        nbins: the number of bins to use for the histogram.\n",
    "    Returns:\n",
    "        None.\n",
    "    \"\"\"\n",
    "    num_train, num_test = len(task.train), len(task.test)\n",
    "    plt.figure(figsize=figsize)\n",
    "    y_train, y_test = task.train.target, task.test.target\n",
    "    if task.train._mu is not NotImplemented:\n",
    "        y_train = task.train.unnormalize(y_train)\n",
    "        y_test = task.test.unnormalize(y_test)\n",
    "    if task.task_name in [\"HIVDB-v0\"]:\n",
    "        # Convert negative viral loads to positive values.\n",
    "        y_train, y_test = -1.0 * y_train, -1.0 * y_test\n",
    "    bins = np.linspace(min(np.min(y_train), np.min(y_test)), max(np.max(y_train), np.max(y_test)), num=nbins)\n",
    "    plt.hist(y_train, bins=bins, alpha=0.5, color=\"red\", label=fr\"Source Dataset ($N$ = {num_train})\", density=True)\n",
    "    plt.hist(y_test, bins=bins, alpha=0.5, color=\"green\", label=fr\"Target Dataset ($N$ = {num_test})\", density=True)\n",
    "\n",
    "    plt.xlabel(f\"Ground Truth {y_name}\", fontweight=\"bold\")\n",
    "    plt.ylabel(\"Density\", fontweight=\"bold\")\n",
    "    plt.title(task_name, fontweight=\"bold\")\n",
    "\n",
    "    plt.legend()\n",
    "    plt.grid(alpha=0.3)\n",
    "    plt.tight_layout()\n",
    "\n",
    "    if savepath is not None:\n",
    "        plt.savefig(savepath, dpi=600, bbox_inches=\"tight\", transparent=True)\n",
    "    else:\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4b98115",
   "metadata": {},
   "source": [
    "We also implement a simple `pprint()` function for viewing datums within each dataset associated with a task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b497b3b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pprint(x: NamedTuple) -> None:\n",
    "    \"\"\"\n",
    "    Pretty prints a NamedTuple datum.\n",
    "    Input:\n",
    "        x: the input NamedTuple object to print.\n",
    "    Returns:\n",
    "        None.\n",
    "    \"\"\"\n",
    "    print(type(x).__name__, \":\", sep=\"\")\n",
    "    for key, val in x._asdict().items():\n",
    "        print(\"  \", key, \": \", val, sep=\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19bee60f",
   "metadata": {},
   "source": [
    "A `BaseTask` should be instantiated using the `make()` function, which closely follows the logic of the [`gymnasium.make()`](https://gymnasium.farama.org/api/registry/) function. The two required arguments for the `make()` function are (1) a `task_name` (chosen from the `pprint_registry` above); and (2) a random `seed`. We set a random seed below for demonstration purposes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df44289e",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed: int = 2025\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac19b6b9",
   "metadata": {},
   "source": [
    "For the purposes of this notebook, we also set the `DEBUG` environmental variable to `True` to be able to evaluate the quality of the trained train and test models on the (typically hidden) test set of designs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18a5b940",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"DEBUG\"] = \"True\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61d1c4d6",
   "metadata": {},
   "source": [
    "## IWPC Warfarin Dosing Optimization Task\n",
    "\n",
    "This task is adapted from [The International Warfarin Pharmacogenetics Consortium. NEJM (2009)](https://www.nejm.org/doi/full/10.1056/NEJMoa0809329) and [Yao MS et al. NeurIPS (2024)](https://openreview.net/forum?id=3RxcarQFRn) on predicting the optimal warfarin dose for a patient given input clinical and pharmacogenetic data. Both the train and test models are instances of the [TabPFN](https://github.com/PriorLabs/TabPFN) regressor model from [Hollmann N et al. Nature (2025)](https://www.nature.com/articles/s41586-024-08328-6) to predict the optimality of a proposed warfarin dose given input patient covariates. The train model is trained only on White patients, and the test model is trained only on non-White patients. The goal of this task is to predict optimal warfarin doses for non-White patients given only access to the model trained only on White patients."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fc25e34",
   "metadata": {},
   "outputs": [],
   "source": [
    "warfarin_task = leon.make(\"IWPCWarfarin-v0\", seed=seed)\n",
    "warfarin_model_fit_results = evaluate_model_fit(warfarin_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dddd31a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_bland_altman(warfarin_model_fit_results, warfarin_task, \"Warfarin\", \"Dose Optimality\", savepath=\"bland_altman_warfarin.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bee41bcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_correlation(warfarin_model_fit_results, warfarin_task, \"Warfarin\", \"Dose Optimality\", savepath=\"correlation_warfarin.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40654b3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_train_test_distribution_shift(warfarin_task, \"Warfarin\", \"Dose Optimality\", savepath=\"label_shift_warfarin.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c51b8dd2",
   "metadata": {},
   "source": [
    "Here is an example patient from the `train` dataset in this task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "307871cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "pprint(warfarin_task.train[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92001ae4",
   "metadata": {},
   "source": [
    "## Stanford HIV Database (HIVDB) Optimization Task\n",
    "\n",
    "[Ree SY et al. Nuc Acids Res (2003)](https://academic.oup.com/nar/article/31/1/298/2401450) established the [HIV Drug Resistance Database](https://hivdb.stanford.edu/pages/citation.html) of de-identified patient responses as a function of different antiretroviral treatment regimens. The raw data accumulated from multiple academic studies from 2000 to 2020 is available [here](https://hivdb.stanford.edu/clinical_studies/). In this adapted task, the train and test models are each fully-connected neural networks that are trained to predict the (negative of the) viral load of a patient given input patient covariates and proposed antiretroviral medications. The train model is trained only on patients included in studies published no later than 2017, and the test model is trained on patients included in studies published no earlier than 2018. The goal of this task is to predict an optimal antiretroviral medication regimen for patients after 2018 given only access to the model trained only on patient outcomes prior to 2017."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3b72bed",
   "metadata": {},
   "outputs": [],
   "source": [
    "hivdb_task = leon.make(\"HIVDB-v0\", seed=seed)\n",
    "hivdb_model_fit_results = evaluate_model_fit(hivdb_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46aaed8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_bland_altman(hivdb_model_fit_results, hivdb_task, \"HIV\", \"Viral Load\", savepath=\"bland_altman_hivdb.pdf\", xy_annotation=(0.98, 0.08))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3b2dc84",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_correlation(hivdb_model_fit_results, hivdb_task, \"HIV\", \"Viral Load\", savepath=\"correlation_hivdb.pdf\", xy_annotation=(0.98, 0.01))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9b32f94",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_train_test_distribution_shift(hivdb_task, \"HIV\", \"Viral Load\", savepath=\"label_shift_hivdb.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f343614a",
   "metadata": {},
   "source": [
    "Here is an example patient from the `train` dataset in this task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9eca8d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "pprint(hivdb_task.train[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "841da2df",
   "metadata": {},
   "source": [
    "For the purposes of this task, we restrict the set of (1) possible antiretroviral medications; (2) positions of protease mutations for a patient; and (3) positions of reverse transcriptase mutations for a patient to those where there are at least 10 patients in the `train` dataset (i.e., corresponding to a frequency of at least 0.8% in the `train` dataset) that have taken the particular medication or have a mutation at that particular position in the respective protein sequence. The final pruned list of features can be found in the [`hivdb_features.json`](../src/leon/data/hivdb_features.json) source file."
   ]
  }
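  ,
  {
   "cell_type": "markdown",
   "id": "a9d6e0c2",
   "metadata": {},
   "source": [
    "The pruning rule above can be sketched as follows. This is an illustrative example only: the binary indicator matrix, the feature names, and the helper `prune_rare_features` are hypothetical, and the actual preprocessing in `leon` may be implemented differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0c4f7a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from typing import List\n",
    "\n",
    "\n",
    "def prune_rare_features(indicators: np.ndarray, names: List[str], min_count: int = 10) -> List[str]:\n",
    "    \"\"\"Keeps the features that are present in at least `min_count` patients.\"\"\"\n",
    "    counts = indicators.sum(axis=0)  # Number of patients per feature.\n",
    "    return [name for name, count in zip(names, counts) if count >= min_count]\n",
    "\n",
    "\n",
    "# Synthetic cohort of 20 patients and 3 hypothetical binary features.\n",
    "demo_indicators = np.zeros((20, 3), dtype=int)\n",
    "demo_indicators[:12, 0] = 1  # Present in 12 patients.\n",
    "demo_indicators[:3, 1] = 1  # Present in 3 patients.\n",
    "demo_indicators[:10, 2] = 1  # Present in exactly 10 patients.\n",
    "print(prune_rare_features(demo_indicators, [\"drug_a\", \"mutation_b\", \"mutation_c\"]))"
   ]
  }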
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "leon",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
