{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a616825b",
   "metadata": {},
   "source": [
    "# Tutorial for LLM-based Entropy-guided Optimization with kNowledgeable priors (LEON)\n",
    "\n",
    "This is a brief tutorial on our implementation of LEON from [our paper](https://openreview.net/forum?id=w025bYRVkO), and how you can use LEON to solve your own black-box optimization functions. We have made LEON available as a `pip` package: all you have to do to get started is run\n",
    "\n",
    "```bash\n",
    "python -m pip install -e .\n",
    "```\n",
    "\n",
    "We use Python 3.10 in our implementation. After successful installation, `leon` can be used just like any other Python package:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6d5220e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import leon"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "678141ca",
   "metadata": {},
   "source": [
    "For the purposes of this tutorial, we also need to set up a few environmental variables - you can add your Azure OpenAI endpoints and API keys to use for chat completion and text embedding tasks (which may be the same for some users!) below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2244d421",
   "metadata": {},
   "outputs": [],
   "source": [
    "%env AZURE_API_VERSION=\"TODO\"\n",
    "%env API_ENDPOINT_CHAT=\"https://TODO.openai.azure.com/\"\n",
    "%env API_KEY_CHAT=\"YOUR_API_KEY_HERE\"\n",
    "%env API_ENDPOINT_EMBED=\"https://TODO.openai.azure.com/\"\n",
    "%env API_KEY_EMBED=\"YOUR_API_KEY_HERE\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1606de79",
   "metadata": {},
   "source": [
    "We will also perform some additional configuration steps that are only needed for the purposes of this tutorial in a Jupyter notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91727d86",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "\n",
    "sys.modules.setdefault(\"tutorial\", sys.modules[__name__])\n",
    "seed = 2025"
   ]
  },
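  {
   "cell_type": "markdown",
   "id": "b7e41c90",
   "metadata": {},
   "source": [
    "The `sys.modules.setdefault()` call above registers this notebook's own namespace under the module name `tutorial`, which is the name that strings such as `\"tutorial:ToyTask\"` refer to in the task registration step later on. A minimal sketch of why this helps, assuming that `module:attribute` strings are resolved with `importlib` (as `gymnasium` does); `resolve()`, `DemoTask`, and `demo_tutorial` are hypothetical names used only for illustration and are not part of `leon`:\n",
    "\n",
    "```python\n",
    "import importlib\n",
    "import sys\n",
    "import types\n",
    "\n",
    "\n",
    "def resolve(entry_point: str):\n",
    "    # Split \"module:attribute\" into its two parts.\n",
    "    module_name, attr_name = entry_point.split(\":\")\n",
    "    # import_module() consults sys.modules before searching the file system.\n",
    "    module = importlib.import_module(module_name)\n",
    "    return getattr(module, attr_name)\n",
    "\n",
    "\n",
    "class DemoTask:\n",
    "    pass\n",
    "\n",
    "\n",
    "# Register a stand-in module under an importable name.\n",
    "demo_module = types.ModuleType(\"demo_tutorial\")\n",
    "demo_module.DemoTask = DemoTask\n",
    "sys.modules.setdefault(\"demo_tutorial\", demo_module)\n",
    "\n",
    "resolved = resolve(\"demo_tutorial:DemoTask\")\n",
    "```\n",
    "\n",
    "Without the `sys.modules` entry, the import would fail because no file named `demo_tutorial.py` exists on the import path."
   ]
  },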
  {
   "cell_type": "markdown",
   "id": "925dc7a7",
   "metadata": {},
   "source": [
    "At its core, our implementation of LEON revolves around the `leon.optimize.minimize()` function, which takes a black-box optimization task specification and performs a language model-based optimization procedure described in [our paper](https://openreview.net/forum?id=w025bYRVkO).\n",
    "\n",
    "Before we can use this function, we first need to define relevant classes to describe our optimization task. For the purposes of this tutorial, we consider a toy example task where the goal is to optimize a modified version of the [multi-objective-multi-fidelity Branin-Currin function](https://botorch.readthedocs.io/en/stable/test_functions.html#botorch.test_functions.multi_objective_multi_fidelity.MOMFBraninCurrin). More specificially, each 'patient' is described by a single parameter $w\\in [0, 1]$, and each 'treatment' is a two-dimensional vector $(x_1, x_2)\\in[0, 1]^2$. The source function (i.e., surrogate model) $\\hat{f}_\\text{source}: [0, 1]^3\\to\\mathbb{R}$ is given by\n",
    "$$\n",
    "\\hat{f}_\\text{source}(x_1, x_2; w) = [w\\cdot B(x_1, x_2, s=0.1)] + [(1-w)\\cdot C(x_1, x_2, s=0.1)]\n",
    "$$\n",
    "where $s$ is the fidelity parameter, $B(x_1, x_2, s)$ is the modified Branin function, $C(x_1, x_2, s)$ is the modified Currin function. Similarly, the target function (i.e., ground-truth objective) $f_\\text{target}: [0, 1]^3\\to\\mathbb{R}$ is given by\n",
    "$$\n",
    "f_\\text{target}(x_1, x_2; w) = [w\\cdot B(x_1, x_2, s=1.0)] + [(1-w)\\cdot C(x_1, x_2, s=1.0)]\n",
    "$$\n",
    "(Note the difference in the values of the fidelity parameters between the source and target functions.)\n",
    "\n",
    "We implement both functions in the `FooBarFunction` class below. Your task's implementation should also implement a `predict()` class function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db2fcba1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from botorch.test_functions.multi_objective_multi_fidelity import (\n",
    "    MOMFBraninCurrin\n",
    ")\n",
    "from typing import Any, Dict, List, Callable, NamedTuple\n",
    "\n",
    "\n",
    "class FooBarFunction(Callable):\n",
    "    def __init__(self, split: str):\n",
    "        self.fidelity = 0.1 + (0.9 * ([\"source\", \"target\"].index(split)))\n",
    "        self.obj_func = MOMFBraninCurrin(negate=True)\n",
    "        self.n_dim = self.obj_func.dim\n",
    "\n",
    "    def predict(\n",
    "        self, x: List[NamedTuple], **kwargs: Dict[str, Any]\n",
    "    ) -> List[float]:\n",
    "        del kwargs\n",
    "\n",
    "        y = []\n",
    "        for obs in x:\n",
    "            tensor_x = torch.tensor([obs.dim_1, obs.dim_2, self.fidelity])\n",
    "            tensor_x = torch.clamp(tensor_x, min=0.0, max=1.0)\n",
    "            tensor_w = torch.tensor([obs.w, 1.0 - obs.w])\n",
    "            y.append(float((self.obj_func(tensor_x) * tensor_w).sum().item()))\n",
    "\n",
    "        return y \n",
    "\n",
    "    def __call__(\n",
    "        self, x: List[NamedTuple], **kwargs: Dict[str, Any]\n",
    "    ) -> List[float]:\n",
    "        return self.predict(x, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "264ae82e",
   "metadata": {},
   "source": [
    "### Dataset Implementation\n",
    "\n",
    "LEON also leverages an offline dataset of patient data. The data from the source distribution and target distribution for our tutorial are both implemented in the `ToyDataset` class below. Each observation from the dataset is represented as a `FooBarObservation` object. Note that the design dimensions $x_1, x_2$, which are meant to represent patient treatments, are set as optional variables in the `FooBarObservation` implementation. This is because patients from the target distribution do not necessarily have a treatment administered yet.\n",
    "\n",
    "Furthermore, note the additional methods that are implemented in both the observation `NamedTuple` objects and also in the `leon.data.BaseDataset` subclass. You will need to implement similar methods for your own task, too."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39632d54",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional, Union\n",
    "from leon.data import BaseDataset\n",
    "\n",
    "\n",
    "class FooBarObservation(NamedTuple):\n",
    "    dim_1: Optional[float]\n",
    "    dim_2: Optional[float]\n",
    "    w: float\n",
    "\n",
    "    def __str__(self) -> str:\n",
    "        return f\"({self.dim_1}, {self.dim_2}, {self.w})\"\n",
    "\n",
    "    def __repr__(self) -> str:\n",
    "        return str(self)\n",
    "\n",
    "    def as_tensor(self) -> torch.Tensor:\n",
    "        assert self.dim_1 is not None and self.dim_2 is not None\n",
    "        return torch.tensor(\n",
    "            [self.dim_1, self.dim_2, self.w], dtype=torch.float32\n",
    "        )\n",
    "\n",
    "    @classmethod\n",
    "    def ignored_features(cls) -> List[str]:\n",
    "        return [\"w\"]\n",
    "\n",
    "    @classmethod\n",
    "    def discrete_features(cls) -> Dict[str, List[Union[str, bool]]]:\n",
    "        return {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b39ecd60",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "class ToyDataset(BaseDataset):\n",
    "    def __init__(self, split: str, N: int = 100, seed: Optional[int] = 2025):\n",
    "        assert split in [\"source\", \"target\"]\n",
    "        super(ToyDataset, self).__init__(split)\n",
    "\n",
    "        # Construct a toy dataset of N points on the fly.\n",
    "        rng = np.random.default_rng(seed=seed)\n",
    "        obj_func = FooBarFunction(split)\n",
    "        self.X: List[NamedTuple] = [\n",
    "            FooBarObservation(x[0], x[1], x[2])\n",
    "            for x in rng.random(size=(2 * N, obj_func.n_dim + 1))\n",
    "        ]\n",
    "        self.y: np.ndarray = np.array(obj_func(self.X))\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx: int) -> FooBarObservation:\n",
    "        return self.X[idx]\n",
    "\n",
    "    @property\n",
    "    def data(self) -> pd.DataFrame:\n",
    "        return pd.DataFrame.from_records([x._asdict() for x in self.X])\n",
    "\n",
    "    @property\n",
    "    def target(self) -> np.ndarray:\n",
    "        return self.y\n",
    "\n",
    "    @property\n",
    "    def target_name(self) -> str:\n",
    "        return \"objective\"\n",
    "\n",
    "    def relabel(self, y: List[float]) -> None:\n",
    "        self.y = np.array(y)\n",
    "\n",
    "    def mask_designs(self) -> None:\n",
    "        self.X = [\n",
    "            FooBarObservation(None, None, w=self.X[i].w)\n",
    "            for i in range(len(self.X))\n",
    "        ]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e06fd731",
   "metadata": {},
   "source": [
    "### Task Implementation\n",
    "\n",
    "Now that we have defined our dataset, we are now ready to define our offline optimization task. Each task should inherit from the `leon.envs.BaseTask` base class, which requires you to define a number of class-specific methods for your implementation. We also need to define a way to represent each treatment design, which should inhert from the `pydantic.BaseModel` base class. We provide an example implementation `FooBarDesign` for the purposes of our toy task in this tutorial. Note that your implementation should similarly define a representation of each design in natural language that makes sense for your application."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "610a8a53",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydantic import BaseModel\n",
    "\n",
    "\n",
    "class FooBarDesign(BaseModel):\n",
    "    dim_1: float\n",
    "    dim_2: float\n",
    "\n",
    "    def __str__(self) -> str:\n",
    "        return f\"({self.dim_1:.4f}, {self.dim_2:.4f})\"\n",
    "\n",
    "    def __repr__(self) -> str:\n",
    "        return str(self)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac334108",
   "metadata": {},
   "outputs": [],
   "source": [
    "from leon.envs import BaseTask\n",
    "from typing import Any, Dict, NamedTuple, Sequence, Type\n",
    "\n",
    "\n",
    "class ToyTask(BaseTask):\n",
    "    def __init__(\n",
    "        self,\n",
    "        train: ToyDataset,\n",
    "        test: ToyDataset,\n",
    "        seed: Optional[int] = 2025,\n",
    "        **kwargs: Dict[str, Any]\n",
    "    ):\n",
    "        del kwargs\n",
    "        super(ToyTask, self).__init__(\n",
    "            task_name=\"ToyTask-v0\",\n",
    "            train=train,\n",
    "            test=test,\n",
    "            train_model=FooBarFunction(\"source\"),\n",
    "            test_model=FooBarFunction(\"target\"),\n",
    "            seed=seed\n",
    "        )\n",
    "\n",
    "    @property\n",
    "    def design_schema(self) -> Type[BaseModel]:\n",
    "        return FooBarDesign\n",
    "\n",
    "    def task_description(\n",
    "        self, *args: Sequence[Any], **kwargs: Dict[str, Any]\n",
    "    ) -> str:\n",
    "        del args, kwargs\n",
    "        return (\n",
    "            \"Propose a 2D coordinate that maximizes the weighted sum of the \"\n",
    "            \"Branin and Currin test functions.\"\n",
    "        )\n",
    "\n",
    "    @property\n",
    "    def ndim(self) -> int:\n",
    "        return self._train_model.n_dim - 1\n",
    "\n",
    "    def extend(self, x: List[BaseModel], ref: Any) -> List[Any]:\n",
    "        return [\n",
    "            FooBarObservation(design.dim_1, design.dim_2, ref.w)\n",
    "            for design in x\n",
    "        ]\n",
    "\n",
    "    def reduce(self, x: List[NamedTuple]) -> List[Any]:\n",
    "        return [FooBarDesign(dim_1=obs.dim_1, dim_2=obs.dim_2) for obs in x]\n",
    "\n",
    "    @property\n",
    "    def disease_name(self) -> str:\n",
    "        return \"a rare disease\""
   ]
  },
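  {
   "cell_type": "markdown",
   "id": "d3a9c1f7",
   "metadata": {},
   "source": [
    "The `extend()` and `reduce()` methods above convert between designs (treatment only) and full observations (treatment plus patient context). A minimal sketch of that round trip using plain `NamedTuple` stand-ins; `DemoDesign`, `DemoObservation`, and the two free functions are hypothetical and only mirror the shape of the methods above:\n",
    "\n",
    "```python\n",
    "from typing import List, NamedTuple\n",
    "\n",
    "\n",
    "class DemoDesign(NamedTuple):\n",
    "    dim_1: float\n",
    "    dim_2: float\n",
    "\n",
    "\n",
    "class DemoObservation(NamedTuple):\n",
    "    dim_1: float\n",
    "    dim_2: float\n",
    "    w: float\n",
    "\n",
    "\n",
    "def extend(\n",
    "    designs: List[DemoDesign], ref: DemoObservation\n",
    ") -> List[DemoObservation]:\n",
    "    # Attach the reference patient's w to every candidate design.\n",
    "    return [DemoObservation(d.dim_1, d.dim_2, ref.w) for d in designs]\n",
    "\n",
    "\n",
    "def reduce(observations: List[DemoObservation]) -> List[DemoDesign]:\n",
    "    # Strip the patient context and keep only the design dimensions.\n",
    "    return [DemoDesign(o.dim_1, o.dim_2) for o in observations]\n",
    "\n",
    "\n",
    "patient = DemoObservation(0.0, 0.0, w=0.4)\n",
    "designs = [DemoDesign(0.1, 0.9), DemoDesign(0.5, 0.5)]\n",
    "round_trip = reduce(extend(designs, patient))\n",
    "```\n",
    "\n",
    "Reducing an extended list recovers the original designs, since only the patient parameter is added and then removed."
   ]
  },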
  {
   "cell_type": "markdown",
   "id": "54345de4",
   "metadata": {},
   "source": [
    "### Task Registration and Instantiation\n",
    "\n",
    "We have now set up all the necessary components to instantiate the task environment for optimization! Our API for task registration and instantiation closely follows the [`gymnasium.make()`](https://gymnasium.farama.org/api/registry/) API (and is in fact built on top of it!). This means we first need to `register` our task and then `make` it to instantiate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74e7a57d",
   "metadata": {},
   "outputs": [],
   "source": [
    "leon.register(\n",
    "    id=\"ToyTask-v0\",  # Come up with a task ID string for your task here!\n",
    "    entry_point=\"tutorial:ToyTask\",  # TODO: Add your task spec here!\n",
    "    max_episode_steps=1,  # This should always be 1.\n",
    "    kwargs={\n",
    "        \"dataset\": \"tutorial:ToyDataset\",  # TODO: Add your data spec here!\n",
    "        \"train_split\": \"source\",  # TODO: Specify the source train split here.\n",
    "        \"test_split\": \"target\"  # TODO: Specify the target test split here.\n",
    "    },\n",
    "    disable_env_checker=True  # This should always be True.\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b57dd06c",
   "metadata": {},
   "outputs": [],
   "source": [
    "task = leon.make(\"ToyTask-v0\", seed=seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9f1bcb4",
   "metadata": {},
   "source": [
    "### Putting It All Together\n",
    "\n",
    "We are now ready to run LEON to optimize this task! We have defined the `leon.optimize.maximize()` function for easy interfacing with LEON; this function is heavily inspired by the [`scipy.optimize.minimize()`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html) API and has been made to be as close of a drop-in replacement as possible for SciPy-like optimization code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0d174ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimization_result = leon.optimize.maximize(\n",
    "    task,\n",
    "    x0=None,  # For most applications, you can safely keep this as None.\n",
    "    args=(0,),  # This should be a tuple of the index of the patient in the target (i.e., test) dataset that you want to optimize for.\n",
    "    options={  # Run `leon.optimize.show_options()` to see a full list of available configuration options.\n",
    "        \"knowledge_source\": [\"None\"],\n",
    "        \"embedder_name\": \"leon/random\",\n",
    "        \"seed\": seed\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82687e48",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimization_result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0636cc5a",
   "metadata": {},
   "source": [
    "For the full list of specification options for LEON to provide in the `options` dictionary above, run `leon.optimize.show_options()`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45238c4b",
   "metadata": {},
   "source": [
    "### Still Have Questions?\n",
    "\n",
    "Reach out to the corresponding paper of our paper [linked here](https://openreview.net/forum?id=w025bYRVkO)!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "leon",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
