{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adding Custom Functions to BALSA - Tutorial\n",
    "\n",
    "This notebook demonstrates how to add your custom optimisation function to the BALSA benchmark suite.\n",
    "\n",
    "## Prerequisites\n",
    "\n",
    "First, let's ensure we have BALSA installed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install BALSA into the kernel's environment (skip this cell if it is already installed).\n",
    "# The note lives on its own line because `%pip` forwards the rest of its line to the shell,\n",
    "# and not every shell treats a trailing `# ...` as a comment.\n",
    "%pip install balsa"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Creating Your Custom Function\n",
    "\n",
    "Let's create a custom objective function by inheriting from `ObjectiveFunction`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass, field\n",
    "from typing import Any, override\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"../../\")  # Adds the project root directory to Python path\n",
    "\n",
    "from balsa.obj_func import ObjectiveFunction\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class MyCustomFunction(ObjectiveFunction):\n",
    "    \"\"\"Custom objective function for optimisation.\n",
    "\n",
    "    This is a simple example that implements a quadratic function,\n",
    "    ``f(x) = sum(x**2)``.\n",
    "\n",
    "    Attributes:\n",
    "        name: Identifier of this function (``\"my_custom\"``).\n",
    "        dims: Number of input dimensions; must be positive.\n",
    "        func_args: Extra settings. ``lb``/``ub`` hold the lower/upper bounds\n",
    "            (one entry per dimension); ``param1``/``param2`` are example\n",
    "            parameters that the quadratic below does not use.\n",
    "    \"\"\"\n",
    "\n",
    "    name: str = \"my_custom\"\n",
    "    dims: int = 5  # Example: 5-dimensional function\n",
    "    turn: float = 0.1\n",
    "    func_args: dict[str, Any] = field(\n",
    "        default_factory=lambda: {\n",
    "            \"param1\": 1.0,\n",
    "            \"param2\": 2.0,\n",
    "            \"lb\": [-5.0] * 5,  # Lower bounds (must match `dims`)\n",
    "            \"ub\": [5.0] * 5,  # Upper bounds (must match `dims`)\n",
    "        }\n",
    "    )\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        \"\"\"Unpack ``func_args`` into attributes and sanity-check the bounds.\"\"\"\n",
    "        super().__post_init__()\n",
    "        assert self.dims > 0\n",
    "        self.lb = np.array(self.func_args.get(\"lb\"))\n",
    "        self.ub = np.array(self.func_args.get(\"ub\"))\n",
    "        # Fail fast when the bounds do not match `dims`, e.g. when `dims` was\n",
    "        # overridden but the 5-dimensional default bounds were kept.\n",
    "        assert self.lb.shape == (self.dims,), \"'lb' must have one entry per dimension\"\n",
    "        assert self.ub.shape == (self.dims,), \"'ub' must have one entry per dimension\"\n",
    "        self.param1 = self.func_args.get(\"param1\")\n",
    "        self.param2 = self.func_args.get(\"param2\")\n",
    "\n",
    "    @override\n",
    "    def _scaled(self, y: float) -> float:\n",
    "        \"\"\"Return ``y`` unchanged (no scaling in this example).\"\"\"\n",
    "        return y\n",
    "\n",
    "    @override\n",
    "    def __call__(self, x: np.ndarray, saver: bool = True, return_scaled: bool = False) -> float:\n",
    "        \"\"\"Evaluate the quadratic ``sum(x**2)`` at ``x`` and record the evaluation.\n",
    "\n",
    "        Args:\n",
    "            x: 1-D array with exactly ``dims`` entries.\n",
    "            saver: Forwarded unchanged to ``self.tracker.track``.\n",
    "            return_scaled: If True, return ``self._scaled(y)`` instead of ``y``.\n",
    "\n",
    "        Returns:\n",
    "            The (optionally scaled) function value as a float.\n",
    "\n",
    "        Raises:\n",
    "            AssertionError: If ``x`` is not 1-D or its length differs from ``dims``.\n",
    "        \"\"\"\n",
    "        # Validate before touching any state so that rejected inputs are not\n",
    "        # counted as evaluations; check `ndim` first so that 0-d inputs fail\n",
    "        # with the intended assertion rather than a TypeError from `len`.\n",
    "        assert x.ndim == 1\n",
    "        assert len(x) == self.dims\n",
    "        self.counter += 1\n",
    "\n",
    "        # Example quadratic function\n",
    "        y = float(np.sum(x**2))\n",
    "\n",
    "        # The unscaled value is what gets tracked, regardless of `return_scaled`.\n",
    "        self.tracker.track(y, x, saver)\n",
    "        return y if not return_scaled else self._scaled(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Testing the Function\n",
    "\n",
    "Let's test our custom function with some sample inputs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create an instance of our custom function\n",
    "custom_func = MyCustomFunction()\n",
    "\n",
    "# Test with a random point drawn inside the function's own bounds.\n",
    "# A seeded generator makes the printed values reproducible on every run.\n",
    "rng = np.random.default_rng(seed=42)\n",
    "test_input = rng.uniform(low=custom_func.lb, high=custom_func.ub)\n",
    "test_output = custom_func(test_input)\n",
    "\n",
    "print(f\"Test input: {test_input}\")\n",
    "print(f\"Function output: {test_output}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Using the Function with BALSA\n",
    "\n",
    "Now let's set up and run an optimisation using our custom function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "This optimisation is based on a turbo optimiser\n",
      "Using dtype = torch.float32 \n",
      "Using device = cpu\n",
      "TR-0 starting from: 20.3\n",
      "TR-1 starting from: 16.41\n",
      "TR-2 starting from: 21.33\n",
      "TR-3 starting from: 16.54\n",
      "TR-4 starting from: 17.81\n",
      "35) New best @ TR-0: 12.82\n"
     ]
    }
   ],
   "source": [
    "from balsa.active_learning import OptimisationRegistry, OptimisationConfig, ActiveLearningPipeline\n",
    "\n",
    "# Problem size, defined once so that `dims` and the bounds below cannot drift apart\n",
    "N_DIMS = 5\n",
    "\n",
    "# Register your custom function under the key used for `obj_func_name` below\n",
    "OptimisationRegistry.FUNCTIONS[\"my_custom\"] = MyCustomFunction\n",
    "\n",
    "# Configuration\n",
    "config = OptimisationConfig(\n",
    "    dims=N_DIMS,\n",
    "    search_method=\"turbo\",\n",
    "    obj_func_name=\"my_custom\",\n",
    "    num_acquisitions=50,\n",
    "    num_samples_per_acquisition=1,\n",
    "    surrogate=\"default_surrogate\",\n",
    "    num_init_samples=30,\n",
    "    func_args={\n",
    "        \"param1\": 1.0,\n",
    "        \"param2\": 2.0,\n",
    "        \"lb\": [-5.0] * N_DIMS,\n",
    "        \"ub\": [5.0] * N_DIMS,\n",
    "    },\n",
    ")\n",
    "\n",
    "# Run optimisation\n",
    "pipeline = ActiveLearningPipeline(config)\n",
    "result = pipeline.run()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
