{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Semi Synthetic Experiments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Analyze the performance of various algorithms to solve the joint matching + activity task, when the number of volunteers is large and structured"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random \n",
    "import matplotlib.pyplot as plt\n",
    "import json \n",
    "import argparse \n",
    "import sys\n",
    "import secrets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr0/home/naveenr/miniconda3/envs/food/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from rmab.simulator import RMABSimulator\n",
    "from rmab.omniscient_policies import *\n",
    "from rmab.fr_dynamics import get_all_transitions\n",
    "from rmab.mcts_policies import *\n",
    "from rmab.utils import get_save_path, delete_duplicate_results, create_prob_distro\n",
    "import resource"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.set_per_process_memory_fraction(0.5)\n",
    "torch.set_num_threads(1)\n",
    "resource.setrlimit(resource.RLIMIT_AS, (30 * 1024 * 1024 * 1024, -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "is_jupyter = 'ipykernel' in sys.modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "if is_jupyter: \n",
    "    seed        = 43\n",
    "    n_arms      = 2\n",
    "    volunteers_per_arm = 2\n",
    "    budget      = 3\n",
    "    discount    = 0.9\n",
    "    alpha       = 3 \n",
    "    n_episodes  = 100\n",
    "    episode_len = 20 \n",
    "    n_epochs    = 1 \n",
    "    save_with_date = False \n",
    "    TIME_PER_RUN = 0.01 * 1000\n",
    "    lamb = 0.5\n",
    "    prob_distro = 'uniform'\n",
    "    policy_lr=0.001\n",
    "    value_lr=0.01\n",
    "    train_iterations = 30\n",
    "    test_iterations = 30\n",
    "    out_folder = 'mcts_exploration/mcts_ablation'\n",
    "else:\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument('--n_arms',         '-N', help='num beneficiaries (arms)', type=int, default=2)\n",
    "    parser.add_argument('--volunteers_per_arm',         '-V', help='volunteers per arm', type=int, default=5)\n",
    "    parser.add_argument('--episode_len',    '-H', help='episode length', type=int, default=20)\n",
    "    parser.add_argument('--n_episodes',     '-T', help='num episodes', type=int, default=100)\n",
    "    parser.add_argument('--budget',         '-B', help='budget', type=int, default=3)\n",
    "    parser.add_argument('--n_epochs',       '-E', help='number of epochs (num_repeats)', type=int, default=1)\n",
    "    parser.add_argument('--discount',       '-d', help='discount factor', type=float, default=0.9)\n",
    "    parser.add_argument('--alpha',          '-a', help='alpha: for conf radius', type=float, default=3)\n",
    "    parser.add_argument('--lamb',          '-l', help='lambda for matching-engagement tradeoff', type=float, default=0.5)\n",
    "    parser.add_argument('--seed',           '-s', help='random seed', type=int, default=42)\n",
    "    parser.add_argument('--prob_distro',           '-p', help='which prob distro [uniform,uniform_small,uniform_large,normal]', type=str, default='uniform')\n",
    "    parser.add_argument('--time_per_run',      '-t', help='time per MCTS run', type=float, default=.01*1000)\n",
    "    parser.add_argument('--policy_lr', help='Learning Rate Policy', type=float, default=0.001)\n",
    "    parser.add_argument('--value_lr', help='Learning Rate Value', type=float, default=0.01)\n",
    "    parser.add_argument('--train_iterations', help='Number of MCTS train iterations', type=int, default=30)\n",
    "    parser.add_argument('--test_iterations', help='Number of MCTS test iterations', type=int, default=30)\n",
    "    parser.add_argument('--out_folder', help='Which folder to write results to', type=str, default='mcts_exploration/mcts_ablation')\n",
    "\n",
    "    parser.add_argument('--use_date', action='store_true')\n",
    "\n",
    "    args = parser.parse_args()\n",
    "\n",
    "    n_arms      = args.n_arms\n",
    "    volunteers_per_arm = args.volunteers_per_arm\n",
    "    budget      = args.budget\n",
    "    discount    = args.discount\n",
    "    alpha       = args.alpha \n",
    "    seed        = args.seed\n",
    "    n_episodes  = args.n_episodes\n",
    "    episode_len = args.episode_len\n",
    "    n_epochs    = args.n_epochs\n",
    "    lamb = args.lamb\n",
    "    save_with_date = args.use_date\n",
    "    TIME_PER_RUN = args.time_per_run\n",
    "    prob_distro = args.prob_distro\n",
    "    policy_lr = args.policy_lr \n",
    "    value_lr = args.value_lr \n",
    "    out_folder = args.out_folder\n",
    "    train_iterations = args.train_iterations \n",
    "    test_iterations = args.test_iterations \n",
    "\n",
    "save_name = secrets.token_hex(4)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_states = 2\n",
    "n_actions = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_population_size = 100 # number of random arms to generate\n",
    "all_transitions = get_all_transitions(all_population_size)"
   ]
  },
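  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick, optional sanity check on the generated transitions. This assumes `get_all_transitions` returns one `n_states x n_actions x n_states` transition tensor per arm (i.e. shape `(all_population_size, 2, 2, 2)` with probability rows summing to one); if the actual structure differs, adjust accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative check only; assumes all_transitions is array-like with one\n",
    "# transition tensor per arm and probability rows along the last axis.\n",
    "all_transitions_arr = np.asarray(all_transitions)\n",
    "print('transitions shape:', all_transitions_arr.shape)\n",
    "print('rows sum to 1:', np.allclose(all_transitions_arr.sum(axis=-1), 1))"
   ]
  },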
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_environment(seed):\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    if prob_distro == 'uniform':\n",
    "        match_probabilities = [np.random.random() for i in range(all_population_size * volunteers_per_arm)] \n",
    "    elif prob_distro == 'uniform_small':\n",
    "        match_probabilities = [np.random.random()/4 for i in range(all_population_size * volunteers_per_arm)] \n",
    "    elif prob_distro == 'uniform_large':\n",
    "        match_probabilities = [np.random.random()/4+0.75 for i in range(all_population_size * volunteers_per_arm)] \n",
    "    elif prob_distro == 'normal':\n",
    "        match_probabilities = [np.clip(np.random.normal(0.25, 0.1),0,1) for i in range(all_population_size * volunteers_per_arm)] \n",
    "\n",
    "    all_features = np.arange(all_population_size)\n",
    "    match_probabilities = create_prob_distro(prob_distro,all_population_size*volunteers_per_arm)\n",
    "    simulator = RMABSimulator(all_population_size, all_features, all_transitions,\n",
    "                n_arms, volunteers_per_arm, episode_len, n_epochs, n_episodes, budget, discount,number_states=n_states, reward_style='match',match_probability_list=match_probabilities,TIME_PER_RUN=TIME_PER_RUN)\n",
    "\n",
    "    return simulator "
   ]
  },
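  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the four `prob_distro` options look like, the sketch below samples them directly with NumPy, mirroring the forms noted in `create_environment`; the canonical implementation is `rmab.utils.create_prob_distro`, which may differ in detail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch of the four match-probability distributions; assumes the forms\n",
    "# noted above rather than calling create_prob_distro itself.\n",
    "rng = np.random.default_rng(0)\n",
    "n_samples = 1000\n",
    "distros = {\n",
    "    'uniform': rng.random(n_samples),\n",
    "    'uniform_small': rng.random(n_samples) / 4,\n",
    "    'uniform_large': rng.random(n_samples) / 4 + 0.75,\n",
    "    'normal': np.clip(rng.normal(0.25, 0.1, n_samples), 0, 1),\n",
    "}\n",
    "fig, axes = plt.subplots(1, 4, figsize=(12, 2.5), sharey=True)\n",
    "for ax, (label, samples) in zip(axes, distros.items()):\n",
    "    ax.hist(samples, bins=20, range=(0, 1))\n",
    "    ax.set_title(label)\n",
    "    ax.set_xlabel('match probability')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },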
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_multi_seed(seed_list,policy,is_mcts=False,per_epoch_function=None,train_iterations=0,test_iterations=0,test_length=500):\n",
    "    memories = []\n",
    "    scores = {\n",
    "        'reward': [],\n",
    "        'time': [], \n",
    "        'match': [], \n",
    "        'active_rate': [],\n",
    "        'train_time': [],\n",
    "    }\n",
    "\n",
    "    for seed in seed_list:\n",
    "        simulator = create_environment(seed)\n",
    "        if is_mcts:\n",
    "            simulator.mcts_train_iterations = train_iterations\n",
    "            simulator.mcts_test_iterations = test_iterations\n",
    "            simulator.policy_lr = policy_lr\n",
    "            simulator.value_lr = value_lr\n",
    "\n",
    "        if is_mcts:\n",
    "            match, active_rate, memory = run_heterogenous_policy(simulator, n_episodes, n_epochs, discount,policy,seed,lamb=lamb,should_train=True,test_T=test_length,get_memory=True,per_epoch_function=per_epoch_function)\n",
    "        else:\n",
    "            match, active_rate = run_heterogenous_policy(simulator, n_episodes, n_epochs, discount,policy,seed,lamb=lamb,should_train=True,test_T=test_length,per_epoch_function=per_epoch_function)\n",
    "        time_whittle = simulator.time_taken\n",
    "        train_time = simulator.train_time\n",
    "        discounted_reward = get_discounted_reward(match,active_rate,discount,lamb)\n",
    "        scores['reward'].append(discounted_reward)\n",
    "        scores['time'].append(time_whittle)\n",
    "        scores['match'].append(np.mean(match))\n",
    "        scores['active_rate'].append(np.mean(active_rate))\n",
    "        scores['train_time'].append(train_time)\n",
    "        if is_mcts:\n",
    "            memories.append(memory)\n",
    "\n",
    "    return scores, memories, simulator"
   ]
  },
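  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, here is a minimal sketch of how a combined matching + engagement objective can be computed from per-step traces. The weighting is an assumption (a discounted blend `(1 - lamb) * match + lamb * active_rate`); the authoritative computation is the imported `get_discounted_reward`, which may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def discounted_reward_sketch(match, active_rate, discount, lamb):\n",
    "    \"\"\"Hypothetical stand-in for get_discounted_reward: a discounted blend\n",
    "    of the per-step match and active-rate signals (assumed weighting).\"\"\"\n",
    "    match = np.asarray(match, dtype=float).ravel()\n",
    "    active_rate = np.asarray(active_rate, dtype=float).ravel()\n",
    "    t = np.arange(len(match))\n",
    "    per_step = (1 - lamb) * match + lamb * active_rate\n",
    "    return float(np.sum(discount ** t * per_step))\n",
    "\n",
    "# Toy traces: two steps of match rates and active rates\n",
    "print(discounted_reward_sketch([0.4, 0.5], [0.8, 0.7], discount=0.9, lamb=0.5))"
   ]
  },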
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {}\n",
    "results['parameters'] = {'seed'      : seed,\n",
    "        'n_arms'    : n_arms,\n",
    "        'volunteers_per_arm': volunteers_per_arm, \n",
    "        'budget'    : budget,\n",
    "        'discount'  : discount, \n",
    "        'alpha'     : alpha, \n",
    "        'n_episodes': n_episodes, \n",
    "        'episode_len': episode_len, \n",
    "        'n_epochs'  : n_epochs, \n",
    "        'lamb': lamb,\n",
    "        'time_per_run': TIME_PER_RUN, \n",
    "        'prob_distro': prob_distro, \n",
    "        'policy_lr': policy_lr, \n",
    "        'value_lr': value_lr, \n",
    "        'train_iterations': train_iterations, \n",
    "        'test_iterations': test_iterations,} "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Index Policies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_list = [seed]"
   ]
  },
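  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All runs below iterate over `seed_list`, so with a single seed the reported means are single-run estimates. To average over several environments one could use, e.g., `seed_list = [seed + i for i in range(5)]`; a single seed is kept here to match the saved results."
   ]
  },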
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acting should always be good! 0.000 < 0.044\n",
      "acting should always be good! 0.000 < 0.162\n",
      "acting should always be good! 0.108 < 0.183\n",
      "good start state should always be good! 0.380 < 0.508\n",
      "good start state should always be good! 0.506 < 0.760\n",
      "cohort [26 66]\n",
      "instance 0, ep 1\n",
      "instance 0, ep 2\n"
     ]
     }
   ],
   "source": [
    "policy = full_mcts_policy \n",
    "name = \"mcts\"\n",
    "\n",
    "train_iterations = 30\n",
    "\n",
    "rewards, memory, simulator = run_multi_seed(seed_list,policy,is_mcts=True,train_iterations=train_iterations,test_iterations=test_iterations)\n",
    "results['{}_reward'.format(name)] = rewards['reward']\n",
    "results['{}_match'.format(name)] =  rewards['match'] \n",
    "results['{}_active'.format(name)] = rewards['active_rate']\n",
    "results['{}_time'.format(name)] =  rewards['time']\n",
    "results['{}_train_time'.format(name)] = rewards['train_time']\n",
    "print(np.mean(rewards['reward']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'mcts_exploration/mcts_ablation'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_folder"
   ]
  },
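  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cells below ablate individual components of `full_mcts_policy`: the volunteer grouping scheme (`group_setup` of `none`, `random`, or `whittle`), the UCB component (`run_ucb=False`), the Whittle-index component (`use_whittle=False`), and the train/test MCTS iteration counts. Each cell repeats the same bookkeeping; a helper like the (hypothetical) one below could replace those lines, but the cells are left expanded so each runs standalone."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def record_results(name, rewards):\n",
    "    \"\"\"Hypothetical refactoring sketch: store one ablation's scores in `results`.\n",
    "    Not used by the cells below, which keep the original expanded form.\"\"\"\n",
    "    for key_in, key_out in [('reward', 'reward'), ('match', 'match'),\n",
    "                            ('active_rate', 'active'), ('time', 'time'),\n",
    "                            ('train_time', 'train_time')]:\n",
    "        results['{}_{}'.format(name, key_out)] = rewards[key_in]\n",
    "    print(np.mean(rewards['reward']))"
   ]
  },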
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acting should always be good! 0.000 < 0.044\n",
      "acting should always be good! 0.000 < 0.162\n",
      "acting should always be good! 0.108 < 0.183\n",
      "good start state should always be good! 0.380 < 0.508\n",
      "good start state should always be good! 0.506 < 0.760\n",
      "cohort [26 66]\n",
      "instance 0, ep 1\n",
      "instance 0, ep 2\n",
      "instance 0, ep 3\n",
      "instance 0, ep 4\n",
      "instance 0, ep 5\n",
      "instance 0, ep 6\n",
      "instance 0, ep 7\n",
      "instance 0, ep 8\n",
      "instance 0, ep 9\n",
      "instance 0, ep 10\n",
      "instance 0, ep 11\n",
      "instance 0, ep 12\n",
      "instance 0, ep 13\n",
      "instance 0, ep 14\n",
      "instance 0, ep 15\n",
      "instance 0, ep 16\n",
      "instance 0, ep 17\n",
      "instance 0, ep 18\n",
      "instance 0, ep 19\n",
      "instance 0, ep 20\n",
      "instance 0, ep 21\n",
      "instance 0, ep 22\n",
      "instance 0, ep 23\n",
      "instance 0, ep 24\n",
      "instance 0, ep 25\n",
      "instance 0, ep 26\n",
      "instance 0, ep 27\n",
      "instance 0, ep 28\n",
      "instance 0, ep 29\n",
      "instance 0, ep 30\n",
      "instance 0, ep 31\n",
      "instance 0, ep 32\n",
      "instance 0, ep 33\n",
      "instance 0, ep 34\n",
      "instance 0, ep 35\n",
      "instance 0, ep 36\n",
      "instance 0, ep 37\n",
      "instance 0, ep 38\n",
      "instance 0, ep 39\n",
      "instance 0, ep 40\n",
      "instance 0, ep 41\n",
      "instance 0, ep 42\n",
      "instance 0, ep 43\n",
      "instance 0, ep 44\n",
      "instance 0, ep 45\n",
      "instance 0, ep 46\n",
      "instance 0, ep 47\n",
      "instance 0, ep 48\n",
      "instance 0, ep 49\n",
      "Took 47.1790988445282 time for inference and 48.06454658508301 time for training\n",
      "5.987343410315398\n"
     ]
    }
   ],
   "source": [
    "def mcts_random(env,state,budget,lamb,memory,per_epoch_results):\n",
    "    return full_mcts_policy(env,state,budget,lamb,memory,per_epoch_results,group_setup=\"none\")\n",
    "policy = mcts_random \n",
    "name = \"mcts_no_group\"\n",
    "\n",
    "train_iterations = 30\n",
    "\n",
    "rewards, memory, simulator = run_multi_seed(seed_list,policy,is_mcts=True,train_iterations=train_iterations,test_iterations=test_iterations)\n",
    "results['{}_reward'.format(name)] = rewards['reward']\n",
    "results['{}_match'.format(name)] =  rewards['match'] \n",
    "results['{}_active'.format(name)] = rewards['active_rate']\n",
    "results['{}_train_time'.format(name)] = rewards['train_time']\n",
    "results['{}_time'.format(name)] =  rewards['time']\n",
    "print(np.mean(rewards['reward']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mcts_random(env,state,budget,lamb,memory,per_epoch_results):\n",
    "    return full_mcts_policy(env,state,budget,lamb,memory,per_epoch_results,group_setup=\"random\")\n",
    "policy = mcts_random \n",
    "name = \"mcts_rand_group\"\n",
    "\n",
    "train_iterations = 30\n",
    "\n",
    "rewards, memory, simulator = run_multi_seed(seed_list,policy,is_mcts=True,train_iterations=train_iterations,test_iterations=test_iterations)\n",
    "results['{}_reward'.format(name)] = rewards['reward']\n",
    "results['{}_match'.format(name)] =  rewards['match'] \n",
    "results['{}_active'.format(name)] = rewards['active_rate']\n",
    "results['{}_train_time'.format(name)] = rewards['train_time']\n",
    "results['{}_time'.format(name)] =  rewards['time']\n",
    "print(np.mean(rewards['reward']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acting should always be good! 0.000 < 0.044\n",
      "acting should always be good! 0.000 < 0.162\n",
      "acting should always be good! 0.108 < 0.183\n",
      "good start state should always be good! 0.380 < 0.508\n",
      "good start state should always be good! 0.506 < 0.760\n",
      "cohort [26 66]\n"
     ]
     },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "instance 0, ep 1\n",
      "instance 0, ep 2\n",
      "instance 0, ep 3\n",
      "instance 0, ep 4\n",
      "instance 0, ep 5\n",
      "instance 0, ep 6\n",
      "instance 0, ep 7\n",
      "instance 0, ep 8\n",
      "instance 0, ep 9\n",
      "instance 0, ep 10\n",
      "instance 0, ep 11\n",
      "instance 0, ep 12\n",
      "instance 0, ep 13\n",
      "instance 0, ep 14\n",
      "instance 0, ep 15\n",
      "instance 0, ep 16\n",
      "instance 0, ep 17\n",
      "instance 0, ep 18\n",
      "instance 0, ep 19\n",
      "instance 0, ep 20\n",
      "instance 0, ep 21\n",
      "instance 0, ep 22\n",
      "instance 0, ep 23\n",
      "instance 0, ep 24\n",
      "instance 0, ep 25\n",
      "instance 0, ep 26\n",
      "instance 0, ep 27\n",
      "instance 0, ep 28\n",
      "instance 0, ep 29\n",
      "instance 0, ep 30\n",
      "instance 0, ep 31\n",
      "instance 0, ep 32\n",
      "instance 0, ep 33\n",
      "instance 0, ep 34\n",
      "instance 0, ep 35\n",
      "instance 0, ep 36\n",
      "instance 0, ep 37\n",
      "instance 0, ep 38\n",
      "instance 0, ep 39\n",
      "instance 0, ep 40\n",
      "instance 0, ep 41\n",
      "instance 0, ep 42\n",
      "instance 0, ep 43\n",
      "instance 0, ep 44\n",
      "instance 0, ep 45\n",
      "instance 0, ep 46\n",
      "instance 0, ep 47\n",
      "instance 0, ep 48\n",
      "instance 0, ep 49\n",
      "Took 42.727776527404785 time for inference and 44.05237698554993 time for training\n",
      "5.990477498704265\n"
     ]
    }
   ],
   "source": [
    "def mcts_random(env,state,budget,lamb,memory,per_epoch_results):\n",
    "    return full_mcts_policy(env,state,budget,lamb,memory,per_epoch_results,group_setup=\"whittle\")\n",
    "policy = mcts_random \n",
    "name = \"mcts_whittle_group\"\n",
    "\n",
    "train_iterations = 30\n",
    "\n",
    "rewards, memory, simulator = run_multi_seed(seed_list,policy,is_mcts=True,train_iterations=train_iterations,test_iterations=test_iterations)\n",
    "results['{}_reward'.format(name)] = rewards['reward']\n",
    "results['{}_match'.format(name)] =  rewards['match'] \n",
    "results['{}_active'.format(name)] = rewards['active_rate']\n",
    "results['{}_train_time'.format(name)] = rewards['train_time']\n",
    "results['{}_time'.format(name)] =  rewards['time']\n",
    "print(np.mean(rewards['reward']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acting should always be good! 0.000 < 0.044\n",
      "acting should always be good! 0.000 < 0.162\n",
      "acting should always be good! 0.108 < 0.183\n",
      "good start state should always be good! 0.380 < 0.508\n",
      "good start state should always be good! 0.506 < 0.760\n",
      "cohort [26 66]\n",
      "instance 0, ep 1\n",
      "instance 0, ep 2\n",
      "instance 0, ep 3\n",
      "instance 0, ep 4\n",
      "instance 0, ep 5\n",
      "instance 0, ep 6\n",
      "instance 0, ep 7\n",
      "instance 0, ep 8\n",
      "instance 0, ep 9\n",
      "instance 0, ep 10\n",
      "instance 0, ep 11\n",
      "instance 0, ep 12\n",
      "instance 0, ep 13\n",
      "instance 0, ep 14\n",
      "instance 0, ep 15\n",
      "instance 0, ep 16\n",
      "instance 0, ep 17\n",
      "instance 0, ep 18\n",
      "instance 0, ep 19\n",
      "instance 0, ep 20\n",
      "instance 0, ep 21\n",
      "instance 0, ep 22\n",
      "instance 0, ep 23\n",
      "instance 0, ep 24\n",
      "instance 0, ep 25\n",
      "instance 0, ep 26\n",
      "instance 0, ep 27\n",
      "instance 0, ep 28\n",
      "instance 0, ep 29\n",
      "instance 0, ep 30\n",
      "instance 0, ep 31\n",
      "instance 0, ep 32\n",
      "instance 0, ep 33\n",
      "instance 0, ep 34\n",
      "instance 0, ep 35\n",
      "instance 0, ep 36\n",
      "instance 0, ep 37\n",
      "instance 0, ep 38\n",
      "instance 0, ep 39\n",
      "instance 0, ep 40\n",
      "instance 0, ep 41\n",
      "instance 0, ep 42\n",
      "instance 0, ep 43\n",
      "instance 0, ep 44\n",
      "instance 0, ep 45\n",
      "instance 0, ep 46\n",
      "instance 0, ep 47\n",
      "instance 0, ep 48\n",
      "instance 0, ep 49\n",
      "Took 8.604904174804688 time for inference and 9.518194437026978 time for training\n",
      "5.887286249496324\n"
     ]
    }
   ],
   "source": [
    "def mcts_without_ucb(env,state,budget,lamb,memory,per_epoch_results):\n",
    "    return full_mcts_policy(env,state,budget,lamb,memory,per_epoch_results,run_ucb=False)\n",
    "policy = mcts_without_ucb \n",
    "name = \"mcts_no_ucb\"\n",
    "\n",
    "train_iterations = 30\n",
    "\n",
    "rewards, memory, simulator = run_multi_seed(seed_list,policy,is_mcts=True,train_iterations=train_iterations,test_iterations=test_iterations)\n",
    "results['{}_reward'.format(name)] = rewards['reward']\n",
    "results['{}_match'.format(name)] =  rewards['match'] \n",
    "results['{}_active'.format(name)] = rewards['active_rate']\n",
    "results['{}_time'.format(name)] =  rewards['time']\n",
    "results['{}_train_time'.format(name)] = rewards['train_time']\n",
    "print(np.mean(rewards['reward']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acting should always be good! 0.000 < 0.044\n",
      "acting should always be good! 0.000 < 0.162\n",
      "acting should always be good! 0.108 < 0.183\n",
      "good start state should always be good! 0.380 < 0.508\n",
      "good start state should always be good! 0.506 < 0.760\n",
      "cohort [26 66]\n",
      "instance 0, ep 1\n",
      "instance 0, ep 2\n",
      "instance 0, ep 3\n",
      "instance 0, ep 4\n",
      "instance 0, ep 5\n",
      "instance 0, ep 6\n",
      "instance 0, ep 7\n",
      "instance 0, ep 8\n",
      "instance 0, ep 9\n",
      "instance 0, ep 10\n",
      "instance 0, ep 11\n",
      "instance 0, ep 12\n",
      "instance 0, ep 13\n",
      "instance 0, ep 14\n",
      "instance 0, ep 15\n",
      "instance 0, ep 16\n",
      "instance 0, ep 17\n",
      "instance 0, ep 18\n",
      "instance 0, ep 19\n",
      "instance 0, ep 20\n",
      "instance 0, ep 21\n",
      "instance 0, ep 22\n",
      "instance 0, ep 23\n",
      "instance 0, ep 24\n",
      "instance 0, ep 25\n",
      "instance 0, ep 26\n",
      "instance 0, ep 27\n",
      "instance 0, ep 28\n",
      "instance 0, ep 29\n",
      "instance 0, ep 30\n",
      "instance 0, ep 31\n",
      "instance 0, ep 32\n",
      "instance 0, ep 33\n",
      "instance 0, ep 34\n",
      "instance 0, ep 35\n",
      "instance 0, ep 36\n",
      "instance 0, ep 37\n",
      "instance 0, ep 38\n",
      "instance 0, ep 39\n",
      "instance 0, ep 40\n",
      "instance 0, ep 41\n",
      "instance 0, ep 42\n",
      "instance 0, ep 43\n",
      "instance 0, ep 44\n",
      "instance 0, ep 45\n",
      "instance 0, ep 46\n",
      "instance 0, ep 47\n",
      "instance 0, ep 48\n",
      "instance 0, ep 49\n",
      "Took 44.96473932266235 time for inference and 46.4594509601593 time for training\n",
      "5.85701937427381\n"
     ]
    }
   ],
   "source": [
    "def mcts_without_whittle(env,state,budget,lamb,memory,per_epoch_results):\n",
    "    return full_mcts_policy(env,state,budget,lamb,memory,per_epoch_results,use_whittle=False)\n",
    "policy = mcts_without_whittle \n",
    "name = \"mcts_no_whittle\"\n",
    "\n",
    "train_iterations = 30\n",
    "\n",
    "rewards, memory, simulator = run_multi_seed(seed_list,policy,is_mcts=True,train_iterations=train_iterations,test_iterations=test_iterations)\n",
    "results['{}_reward'.format(name)] = rewards['reward']\n",
    "results['{}_match'.format(name)] =  rewards['match'] \n",
    "results['{}_active'.format(name)] = rewards['active_rate']\n",
    "results['{}_time'.format(name)] =  rewards['time']\n",
    "results['{}_train_time'.format(name)] = rewards['train_time']\n",
    "print(np.mean(rewards['reward']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acting should always be good! 0.000 < 0.044\n",
      "acting should always be good! 0.000 < 0.162\n",
      "acting should always be good! 0.108 < 0.183\n",
      "good start state should always be good! 0.380 < 0.508\n",
      "good start state should always be good! 0.506 < 0.760\n",
      "cohort [26 66]\n",
      "instance 0, ep 1\n",
      "instance 0, ep 2\n",
      "instance 0, ep 3\n",
      "instance 0, ep 4\n",
      "instance 0, ep 5\n",
      "instance 0, ep 6\n",
      "instance 0, ep 7\n",
      "instance 0, ep 8\n",
      "instance 0, ep 9\n",
      "instance 0, ep 10\n",
      "instance 0, ep 11\n",
      "instance 0, ep 12\n",
      "instance 0, ep 13\n",
      "instance 0, ep 14\n",
      "instance 0, ep 15\n",
      "instance 0, ep 16\n",
      "instance 0, ep 17\n",
      "instance 0, ep 18\n",
      "instance 0, ep 19\n",
      "instance 0, ep 20\n",
      "instance 0, ep 21\n",
      "instance 0, ep 22\n",
      "instance 0, ep 23\n",
      "instance 0, ep 24\n",
      "instance 0, ep 25\n",
      "instance 0, ep 26\n",
      "instance 0, ep 27\n",
      "instance 0, ep 28\n",
      "instance 0, ep 29\n",
      "instance 0, ep 30\n",
      "instance 0, ep 31\n",
      "instance 0, ep 32\n",
      "instance 0, ep 33\n",
      "instance 0, ep 34\n",
      "instance 0, ep 35\n",
      "instance 0, ep 36\n",
      "instance 0, ep 37\n",
      "instance 0, ep 38\n",
      "instance 0, ep 39\n",
      "instance 0, ep 40\n",
      "instance 0, ep 41\n",
      "instance 0, ep 42\n",
      "instance 0, ep 43\n",
      "instance 0, ep 44\n",
      "instance 0, ep 45\n",
      "instance 0, ep 46\n",
      "instance 0, ep 47\n",
      "instance 0, ep 48\n",
      "instance 0, ep 49\n",
      "Took 45.54479002952576 time for inference and 1.869741439819336 time for training\n",
      "5.88048247887817\n"
     ]
    }
   ],
   "source": [
    "policy = full_mcts_policy \n",
    "name = \"mcts_no_train\"\n",
    "\n",
    "rewards, memory, simulator = run_multi_seed(seed_list,policy,is_mcts=True,train_iterations=0,test_iterations=30)\n",
    "results['{}_reward'.format(name)] = rewards['reward']\n",
    "results['{}_match'.format(name)] =  rewards['match'] \n",
    "results['{}_active'.format(name)] = rewards['active_rate']\n",
    "results['{}_time'.format(name)] =  rewards['time']\n",
    "results['{}_train_time'.format(name)] = rewards['train_time']\n",
    "print(np.mean(rewards['reward']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acting should always be good! 0.000 < 0.044\n",
      "acting should always be good! 0.000 < 0.162\n",
      "acting should always be good! 0.108 < 0.183\n",
      "good start state should always be good! 0.380 < 0.508\n",
      "good start state should always be good! 0.506 < 0.760\n",
      "cohort [26 66]\n",
      "instance 0, ep 1\n",
      "instance 0, ep 2\n",
      "instance 0, ep 3\n",
      "instance 0, ep 4\n",
      "instance 0, ep 5\n",
      "instance 0, ep 6\n",
      "instance 0, ep 7\n",
      "instance 0, ep 8\n",
      "instance 0, ep 9\n",
      "instance 0, ep 10\n",
      "instance 0, ep 11\n",
      "instance 0, ep 12\n",
      "instance 0, ep 13\n",
      "instance 0, ep 14\n",
      "instance 0, ep 15\n",
      "instance 0, ep 16\n",
      "instance 0, ep 17\n",
      "instance 0, ep 18\n",
      "instance 0, ep 19\n",
      "instance 0, ep 20\n",
      "instance 0, ep 21\n",
      "instance 0, ep 22\n",
      "instance 0, ep 23\n",
      "instance 0, ep 24\n",
      "instance 0, ep 25\n",
      "instance 0, ep 26\n",
      "instance 0, ep 27\n",
      "instance 0, ep 28\n",
      "instance 0, ep 29\n",
      "instance 0, ep 30\n",
      "instance 0, ep 31\n",
      "instance 0, ep 32\n",
      "instance 0, ep 33\n",
      "instance 0, ep 34\n",
      "instance 0, ep 35\n",
      "instance 0, ep 36\n",
      "instance 0, ep 37\n",
      "instance 0, ep 38\n",
      "instance 0, ep 39\n",
      "instance 0, ep 40\n",
      "instance 0, ep 41\n",
      "instance 0, ep 42\n",
      "instance 0, ep 43\n",
      "instance 0, ep 44\n",
      "instance 0, ep 45\n",
      "instance 0, ep 46\n",
      "instance 0, ep 47\n",
      "instance 0, ep 48\n",
      "instance 0, ep 49\n",
      "Took 0.7968130111694336 time for inference and 47.424615144729614 time for training\n",
      "5.834900010832622\n"
     ]
    }
   ],
   "source": [
    "policy = full_mcts_policy \n",
    "name = \"mcts_no_test\"\n",
    "\n",
    "rewards, memory, simulator = run_multi_seed(seed_list,policy,is_mcts=True,train_iterations=train_iterations,test_iterations=0)\n",
    "results['{}_reward'.format(name)] = rewards['reward']\n",
    "results['{}_match'.format(name)] =  rewards['match'] \n",
    "results['{}_active'.format(name)] = rewards['active_rate']\n",
    "results['{}_time'.format(name)] =  rewards['time']\n",
    "results['{}_train_time'.format(name)] = rewards['train_time']\n",
    "print(np.mean(rewards['reward']))"
   ]
  },
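  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick comparison of mean discounted reward across whichever variants have been recorded so far (keys of the form `<name>_reward`; variants whose cells did not run to completion are simply absent)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarize mean reward per recorded variant\n",
    "for key in sorted(results):\n",
    "    if key.endswith('_reward'):\n",
    "        print('{:<25} {:.3f}'.format(key, float(np.mean(results[key]))))"
   ]
  },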
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Write Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_path = get_save_path(out_folder,save_name,seed,use_date=save_with_date)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "delete_duplicate_results(out_folder,\"\",results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "json.dump(results,open('../../results/'+save_path,'w'))"
   ]
  },
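  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check, the saved results can be read back from the same path (this assumes `get_save_path` returned a path relative to `../../results/`, as used when writing above)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the file written above and echo the run parameters\n",
    "with open('../../results/' + save_path) as f:\n",
    "    loaded = json.load(f)\n",
    "print(loaded['parameters'])"
   ]
  }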
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "food",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
