{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "956e2f15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import changed modules automatically before each cell executes.\n",
    "# NOTE(review): all helpers in this notebook are defined inline and no local\n",
    "# module is imported, so this only matters if helpers are moved to a .py file.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc38efa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Callable, Union, Sequence\n",
    "import math\n",
    "import torch\n",
    "from scipy.spatial.distance import cdist\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def peaks(meshgrid: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    \"Peaks\" function that has multiple local minima.\n",
    "\n",
    "    :params meshgrid: tensor of shape [..., 2], the (x, y) coordinates\n",
    "    \"\"\"\n",
    "    meshgrid = torch.as_tensor(meshgrid, dtype=torch.float)\n",
    "    xx = meshgrid[..., 0]\n",
    "    yy = meshgrid[..., 1]\n",
    "    return 0.25 * (\n",
    "        3 * (1 - xx) ** 2 * torch.exp(-(xx**2) - (yy + 1) ** 2)\n",
    "        - 10 * (xx / 5 - xx**3 - yy**5) * torch.exp(-(xx**2) - yy**2)\n",
    "        - 1 / 3 * torch.exp(-((xx + 1) ** 2) - yy**2)\n",
    "    )\n",
    "\n",
    "\n",
    "def rastrigin(meshgrid: torch.Tensor, shift: int = 0) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    \"Rastrigin\" function with `A = 3`\n",
    "    https://en.wikipedia.org/wiki/Rastrigin_function\n",
    "\n",
    "    :params meshgrid: tensor of shape [..., 2], the (x, y) coordinates\n",
    "    \"\"\"\n",
    "    meshgrid = torch.as_tensor(meshgrid, dtype=torch.float)\n",
    "    xx = meshgrid[..., 0]\n",
    "    yy = meshgrid[..., 1]\n",
    "    A = 3\n",
    "    return A * 2 + (\n",
    "        ((xx - shift) ** 2 - A * torch.cos(2 * torch.tensor(math.pi, dtype=torch.float, device=xx.device) * xx))\n",
    "        + ((yy - shift) ** 2 - A * torch.cos(2 * torch.tensor(math.pi, dtype=torch.float, device=xx.device) * yy))\n",
    "    )\n",
    "\n",
    "\n",
    "def rosenbrock(meshgrid: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    \"Rosenbrock\" function\n",
    "    https://en.wikipedia.org/wiki/Rosenbrock_function\n",
    "\n",
    "    It has a global minimum at $(x , y) = (a, a^2) = (1, 1)$\n",
    "\n",
    "    :params meshgrid: tensor of shape [..., 2], the (x, y) coordinates\n",
    "    \"\"\"\n",
    "    meshgrid = torch.as_tensor(meshgrid, dtype=torch.float)\n",
    "    xx = meshgrid[..., 0]\n",
    "    yy = meshgrid[..., 1]\n",
    "\n",
    "    a = 1\n",
    "    b = 100\n",
    "    return (a - xx) ** 2 + b * (yy - xx**2) ** 2\n",
    "\n",
    "\n",
    "def simple_fn(meshgrid: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    :params meshgrid: tensor of shape [..., 2], the (x, y) coordinates\n",
    "    \"\"\"\n",
    "    meshgrid = torch.as_tensor(meshgrid, dtype=torch.float)\n",
    "    xx = meshgrid[..., 0]\n",
    "    yy = meshgrid[..., 1]\n",
    "\n",
    "    output = -1 / (1 + xx**2 + yy**2)\n",
    "\n",
    "    return output\n",
    "\n",
    "\n",
    "def simple_fn2(meshgrid: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    :params meshgrid: tensor of shape [..., 2], the (x, y) coordinates\n",
    "    \"\"\"\n",
    "    meshgrid = torch.as_tensor(meshgrid, dtype=torch.float)\n",
    "    xx = meshgrid[..., 0]\n",
    "    yy = meshgrid[..., 1]\n",
    "\n",
    "    output = (1 + xx**2 + yy**2) ** (1 / 2)\n",
    "\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce51c578",
   "metadata": {},
   "outputs": [],
   "source": [
    "ALPHA = 1\n",
    "SAMPLE_MARKER = \".\"\n",
    "SAMPLE_MARKER_SIZE = 7.5\n",
    "\n",
    "\n",
    "GREEN = \"#3EB863\"\n",
    "PURPLE = \"#6a4c93\"\n",
    "RED = \"#CF294A\"\n",
    "BLUE = \"#275299\"\n",
    "\n",
    "\n",
    "anchors = np.asarray(\n",
    "    [\n",
    "        [-1.25, 0.5],\n",
    "        [0, 1.25],\n",
    "        [1.2, 0],\n",
    "    ]\n",
    ")\n",
    "\n",
    "A_w = 0.1\n",
    "B_w = 0.1\n",
    "C_w = 0.8\n",
    "assert A_w + B_w + C_w == 1\n",
    "\n",
    "point = A_w * anchors[0] + B_w * anchors[1] + C_w * anchors[2]\n",
    "\n",
    "colors = [\n",
    "    RED,\n",
    "    PURPLE,\n",
    "    GREEN,\n",
    "]\n",
    "sample_color = BLUE\n",
    "\n",
    "AXIS_OFF = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5acb846",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib import cm\n",
    "from matplotlib.ticker import LinearLocator\n",
    "from tueplots import bundles, figsizes\n",
    "\n",
    "# Figure grid: the function surface on the left, the relative-coordinate view on the right.\n",
    "N_ROWS = 1\n",
    "N_COLS = 2\n",
    "RATIO = 1  # NOTE(review): unused; the height/width ratio below is hard-coded to 0.7\n",
    "\n",
    "# Reset to matplotlib defaults, then apply the ICML 2022 style bundle and a\n",
    "# full-width ICML figure size for this grid.\n",
    "plt.style.use(\"default\")\n",
    "plt.rcParams.update(bundles.icml2022())\n",
    "plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=0.7))\n",
    "\n",
    "# Two 3D axes side by side: `ax` holds the function surface, `proj` the\n",
    "# sample's distances to the anchors.\n",
    "fig, [ax, proj] = plt.subplots(\n",
    "    N_ROWS,\n",
    "    N_COLS,\n",
    "    subplot_kw={\"projection\": \"3d\"},\n",
    "    sharex=False,\n",
    "    sharey=False,\n",
    "    dpi=300,\n",
    "    # constrained_layout=True\n",
    ")\n",
    "\n",
    "\n",
    "# Make data: a regular (x, y) grid with 0.01 spacing.\n",
    "X = np.arange(-2, 2, 0.01)\n",
    "Y = np.arange(-3, 3, 0.01)\n",
    "X, Y = np.meshgrid(X, Y)\n",
    "meshgrid = np.stack((X, Y), -1)  # shape [len(Y), len(X), 2]\n",
    "\n",
    "\n",
    "# Plot surface\n",
    "Z = peaks(meshgrid)\n",
    "COUNT = 100  # rcount/ccount: maximum number of samples per direction drawn by plot_surface\n",
    "surf = ax.plot_surface(\n",
    "    X, Y, Z, cmap=cm.coolwarm, alpha=ALPHA, linewidth=1, antialiased=True, shade=True, rcount=COUNT, ccount=COUNT\n",
    ")\n",
    "\n",
    "\n",
    "# Lift the anchors and the sample onto the surface by appending their peaks() value as z.\n",
    "# Only the (x, y) columns are read, so re-running this cell does not keep appending\n",
    "# z columns to the arrays created in the config cell.\n",
    "anchors = np.concatenate((anchors[:, :2], peaks(anchors[:, :2])[:, None].numpy()), axis=-1)\n",
    "point = np.asarray([*point[:2], peaks(point[:2]).item()])\n",
    "\n",
    "# Plot anchors: one star per anchor on the surface, in that anchor's color.\n",
    "for anchor_xyz, anchor_color in zip(anchors, colors):\n",
    "    ax.plot(\n",
    "        anchor_xyz[..., 0],\n",
    "        anchor_xyz[..., 1],\n",
    "        anchor_xyz[..., 2],\n",
    "        c=anchor_color,\n",
    "        marker=\"*\",\n",
    "        zorder=10,\n",
    "        alpha=1,\n",
    "        antialiased=True,\n",
    "    )\n",
    "\n",
    "# Plot sample: the weighted point, also placed on the surface.\n",
    "sample_kwargs = dict(c=sample_color, markersize=SAMPLE_MARKER_SIZE, marker=SAMPLE_MARKER, zorder=10, antialiased=True)\n",
    "ax.plot([point[0]], [point[1]], [point[2]], **sample_kwargs)\n",
    "\n",
    "# Plot anchors lines: a dashed segment from the sample to every anchor.\n",
    "for anchor_xyz, anchor_color in zip(anchors, colors):\n",
    "    xs, ys, zs = ([point[i], anchor_xyz[i]] for i in range(3))\n",
    "    ax.plot(\n",
    "        xs,\n",
    "        ys,\n",
    "        zs,\n",
    "        c=anchor_color,\n",
    "        markersize=0,\n",
    "        zorder=8,\n",
    "        linewidth=1,\n",
    "        linestyle=\"--\",\n",
    "        antialiased=True,\n",
    "    )\n",
    "\n",
    "# Plot relative axis: the sample's distance to each anchor becomes its coordinate\n",
    "# along one axis of the projection panel, drawn in that anchor's color.\n",
    "anchors_dists = cdist(anchors, point[None]).squeeze()\n",
    "for dist_ax, color in zip(\n",
    "    (\n",
    "        ([0, anchors_dists[0]], [0, 0], [0, 0]),\n",
    "        ([0, 0], [0, anchors_dists[1]], [0, 0]),\n",
    "        ([0, 0], [0, 0], [0, anchors_dists[2]]),\n",
    "    ),\n",
    "    colors,\n",
    "):\n",
    "    proj.plot(*dist_ax, c=color, markersize=0, zorder=8, linewidth=2, linestyle=\"--\", antialiased=True)\n",
    "\n",
    "# Plot anchors axis ends: a star at the far end of each relative axis.\n",
    "for axis_end, symbol, color, zorder in zip(\n",
    "    (([anchors_dists[0]], [0], [0]), ([0], [anchors_dists[1]], [0]), ([0], [0], [anchors_dists[2]])),\n",
    "    [\"*\", \"*\", \"*\"],\n",
    "    colors,\n",
    "    (11, 9.5, 11),\n",
    "):\n",
    "    proj.plot(*axis_end, c=color, marker=symbol, markersize=10, zorder=zorder, alpha=1, antialiased=True)\n",
    "\n",
    "# Plot sample at its relative coordinates. The zorder is stated explicitly instead of\n",
    "# reusing the loop variable left over from the loop above (its last value was 11).\n",
    "SAMPLE_PROJ_ZORDER = 11\n",
    "proj.plot(\n",
    "    [anchors_dists[0]],\n",
    "    [anchors_dists[1]],\n",
    "    [anchors_dists[2]],\n",
    "    c=sample_color,\n",
    "    markersize=SAMPLE_MARKER_SIZE,\n",
    "    marker=SAMPLE_MARKER,\n",
    "    zorder=SAMPLE_PROJ_ZORDER,\n",
    "    antialiased=True,\n",
    ")\n",
    "\n",
    "# Plot cube: dashed edges of the box between the origin and the sample's relative\n",
    "# coordinates (d0, d1, d2). With the three relative axes drawn above, these cover\n",
    "# all 12 box edges, each drawn exactly once.\n",
    "for lines, zorder in zip(\n",
    "    (\n",
    "        # Polyline over six edges:\n",
    "        # (d0,0,0) -> (d0,d1,0) -> (d0,d1,d2) -> (0,d1,d2) -> (0,0,d2) -> (d0,0,d2) -> (d0,d1,d2)\n",
    "        (\n",
    "            [anchors_dists[0], anchors_dists[0], anchors_dists[0], 0, 0, anchors_dists[0], anchors_dists[0]],\n",
    "            [0, anchors_dists[1], anchors_dists[1], anchors_dists[1], 0, 0, anchors_dists[1]],\n",
    "            [0, 0, anchors_dists[2], anchors_dists[2], anchors_dists[2], anchors_dists[2], anchors_dists[2]],\n",
    "        ),\n",
    "        # (d0,0,0) -> (d0,0,d2)\n",
    "        ([anchors_dists[0], anchors_dists[0]], [0, 0], [0, anchors_dists[2]]),\n",
    "        # (0,d1,0) -> (d0,d1,0)\n",
    "        ([0, anchors_dists[0]], [anchors_dists[1], anchors_dists[1]], [0, 0]),\n",
    "        # (0,d1,0) -> (0,d1,d2)\n",
    "        ([0, 0], [anchors_dists[1], anchors_dists[1]], [0, anchors_dists[2]]),\n",
    "    ),\n",
    "    (10, 9, 9, 9),\n",
    "):\n",
    "    proj.plot(*lines, c=sample_color, linestyle=\"--\", linewidth=0.5, zorder=zorder, alpha=1, antialiased=True)\n",
    "\n",
    "\n",
    "# Scale the projection box by the three distances so that one data unit has the\n",
    "# same length along every axis (aspect ratio 1:1:1 in data space).\n",
    "proj.set_box_aspect(tuple(anchors_dists[:3]))\n",
    "\n",
    "# Show each relative axis from 0 to slightly past its anchor star.\n",
    "AXIS_MARGIN = 0.1\n",
    "for set_limits, axis_length in zip((proj.set_xlim3d, proj.set_ylim3d, proj.set_zlim3d), anchors_dists):\n",
    "    set_limits(0, axis_length + AXIS_MARGIN)\n",
    "\n",
    "# Camera angles for the projection panel and the surface panel.\n",
    "proj.view_init(elev=17.0, azim=-50)\n",
    "ax.view_init(elev=40.0, azim=200)\n",
    "\n",
    "\n",
    "if AXIS_OFF:\n",
    "    for panel in (ax, proj):\n",
    "        panel.axis(\"off\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9834150",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the figure as SVG, convert it to PDF with the external `rsvg-convert` CLI\n",
    "# (must be installed), then delete the intermediate SVG.\n",
    "# NOTE(review): IPython `!` commands do not stop the cell on a non-zero exit status,\n",
    "# so `rm` also runs if the conversion fails; confirm that is acceptable.\n",
    "fig.savefig(\"teaser.svg\", bbox_inches=\"tight\", pad_inches=0)\n",
    "!rsvg-convert -f pdf -o teaser.pdf teaser.svg\n",
    "!rm teaser.svg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49dd3af3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional and currently disabled: manual styling of the projection axes\n",
    "# (tick labels, axis lines, grid, panes).\n",
    "# NOTE(review): relies on the private `_axinfo` attribute of mplot3d axes, which\n",
    "# may change between matplotlib versions.\n",
    "# box = proj.get_position()\n",
    "# proj.set_position([box.x0, box.y0, box.x1, box.y1])\n",
    "# for axis in [proj.xaxis, proj.yaxis, proj.zaxis]:\n",
    "#     axis.set_ticklabels([])\n",
    "#     axis._axinfo['axisline']['linewidth'] = 1\n",
    "#     axis._axinfo['axisline']['color'] = (0, 0, 0)\n",
    "#     axis._axinfo['grid']['linewidth'] = 0.25\n",
    "#     axis._axinfo['grid']['linestyle'] = \"-\"\n",
    "#     axis._axinfo['grid']['color'] = (0, 0, 0)\n",
    "#     axis._axinfo['tick']['inward_factor'] = 0.0\n",
    "#     axis._axinfo['tick']['outward_factor'] = 0.0\n",
    "#     axis.set_pane_color((0.95, 0.95, 0.95))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
