{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5af836d2-1283-4ddc-9f9a-f39da68614a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import random\n",
    "import string\n",
    "import uuid\n",
    "import shutil\n",
    "\n",
    "from typing import Optional, Union\n",
    "from pprint import pprint\n",
    "\n",
    "import configargparse\n",
    "import sys\n",
    "from contextlib import contextmanager, redirect_stderr, redirect_stdout\n",
    "import torch.nn as nn\n",
    "\n",
    "from bo import gp_initialize_model, gp_optimize_acqf_and_get_observation\n",
    "import numpy as np\n",
    "\n",
    "@contextmanager\n",
    "def suppress_output():\n",
    "    \"\"\"\n",
    "        A context manager that redirects stdout and stderr to devnull\n",
    "        https://stackoverflow.com/a/52442331\n",
    "    \"\"\"\n",
    "    with open(os.devnull, 'w') as fnull:\n",
    "        with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out:\n",
    "            yield (err, out)\n",
    "\n",
    "with suppress_output():\n",
    "    import design_bench\n",
    "\n",
    "    from design_bench.datasets.discrete.tf_bind_8_dataset import TFBind8Dataset\n",
    "    from design_bench.datasets.discrete.tf_bind_10_dataset import TFBind10Dataset\n",
    "    from design_bench.datasets.discrete.cifar_nas_dataset import CIFARNASDataset\n",
    "    from design_bench.datasets.discrete.chembl_dataset import ChEMBLDataset\n",
    "\n",
    "    from design_bench.datasets.continuous.ant_morphology_dataset import AntMorphologyDataset\n",
    "    from design_bench.datasets.continuous.dkitty_morphology_dataset import DKittyMorphologyDataset\n",
    "    from design_bench.datasets.continuous.superconductor_dataset import SuperconductorDataset\n",
    "    # from design_bench.datasets.continuous.hopper_controller_dataset import HopperControllerDataset\n",
    "\n",
    "from util import TASKNAME2TASK, configure_gpu, set_seed, get_weights\n",
    "\n",
    "task = design_bench.make(TASKNAME2TASK['dkitty'])\n",
    "# dataset = task.dataset\n",
    "original_x = task.x\n",
    "original_y = task.y\n",
    "\n",
    "task_pred = task.predict(task.x)\n",
    "\n",
    "print(np.abs(original_y - task_pred).max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4f0dbc97-c914-4d44-ae3f-77512b9007bb",
   "metadata": {},
   "outputs": [
    {
     "ename": "Exception",
     "evalue": "\nMissing path to your environment variable. \nCurrent values LD_LIBRARY_PATH=\nPlease add following line to .bashrc:\nexport LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/ubuntu/.mujoco/mujoco210/bin",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mException\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-2-03f18ad6be87>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtask\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdesign_bench\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mTASKNAME2TASK\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'dkitty'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;31m# dataset = task.dataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0moriginal_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0moriginal_y\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/site-packages/design_bench/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(task_name, dataset_kwargs, oracle_kwargs, **kwargs)\u001b[0m\n\u001b[1;32m    326\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    327\u001b[0m     return registry.make(task_name, dataset_kwargs=dataset_kwargs,\n\u001b[0;32m--> 328\u001b[0;31m                          oracle_kwargs=oracle_kwargs, **kwargs)\n\u001b[0m\u001b[1;32m    329\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    330\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/site-packages/design_bench/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(self, task_name, dataset_kwargs, oracle_kwargs, **kwargs)\u001b[0m\n\u001b[1;32m    155\u001b[0m         return self.spec(task_name).make(\n\u001b[1;32m    156\u001b[0m             \u001b[0mdataset_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdataset_kwargs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 157\u001b[0;31m             oracle_kwargs=oracle_kwargs, **kwargs)\n\u001b[0m\u001b[1;32m    158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    159\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/site-packages/design_bench/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(self, dataset_kwargs, oracle_kwargs, **kwargs)\u001b[0m\n\u001b[1;32m    109\u001b[0m         return Task(self.dataset, self.oracle,\n\u001b[1;32m    110\u001b[0m                     \u001b[0mdataset_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdataset_kwargs_final\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 111\u001b[0;31m                     oracle_kwargs=oracle_kwargs_final, **kwargs)\n\u001b[0m\u001b[1;32m    112\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    113\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__repr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/site-packages/design_bench/task.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, dataset, oracle, dataset_kwargs, oracle_kwargs, relabel)\u001b[0m\n\u001b[1;32m    261\u001b[0m         \u001b[0;31m# if self.entry_point is a string import it first\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    262\u001b[0m         \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moracle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m             \u001b[0moracle\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimport_name\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moracle\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    264\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    265\u001b[0m         \u001b[0;31m# return if the oracle could not be loaded\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/site-packages/design_bench/task.py\u001b[0m in \u001b[0;36mimport_name\u001b[0;34m(name)\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mimport_name\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     18\u001b[0m     \u001b[0mmod_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattr_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\":\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimport_module\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmod_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattr_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/importlib/__init__.py\u001b[0m in \u001b[0;36mimport_module\u001b[0;34m(name, package)\u001b[0m\n\u001b[1;32m    125\u001b[0m                 \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    126\u001b[0m             \u001b[0mlevel\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0m_bootstrap\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_gcd_import\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlevel\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpackage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    128\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_gcd_import\u001b[0;34m(name, package, level)\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load\u001b[0;34m(name, import_)\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load_unlocked\u001b[0;34m(name, import_)\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_load_unlocked\u001b[0;34m(spec)\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/importlib/_bootstrap_external.py\u001b[0m in \u001b[0;36mexec_module\u001b[0;34m(self, module)\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_call_with_frames_removed\u001b[0;34m(f, *args, **kwds)\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/site-packages/design_bench/oracles/exact/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mhopper_controller_oracle\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mHopperControllerOracle\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mant_morphology_oracle\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAntMorphologyOracle\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mdkitty_morphology_oracle\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDKittyMorphologyOracle\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mtoy_continuous_oracle\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mToyContinuousOracle\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mnas_bench_oracle\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mNASBenchOracle\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/site-packages/design_bench/oracles/exact/ant_morphology_oracle.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mmorphing_agents\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmujoco\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mant\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menv\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mMorphingAntEnv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmorphing_agents\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmujoco\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mant\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0melements\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mLEG\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmorphing_agents\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmujoco\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mant\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0melements\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mLEG_LOWER_BOUND\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmorphing_agents\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmujoco\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mant\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0melements\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mLEG_UPPER_BOUND\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mdesign_bench\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moracles\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexact_oracle\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mExactOracle\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/site-packages/morphing_agents/mujoco/ant/env.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmorphing_agents\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmujoco\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mant\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdesigns\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnormalize_design_vector\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmujoco\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmujoco_env\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      7\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtempfile\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/site-packages/gym/envs/mujoco/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmujoco\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmujoco_env\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mMujocoEnv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;31m# ^^^^^ so that user gets the correct error\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;31m# message if mujoco is not installed correctly\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmujoco\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mant\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAntEnv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmujoco\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhalf_cheetah\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mHalfCheetahEnv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/site-packages/gym/envs/mujoco/mujoco_env.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m     \u001b[0;32mimport\u001b[0m \u001b[0mmujoco_py\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     12\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m     \u001b[0;32mraise\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDependencyNotInstalled\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/site-packages/mujoco_py/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m#!/usr/bin/env python\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mmujoco_py\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuilder\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcymj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_mujoco_warnings\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunctions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMujocoException\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmujoco_py\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgenerated\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mconst\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmujoco_py\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmjrenderpool\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mMjRenderPool\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmujoco_py\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmjviewer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mMjViewer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMjViewerBasic\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/site-packages/mujoco_py/builder.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m    502\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    503\u001b[0m \u001b[0mmujoco_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdiscover_mujoco\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 504\u001b[0;31m \u001b[0mcymj\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_cython_ext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmujoco_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    505\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    506\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/site-packages/mujoco_py/builder.py\u001b[0m in \u001b[0;36mload_cython_ext\u001b[0;34m(mujoco_path)\u001b[0m\n\u001b[1;32m     72\u001b[0m         \u001b[0mBuilder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMacExtensionBuilder\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     73\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplatform\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'linux'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m         \u001b[0m_ensure_set_env_var\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"LD_LIBRARY_PATH\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlib_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     75\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetenv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'MUJOCO_PY_FORCE_CPU'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mget_nvidia_lib_dir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     76\u001b[0m             \u001b[0m_ensure_set_env_var\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"LD_LIBRARY_PATH\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_nvidia_lib_dir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/D4MO/lib/python3.7/site-packages/mujoco_py/builder.py\u001b[0m in \u001b[0;36m_ensure_set_env_var\u001b[0;34m(var_name, lib_path)\u001b[0m\n\u001b[1;32m    122\u001b[0m                         \u001b[0;34m\"Please add following line to .bashrc:\\n\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    123\u001b[0m                         \"export %s=$%s:%s\" % (var_name, os.environ.get(var_name, \"\"),\n\u001b[0;32m--> 124\u001b[0;31m                                               var_name, var_name, lib_path))\n\u001b[0m\u001b[1;32m    125\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    126\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mException\u001b[0m: \nMissing path to your environment variable. \nCurrent values LD_LIBRARY_PATH=\nPlease add following line to .bashrc:\nexport LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/ubuntu/.mujoco/mujoco210/bin"
     ]
    }
   ],
   "source": [
    "task = design_bench.make(TASKNAME2TASK['dkitty'])\n",
    "# dataset = task.dataset\n",
    "original_x = task.x\n",
    "original_y = task.y\n",
    "\n",
    "task_pred = task.predict(task.x)\n",
    "\n",
    "print(np.abs(original_y - task_pred).max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 222,
   "id": "12a6ff0d-6fa9-4b52-95c0-6103b00889bc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[383300.],\n",
       "       [374300.],\n",
       "       [325200.],\n",
       "       [371700.],\n",
       "       [375700.],\n",
       "       [336000.],\n",
       "       [367000.],\n",
       "       [374300.],\n",
       "       [370700.],\n",
       "       [331700.],\n",
       "       [308000.],\n",
       "       [378700.],\n",
       "       [364300.],\n",
       "       [379700.],\n",
       "       [364200.],\n",
       "       [377000.],\n",
       "       [376500.],\n",
       "       [373300.],\n",
       "       [370000.],\n",
       "       [374700.]], dtype=float32)"
      ]
     },
     "execution_count": 222,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "task.y[:20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 223,
   "id": "ee3e56c6-a803-42f7-b9f0-75aba909e06e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[379929.66],\n",
       "       [376227.78],\n",
       "       [327299.66],\n",
       "       [378684.28],\n",
       "       [374106.12],\n",
       "       [335965.25],\n",
       "       [371553.62],\n",
       "       [388616.1 ],\n",
       "       [370014.94],\n",
       "       [344735.  ],\n",
       "       [318501.4 ],\n",
       "       [384000.38],\n",
       "       [367647.72],\n",
       "       [355881.28],\n",
       "       [369990.25],\n",
       "       [377041.2 ],\n",
       "       [380548.03],\n",
       "       [380548.03],\n",
       "       [377046.  ],\n",
       "       [378684.28]], dtype=float32)"
      ]
     },
     "execution_count": 223,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "task_pred[:20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 185,
   "id": "dbc17c71-1d70-4c8d-bab7-864b31121bc3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(17014, 86)\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "0.0\n",
      "[ 14.        0.        3.       40.      105.      120.       12.8\n",
      "  66.        4.        0.        4.       12.       99.925   100.\n",
      "  20.       15.        3.        0.        3.3      24.        5.\n",
      "  75.       79.5      34.9      14.       30.       35.38     45.\n",
      "  98.       20.       41.       46.       18.       19.        5.\n",
      "   0.        4.       11.        9.       96.71     99.976    99.992\n",
      "   6.       64.       45.       50.99745   7.       99.995    31.5\n",
      "  99.2      83.5      66.7       4.        0.        3.       24.\n",
      "  98.        4.998   185.        6.        0.        6.        6.\n",
      "   4.        5.        5.        5.        5.        5.       16.\n",
      "   7.       25.       55.       14.       97.24     10.       45.\n",
      "   5.8      64.        8.        7.       19.       14.        0.\n",
      "   0.        0.     ]\n",
      "185.0\n",
      "0.11304927\n"
     ]
    }
   ],
   "source": [
    "print(original_x.shape)\n",
    "print(np.min(original_x, 0))\n",
    "print(np.min(original_x))\n",
    "print(np.max(original_x, 0))\n",
    "print(np.max(original_x))\n",
    "print(np.mean(original_x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 186,
   "id": "4767032a-c9a0-4e84-84bf-9d1798c43ad3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(17014, 1)\n",
      "0.00021\n",
      "74.0\n"
     ]
    }
   ],
   "source": [
    "print(original_y.shape)\n",
    "print(np.min(original_y))\n",
    "print(np.max(original_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 187,
   "id": "dc3e95a9-d979-4f0f-9c9d-00998160d3a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "task.map_normalize_x()\n",
    "task.map_normalize_y()\n",
    "\n",
    "normalize_x = task.x\n",
    "normalize_y = task.y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 188,
   "id": "447c0ee9-1863-4eaa-b969-2256318bff61",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(17014, 86)\n",
      "[-0.07024373  0.         -0.10082777 -0.04498985 -0.15216358 -0.09674848\n",
      " -0.0981552  -0.65044296 -0.1137692   0.         -0.09483302 -0.10602219\n",
      " -0.06090649 -0.09574476 -0.06700765 -0.15591976 -0.08380999  0.\n",
      " -0.12770317 -0.18316641 -0.06530177 -0.06419145 -0.07353079 -0.02548972\n",
      " -0.02559821 -0.24043933 -0.06680049 -0.10110299 -0.4243231  -0.03529252\n",
      " -0.0737289  -0.09015684 -0.16165958 -0.13020928 -0.05178379  0.\n",
      " -0.07157228 -0.39027238 -0.3325495  -0.08561306 -0.10208447 -0.07839008\n",
      " -0.03957614 -0.08032851 -0.07569787 -0.06119504 -0.0478801  -0.01408633\n",
      " -0.10550281 -0.0716495  -0.06149168 -0.0630479  -0.05562524  0.\n",
      " -0.05591072 -0.4086696  -0.12389437 -0.19801255 -0.03464302 -0.17791522\n",
      "  0.         -0.13422918 -0.1170507  -0.14132783 -0.04679586 -0.08084951\n",
      " -0.08052266 -0.09088819 -0.05988756 -0.04789465 -0.10239963 -0.04847556\n",
      " -0.04739527 -0.06958839 -0.03596339 -0.08923171 -0.07962246 -0.12401104\n",
      " -0.0317126  -0.11319014 -0.13223165 -0.1410988  -0.26108015  0.\n",
      "  0.          0.        ]\n",
      "-0.65044296\n",
      "[ 47.11161    0.        20.799906  42.151882  90.05564   24.498062\n",
      "  76.15242   18.280373  32.964928   0.        35.35371   39.857002\n",
      "  79.3407    40.285736  38.29282   17.526285  22.831608   0.\n",
      "  21.356155  29.51336   24.044132  24.538132  20.806587 122.93493\n",
      "  96.91952   37.69575   54.499725  40.940025  47.02694   44.72745\n",
      "  32.83997   40.236122  14.841154  25.044083  55.09999    0.\n",
      "  29.460804  15.778523  21.691418  17.77785   18.363543  42.86249\n",
      "  82.8936    74.28762   39.96561   29.305563  37.662315 129.91223\n",
      "  53.98509   46.98005   40.57018   83.06526   41.659317   0.\n",
      "  35.05261   27.58572   37.73351   25.734581 129.0817    24.917873\n",
      "   0.        34.29658   37.82303   26.455982  69.75858   49.93417\n",
      "  52.320164  39.930325  36.871136  84.69465   22.736897 107.01295\n",
      "  57.75397   76.24952   73.848755  31.62758   46.4943    16.753288\n",
      "  79.71986   46.67039   31.958769  65.28541   23.77732    0.\n",
      "   0.         0.      ]\n",
      "129.91223\n",
      "-6.4238574e-09\n"
     ]
    }
   ],
   "source": [
    "print(normalize_x.shape)\n",
    "print(np.min(normalize_x, 0))\n",
    "print(np.min(normalize_x))\n",
    "print(np.max(normalize_x, 0))\n",
    "print(np.max(normalize_x))\n",
    "print(np.mean(normalize_x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 189,
   "id": "c85f5c3b-73e3-4eeb-81b1-e4aa65a6a01a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(17014, 1)\n",
      "-0.9961795\n",
      "2.6516874\n"
     ]
    }
   ],
   "source": [
    "print(normalize_y.shape)\n",
    "print(np.min(normalize_y))\n",
    "print(np.max(normalize_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 213,
   "id": "2cbbc769-b44c-4050-ace2-bd3adb543ded",
   "metadata": {},
   "outputs": [],
   "source": [
    "# task.map_normalize_x()\n",
    "# task.map_normalize_y()\n",
    "task.map_denormalize_x()\n",
    "task.map_denormalize_y()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 214,
   "id": "e1df5bf9-a82b-4ccd-ac68-7c746329cd42",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(17014, 1) (17014, 1)\n"
     ]
    }
   ],
   "source": [
    "pred_original_x = task.predict(original_x)\n",
    "print(original_y.shape, pred_original_x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 215,
   "id": "aee3cf9a-a33b-41c5-ba2c-d69836743182",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[16.639427],\n",
       "       [17.297699],\n",
       "       [16.72577 ],\n",
       "       [15.684952],\n",
       "       [15.722851],\n",
       "       [32.62069 ],\n",
       "       [52.32652 ],\n",
       "       [16.783018],\n",
       "       [17.920729],\n",
       "       [17.871645]], dtype=float32)"
      ]
     },
     "execution_count": 215,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the first ten oracle predictions.\n",
    "pred_original_x[0:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 216,
   "id": "cbacba22-ea33-42aa-8442-053964a6f059",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[29.],\n",
       "       [26.],\n",
       "       [19.],\n",
       "       [22.],\n",
       "       [23.],\n",
       "       [23.],\n",
       "       [11.],\n",
       "       [33.],\n",
       "       [36.],\n",
       "       [31.]], dtype=float32)"
      ]
     },
     "execution_count": 216,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First ten labels as currently stored on the task.\n",
    "task.y[0:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 201,
   "id": "45bf9fbe-3a77-4146-b0d1-9c594ff652fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(17014, 1) (17014, 1) (17014, 1) (17014, 1)\n"
     ]
    }
   ],
   "source": [
    "# Query the oracle on both the raw and the normalized designs, then compare shapes.\n",
    "pred_original_x, pred_normalize_x = task.predict(original_x), task.predict(normalize_x)\n",
    "print(*(arr.shape for arr in (original_y, normalize_y, pred_original_x, pred_normalize_x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 202,
   "id": "7f170c61-97a0-4859-9d9b-8d6252175f4b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.2912741],\n",
       "       [1.2249496],\n",
       "       [1.2282903],\n",
       "       [1.2282903],\n",
       "       [1.2912741],\n",
       "       [1.2912741],\n",
       "       [1.2898004],\n",
       "       [1.2282903],\n",
       "       [1.2291747],\n",
       "       [1.2291747]], dtype=float32)"
      ]
     },
     "execution_count": 202,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Leading oracle predictions for the raw designs.\n",
    "pred_original_x[0:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73436c48-c1d9-4612-8ff6-55a206a9f727",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predictions are float32, so exact equality is the wrong test: np.array_equal reports\n",
    "# False for any rounding-level difference. Compare with a tolerance instead\n",
    "# (adjust rtol / atol to the accuracy you actually expect from the oracle).\n",
    "print(np.allclose(pred_original_x, original_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 203,
   "id": "3e0c5b72-06bc-4459-875b-4b396269c812",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "73.40921\n",
      "3.8703809\n"
     ]
    }
   ],
   "source": [
    "# Worst-case absolute oracle error, first in raw units, then in normalized units.\n",
    "print(*(np.abs(pred - target).max()\n",
    "        for pred, target in ((pred_original_x, original_y),\n",
    "                             (pred_normalize_x, normalize_y))),\n",
    "      sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 181,
   "id": "efa05611-bc9c-4c10-be02-4143e366a935",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[29.]\n",
      " [26.]\n",
      " [19.]\n",
      " [22.]\n",
      " [23.]\n",
      " [23.]\n",
      " [11.]\n",
      " [33.]\n",
      " [36.]\n",
      " [31.]]\n",
      "[[1.3019421]\n",
      " [1.1940252]\n",
      " [1.2165918]\n",
      " [1.2165918]\n",
      " [1.3019421]\n",
      " [1.300467 ]\n",
      " [1.3346857]\n",
      " [1.2292272]\n",
      " [1.2292272]\n",
      " [1.2301121]]\n"
     ]
    }
   ],
   "source": [
    "# Labels above, oracle predictions below, for the first ten raw designs.\n",
    "print(original_y[:10], pred_original_x[:10], sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 192,
   "id": "17bec5a0-1ffa-4c27-9b1d-debddfca7bbf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.43338364]\n",
      " [ 0.28549674]\n",
      " [-0.05957274]\n",
      " [ 0.08831418]\n",
      " [ 0.13760981]\n",
      " [ 0.13760981]\n",
      " [-0.45393786]\n",
      " [ 0.6305662 ]\n",
      " [ 0.7784531 ]\n",
      " [ 0.5319749 ]]\n",
      "[[-0.2091926 ]\n",
      " [-0.19791245]\n",
      " [-0.1821518 ]\n",
      " [-0.2091926 ]\n",
      " [-0.16251968]\n",
      " [ 0.64472604]\n",
      " [ 1.5432149 ]\n",
      " [-0.15260701]\n",
      " [-0.13735726]\n",
      " [-0.13988246]]\n"
     ]
    }
   ],
   "source": [
    "# Normalized labels above, oracle predictions on normalized designs below.\n",
    "print(normalize_y[:10], pred_normalize_x[:10], sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d055af6-f54f-42d0-8527-3dffcf1bafce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Round-trip checks for the task's normalization transforms. These are float arrays, so\n",
    "# denormalize(normalized) only recovers the original up to rounding error, and\n",
    "# np.array_equal would report False for that. Compare with a tolerance instead.\n",
    "print(np.allclose(task.denormalize_y(normalize_y), original_y))\n",
    "print(np.allclose(task.denormalize_x(normalize_x), original_x))\n",
    "\n",
    "print(np.allclose(task.normalize_y(original_y), normalize_y))\n",
    "print(np.allclose(task.normalize_x(original_x), normalize_x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "id": "65952e9c-df54-4c16-b0ea-51a151b26166",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7.6293945e-06\n",
      "1.5258789e-05\n",
      "0.0\n",
      "0.0\n"
     ]
    }
   ],
   "source": [
    "# Largest absolute round-trip error for each direction of the y / x transforms.\n",
    "print(np.max(np.abs(task.denormalize_y(normalize_y) - original_y)))\n",
    "print(np.max(np.abs(task.denormalize_x(normalize_x) - original_x)))\n",
    "print(np.max(np.abs(task.normalize_y(original_y) - normalize_y)))\n",
    "print(np.max(np.abs(task.normalize_x(original_x) - normalize_x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "id": "688875cd-1af4-46a9-9aee-e2815fb95741",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[29.],\n",
       "       [26.],\n",
       "       [19.],\n",
       "       [22.],\n",
       "       [23.],\n",
       "       [23.],\n",
       "       [11.],\n",
       "       [33.],\n",
       "       [36.],\n",
       "       [31.]], dtype=float32)"
      ]
     },
     "execution_count": 125,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Denormalized targets, to compare by eye with original_y.\n",
    "task.denormalize_y(normalize_y)[0:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "id": "6914b722-fe96-456b-9224-eee4c4ea7017",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[29.],\n",
       "       [26.],\n",
       "       [19.],\n",
       "       [22.],\n",
       "       [23.],\n",
       "       [23.],\n",
       "       [11.],\n",
       "       [33.],\n",
       "       [36.],\n",
       "       [31.]], dtype=float32)"
      ]
     },
     "execution_count": 126,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference labels in their original units.\n",
    "original_y[0:10]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "D4MO",
   "language": "python",
   "name": "d4mo"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
