{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
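    "# PACOH\n",
    "\n",
    "Meta-train `GPRegressionPACMAML` with `theorem='PACOH'` on sinusoid regression tasks. Every generated sample of each meta-train task is passed as training data (`val_number` is 0). Validation RMSE is logged during `meta_fit` on meta-test tasks drawn from a separately seeded environment."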
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-26T02:21:44.513545Z",
     "start_time": "2021-05-26T02:21:43.744671Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import sys\n",
    "import os\n",
    "import argparse\n",
    "from meta_learn import GPRegressionPACMAML\n",
    "from experiments.data_sim import SinusoidDataset\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-26T02:22:25.522619Z",
     "start_time": "2021-05-26T02:22:00.405250Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a33aadf659994f9d9ce8de60fd940879",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2021-05-25 22:22:00,583 -INFO]  Iter 1/2000 - Loss: 0.137260 - Time 0.09 sec - Valid-RMSE: 0.943\n",
      "[2021-05-25 22:22:01,199 -INFO]  Iter 50/2000 - Loss: 0.052326 - Time 0.61 sec - Valid-RMSE: 0.691\n",
      "[2021-05-25 22:22:01,867 -INFO]  Iter 100/2000 - Loss: 0.042182 - Time 0.66 sec - Valid-RMSE: 0.716\n",
      "[2021-05-25 22:22:02,499 -INFO]  Iter 150/2000 - Loss: 0.041738 - Time 0.63 sec - Valid-RMSE: 0.689\n",
      "[2021-05-25 22:22:03,119 -INFO]  Iter 200/2000 - Loss: 0.041642 - Time 0.62 sec - Valid-RMSE: 0.660\n",
      "[2021-05-25 22:22:03,760 -INFO]  Iter 250/2000 - Loss: 0.040725 - Time 0.64 sec - Valid-RMSE: 0.626\n",
      "[2021-05-25 22:22:04,401 -INFO]  Iter 300/2000 - Loss: 0.040152 - Time 0.64 sec - Valid-RMSE: 0.564\n",
      "[2021-05-25 22:22:05,029 -INFO]  Iter 350/2000 - Loss: 0.038971 - Time 0.62 sec - Valid-RMSE: 0.524\n",
      "[2021-05-25 22:22:05,706 -INFO]  Iter 400/2000 - Loss: 0.038231 - Time 0.67 sec - Valid-RMSE: 0.492\n",
      "[2021-05-25 22:22:06,358 -INFO]  Iter 450/2000 - Loss: 0.038342 - Time 0.65 sec - Valid-RMSE: 0.475\n",
      "[2021-05-25 22:22:07,023 -INFO]  Iter 500/2000 - Loss: 0.038197 - Time 0.65 sec - Valid-RMSE: 0.484\n",
      "[2021-05-25 22:22:07,658 -INFO]  Iter 550/2000 - Loss: 0.038381 - Time 0.64 sec - Valid-RMSE: 0.466\n",
      "[2021-05-25 22:22:08,307 -INFO]  Iter 600/2000 - Loss: 0.038175 - Time 0.65 sec - Valid-RMSE: 0.462\n",
      "[2021-05-25 22:22:08,948 -INFO]  Iter 650/2000 - Loss: 0.038204 - Time 0.64 sec - Valid-RMSE: 0.459\n",
      "[2021-05-25 22:22:09,608 -INFO]  Iter 700/2000 - Loss: 0.037767 - Time 0.65 sec - Valid-RMSE: 0.458\n",
      "[2021-05-25 22:22:10,236 -INFO]  Iter 750/2000 - Loss: 0.037496 - Time 0.63 sec - Valid-RMSE: 0.454\n",
      "[2021-05-25 22:22:10,950 -INFO]  Iter 800/2000 - Loss: 0.037653 - Time 0.71 sec - Valid-RMSE: 0.444\n",
      "[2021-05-25 22:22:11,623 -INFO]  Iter 850/2000 - Loss: 0.037509 - Time 0.67 sec - Valid-RMSE: 0.441\n",
      "[2021-05-25 22:22:12,224 -INFO]  Iter 900/2000 - Loss: 0.037450 - Time 0.59 sec - Valid-RMSE: 0.415\n",
      "[2021-05-25 22:22:12,834 -INFO]  Iter 950/2000 - Loss: 0.036615 - Time 0.61 sec - Valid-RMSE: 0.385\n",
      "[2021-05-25 22:22:13,437 -INFO]  Iter 1000/2000 - Loss: 0.034342 - Time 0.60 sec - Valid-RMSE: 0.344\n",
      "[2021-05-25 22:22:14,035 -INFO]  Iter 1050/2000 - Loss: 0.031782 - Time 0.59 sec - Valid-RMSE: 0.290\n",
      "[2021-05-25 22:22:14,638 -INFO]  Iter 1100/2000 - Loss: 0.030495 - Time 0.60 sec - Valid-RMSE: 0.255\n",
      "[2021-05-25 22:22:15,244 -INFO]  Iter 1150/2000 - Loss: 0.029550 - Time 0.60 sec - Valid-RMSE: 0.274\n",
      "[2021-05-25 22:22:15,840 -INFO]  Iter 1200/2000 - Loss: 0.029149 - Time 0.59 sec - Valid-RMSE: 0.285\n",
      "[2021-05-25 22:22:16,432 -INFO]  Iter 1250/2000 - Loss: 0.029208 - Time 0.59 sec - Valid-RMSE: 0.262\n",
      "[2021-05-25 22:22:17,050 -INFO]  Iter 1300/2000 - Loss: 0.028726 - Time 0.61 sec - Valid-RMSE: 0.248\n",
      "[2021-05-25 22:22:17,648 -INFO]  Iter 1350/2000 - Loss: 0.029139 - Time 0.59 sec - Valid-RMSE: 0.244\n",
      "[2021-05-25 22:22:18,250 -INFO]  Iter 1400/2000 - Loss: 0.028699 - Time 0.60 sec - Valid-RMSE: 0.246\n",
      "[2021-05-25 22:22:18,846 -INFO]  Iter 1450/2000 - Loss: 0.028904 - Time 0.60 sec - Valid-RMSE: 0.249\n",
      "[2021-05-25 22:22:19,444 -INFO]  Iter 1500/2000 - Loss: 0.027956 - Time 0.59 sec - Valid-RMSE: 0.248\n",
      "[2021-05-25 22:22:20,040 -INFO]  Iter 1550/2000 - Loss: 0.028598 - Time 0.59 sec - Valid-RMSE: 0.253\n",
      "[2021-05-25 22:22:20,642 -INFO]  Iter 1600/2000 - Loss: 0.028616 - Time 0.59 sec - Valid-RMSE: 0.238\n",
      "[2021-05-25 22:22:21,242 -INFO]  Iter 1650/2000 - Loss: 0.028189 - Time 0.59 sec - Valid-RMSE: 0.252\n",
      "[2021-05-25 22:22:21,841 -INFO]  Iter 1700/2000 - Loss: 0.028120 - Time 0.59 sec - Valid-RMSE: 0.235\n",
      "[2021-05-25 22:22:22,431 -INFO]  Iter 1750/2000 - Loss: 0.028120 - Time 0.59 sec - Valid-RMSE: 0.244\n",
      "[2021-05-25 22:22:23,027 -INFO]  Iter 1800/2000 - Loss: 0.028899 - Time 0.59 sec - Valid-RMSE: 0.246\n",
      "[2021-05-25 22:22:23,645 -INFO]  Iter 1850/2000 - Loss: 0.028174 - Time 0.61 sec - Valid-RMSE: 0.243\n",
      "[2021-05-25 22:22:24,270 -INFO]  Iter 1900/2000 - Loss: 0.028853 - Time 0.62 sec - Valid-RMSE: 0.230\n",
      "[2021-05-25 22:22:24,898 -INFO]  Iter 1950/2000 - Loss: 0.028905 - Time 0.62 sec - Valid-RMSE: 0.239\n",
      "[2021-05-25 22:22:25,512 -INFO]  Iter 2000/2000 - Loss: 0.028805 - Time 0.61 sec - Valid-RMSE: 0.237\n"
     ]
    }
   ],
   "source": [
    "n_task = 20\n",
    "m = 5\n",
    "\n",
    "m_i=20\n",
    "r_i = m_i\n",
    "beta = 30*m_i\n",
    "lr = 3e-3\n",
    "var = 3\n",
    "random_train = 33\n",
    "random_state = np.random.RandomState(random_train)\n",
    "task_environment = SinusoidDataset(random_state=random_state)\n",
    "\n",
    "meta_train_data = task_environment.generate_meta_train_data(n_tasks=n_task, n_samples=r_i)\n",
    "\n",
    "random_state = np.random.RandomState(25)\n",
    "test_environment = SinusoidDataset(random_state=random_state)\n",
    "meta_test_data = test_environment.generate_meta_test_data(n_tasks=n_task, n_samples_context=m, \n",
    "                                                          n_samples_test=100)\n",
    "\n",
    "pacmaml = GPRegressionPACMAML(meta_train_data, beta=beta, var=var, num_iter_fit=2000,\n",
    "                                      random_seed=4, lr_params = lr, train_number = m_i,\n",
    "                                      val_number = r_i-m_i, theorem='PACOH')\n",
    "val_result = pacmaml.meta_fit(meta_test_data, log_period=50, verbose = True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
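    "# PACMAML\n",
    "\n",
    "Same sinusoid setup and meta-test tasks, but fitted with `theorem='PACMAML'`. Each meta-train task generates `m_i` samples and the model is given `train_number=m` and `val_number=m_i - m`. An additional `alpha` hyperparameter (here `0.2 * beta`) is also passed."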
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-26T02:40:52.185253Z",
     "start_time": "2021-05-26T02:39:01.299991Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2021-05-25 22:39:01,465 -INFO]  Iter 1/3000 - Loss: 0.342370 - Time 0.07 sec - Valid-RMSE: 0.972\n",
      "[2021-05-25 22:39:03,214 -INFO]  Iter 50/3000 - Loss: 0.385586 - Time 1.76 sec - Valid-RMSE: 0.918\n",
      "[2021-05-25 22:39:04,997 -INFO]  Iter 100/3000 - Loss: 0.285023 - Time 1.76 sec - Valid-RMSE: 0.925\n",
      "[2021-05-25 22:39:06,778 -INFO]  Iter 150/3000 - Loss: 0.305446 - Time 1.77 sec - Valid-RMSE: 0.942\n",
      "[2021-05-25 22:39:08,501 -INFO]  Iter 200/3000 - Loss: 0.337604 - Time 1.71 sec - Valid-RMSE: 0.941\n",
      "[2021-05-25 22:39:10,238 -INFO]  Iter 250/3000 - Loss: 0.322798 - Time 1.73 sec - Valid-RMSE: 0.930\n",
      "[2021-05-25 22:39:12,055 -INFO]  Iter 300/3000 - Loss: 0.307715 - Time 1.80 sec - Valid-RMSE: 0.897\n",
      "[2021-05-25 22:39:14,016 -INFO]  Iter 350/3000 - Loss: 0.280420 - Time 1.95 sec - Valid-RMSE: 0.865\n",
      "[2021-05-25 22:39:15,850 -INFO]  Iter 400/3000 - Loss: 0.278417 - Time 1.80 sec - Valid-RMSE: 0.875\n",
      "[2021-05-25 22:39:17,658 -INFO]  Iter 450/3000 - Loss: 0.290983 - Time 1.79 sec - Valid-RMSE: 0.868\n",
      "[2021-05-25 22:39:19,445 -INFO]  Iter 500/3000 - Loss: 0.277506 - Time 1.76 sec - Valid-RMSE: 0.860\n",
      "[2021-05-25 22:39:21,269 -INFO]  Iter 550/3000 - Loss: 0.273261 - Time 1.81 sec - Valid-RMSE: 0.844\n",
      "[2021-05-25 22:39:23,058 -INFO]  Iter 600/3000 - Loss: 0.257036 - Time 1.76 sec - Valid-RMSE: 0.835\n",
      "[2021-05-25 22:39:24,937 -INFO]  Iter 650/3000 - Loss: 0.254546 - Time 1.86 sec - Valid-RMSE: 0.806\n",
      "[2021-05-25 22:39:27,107 -INFO]  Iter 700/3000 - Loss: 0.226713 - Time 2.14 sec - Valid-RMSE: 0.757\n",
      "[2021-05-25 22:39:29,531 -INFO]  Iter 750/3000 - Loss: 0.206727 - Time 2.41 sec - Valid-RMSE: 0.714\n",
      "[2021-05-25 22:39:31,330 -INFO]  Iter 800/3000 - Loss: 0.169377 - Time 1.77 sec - Valid-RMSE: 0.641\n",
      "[2021-05-25 22:39:33,141 -INFO]  Iter 850/3000 - Loss: 0.139753 - Time 1.79 sec - Valid-RMSE: 0.595\n",
      "[2021-05-25 22:39:34,960 -INFO]  Iter 900/3000 - Loss: 0.127833 - Time 1.79 sec - Valid-RMSE: 0.542\n",
      "[2021-05-25 22:39:36,971 -INFO]  Iter 950/3000 - Loss: 0.114206 - Time 1.99 sec - Valid-RMSE: 0.493\n",
      "[2021-05-25 22:39:38,857 -INFO]  Iter 1000/3000 - Loss: 0.104288 - Time 1.86 sec - Valid-RMSE: 0.458\n",
      "[2021-05-25 22:39:40,894 -INFO]  Iter 1050/3000 - Loss: 0.095825 - Time 2.02 sec - Valid-RMSE: 0.417\n",
      "[2021-05-25 22:39:42,921 -INFO]  Iter 1100/3000 - Loss: 0.085107 - Time 2.00 sec - Valid-RMSE: 0.399\n",
      "[2021-05-25 22:39:44,953 -INFO]  Iter 1150/3000 - Loss: 0.072252 - Time 2.01 sec - Valid-RMSE: 0.358\n",
      "[2021-05-25 22:39:47,022 -INFO]  Iter 1200/3000 - Loss: 0.070364 - Time 2.04 sec - Valid-RMSE: 0.344\n",
      "[2021-05-25 22:39:49,119 -INFO]  Iter 1250/3000 - Loss: 0.078745 - Time 2.08 sec - Valid-RMSE: 0.383\n",
      "[2021-05-25 22:39:51,182 -INFO]  Iter 1300/3000 - Loss: 0.075130 - Time 2.03 sec - Valid-RMSE: 0.336\n",
      "[2021-05-25 22:39:53,017 -INFO]  Iter 1350/3000 - Loss: 0.073001 - Time 1.82 sec - Valid-RMSE: 0.337\n",
      "[2021-05-25 22:39:54,739 -INFO]  Iter 1400/3000 - Loss: 0.067442 - Time 1.69 sec - Valid-RMSE: 0.323\n",
      "[2021-05-25 22:39:56,475 -INFO]  Iter 1450/3000 - Loss: 0.064484 - Time 1.72 sec - Valid-RMSE: 0.323\n",
      "[2021-05-25 22:39:58,194 -INFO]  Iter 1500/3000 - Loss: 0.063767 - Time 1.70 sec - Valid-RMSE: 0.320\n",
      "[2021-05-25 22:39:59,924 -INFO]  Iter 1550/3000 - Loss: 0.065799 - Time 1.72 sec - Valid-RMSE: 0.296\n",
      "[2021-05-25 22:40:01,655 -INFO]  Iter 1600/3000 - Loss: 0.056486 - Time 1.71 sec - Valid-RMSE: 0.287\n",
      "[2021-05-25 22:40:03,420 -INFO]  Iter 1650/3000 - Loss: 0.057256 - Time 1.75 sec - Valid-RMSE: 0.286\n",
      "[2021-05-25 22:40:05,147 -INFO]  Iter 1700/3000 - Loss: 0.061630 - Time 1.71 sec - Valid-RMSE: 0.295\n",
      "[2021-05-25 22:40:06,888 -INFO]  Iter 1750/3000 - Loss: 0.058217 - Time 1.73 sec - Valid-RMSE: 0.288\n",
      "[2021-05-25 22:40:08,608 -INFO]  Iter 1800/3000 - Loss: 0.057094 - Time 1.70 sec - Valid-RMSE: 0.280\n",
      "[2021-05-25 22:40:10,366 -INFO]  Iter 1850/3000 - Loss: 0.063932 - Time 1.72 sec - Valid-RMSE: 0.289\n",
      "[2021-05-25 22:40:12,082 -INFO]  Iter 1900/3000 - Loss: 0.057511 - Time 1.72 sec - Valid-RMSE: 0.299\n",
      "[2021-05-25 22:40:13,858 -INFO]  Iter 1950/3000 - Loss: 0.058668 - Time 1.76 sec - Valid-RMSE: 0.289\n",
      "[2021-05-25 22:40:15,586 -INFO]  Iter 2000/3000 - Loss: 0.061382 - Time 1.71 sec - Valid-RMSE: 0.324\n",
      "[2021-05-25 22:40:17,324 -INFO]  Iter 2050/3000 - Loss: 0.065120 - Time 1.72 sec - Valid-RMSE: 0.400\n",
      "[2021-05-25 22:40:19,054 -INFO]  Iter 2100/3000 - Loss: 0.065346 - Time 1.71 sec - Valid-RMSE: 0.320\n",
      "[2021-05-25 22:40:20,804 -INFO]  Iter 2150/3000 - Loss: 0.061110 - Time 1.73 sec - Valid-RMSE: 0.293\n",
      "[2021-05-25 22:40:22,537 -INFO]  Iter 2200/3000 - Loss: 0.066732 - Time 1.71 sec - Valid-RMSE: 0.279\n",
      "[2021-05-25 22:40:24,270 -INFO]  Iter 2250/3000 - Loss: 0.058871 - Time 1.72 sec - Valid-RMSE: 0.295\n",
      "[2021-05-25 22:40:26,045 -INFO]  Iter 2300/3000 - Loss: 0.062803 - Time 1.75 sec - Valid-RMSE: 0.293\n",
      "[2021-05-25 22:40:27,788 -INFO]  Iter 2350/3000 - Loss: 0.062541 - Time 1.73 sec - Valid-RMSE: 0.292\n",
      "[2021-05-25 22:40:29,527 -INFO]  Iter 2400/3000 - Loss: 0.070145 - Time 1.72 sec - Valid-RMSE: 0.295\n",
      "[2021-05-25 22:40:31,268 -INFO]  Iter 2450/3000 - Loss: 0.060091 - Time 1.73 sec - Valid-RMSE: 0.305\n",
      "[2021-05-25 22:40:33,003 -INFO]  Iter 2500/3000 - Loss: 0.058839 - Time 1.71 sec - Valid-RMSE: 0.266\n",
      "[2021-05-25 22:40:34,753 -INFO]  Iter 2550/3000 - Loss: 0.061190 - Time 1.73 sec - Valid-RMSE: 0.305\n",
      "[2021-05-25 22:40:36,531 -INFO]  Iter 2600/3000 - Loss: 0.066586 - Time 1.76 sec - Valid-RMSE: 0.306\n",
      "[2021-05-25 22:40:38,272 -INFO]  Iter 2650/3000 - Loss: 0.061435 - Time 1.72 sec - Valid-RMSE: 0.309\n",
      "[2021-05-25 22:40:40,008 -INFO]  Iter 2700/3000 - Loss: 0.064000 - Time 1.72 sec - Valid-RMSE: 0.278\n",
      "[2021-05-25 22:40:41,782 -INFO]  Iter 2750/3000 - Loss: 0.058877 - Time 1.76 sec - Valid-RMSE: 0.274\n",
      "[2021-05-25 22:40:43,523 -INFO]  Iter 2800/3000 - Loss: 0.061856 - Time 1.72 sec - Valid-RMSE: 0.282\n",
      "[2021-05-25 22:40:46,975 -INFO]  Iter 2850/3000 - Loss: 0.061835 - Time 3.44 sec - Valid-RMSE: 0.283\n",
      "[2021-05-25 22:40:48,676 -INFO]  Iter 2900/3000 - Loss: 0.061809 - Time 1.68 sec - Valid-RMSE: 0.277\n",
      "[2021-05-25 22:40:50,439 -INFO]  Iter 2950/3000 - Loss: 0.059732 - Time 1.75 sec - Valid-RMSE: 0.271\n",
      "[2021-05-25 22:40:52,159 -INFO]  Iter 3000/3000 - Loss: 0.060268 - Time 1.70 sec - Valid-RMSE: 0.272\n"
     ]
    }
   ],
   "source": [
    "n_task = 20\n",
    "m = 5\n",
    "m_i = 50\n",
    "beta = 30*m_i\n",
    "alpha = 0.2*beta\n",
    "lr = 3e-3\n",
    "var = 3\n",
    "random_train = 33\n",
    "random_state = np.random.RandomState(random_train)\n",
    "task_environment = SinusoidDataset(random_state=random_state)\n",
    "\n",
    "meta_train_data = task_environment.generate_meta_train_data(n_tasks=n_task, n_samples=m_i)\n",
    "\n",
    "random_state = np.random.RandomState(25)\n",
    "test_environment = SinusoidDataset(random_state=random_state)\n",
    "meta_test_data = test_environment.generate_meta_test_data(n_tasks=n_task, n_samples_context=m, \n",
    "                                                          n_samples_test=100)\n",
    "\n",
    "pacmaml = GPRegressionPACMAML(meta_train_data, beta=beta, alpha=alpha, var=var, num_iter_fit=3000,\n",
    "                                      random_seed=4, lr_params = lr, train_number = m,\n",
    "                                      val_number = m_i-m, theorem='PACMAML')\n",
    "val_result = pacmaml.meta_fit(meta_test_data, log_period=50, verbose = True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
