{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f32e2ce2",
   "metadata": {},
   "source": [
    "## Train your own models\n",
    "To train you simply need to call `train.train`.\n",
    "We give all necessary code. The most important bits are in the `priors` dir, e.g. `hebo_prior`, it stores the priors\n",
    "with which we train our models.\n",
    "\n",
    "### Training the HEBO+ model, `model_hebo_morebudget_9_unused_features_3.pt`\n",
    "You can train this model on 8 GPUs using `torchrun` or `submitit`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3d41a7e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4949ecba",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/anon/mtpfn/PFNs/pfns/priors/utils.py:293: FutureWarning: Decorating classes is deprecated and will be disabled in future versions. You should only decorate functions or methods. To preserve the current behavior of class decoration, you can directly decorate the `__init__` method and nothing else.\n",
      "  @torch.no_grad()\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from pfns import priors, encoders, utils, bar_distribution, train\n",
    "from ConfigSpace import hyperparameters as CSH\n",
    "\n",
    "config_heboplus = {\n",
    "     'priordataloader_class_or_get_batch': priors.get_batch_to_dataloader(\n",
    "         priors.get_batch_sequence(\n",
    "             priors.hebo_prior.get_batch,\n",
    "             priors.utils.sample_num_features_get_batch,\n",
    "         )\n",
    "     ),\n",
    "     'encoder_generator': encoders.get_normalized_uniform_encoder(encoders.get_variable_num_features_encoder(encoders.Linear)),\n",
    "     'emsize': 512,\n",
    "     'nhead': 4,\n",
    "     'warmup_epochs': 5,\n",
    "     'y_encoder_generator': encoders.Linear,\n",
    "     'batch_size': 128,\n",
    "     'scheduler': utils.get_cosine_schedule_with_warmup,\n",
    "     'extra_prior_kwargs_dict': {'num_features': 18,\n",
    "      'hyperparameters': {\n",
    "       'lengthscale_concentration': 1.2106559584074301,\n",
    "       'lengthscale_rate': 1.5212245992840594,\n",
    "       'outputscale_concentration': 0.8452312502679863,\n",
    "       'outputscale_rate': 0.3993553245745406,\n",
    "       'add_linear_kernel': False,\n",
    "       'power_normalization': False,\n",
    "       'hebo_warping': False,\n",
    "       'unused_feature_likelihood': 0.3,\n",
    "       'observation_noise': True}},\n",
    "     'epochs': 50,\n",
    "     'lr': 0.0001,\n",
    "     'seq_len': 60,\n",
    "     'single_eval_pos_gen': utils.get_uniform_single_eval_pos_sampler(50, min_len=1), #<function utils.get_uniform_single_eval_pos_sampler.<locals>.<lambda>()>,\n",
    "     'aggregate_k_gradients': 2,\n",
    "     'nhid': 1024,\n",
    "     'steps_per_epoch': 1024,\n",
    "     'weight_decay': 0.0,\n",
    "     'train_mixed_precision': False,\n",
    "     'efficient_eval_masking': True,\n",
    "     'nlayers': 12}\n",
    "\n",
    "\n",
    "config_heboplus_userpriors = {**config_heboplus,\n",
    "    'priordataloader_class_or_get_batch': priors.get_batch_to_dataloader(\n",
    "                              priors.get_batch_sequence(\n",
    "                                  priors.hebo_prior.get_batch,\n",
    "                                  priors.condition_on_area_of_opt.get_batch,\n",
    "                                  priors.utils.sample_num_features_get_batch,\n",
    "                              )),\n",
    "    'style_encoder_generator': encoders.get_normalized_uniform_encoder(encoders.get_variable_num_features_encoder(encoders.Linear))\n",
    "}\n",
    "\n",
    "config_bnn = {'priordataloader_class_or_get_batch': priors.get_batch_to_dataloader(\n",
    "         priors.get_batch_sequence(\n",
    "             priors.simple_mlp.get_batch,\n",
    "             priors.input_warping.get_batch,\n",
    "             priors.utils.sample_num_features_get_batch,\n",
    "         )\n",
    "     ),\n",
    "     'encoder_generator': encoders.get_normalized_uniform_encoder(encoders.get_variable_num_features_encoder(encoders.Linear)),\n",
    "     'emsize': 512,\n",
    "     'nhead': 4,\n",
    "     'warmup_epochs': 5,\n",
    "     'y_encoder_generator': encoders.Linear,\n",
    "     'batch_size': 128,\n",
    "     'scheduler': utils.get_cosine_schedule_with_warmup,\n",
    "     'extra_prior_kwargs_dict': {'num_features': 18,\n",
    "      'hyperparameters': {'mlp_num_layers': CSH.UniformIntegerHyperparameter('mlp_num_layers', 8, 15),\n",
    "       'mlp_num_hidden': CSH.UniformIntegerHyperparameter('mlp_num_hidden', 36, 150),\n",
    "       'mlp_init_std': CSH.UniformFloatHyperparameter('mlp_init_std',0.08896049884896237, 0.1928554813280186),\n",
    "       'mlp_sparseness': 0.1449806273312999,\n",
    "       'mlp_input_sampling': 'uniform',\n",
    "       'mlp_output_noise': CSH.UniformFloatHyperparameter('mlp_output_noise', 0.00035983014290491186, 0.0013416342770574585),\n",
    "       'mlp_noisy_targets': True,\n",
    "       'mlp_preactivation_noise_std': CSH.UniformFloatHyperparameter('mlp_preactivation_noise_std',0.0003145707276259681, 0.0013753183831259406),\n",
    "       'input_warping_c1_std': 0.9759720822120248,\n",
    "       'input_warping_c0_std': 0.8002534583197192,\n",
    "       'num_hyperparameter_samples_per_batch': 16}\n",
    "                                 },\n",
    "     'epochs': 50,\n",
    "     'lr': 0.0001,\n",
    "     'seq_len': 60,\n",
    "     'single_eval_pos_gen': utils.get_uniform_single_eval_pos_sampler(50, min_len=1), \n",
    "     'aggregate_k_gradients': 1,\n",
    "     'nhid': 1024,\n",
    "     'steps_per_epoch': 1024,\n",
    "     'weight_decay': 0.0,\n",
    "     'train_mixed_precision': True,\n",
    "     'efficient_eval_masking': True,\n",
    "}\n",
    "\n",
    "\n",
    "# now let's add the criterions, where we decide the border positions based on the prior\n",
    "def get_ys(config, device='cuda:0'):\n",
    "    bs = 128\n",
    "    all_targets = []\n",
    "    for num_hps in [2,8,12]: # a few different samples in case the number of features makes a difference in y dist\n",
    "        b = config['priordataloader_class_or_get_batch'].get_batch_method(\n",
    "            bs,1000,  num_hps, epoch=0, device=device, hyperparameters={**config['extra_prior_kwargs_dict']['hyperparameters'],\n",
    "                                                                        'num_hyperparameter_samples_per_batch': -1,})\n",
    "        all_targets.append(b.target_y.flatten())\n",
    "    return torch.cat(all_targets,0)\n",
    "\n",
    "def add_criterion(config, device='cuda:0'):\n",
    "    return {**config, 'criterion': bar_distribution.FullSupportBarDistribution(\n",
    "        bar_distribution.get_bucket_limits(1000,ys=get_ys(config,device).cpu())\n",
    "    )}\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c8789428",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore', category=UserWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4dec5326",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/anon/mtpfn/PFNs/pfns/priors/hebo_prior.py:413: InputDataWarning: The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444\n",
      "  model = botorch.models.SingleTaskGP(\n",
      "/ext3/miniconda3/envs/pfn/lib/python3.12/site-packages/botorch/models/utils/assorted.py:268: InputDataWarning: Data (outcome observations) is not standardized (std = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]]), mean = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]])).Please consider scaling the input to zero mean and unit variance.\n",
      "  check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)\n",
      "/home/anon/mtpfn/PFNs/pfns/priors/hebo_prior.py:413: InputDataWarning: The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444\n",
      "  model = botorch.models.SingleTaskGP(\n",
      "/ext3/miniconda3/envs/pfn/lib/python3.12/site-packages/botorch/models/utils/assorted.py:268: InputDataWarning: Data (outcome observations) is not standardized (std = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]]), mean = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]])).Please consider scaling the input to zero mean and unit variance.\n",
      "  check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)\n",
      "/home/anon/mtpfn/PFNs/pfns/priors/hebo_prior.py:413: InputDataWarning: The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444\n",
      "  model = botorch.models.SingleTaskGP(\n",
      "/ext3/miniconda3/envs/pfn/lib/python3.12/site-packages/botorch/models/utils/assorted.py:268: InputDataWarning: Data (outcome observations) is not standardized (std = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]]), mean = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]])).Please consider scaling the input to zero mean and unit variance.\n",
      "  check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)\n",
      "/home/anon/mtpfn/PFNs/pfns/priors/hebo_prior.py:413: InputDataWarning: The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444\n",
      "  model = botorch.models.SingleTaskGP(\n",
      "/ext3/miniconda3/envs/pfn/lib/python3.12/site-packages/botorch/models/utils/assorted.py:268: InputDataWarning: Data (outcome observations) is not standardized (std = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]]), mean = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]])).Please consider scaling the input to zero mean and unit variance.\n",
      "  check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)\n",
      "/home/anon/mtpfn/PFNs/pfns/priors/hebo_prior.py:413: InputDataWarning: The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444\n",
      "  model = botorch.models.SingleTaskGP(\n",
      "/ext3/miniconda3/envs/pfn/lib/python3.12/site-packages/botorch/models/utils/assorted.py:268: InputDataWarning: Data (outcome observations) is not standardized (std = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]]), mean = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]])).Please consider scaling the input to zero mean and unit variance.\n",
      "  check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)\n",
      "/home/anon/mtpfn/PFNs/pfns/priors/hebo_prior.py:413: InputDataWarning: The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444\n",
      "  model = botorch.models.SingleTaskGP(\n",
      "/ext3/miniconda3/envs/pfn/lib/python3.12/site-packages/botorch/models/utils/assorted.py:268: InputDataWarning: Data (outcome observations) is not standardized (std = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]]), mean = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]])).Please consider scaling the input to zero mean and unit variance.\n",
      "  check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)\n",
      "/home/anon/mtpfn/PFNs/pfns/priors/hebo_prior.py:413: InputDataWarning: The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444\n",
      "  model = botorch.models.SingleTaskGP(\n",
      "/ext3/miniconda3/envs/pfn/lib/python3.12/site-packages/botorch/models/utils/assorted.py:268: InputDataWarning: Data (outcome observations) is not standardized (std = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]]), mean = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]])).Please consider scaling the input to zero mean and unit variance.\n",
      "  check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)\n",
      "/home/anon/mtpfn/PFNs/pfns/priors/hebo_prior.py:413: InputDataWarning: The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444\n",
      "  model = botorch.models.SingleTaskGP(\n",
      "/ext3/miniconda3/envs/pfn/lib/python3.12/site-packages/botorch/models/utils/assorted.py:268: InputDataWarning: Data (outcome observations) is not standardized (std = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]]), mean = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]])).Please consider scaling the input to zero mean and unit variance.\n",
      "  check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)\n",
      "/home/anon/mtpfn/PFNs/pfns/priors/hebo_prior.py:413: InputDataWarning: The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444\n",
      "  model = botorch.models.SingleTaskGP(\n",
      "/ext3/miniconda3/envs/pfn/lib/python3.12/site-packages/botorch/models/utils/assorted.py:268: InputDataWarning: Data (outcome observations) is not standardized (std = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]]), mean = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]])).Please consider scaling the input to zero mean and unit variance.\n",
      "  check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)\n",
      "/home/anon/mtpfn/PFNs/pfns/priors/hebo_prior.py:413: InputDataWarning: The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444\n",
      "  model = botorch.models.SingleTaskGP(\n",
      "/ext3/miniconda3/envs/pfn/lib/python3.12/site-packages/botorch/models/utils/assorted.py:268: InputDataWarning: Data (outcome observations) is not standardized (std = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]]), mean = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]])).Please consider scaling the input to zero mean and unit variance.\n",
      "  check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)\n",
      "/home/anon/mtpfn/PFNs/pfns/priors/hebo_prior.py:413: InputDataWarning: The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444\n",
      "  model = botorch.models.SingleTaskGP(\n",
      "/ext3/miniconda3/envs/pfn/lib/python3.12/site-packages/botorch/models/utils/assorted.py:268: InputDataWarning: Data (outcome observations) is not standardized (std = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]]), mean = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]])).Please consider scaling the input to zero mean and unit variance.\n",
      "  check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)\n",
      "/home/anon/mtpfn/PFNs/pfns/priors/hebo_prior.py:413: InputDataWarning: The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444\n",
      "  model = botorch.models.SingleTaskGP(\n",
      "/ext3/miniconda3/envs/pfn/lib/python3.12/site-packages/botorch/models/utils/assorted.py:268: InputDataWarning: Data (outcome observations) is not standardized (std = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]]), mean = tensor([[0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.],\n",
      "        [0.]])).Please consider scaling the input to zero mean and unit variance.\n",
      "  check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using 384000 y evals to estimate 1000 buckets. Cut off the last 0 ys.\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "train() missing 3 required positional arguments: 'train_loader', 'eval_loader', and 'task_encoder_generator'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[4], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# Now let's train either with\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m train\u001b[38;5;241m.\u001b[39mtrain(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39madd_criterion(config_heboplus,device\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcpu:0\u001b[39m\u001b[38;5;124m'\u001b[39m))\n",
      "\u001b[0;31mTypeError\u001b[0m: train() missing 3 required positional arguments: 'train_loader', 'eval_loader', and 'task_encoder_generator'"
     ]
    }
   ],
   "source": [
    "# Now let's train either with\n",
    "train.train(**add_criterion(config_heboplus,device='cpu:0'))\n",
    "# or\n",
    "#train.train(**add_criterion(config_heboplus_userpriors))\n",
    "# or\n",
    "#train.train(**add_criterion(config_bnn))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c427f1e1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pfn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
