{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "331f696f-c0fe-4f74-b366-bccb8b5bd475",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import wandb\n",
    "from tqdm import tqdm\n",
    "\n",
    "api = wandb.Api()\n",
    "\n",
    "def load_empirical_variable_cost_runs(acq, tfn):\n",
    "    \"\"\"Fetch the config and metric histories of every matching run in sweep 8pz4jcth.\n",
    "\n",
    "    Runs are filtered to ``config.cost_function_type == \"max\"``.\n",
    "\n",
    "    Args:\n",
    "        acq: Value matched against ``config.policy`` in the W&B run filter.\n",
    "        tfn: Value matched against ``config.problem`` in the W&B run filter.\n",
    "\n",
    "    Returns:\n",
    "        A list with one ``(run.config, metrics)`` tuple per matching run, where\n",
    "        ``metrics`` maps \"cumulative cost\" and \"best observed\" to the list of\n",
    "        values scanned from that run's history.\n",
    "    \"\"\"\n",
    "    # Loop-invariant, so build it once instead of once per run.\n",
    "    metric_keys = [\"cumulative cost\", \"best observed\"]\n",
    "\n",
    "    runs = api.runs(path=\"ziv-scully-group/PandoraBayesOPT\", filters={\n",
    "        \"sweep\": \"8pz4jcth\",\n",
    "        \"config.problem\": tfn,\n",
    "        \"config.policy\": acq,\n",
    "        \"config.cost_function_type\": \"max\"})\n",
    "\n",
    "    configs_and_metrics = []\n",
    "    for run in tqdm(runs):\n",
    "        # Materialise the scan a single time: looping over the scan object once\n",
    "        # per key walks the whole history again for every metric (and may\n",
    "        # trigger a fresh fetch each time).\n",
    "        rows = list(run.scan_history(keys=metric_keys, page_size=1_000_000_000))\n",
    "        metrics = {k: [row[k] for row in rows] for k in metric_keys}\n",
    "        configs_and_metrics.append((run.config, metrics))\n",
    "\n",
    "    return configs_and_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c1d277e2-7a92-4761-833c-41d20f915b8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "acquisition_functions = {\n",
    "    'RandomSearch': 'RandomSearch',\n",
    "    'ExpectedImprovementWithoutCost':'ExpectedImprovementWithoutCost',\n",
    "    'ExpectedImprovementPerUnitCost':'ExpectedImprovementPerUnitCost',\n",
    "    'BudgetedMultiStepLookaheadEI': 'BudgetedMultiStepLookaheadEI',\n",
    "    'Gittins_Lambda_0001':'Gittins_Lambda0001',\n",
    "    'Gittins_Step_Divide2':'Gittins_Step_Divide2',\n",
    "    }\n",
    "target_functions = [\"RobotPushing14D\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f267f939-8247-4106-8887-1a9153bd18d7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 16/16 [00:14<00:00,  1.12it/s]\n",
      "100%|██████████| 16/16 [00:13<00:00,  1.15it/s]\n",
      "100%|██████████| 16/16 [00:11<00:00,  1.36it/s]\n",
      "100%|██████████| 16/16 [00:13<00:00,  1.18it/s]\n",
      "100%|██████████| 16/16 [00:12<00:00,  1.33it/s]\n",
      "100%|██████████| 16/16 [00:11<00:00,  1.39it/s]\n"
     ]
    }
   ],
   "source": [
    "# Load the runs for every (acquisition function, target function) pair, keyed by that pair.\n",
    "grouped_runs = {\n",
    "    (acq, tfn): load_empirical_variable_cost_runs(acq, tfn)\n",
    "    for acq in acquisition_functions\n",
    "    for tfn in target_functions\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "124ece1e-1b36-40fb-ad1c-b9d01a242262",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RandomSearch RobotPushing14D (601, 16)\n",
      "ExpectedImprovementWithoutCost RobotPushing14D (601, 16)\n",
      "ExpectedImprovementPerUnitCost RobotPushing14D (601, 16)\n",
      "BudgetedMultiStepLookaheadEI RobotPushing14D (601, 16)\n",
      "Gittins_Lambda_0001 RobotPushing14D (601, 16)\n",
      "Gittins_Step_Divide2 RobotPushing14D (601, 16)\n"
     ]
    }
   ],
   "source": [
    "# Directory that receives one CSV per (target function, acquisition function) pair.\n",
    "output_dir = \"results/empirical_cost_aware\"\n",
    "# np.savetxt does not create missing directories, so make sure the target exists first.\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "for a in acquisition_functions.keys():\n",
    "    for t in target_functions:\n",
    "        config_and_metrics_per_seed = grouped_runs[a,t]\n",
    "\n",
    "        # Stack the per-run series and transpose: axis 0 indexes logged steps, axis 1 indexes runs.\n",
    "        cumulative_cost_per_seed = np.array([m['cumulative cost'] for (_,m) in config_and_metrics_per_seed]).T\n",
    "        best_observed_per_seed = np.array([m['best observed'] for (_,m) in config_and_metrics_per_seed]).T\n",
    "\n",
    "        print(a, t, best_observed_per_seed.shape)\n",
    "\n",
    "        # Flip the sign of the logged values before summarising\n",
    "        # (presumably the runs log a negated objective -- TODO confirm).\n",
    "        best_observed_per_seed = -best_observed_per_seed\n",
    "\n",
    "        # Quartiles across runs at each logged step, computed in a single call.\n",
    "        best_25, best_50, best_75 = np.quantile(best_observed_per_seed, [0.25, 0.5, 0.75], axis=1)\n",
    "\n",
    "        # Columns: mean cumulative cost across runs, then the 25th/50th/75th percentiles.\n",
    "        output = np.stack((cumulative_cost_per_seed.mean(axis=1), best_25, best_50, best_75),axis=-1)\n",
    "\n",
    "        np.savetxt(f\"{output_dir}/Empirical_VariableCost_{t}_max_{acquisition_functions[a]}.csv\", output, header=\"cc, q25, q50, q75\", delimiter=', ', comments='')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
