{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15ffba0a-07db-41ff-9cb6-9cee403fe594",
   "metadata": {},
   "outputs": [],
   "source": [
    "from problem import *\n",
    "import torch\n",
    "\n",
    "\n",
    "problem = get_problem('molecular_discovery')\n",
    "# problem.targets : stores embeddings for each protein target, of which there are 102\n",
    "# problem.objectives : names of each of the 10 objectives\n",
    "\n",
    "N = 10      # batch size\n",
    "a = torch.randn(N, 32).cuda()       # compounds randomly sampled from latent space\n",
    "# c = problem.targets[torch.randperm(len(problem.targets))[:N]]       # randomly selected protein targets\n",
    "\n",
    "# Condition every compound in the batch on the same protein target (index 0).\n",
    "# The embedding is a torch tensor that can live on the GPU, so it is tiled with\n",
    "# torch ops: np.repeat cannot consume a CUDA tensor and would return an ndarray\n",
    "# that torch.cat rejects. The result is placed on the same device as `a`.\n",
    "c = problem.targets[0].unsqueeze(0).repeat(N, 1).to(a.device)      # (N, embedding_dim)\n",
    "print(c.shape)\n",
    "x = torch.cat((a, c), dim=1)        # input to the surrogate model\n",
    "\n",
    "print(problem.evaluate(x, 0).shape)     # output shape: (N, 10), each objective is in the range [0, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "83c42fb4-245a-492d-87ef-66a560ee78f3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.5156, 0.4910, 0.3313, 0.4172], device='cuda:0',\n",
       "       grad_fn=<SliceBackward0>)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First four objective values predicted for the first row of the batch.\n",
    "problem.evaluate(x, 0)[0][:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e98dea57-ff1b-45a4-8bb6-4e6bd3022e2a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.4172, device='cuda:0', grad_fn=<SelectBackward0>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Objective at column index 3 for the first row of the batch.\n",
    "problem.evaluate(x, 0)[0][3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "035a5748-077c-418c-90a8-f3fa7f71232a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.6164, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.2750, device='cuda:0', grad_fn=<MinBackward1>)\n"
     ]
    }
   ],
   "source": [
    "# Largest, then smallest, value of objective column 3 across the batch.\n",
    "# The model is evaluated once per statistic, one value printed per line.\n",
    "print(\n",
    "    problem.evaluate(x, 0)[:, 3].max(),\n",
    "    problem.evaluate(x, 0)[:, 3].min(),\n",
    "    sep='\\n',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b71fdc5e-e5e4-4e60-bdab-e84533a12cea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([102, 64])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dimensions of the protein-target embedding table (one row per target).\n",
    "problem.targets.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1614ac30-503a-4c54-a8ef-7b327bd0c83d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([10, 64])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dimensions of the target-embedding tensor `c` that is concatenated with `a`.\n",
    "c.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "204b452b-30b2-4c8b-b64c-2dcb6c29fe3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): repeats the embedding-table shape check from an earlier cell;\n",
    "# kept as a scratch check, candidate for removal.\n",
    "problem.targets.size()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
