{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Semi-Synthetic Experiments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Analyze the performance of various algorithms on the joint matching + activity task when the number of volunteers is large and structured."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Automatically reload imported modules (e.g. the rmab package) before executing each cell.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random \n",
    "import matplotlib.pyplot as plt\n",
    "import json \n",
    "import argparse \n",
    "import sys\n",
    "import secrets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr0/home/naveenr/miniconda3/envs/food/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from rmab.simulator import RMABSimulator, run_heterogenous_policy, get_discounted_reward, create_random_transitions\n",
    "from rmab.omniscient_policies import *\n",
    "from rmab.dqn_policies import *\n",
    "from rmab.fr_dynamics import get_all_transitions, get_db_data, get_all_transitions_partition\n",
    "from rmab.mcts_policies import *\n",
    "from rmab.utils import get_save_path, delete_duplicate_results, create_prob_distro\n",
    "import resource"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resource caps: limit this process's GPU memory share, torch CPU threads and address space.\n",
    "import torch  # explicit; torch is otherwise only reachable through the rmab star imports\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    # set_per_process_memory_fraction fails on machines without CUDA, so only cap when a GPU exists.\n",
    "    torch.cuda.set_per_process_memory_fraction(0.5)\n",
    "torch.set_num_threads(1)\n",
    "\n",
    "address_space_limit = 30 * 1024 ** 3  # 30 GiB\n",
    "_, hard_limit = resource.getrlimit(resource.RLIMIT_AS)\n",
    "if hard_limit != resource.RLIM_INFINITY:\n",
    "    # The soft limit may not exceed the hard limit, and an unprivileged process cannot\n",
    "    # raise its hard limit, so keep the existing hard limit and clamp the soft one to it.\n",
    "    address_space_limit = min(address_space_limit, hard_limit)\n",
    "resource.setrlimit(resource.RLIMIT_AS, (address_space_limit, hard_limit))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# True inside a Jupyter kernel; the configuration cell uses this to choose between\n",
    "# hard-coded parameters and command-line arguments.\n",
    "is_jupyter = 'ipykernel' in sys.modules.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration: hard-coded values when run interactively in Jupyter,\n",
    "# command-line flags when the notebook is executed as a script.\n",
    "if is_jupyter:\n",
    "    seed        = 51\n",
    "    n_arms      = 4\n",
    "    volunteers_per_arm = 1\n",
    "    budget      = 2\n",
    "    discount    = 0.9\n",
    "    alpha       = 3\n",
    "    n_episodes  = 105\n",
    "    episode_len = 50\n",
    "    n_epochs    = 1\n",
    "    save_with_date = False\n",
    "    TIME_PER_RUN = 0.01 * 1000\n",
    "    lamb = 0.5\n",
    "    prob_distro = 'uniform'\n",
    "    reward_type = \"linear\"\n",
    "    reward_parameters = {'universe_size': 20, 'arm_set_low': 0, 'arm_set_high': 1}\n",
    "    policy_lr=5e-3\n",
    "    value_lr=1e-4\n",
    "    train_iterations = 30\n",
    "    test_iterations = 30\n",
    "    out_folder = 'iterative'\n",
    "    time_limit = 100\n",
    "else:\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument('--n_arms',         '-N', help='num beneficiaries (arms)', type=int, default=2)\n",
    "    parser.add_argument('--volunteers_per_arm',         '-V', help='volunteers per arm', type=int, default=5)\n",
    "    parser.add_argument('--episode_len',    '-H', help='episode length', type=int, default=50)\n",
    "    parser.add_argument('--n_episodes',     '-T', help='num episodes', type=int, default=105)\n",
    "    parser.add_argument('--budget',         '-B', help='budget', type=int, default=3)\n",
    "    parser.add_argument('--n_epochs',       '-E', help='number of epochs (num_repeats)', type=int, default=1)\n",
    "    parser.add_argument('--discount',       '-d', help='discount factor', type=float, default=0.9)\n",
    "    parser.add_argument('--alpha',          '-a', help='alpha: for conf radius', type=float, default=3)\n",
    "    parser.add_argument('--lamb',          '-l', help='lambda for matching-engagement tradeoff', type=float, default=0.5)\n",
    "    parser.add_argument('--universe_size', help='For set cover, total num universe elems', type=int, default=10)\n",
    "    parser.add_argument('--arm_set_low', help='Least size of arm set, for set cover (lower bound of uniform match prob otherwise)', type=float, default=3)\n",
    "    parser.add_argument('--arm_set_high', help='Largest size of arm set, for set cover (upper bound of uniform match prob otherwise)', type=float, default=6)\n",
    "    parser.add_argument('--reward_type',          '-r', help='Which type of custom reward', type=str, default='set_cover')\n",
    "    parser.add_argument('--seed',           '-s', help='random seed', type=int, default=42)\n",
    "    parser.add_argument('--prob_distro',           '-p', help='which prob distro [uniform,fixed,high_prob,one_time,food_rescue,food_rescue_top]', type=str, default='uniform')\n",
    "    parser.add_argument('--time_per_run',      '-t', help='time per MCTS run', type=float, default=.01*1000)\n",
    "    parser.add_argument('--policy_lr', help='Learning Rate Policy', type=float, default=5e-3)\n",
    "    parser.add_argument('--value_lr', help='Learning Rate Value', type=float, default=1e-4)\n",
    "    parser.add_argument('--train_iterations', help='Number of MCTS train iterations', type=int, default=30)\n",
    "    parser.add_argument('--test_iterations', help='Number of MCTS test iterations', type=int, default=30)\n",
    "    parser.add_argument('--out_folder', help='Which folder to write results to', type=str, default='iterative')\n",
    "    parser.add_argument('--time_limit', help='Online time limit for computation', type=float, default=100)\n",
    "\n",
    "    parser.add_argument('--use_date', action='store_true')\n",
    "\n",
    "    args = parser.parse_args()\n",
    "\n",
    "    n_arms      = args.n_arms\n",
    "    volunteers_per_arm = args.volunteers_per_arm\n",
    "    budget      = args.budget\n",
    "    discount    = args.discount\n",
    "    alpha       = args.alpha\n",
    "    seed        = args.seed\n",
    "    n_episodes  = args.n_episodes\n",
    "    episode_len = args.episode_len\n",
    "    n_epochs    = args.n_epochs\n",
    "    lamb = args.lamb\n",
    "    save_with_date = args.use_date\n",
    "    TIME_PER_RUN = args.time_per_run\n",
    "    prob_distro = args.prob_distro\n",
    "    policy_lr = args.policy_lr\n",
    "    value_lr = args.value_lr\n",
    "    out_folder = args.out_folder\n",
    "    train_iterations = args.train_iterations\n",
    "    test_iterations = args.test_iterations\n",
    "    reward_type = args.reward_type\n",
    "    reward_parameters = {'universe_size': args.universe_size,\n",
    "                        'arm_set_low': args.arm_set_low,\n",
    "                        'arm_set_high': args.arm_set_high}\n",
    "    time_limit = args.time_limit\n",
    "\n",
    "# Random identifier for this run's output.\n",
    "save_name = secrets.token_hex(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size of each arm's MDP; n_states is passed to RMABSimulator in create_environment.\n",
    "n_states, n_actions = 2, 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default synthetic population: re-seed, then draw random transitions for a population of 100.\n",
    "# The prob_distro-specific cells below may overwrite these values.\n",
    "all_population_size, max_transition_prob = 100, 0.25\n",
    "np.random.seed(seed)\n",
    "all_transitions = create_random_transitions(all_population_size, max_transition_prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "def partition_volunteers(probs_by_num, num_by_section):\n",
    "    \"\"\"Group the sorted keys of probs_by_num into contiguous, roughly equal-sized partitions.\n",
    "\n",
    "    Keys are visited in ascending order and each contributes len(probs_by_num[key])\n",
    "    volunteers. A partition is closed once the running volunteer count reaches the next\n",
    "    multiple of total // num_by_section.\n",
    "\n",
    "    Args:\n",
    "        probs_by_num: dict mapping a sortable key (a rescue count, in the cells below)\n",
    "            to a list of per-volunteer match probabilities.\n",
    "        num_by_section: desired number of partitions; must be at least 1.\n",
    "\n",
    "    Returns:\n",
    "        A list of at most num_by_section non-empty lists of keys, covering every key.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if num_by_section < 1.\n",
    "    \"\"\"\n",
    "    if num_by_section < 1:\n",
    "        raise ValueError('num_by_section must be at least 1')\n",
    "\n",
    "    total = sum(len(values) for values in probs_by_num.values())\n",
    "    num_per_section = total // num_by_section\n",
    "\n",
    "    nums_by_partition = []\n",
    "    current_count = 0\n",
    "    current_partition = []\n",
    "\n",
    "    for key in sorted(probs_by_num):\n",
    "        # Close the current partition once the running count reaches the next threshold.\n",
    "        # Never emit an empty partition (possible when num_per_section == 0), and stop closing\n",
    "        # after num_by_section - 1 partitions so any remainder lands in the last one.\n",
    "        if (current_partition\n",
    "                and len(nums_by_partition) < num_by_section - 1\n",
    "                and current_count >= num_per_section * (len(nums_by_partition) + 1)):\n",
    "            nums_by_partition.append(current_partition)\n",
    "            current_partition = []\n",
    "\n",
    "        current_partition.append(key)\n",
    "        current_count += len(probs_by_num[key])\n",
    "\n",
    "    # Flush the final partition; without this its keys would be silently dropped.\n",
    "    if current_partition:\n",
    "        nums_by_partition.append(current_partition)\n",
    "    return nums_by_partition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "if prob_distro == \"food_rescue_top\":\n",
    "    all_population_size = 20\n",
    "    # Context manager so the file handle is closed after loading.\n",
    "    with open(\"../../results/food_rescue/match_probs.json\", \"r\") as match_probs_file:\n",
    "        probs_by_user = json.load(match_probs_file)\n",
    "    donation_id_to_latlon, recipient_location_to_latlon, rescues_by_user, all_rescue_data, user_id_to_latlon = get_db_data()\n",
    "\n",
    "    # Bucket the positive match probabilities of volunteers with at least 100 rescues\n",
    "    # by their number of rescues.\n",
    "    probs_by_num = {}\n",
    "    for user_id in rescues_by_user:\n",
    "        num_rescues = len(rescues_by_user[user_id])\n",
    "        if str(user_id) in probs_by_user and probs_by_user[str(user_id)] > 0 and num_rescues >= 100:\n",
    "            probs_by_num.setdefault(num_rescues, []).append(probs_by_user[str(user_id)])\n",
    "\n",
    "    partitions = partition_volunteers(probs_by_num, all_population_size)\n",
    "    # Pool the match probabilities of every rescue count within a partition.\n",
    "    probs_by_partition = [[prob for num in partition for prob in probs_by_num[num]] for partition in partitions]\n",
    "\n",
    "    all_transitions = get_all_transitions_partition(all_population_size, partitions)\n",
    "\n",
    "    # Collapse each partition's stacked transition arrays into one, weighting each rescue count\n",
    "    # by its share of the partition's volunteers (assumes axis 0 of all_transitions[i] follows\n",
    "    # the order of `partition` -- TODO confirm against get_all_transitions_partition).\n",
    "    for i, partition in enumerate(partitions):\n",
    "        current_transitions = np.array(all_transitions[i])\n",
    "        partition_scale = np.array([len(probs_by_num[j]) for j in partition])\n",
    "        partition_scale = partition_scale / np.sum(partition_scale)\n",
    "        prod = current_transitions * partition_scale[:, np.newaxis, np.newaxis, np.newaxis]\n",
    "        all_transitions[i] = np.sum(prod, axis=0)\n",
    "    all_transitions = np.array(all_transitions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "if prob_distro == \"food_rescue\":\n",
    "    all_population_size = 100\n",
    "\n",
    "    # Context manager so the file handle is closed after loading.\n",
    "    with open(\"../../results/food_rescue/match_probs.json\", \"r\") as match_probs_file:\n",
    "        probs_by_user = json.load(match_probs_file)\n",
    "    donation_id_to_latlon, recipient_location_to_latlon, rescues_by_user, all_rescue_data, user_id_to_latlon = get_db_data()\n",
    "\n",
    "    # Bucket the positive match probabilities of volunteers with at least 3 rescues\n",
    "    # by their number of rescues.\n",
    "    probs_by_num = {}\n",
    "    for user_id in rescues_by_user:\n",
    "        num_rescues = len(rescues_by_user[user_id])\n",
    "        if str(user_id) in probs_by_user and probs_by_user[str(user_id)] > 0 and num_rescues >= 3:\n",
    "            probs_by_num.setdefault(num_rescues, []).append(probs_by_user[str(user_id)])\n",
    "\n",
    "    partitions = partition_volunteers(probs_by_num, all_population_size)\n",
    "    all_transitions = get_all_transitions_partition(all_population_size, partitions)\n",
    "\n",
    "    # Pool the match probabilities of every rescue count within a partition.\n",
    "    probs_by_partition = [[prob for num in partition for prob in probs_by_num[num]] for partition in partitions]\n",
    "\n",
    "    # Collapse each partition's stacked transition arrays into one, weighting each rescue count\n",
    "    # by its share of the partition's volunteers (assumes axis 0 of all_transitions[i] follows\n",
    "    # the order of `partition` -- TODO confirm against get_all_transitions_partition).\n",
    "    for i, partition in enumerate(partitions):\n",
    "        current_transitions = np.array(all_transitions[i])\n",
    "        partition_scale = np.array([len(probs_by_num[j]) for j in partition])\n",
    "        partition_scale = partition_scale / np.sum(partition_scale)\n",
    "        prod = current_transitions * partition_scale[:, np.newaxis, np.newaxis, np.newaxis]\n",
    "        all_transitions[i] = np.sum(prod, axis=0)\n",
    "    all_transitions = np.array(all_transitions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "if prob_distro == \"high_prob\":\n",
    "    # Random transitions as in the default setup, but with max_transition_prob raised to 1.0.\n",
    "    np.random.seed(seed)\n",
    "    all_population_size, max_transition_prob = 100, 1.0\n",
    "    all_transitions = create_random_transitions(all_population_size, max_transition_prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "if prob_distro == \"one_time\":\n",
    "    np.random.seed(seed)\n",
    "    all_population_size = 100\n",
    "    max_transition_prob = 1.0\n",
    "    # Every member of the population shares one fixed 2x2x2 transition array.\n",
    "    # NOTE(review): presumably indexed [state, action, next_state] -- confirm.\n",
    "    one_time_transitions = np.zeros((2, 2, 2))\n",
    "    one_time_transitions[:, 1, 0] = 1\n",
    "    one_time_transitions[1, 0, 1] = 1\n",
    "    one_time_transitions[0, 0, 0] = 1\n",
    "    all_transitions = np.tile(one_time_transitions, (all_population_size, 1, 1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _sample_distinct_elements(set_size, universe_size):\n",
    "    \"\"\"Draw integers from [0, universe_size) with np.random until set_size distinct ones are collected.\"\"\"\n",
    "    elements = set()\n",
    "    while len(elements) < set_size:\n",
    "        elements.add(np.random.randint(0, universe_size))\n",
    "    return elements\n",
    "\n",
    "\n",
    "def create_environment(seed):\n",
    "    \"\"\"Build an RMABSimulator for one seed from the notebook-level configuration.\n",
    "\n",
    "    Seeds both `random` and `np.random`, draws one match-probability entry per volunteer\n",
    "    (all_population_size * volunteers_per_arm entries), and attaches reward_type and\n",
    "    reward_parameters to the returned simulator.\n",
    "\n",
    "    Entries are sets of universe elements when reward_type == \"set_cover\"; values drawn from\n",
    "    the volunteer's partition for the food_rescue distributions; and\n",
    "    np.random.uniform(arm_set_low, arm_set_high) samples otherwise.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for set_cover, if a set size exceeds universe_size (sampling that many\n",
    "            distinct elements could never terminate).\n",
    "    \"\"\"\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    all_features = np.arange(all_population_size)\n",
    "    N = all_population_size * volunteers_per_arm\n",
    "\n",
    "    if reward_type == \"set_cover\":\n",
    "        universe_size = reward_parameters['universe_size']\n",
    "        arm_set_low = int(reward_parameters['arm_set_low'])\n",
    "        if prob_distro == \"fixed\":\n",
    "            set_sizes = [arm_set_low for _ in range(N)]\n",
    "        else:\n",
    "            arm_set_high = int(reward_parameters['arm_set_high'])\n",
    "            # All sizes are drawn before any set, keeping the original RNG call order.\n",
    "            set_sizes = [np.random.randint(arm_set_low, arm_set_high + 1) for _ in range(N)]\n",
    "        if set_sizes and max(set_sizes) > universe_size:\n",
    "            raise ValueError('set_cover set size {} exceeds universe_size {}'.format(max(set_sizes), universe_size))\n",
    "        match_probabilities = [_sample_distinct_elements(size, universe_size) for size in set_sizes]\n",
    "    elif prob_distro == \"food_rescue\" or prob_distro == \"food_rescue_top\":\n",
    "        # Volunteer i draws from partition i // volunteers_per_arm.\n",
    "        match_probabilities = [np.random.choice(probs_by_partition[i // volunteers_per_arm]) for i in range(N)]\n",
    "    else:\n",
    "        match_probabilities = [np.random.uniform(reward_parameters['arm_set_low'], reward_parameters['arm_set_high']) for i in range(N)]\n",
    "\n",
    "    simulator = RMABSimulator(all_population_size, all_features, all_transitions,\n",
    "                n_arms, volunteers_per_arm, episode_len, n_epochs, n_episodes, budget, discount, number_states=n_states, reward_style='custom', match_probability_list=match_probabilities, TIME_PER_RUN=TIME_PER_RUN)\n",
    "    simulator.reward_type = reward_type\n",
    "    simulator.reward_parameters = reward_parameters\n",
    "    return simulator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_multi_seed(seed_list,policy,is_mcts=False,per_epoch_function=None,train_iterations=0,test_iterations=400,test_length=20,avg_reward=0,num_samples=100):\n",
    "    \"\"\"Evaluate `policy` on a freshly built environment for each seed.\n",
    "\n",
    "    Args:\n",
    "        seed_list: seeds; each one gets its own simulator from create_environment.\n",
    "        policy: policy passed through to run_heterogenous_policy.\n",
    "        is_mcts: when True, run with should_train=True and also collect the returned memory.\n",
    "        per_epoch_function: optional function passed through to run_heterogenous_policy.\n",
    "        train_iterations: stored on the simulator as mcts_train_iterations.\n",
    "        test_iterations: stored on the simulator as mcts_test_iterations. Defaults to 400 so\n",
    "            calls that omit it keep receiving the value that used to be hard-coded.\n",
    "        test_length: passed to run_heterogenous_policy as test_T.\n",
    "        avg_reward, num_samples: stored on the simulator under the same names.\n",
    "\n",
    "    Returns:\n",
    "        (scores, memories, simulator): scores maps 'reward', 'time', 'match' and 'active_rate'\n",
    "        to one entry per seed; memories has one entry per seed when is_mcts (else empty);\n",
    "        simulator is the one built for the last seed.\n",
    "    \"\"\"\n",
    "    memories = []\n",
    "    scores = {\n",
    "        'reward': [],\n",
    "        'time': [],\n",
    "        'match': [],\n",
    "        'active_rate': [],\n",
    "    }\n",
    "\n",
    "    for seed in seed_list:\n",
    "        simulator = create_environment(seed)\n",
    "        simulator.time_limit = time_limit\n",
    "        simulator.avg_reward = avg_reward\n",
    "        simulator.num_samples = num_samples\n",
    "        simulator.mcts_train_iterations = train_iterations\n",
    "        simulator.mcts_test_iterations = test_iterations\n",
    "        simulator.policy_lr = policy_lr\n",
    "        simulator.value_lr = value_lr\n",
    "        simulator.mcts_depth = 2\n",
    "        simulator.shapley_iterations = 1000\n",
    "\n",
    "        if prob_distro == \"one_time\":\n",
    "            N = n_arms*volunteers_per_arm\n",
    "            simulator.first_init_states = np.array([[[1 for _ in range(N)] for _ in range(n_episodes)]])\n",
    "            random.seed(seed)\n",
    "            # NOTE(review): both parts of this list previously used arm_set_high, which made the\n",
    "            # shuffle a no-op. Assumed intent: two arm_set_high volunteers among arm_set_low ones -- confirm.\n",
    "            shuffled_list = [reward_parameters['arm_set_high'] for _ in range(2)] + [reward_parameters['arm_set_low'] for _ in range(N-2)]\n",
    "            random.shuffle(shuffled_list)\n",
    "\n",
    "            simulator.match_probability_list[simulator.cohort_selection[0]] = shuffled_list\n",
    "\n",
    "        if is_mcts:\n",
    "            match, active_rate, memory = run_heterogenous_policy(simulator, n_episodes, n_epochs, discount,policy,seed,lamb=lamb,should_train=True,test_T=test_length,get_memory=True,per_epoch_function=per_epoch_function)\n",
    "        else:\n",
    "            match, active_rate = run_heterogenous_policy(simulator, n_episodes, n_epochs, discount,policy,seed,lamb=lamb,should_train=False,test_T=test_length,per_epoch_function=per_epoch_function)\n",
    "        # Reshape the per-step results to (num_timesteps // episode_len, episode_len).\n",
    "        num_timesteps = match.size\n",
    "        match = match.reshape((num_timesteps//episode_len,episode_len))\n",
    "        active_rate = active_rate.reshape((num_timesteps//episode_len,episode_len))\n",
    "\n",
    "        time_taken = simulator.time_taken\n",
    "        discounted_reward = get_discounted_reward(match,active_rate,discount,lamb)\n",
    "        scores['reward'].append(discounted_reward)\n",
    "        scores['time'].append(time_taken)\n",
    "        scores['match'].append(np.mean(match))\n",
    "        scores['active_rate'].append(np.mean(active_rate))\n",
    "        if is_mcts:\n",
    "            memories.append(memory)\n",
    "\n",
    "    return scores, memories, simulator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {}\n",
    "# Run configuration, stored next to the per-policy scores added by the cells below.\n",
    "results['parameters'] = dict(\n",
    "    seed=seed,\n",
    "    n_arms=n_arms,\n",
    "    volunteers_per_arm=volunteers_per_arm,\n",
    "    budget=budget,\n",
    "    discount=discount,\n",
    "    alpha=alpha,\n",
    "    n_episodes=n_episodes,\n",
    "    episode_len=episode_len,\n",
    "    n_epochs=n_epochs,\n",
    "    lamb=lamb,\n",
    "    time_per_run=TIME_PER_RUN,\n",
    "    prob_distro=prob_distro,\n",
    "    policy_lr=policy_lr,\n",
    "    value_lr=value_lr,\n",
    "    reward_type=reward_type,\n",
    "    universe_size=reward_parameters['universe_size'],\n",
    "    arm_set_low=reward_parameters['arm_set_low'],\n",
    "    arm_set_high=reward_parameters['arm_set_high'],\n",
    "    time_limit=time_limit,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Index Policies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeds to evaluate each policy on; add more entries for multi-seed averages.\n",
    "seed_list = [seed]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cohort [61 54 87 93]\n",
      "instance 0, ep 1\n",
      "instance 0, ep 2\n",
      "instance 0, ep 3\n",
      "instance 0, ep 4\n",
      "instance 0, ep 5\n",
      "instance 0, ep 6\n",
      "instance 0, ep 7\n",
      "instance 0, ep 8\n",
      "instance 0, ep 9\n",
      "instance 0, ep 10\n",
      "instance 0, ep 11\n",
      "instance 0, ep 12\n",
      "instance 0, ep 13\n",
      "instance 0, ep 14\n",
      "instance 0, ep 15\n",
      "instance 0, ep 16\n",
      "instance 0, ep 17\n",
      "instance 0, ep 18\n",
      "instance 0, ep 19\n",
      "instance 0, ep 20\n",
      "instance 0, ep 21\n",
      "instance 0, ep 22\n",
      "instance 0, ep 23\n",
      "instance 0, ep 24\n",
      "instance 0, ep 25\n",
      "instance 0, ep 26\n",
      "instance 0, ep 27\n",
      "instance 0, ep 28\n",
      "instance 0, ep 29\n",
      "instance 0, ep 30\n",
      "instance 0, ep 31\n",
      "instance 0, ep 32\n",
      "instance 0, ep 33\n",
      "instance 0, ep 34\n",
      "instance 0, ep 35\n",
      "instance 0, ep 36\n",
      "instance 0, ep 37\n",
      "instance 0, ep 38\n",
      "instance 0, ep 39\n",
      "instance 0, ep 40\n",
      "instance 0, ep 41\n",
      "instance 0, ep 42\n",
      "instance 0, ep 43\n",
      "instance 0, ep 44\n",
      "instance 0, ep 45\n",
      "instance 0, ep 46\n",
      "instance 0, ep 47\n",
      "instance 0, ep 48\n",
      "instance 0, ep 49\n",
      "instance 0, ep 50\n",
      "instance 0, ep 51\n",
      "instance 0, ep 52\n",
      "instance 0, ep 53\n",
      "instance 0, ep 54\n",
      "instance 0, ep 55\n",
      "instance 0, ep 56\n",
      "instance 0, ep 57\n",
      "instance 0, ep 58\n",
      "instance 0, ep 59\n",
      "instance 0, ep 60\n",
      "instance 0, ep 61\n",
      "instance 0, ep 62\n",
      "instance 0, ep 63\n",
      "instance 0, ep 64\n",
      "instance 0, ep 65\n",
      "instance 0, ep 66\n",
      "instance 0, ep 67\n",
      "instance 0, ep 68\n",
      "instance 0, ep 69\n",
      "instance 0, ep 70\n",
      "instance 0, ep 71\n",
      "instance 0, ep 72\n",
      "instance 0, ep 73\n",
      "instance 0, ep 74\n",
      "instance 0, ep 75\n",
      "instance 0, ep 76\n",
      "instance 0, ep 77\n",
      "instance 0, ep 78\n",
      "instance 0, ep 79\n",
      "instance 0, ep 80\n",
      "instance 0, ep 81\n",
      "instance 0, ep 82\n",
      "instance 0, ep 83\n",
      "instance 0, ep 84\n",
      "instance 0, ep 85\n",
      "instance 0, ep 86\n",
      "instance 0, ep 87\n",
      "instance 0, ep 88\n",
      "instance 0, ep 89\n",
      "instance 0, ep 90\n",
      "instance 0, ep 91\n",
      "instance 0, ep 92\n",
      "instance 0, ep 93\n",
      "instance 0, ep 94\n",
      "instance 0, ep 95\n",
      "instance 0, ep 96\n",
      "instance 0, ep 97\n",
      "instance 0, ep 98\n",
      "instance 0, ep 99\n",
      "instance 0, ep 100\n",
      "instance 0, ep 101\n",
      "instance 0, ep 102\n",
      "instance 0, ep 103\n",
      "instance 0, ep 104\n",
      "Took 0.03727865219116211 time for inference and 0.4646177291870117 time for training\n",
      "7.359031745134615\n"
     ]
    }
   ],
   "source": [
    "# Whittle-index policy; results are stored under the 'linear_whittle' prefix.\n",
    "policy, name = whittle_policy, \"linear_whittle\"\n",
    "\n",
    "# NOTE(review): test_length is 0 whenever n_episodes is a multiple of 50 -- confirm intent.\n",
    "rewards, memory, simulator = run_multi_seed(seed_list, policy, test_length=episode_len * (n_episodes % 50))\n",
    "results.update({\n",
    "    '{}_reward'.format(name): rewards['reward'],\n",
    "    '{}_match'.format(name): rewards['match'],\n",
    "    '{}_active'.format(name): rewards['active_rate'],\n",
    "    '{}_time'.format(name): rewards['time'],\n",
    "})\n",
    "print(np.mean(rewards['reward']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cohort [61 54 87 93]\n",
      "instance 0, ep 1\n",
      "instance 0, ep 2\n",
      "instance 0, ep 3\n",
      "instance 0, ep 4\n",
      "instance 0, ep 5\n",
      "instance 0, ep 6\n",
      "instance 0, ep 7\n",
      "instance 0, ep 8\n",
      "instance 0, ep 9\n",
      "instance 0, ep 10\n",
      "instance 0, ep 11\n",
      "instance 0, ep 12\n",
      "instance 0, ep 13\n",
      "instance 0, ep 14\n",
      "instance 0, ep 15\n",
      "instance 0, ep 16\n",
      "instance 0, ep 17\n",
      "instance 0, ep 18\n",
      "instance 0, ep 19\n",
      "instance 0, ep 20\n",
      "instance 0, ep 21\n",
      "instance 0, ep 22\n",
      "instance 0, ep 23\n",
      "instance 0, ep 24\n",
      "instance 0, ep 25\n",
      "instance 0, ep 26\n",
      "instance 0, ep 27\n",
      "instance 0, ep 28\n",
      "instance 0, ep 29\n",
      "instance 0, ep 30\n",
      "instance 0, ep 31\n",
      "instance 0, ep 32\n",
      "instance 0, ep 33\n",
      "instance 0, ep 34\n",
      "instance 0, ep 35\n",
      "instance 0, ep 36\n",
      "instance 0, ep 37\n",
      "instance 0, ep 38\n",
      "instance 0, ep 39\n",
      "instance 0, ep 40\n",
      "instance 0, ep 41\n",
      "instance 0, ep 42\n",
      "instance 0, ep 43\n",
      "instance 0, ep 44\n",
      "instance 0, ep 45\n",
      "instance 0, ep 46\n",
      "instance 0, ep 47\n",
      "instance 0, ep 48\n",
      "instance 0, ep 49\n",
      "instance 0, ep 50\n",
      "instance 0, ep 51\n",
      "instance 0, ep 52\n",
      "instance 0, ep 53\n",
      "instance 0, ep 54\n",
      "instance 0, ep 55\n",
      "instance 0, ep 56\n",
      "instance 0, ep 57\n",
      "instance 0, ep 58\n",
      "instance 0, ep 59\n",
      "instance 0, ep 60\n",
      "instance 0, ep 61\n",
      "instance 0, ep 62\n",
      "instance 0, ep 63\n",
      "instance 0, ep 64\n",
      "instance 0, ep 65\n",
      "instance 0, ep 66\n",
      "instance 0, ep 67\n",
      "instance 0, ep 68\n",
      "instance 0, ep 69\n",
      "instance 0, ep 70\n",
      "instance 0, ep 71\n",
      "instance 0, ep 72\n",
      "instance 0, ep 73\n",
      "instance 0, ep 74\n",
      "instance 0, ep 75\n",
      "instance 0, ep 76\n",
      "instance 0, ep 77\n",
      "instance 0, ep 78\n",
      "instance 0, ep 79\n",
      "instance 0, ep 80\n",
      "instance 0, ep 81\n",
      "instance 0, ep 82\n",
      "instance 0, ep 83\n",
      "instance 0, ep 84\n",
      "instance 0, ep 85\n",
      "instance 0, ep 86\n",
      "instance 0, ep 87\n",
      "instance 0, ep 88\n",
      "instance 0, ep 89\n",
      "instance 0, ep 90\n",
      "instance 0, ep 91\n",
      "instance 0, ep 92\n",
      "instance 0, ep 93\n",
      "instance 0, ep 94\n",
      "instance 0, ep 95\n",
      "instance 0, ep 96\n",
      "instance 0, ep 97\n",
      "instance 0, ep 98\n",
      "instance 0, ep 99\n",
      "instance 0, ep 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr0/home/naveenr/miniconda3/envs/food/lib/python3.8/site-packages/numpy/core/fromnumeric.py:3464: RuntimeWarning: Mean of empty slice.\n",
      "  return _methods._mean(a, axis=axis, dtype=dtype,\n",
      "/usr0/home/naveenr/miniconda3/envs/food/lib/python3.8/site-packages/numpy/core/_methods.py:192: RuntimeWarning: invalid value encountered in scalar divide\n",
      "  ret = ret.dtype.type(ret / rcount)\n",
      "/usr0/home/naveenr/projects/food_rescue_rmab/rmab/mcts_policies.py:157: RuntimeWarning: divide by zero encountered in log\n",
      "  choices_weights = [(c.q() / c.n()) + c_param * np.sqrt((2 * np.log(self.n()) / c.n())) for c in self.children]\n",
      "/usr0/home/naveenr/projects/food_rescue_rmab/rmab/mcts_policies.py:157: RuntimeWarning: invalid value encountered in sqrt\n",
      "  choices_weights = [(c.q() / c.n()) + c_param * np.sqrt((2 * np.log(self.n()) / c.n())) for c in self.children]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "instance 0, ep 101\n",
      "instance 0, ep 102\n",
      "instance 0, ep 103\n",
      "instance 0, ep 104\n",
      "Took 27.098926067352295 time for inference and 0.5985987186431885 time for training\n",
      "7.316326794639589\n"
     ]
    }
   ],
   "source": [
    "# MCTS policy evaluated with 400 test iterations; results are stored under the 'mcts' prefix.\n",
    "policy, name = mcts_policy, \"mcts\"\n",
    "\n",
    "# NOTE(review): test_length is 0 whenever n_episodes is a multiple of 50 -- confirm intent.\n",
    "rewards, memory, simulator = run_multi_seed(seed_list, policy, test_length=episode_len * (n_episodes % 50), test_iterations=400)\n",
    "results.update({\n",
    "    '{}_reward'.format(name): rewards['reward'],\n",
    "    '{}_match'.format(name): rewards['match'],\n",
    "    '{}_active'.format(name): rewards['active_rate'],\n",
    "    '{}_time'.format(name): rewards['time'],\n",
    "})\n",
    "print(np.mean(rewards['reward']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cohort [61 54 87 93]\n",
      "instance 0, ep 1\n",
      "instance 0, ep 2\n",
      "instance 0, ep 3\n",
      "instance 0, ep 4\n",
      "instance 0, ep 5\n",
      "instance 0, ep 6\n",
      "instance 0, ep 7\n",
      "instance 0, ep 8\n",
      "instance 0, ep 9\n",
      "instance 0, ep 10\n",
      "instance 0, ep 11\n",
      "instance 0, ep 12\n",
      "instance 0, ep 13\n",
      "instance 0, ep 14\n",
      "instance 0, ep 15\n",
      "instance 0, ep 16\n",
      "instance 0, ep 17\n",
      "instance 0, ep 18\n",
      "instance 0, ep 19\n",
      "instance 0, ep 20\n",
      "instance 0, ep 21\n",
      "instance 0, ep 22\n",
      "instance 0, ep 23\n",
      "instance 0, ep 24\n",
      "instance 0, ep 25\n",
      "instance 0, ep 26\n",
      "instance 0, ep 27\n",
      "instance 0, ep 28\n",
      "instance 0, ep 29\n",
      "instance 0, ep 30\n",
      "instance 0, ep 31\n",
      "instance 0, ep 32\n",
      "instance 0, ep 33\n",
      "instance 0, ep 34\n",
      "instance 0, ep 35\n",
      "instance 0, ep 36\n",
      "instance 0, ep 37\n",
      "instance 0, ep 38\n",
      "instance 0, ep 39\n",
      "instance 0, ep 40\n",
      "instance 0, ep 41\n",
      "instance 0, ep 42\n",
      "instance 0, ep 43\n",
      "instance 0, ep 44\n",
      "instance 0, ep 45\n",
      "instance 0, ep 46\n",
      "instance 0, ep 47\n",
      "instance 0, ep 48\n",
      "instance 0, ep 49\n",
      "instance 0, ep 50\n",
      "instance 0, ep 51\n",
      "instance 0, ep 52\n",
      "instance 0, ep 53\n",
      "instance 0, ep 54\n",
      "instance 0, ep 55\n",
      "instance 0, ep 56\n",
      "instance 0, ep 57\n",
      "instance 0, ep 58\n",
      "instance 0, ep 59\n",
      "instance 0, ep 60\n",
      "instance 0, ep 61\n",
      "instance 0, ep 62\n",
      "instance 0, ep 63\n",
      "instance 0, ep 64\n",
      "instance 0, ep 65\n",
      "instance 0, ep 66\n",
      "instance 0, ep 67\n",
      "instance 0, ep 68\n",
      "instance 0, ep 69\n",
      "instance 0, ep 70\n",
      "instance 0, ep 71\n",
      "instance 0, ep 72\n",
      "instance 0, ep 73\n",
      "instance 0, ep 74\n",
      "instance 0, ep 75\n",
      "instance 0, ep 76\n",
      "instance 0, ep 77\n",
      "instance 0, ep 78\n",
      "instance 0, ep 79\n",
      "instance 0, ep 80\n",
      "instance 0, ep 81\n",
      "instance 0, ep 82\n",
      "instance 0, ep 83\n",
      "instance 0, ep 84\n",
      "instance 0, ep 85\n",
      "instance 0, ep 86\n",
      "instance 0, ep 87\n",
      "instance 0, ep 88\n",
      "instance 0, ep 89\n",
      "instance 0, ep 90\n",
      "instance 0, ep 91\n",
      "instance 0, ep 92\n",
      "instance 0, ep 93\n",
      "instance 0, ep 94\n",
      "instance 0, ep 95\n",
      "instance 0, ep 96\n",
      "instance 0, ep 97\n",
      "instance 0, ep 98\n",
      "instance 0, ep 99\n",
      "instance 0, ep 100\n",
      "instance 0, ep 101\n",
      "instance 0, ep 102\n",
      "instance 0, ep 103\n",
      "instance 0, ep 104\n",
      "Took 0.0383305549621582 time for inference and 0.45908331871032715 time for training\n",
      "7.373958381332352\n"
     ]
    }
   ],
   "source": [
    "# Greedy policy; results are stored under the 'greedy' prefix.\n",
    "policy, name = greedy_policy, \"greedy\"\n",
    "\n",
    "# NOTE(review): test_length is 0 whenever n_episodes is a multiple of 50 -- confirm intent.\n",
    "rewards, memory, simulator = run_multi_seed(seed_list, policy, test_length=episode_len * (n_episodes % 50))\n",
    "results.update({\n",
    "    '{}_reward'.format(name): rewards['reward'],\n",
    "    '{}_match'.format(name): rewards['match'],\n",
    "    '{}_active'.format(name): rewards['active_rate'],\n",
    "    '{}_time'.format(name): rewards['time'],\n",
    "})\n",
    "print(np.mean(rewards['reward']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cohort [61 54 87 93]\n",
      "instance 0, ep 1\n",
      "instance 0, ep 2\n",
      "instance 0, ep 3\n",
      "instance 0, ep 4\n",
      "instance 0, ep 5\n",
      "instance 0, ep 6\n",
      "instance 0, ep 7\n",
      "instance 0, ep 8\n",
      "instance 0, ep 9\n",
      "instance 0, ep 10\n",
      "instance 0, ep 11\n",
      "instance 0, ep 12\n",
      "instance 0, ep 13\n",
      "instance 0, ep 14\n",
      "instance 0, ep 15\n",
      "instance 0, ep 16\n",
      "instance 0, ep 17\n",
      "instance 0, ep 18\n",
      "instance 0, ep 19\n",
      "instance 0, ep 20\n",
      "instance 0, ep 21\n",
      "instance 0, ep 22\n",
      "instance 0, ep 23\n",
      "instance 0, ep 24\n",
      "instance 0, ep 25\n",
      "instance 0, ep 26\n",
      "instance 0, ep 27\n",
      "instance 0, ep 28\n",
      "instance 0, ep 29\n",
      "instance 0, ep 30\n",
      "instance 0, ep 31\n",
      "instance 0, ep 32\n",
      "instance 0, ep 33\n",
      "instance 0, ep 34\n",
      "instance 0, ep 35\n",
      "instance 0, ep 36\n",
      "instance 0, ep 37\n",
      "instance 0, ep 38\n",
      "instance 0, ep 39\n",
      "instance 0, ep 40\n",
      "instance 0, ep 41\n",
      "instance 0, ep 42\n",
      "instance 0, ep 43\n",
      "instance 0, ep 44\n",
      "instance 0, ep 45\n",
      "instance 0, ep 46\n",
      "instance 0, ep 47\n",
      "instance 0, ep 48\n",
      "instance 0, ep 49\n",
      "instance 0, ep 50\n",
      "instance 0, ep 51\n",
      "instance 0, ep 52\n",
      "instance 0, ep 53\n",
      "instance 0, ep 54\n",
      "instance 0, ep 55\n",
      "instance 0, ep 56\n",
      "instance 0, ep 57\n",
      "instance 0, ep 58\n",
      "instance 0, ep 59\n",
      "instance 0, ep 60\n",
      "instance 0, ep 61\n",
      "instance 0, ep 62\n",
      "instance 0, ep 63\n",
      "instance 0, ep 64\n",
      "instance 0, ep 65\n",
      "instance 0, ep 66\n",
      "instance 0, ep 67\n",
      "instance 0, ep 68\n",
      "instance 0, ep 69\n",
      "instance 0, ep 70\n",
      "instance 0, ep 71\n",
      "instance 0, ep 72\n",
      "instance 0, ep 73\n",
      "instance 0, ep 74\n",
      "instance 0, ep 75\n",
      "instance 0, ep 76\n",
      "instance 0, ep 77\n",
      "instance 0, ep 78\n",
      "instance 0, ep 79\n",
      "instance 0, ep 80\n",
      "instance 0, ep 81\n",
      "instance 0, ep 82\n",
      "instance 0, ep 83\n",
      "instance 0, ep 84\n",
      "instance 0, ep 85\n",
      "instance 0, ep 86\n",
      "instance 0, ep 87\n",
      "instance 0, ep 88\n",
      "instance 0, ep 89\n",
      "instance 0, ep 90\n",
      "instance 0, ep 91\n",
      "instance 0, ep 92\n",
      "instance 0, ep 93\n",
      "instance 0, ep 94\n",
      "instance 0, ep 95\n",
      "instance 0, ep 96\n",
      "instance 0, ep 97\n",
      "instance 0, ep 98\n",
      "instance 0, ep 99\n",
      "instance 0, ep 100\n",
      "instance 0, ep 101\n",
      "instance 0, ep 102\n",
      "instance 0, ep 103\n",
      "instance 0, ep 104\n",
      "Took 0.024310588836669922 time for inference and 0.4572174549102783 time for training\n",
      "5.006209262579181\n"
     ]
    }
   ],
   "source": [
    "name = \"random\"\n",
    "policy = random_policy\n",
    "\n",
    "rewards, memory, simulator = run_multi_seed(\n",
    "    seed_list, policy, test_length=episode_len * (n_episodes % 50))\n",
    "\n",
    "# Store each metric under '<policy name>_<metric>' in the shared results dict\n",
    "for metric_template, reward_key in [('{}_reward', 'reward'), ('{}_match', 'match'),\n",
    "                                    ('{}_active', 'active_rate'), ('{}_time', 'time')]:\n",
    "    results[metric_template.format(name)] = rewards[reward_key]\n",
    "print(np.mean(rewards['reward']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cohort [61 54 87 93]\n",
      "instance 0, ep 1\n",
      "instance 0, ep 2\n",
      "instance 0, ep 3\n",
      "instance 0, ep 4\n",
      "instance 0, ep 5\n",
      "instance 0, ep 6\n",
      "instance 0, ep 7\n",
      "instance 0, ep 8\n",
      "instance 0, ep 9\n",
      "instance 0, ep 10\n",
      "instance 0, ep 11\n",
      "instance 0, ep 12\n",
      "instance 0, ep 13\n",
      "instance 0, ep 14\n",
      "instance 0, ep 15\n",
      "instance 0, ep 16\n",
      "instance 0, ep 17\n",
      "instance 0, ep 18\n",
      "instance 0, ep 19\n",
      "instance 0, ep 20\n",
      "instance 0, ep 21\n",
      "instance 0, ep 22\n",
      "instance 0, ep 23\n",
      "instance 0, ep 24\n",
      "instance 0, ep 25\n",
      "instance 0, ep 26\n",
      "instance 0, ep 27\n",
      "instance 0, ep 28\n",
      "instance 0, ep 29\n",
      "instance 0, ep 30\n",
      "instance 0, ep 31\n",
      "instance 0, ep 32\n",
      "instance 0, ep 33\n",
      "instance 0, ep 34\n",
      "instance 0, ep 35\n",
      "instance 0, ep 36\n",
      "instance 0, ep 37\n",
      "instance 0, ep 38\n",
      "instance 0, ep 39\n",
      "instance 0, ep 40\n",
      "instance 0, ep 41\n",
      "instance 0, ep 42\n",
      "instance 0, ep 43\n",
      "instance 0, ep 44\n",
      "instance 0, ep 45\n",
      "instance 0, ep 46\n",
      "instance 0, ep 47\n",
      "instance 0, ep 48\n",
      "instance 0, ep 49\n",
      "instance 0, ep 50\n",
      "instance 0, ep 51\n",
      "instance 0, ep 52\n",
      "instance 0, ep 53\n",
      "instance 0, ep 54\n",
      "instance 0, ep 55\n",
      "instance 0, ep 56\n",
      "instance 0, ep 57\n",
      "instance 0, ep 58\n",
      "instance 0, ep 59\n",
      "instance 0, ep 60\n",
      "instance 0, ep 61\n",
      "instance 0, ep 62\n",
      "instance 0, ep 63\n",
      "instance 0, ep 64\n",
      "instance 0, ep 65\n",
      "instance 0, ep 66\n",
      "instance 0, ep 67\n",
      "instance 0, ep 68\n",
      "instance 0, ep 69\n",
      "instance 0, ep 70\n",
      "instance 0, ep 71\n",
      "instance 0, ep 72\n",
      "instance 0, ep 73\n",
      "instance 0, ep 74\n",
      "instance 0, ep 75\n",
      "instance 0, ep 76\n",
      "instance 0, ep 77\n",
      "instance 0, ep 78\n",
      "instance 0, ep 79\n",
      "instance 0, ep 80\n",
      "instance 0, ep 81\n",
      "instance 0, ep 82\n",
      "instance 0, ep 83\n",
      "instance 0, ep 84\n",
      "instance 0, ep 85\n",
      "instance 0, ep 86\n",
      "instance 0, ep 87\n",
      "instance 0, ep 88\n",
      "instance 0, ep 89\n",
      "instance 0, ep 90\n",
      "instance 0, ep 91\n",
      "instance 0, ep 92\n",
      "instance 0, ep 93\n",
      "instance 0, ep 94\n",
      "instance 0, ep 95\n",
      "instance 0, ep 96\n",
      "instance 0, ep 97\n",
      "instance 0, ep 98\n",
      "instance 0, ep 99\n",
      "instance 0, ep 100\n",
      "instance 0, ep 101\n",
      "instance 0, ep 102\n",
      "instance 0, ep 103\n",
      "instance 0, ep 104\n",
      "Took 0.03218364715576172 time for inference and 0.45625734329223633 time for training\n",
      "7.263818437942289\n"
     ]
    }
   ],
   "source": [
    "name = \"whittle_activity\"\n",
    "policy = whittle_activity_policy\n",
    "\n",
    "rewards, memory, simulator = run_multi_seed(\n",
    "    seed_list, policy, test_length=episode_len * (n_episodes % 50))\n",
    "\n",
    "# Store each metric under '<policy name>_<metric>' in the shared results dict\n",
    "for metric_template, reward_key in [('{}_reward', 'reward'), ('{}_match', 'match'),\n",
    "                                    ('{}_active', 'active_rate'), ('{}_time', 'time')]:\n",
    "    results[metric_template.format(name)] = rewards[reward_key]\n",
    "print(np.mean(rewards['reward']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running DQN\n",
      "cohort [61 54 87 93]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "instance 0, ep 1\n",
      "instance 0, ep 2\n",
      "instance 0, ep 3\n",
      "instance 0, ep 4\n",
      "instance 0, ep 5\n",
      "instance 0, ep 6\n",
      "instance 0, ep 7\n",
      "instance 0, ep 8\n",
      "instance 0, ep 9\n",
      "instance 0, ep 10\n",
      "instance 0, ep 11\n",
      "instance 0, ep 12\n",
      "instance 0, ep 13\n",
      "instance 0, ep 14\n",
      "instance 0, ep 15\n",
      "instance 0, ep 16\n",
      "instance 0, ep 17\n",
      "instance 0, ep 18\n",
      "instance 0, ep 19\n",
      "instance 0, ep 20\n",
      "instance 0, ep 21\n",
      "instance 0, ep 22\n",
      "instance 0, ep 23\n",
      "instance 0, ep 24\n",
      "instance 0, ep 25\n",
      "instance 0, ep 26\n",
      "instance 0, ep 27\n",
      "instance 0, ep 28\n",
      "instance 0, ep 29\n",
      "instance 0, ep 30\n",
      "instance 0, ep 31\n",
      "instance 0, ep 32\n",
      "instance 0, ep 33\n",
      "instance 0, ep 34\n",
      "instance 0, ep 35\n",
      "instance 0, ep 36\n",
      "instance 0, ep 37\n",
      "instance 0, ep 38\n",
      "instance 0, ep 39\n",
      "instance 0, ep 40\n",
      "instance 0, ep 41\n",
      "instance 0, ep 42\n",
      "instance 0, ep 43\n",
      "instance 0, ep 44\n",
      "instance 0, ep 45\n",
      "instance 0, ep 46\n",
      "instance 0, ep 47\n",
      "instance 0, ep 48\n",
      "instance 0, ep 49\n",
      "instance 0, ep 50\n",
      "instance 0, ep 51\n",
      "instance 0, ep 52\n",
      "instance 0, ep 53\n",
      "instance 0, ep 54\n",
      "instance 0, ep 55\n",
      "instance 0, ep 56\n",
      "instance 0, ep 57\n",
      "instance 0, ep 58\n",
      "instance 0, ep 59\n",
      "instance 0, ep 60\n",
      "instance 0, ep 61\n",
      "instance 0, ep 62\n",
      "instance 0, ep 63\n",
      "instance 0, ep 64\n",
      "instance 0, ep 65\n",
      "instance 0, ep 66\n",
      "instance 0, ep 67\n",
      "instance 0, ep 68\n",
      "instance 0, ep 69\n",
      "instance 0, ep 70\n",
      "instance 0, ep 71\n",
      "instance 0, ep 72\n",
      "instance 0, ep 73\n",
      "instance 0, ep 74\n",
      "instance 0, ep 75\n",
      "instance 0, ep 76\n",
      "instance 0, ep 77\n",
      "instance 0, ep 78\n",
      "instance 0, ep 79\n",
      "instance 0, ep 80\n",
      "instance 0, ep 81\n",
      "instance 0, ep 82\n",
      "instance 0, ep 83\n",
      "instance 0, ep 84\n",
      "instance 0, ep 85\n",
      "instance 0, ep 86\n",
      "instance 0, ep 87\n",
      "instance 0, ep 88\n",
      "instance 0, ep 89\n",
      "instance 0, ep 90\n",
      "instance 0, ep 91\n",
      "instance 0, ep 92\n",
      "instance 0, ep 93\n",
      "instance 0, ep 94\n",
      "instance 0, ep 95\n",
      "instance 0, ep 96\n",
      "instance 0, ep 97\n",
      "instance 0, ep 98\n",
      "instance 0, ep 99\n",
      "instance 0, ep 100\n",
      "instance 0, ep 101\n",
      "instance 0, ep 102\n",
      "instance 0, ep 103\n",
      "instance 0, ep 104\n",
      "Took 0.0848855972290039 time for inference and 8.914021730422974 time for training\n",
      "6.777497652478743\n"
     ]
    }
   ],
   "source": [
    "# Plain DQN is only run when n_arms * volunteers_per_arm <= 10\n",
    "if n_arms * volunteers_per_arm <= 10:\n",
    "    name = \"dqn\"\n",
    "    policy = dqn_policy_greedy\n",
    "    print(\"Running DQN\")\n",
    "\n",
    "    # Requires the linear Whittle cell to have run first: its first reward entry is the baseline\n",
    "    baseline_reward = np.mean(results['linear_whittle_reward'][0])\n",
    "    rewards, memory, simulator = run_multi_seed(\n",
    "        seed_list, policy, is_mcts=True, avg_reward=baseline_reward,\n",
    "        test_length=episode_len * (n_episodes % 50))\n",
    "\n",
    "    # Store each metric under '<policy name>_<metric>' in the shared results dict\n",
    "    for metric_template, reward_key in [('{}_reward', 'reward'), ('{}_match', 'match'),\n",
    "                                        ('{}_active', 'active_rate'), ('{}_time', 'time')]:\n",
    "        results[metric_template.format(name)] = rewards[reward_key]\n",
    "    print(np.mean(rewards['reward']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running DQN Step\n",
      "cohort [61 54 87 93]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "instance 0, ep 1\n",
      "instance 0, ep 2\n",
      "instance 0, ep 3\n",
      "instance 0, ep 4\n",
      "instance 0, ep 5\n",
      "instance 0, ep 6\n",
      "instance 0, ep 7\n",
      "instance 0, ep 8\n",
      "instance 0, ep 9\n",
      "instance 0, ep 10\n",
      "instance 0, ep 11\n",
      "instance 0, ep 12\n",
      "instance 0, ep 13\n",
      "instance 0, ep 14\n",
      "instance 0, ep 15\n",
      "instance 0, ep 16\n",
      "instance 0, ep 17\n",
      "instance 0, ep 18\n",
      "instance 0, ep 19\n",
      "instance 0, ep 20\n",
      "instance 0, ep 21\n",
      "instance 0, ep 22\n",
      "instance 0, ep 23\n",
      "instance 0, ep 24\n",
      "instance 0, ep 25\n",
      "instance 0, ep 26\n",
      "instance 0, ep 27\n",
      "instance 0, ep 28\n",
      "instance 0, ep 29\n",
      "instance 0, ep 30\n",
      "instance 0, ep 31\n",
      "instance 0, ep 32\n",
      "instance 0, ep 33\n",
      "instance 0, ep 34\n",
      "instance 0, ep 35\n",
      "instance 0, ep 36\n",
      "instance 0, ep 37\n",
      "instance 0, ep 38\n",
      "instance 0, ep 39\n",
      "instance 0, ep 40\n",
      "instance 0, ep 41\n",
      "instance 0, ep 42\n",
      "instance 0, ep 43\n",
      "instance 0, ep 44\n",
      "instance 0, ep 45\n",
      "instance 0, ep 46\n",
      "instance 0, ep 47\n",
      "instance 0, ep 48\n",
      "instance 0, ep 49\n",
      "instance 0, ep 50\n",
      "instance 0, ep 51\n",
      "instance 0, ep 52\n",
      "instance 0, ep 53\n",
      "instance 0, ep 54\n",
      "instance 0, ep 55\n",
      "instance 0, ep 56\n",
      "instance 0, ep 57\n",
      "instance 0, ep 58\n",
      "instance 0, ep 59\n",
      "instance 0, ep 60\n",
      "instance 0, ep 61\n",
      "instance 0, ep 62\n",
      "instance 0, ep 63\n",
      "instance 0, ep 64\n",
      "instance 0, ep 65\n",
      "instance 0, ep 66\n",
      "instance 0, ep 67\n",
      "instance 0, ep 68\n",
      "instance 0, ep 69\n",
      "instance 0, ep 70\n",
      "instance 0, ep 71\n",
      "instance 0, ep 72\n",
      "instance 0, ep 73\n",
      "instance 0, ep 74\n",
      "instance 0, ep 75\n",
      "instance 0, ep 76\n",
      "instance 0, ep 77\n",
      "instance 0, ep 78\n",
      "instance 0, ep 79\n",
      "instance 0, ep 80\n",
      "instance 0, ep 81\n",
      "instance 0, ep 82\n",
      "instance 0, ep 83\n",
      "instance 0, ep 84\n",
      "instance 0, ep 85\n",
      "instance 0, ep 86\n",
      "instance 0, ep 87\n",
      "instance 0, ep 88\n",
      "instance 0, ep 89\n",
      "instance 0, ep 90\n",
      "instance 0, ep 91\n",
      "instance 0, ep 92\n",
      "instance 0, ep 93\n",
      "instance 0, ep 94\n",
      "instance 0, ep 95\n",
      "instance 0, ep 96\n",
      "instance 0, ep 97\n",
      "instance 0, ep 98\n",
      "instance 0, ep 99\n",
      "instance 0, ep 100\n",
      "instance 0, ep 101\n",
      "instance 0, ep 102\n",
      "instance 0, ep 103\n",
      "instance 0, ep 104\n",
      "Took 0.08581185340881348 time for inference and 10.183659791946411 time for training\n",
      "6.7161052595174215\n"
     ]
    }
   ],
   "source": [
    "name = \"dqn_step\"\n",
    "policy = dqn_with_steps\n",
    "print(\"Running DQN Step\")\n",
    "\n",
    "# Requires the linear Whittle cell to have run first: its first reward entry is the baseline\n",
    "baseline_reward = np.mean(results['linear_whittle_reward'][0])\n",
    "rewards, memory, simulator = run_multi_seed(\n",
    "    seed_list, policy, is_mcts=True, avg_reward=baseline_reward,\n",
    "    test_length=episode_len * (n_episodes % 50))\n",
    "\n",
    "# Store each metric under '<policy name>_<metric>' in the shared results dict\n",
    "for metric_template, reward_key in [('{}_reward', 'reward'), ('{}_match', 'match'),\n",
    "                                    ('{}_active', 'active_rate'), ('{}_time', 'time')]:\n",
    "    results[metric_template.format(name)] = rewards[reward_key]\n",
    "print(np.mean(rewards['reward']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled experiment: DQN with stabilization steps (\"dqn_stable_step\"). Uncomment to run.\n",
    "# NOTE(review): unlike the other policy cells, this uses test_length=episode_len (not\n",
    "# episode_len*(n_episodes%50)) and passes num_samples=1024; it also prints the same\n",
    "# \"Running DQN Step\" message as the dqn_step cell - confirm these before re-enabling.\n",
    "# policy = dqn_with_stablization_steps\n",
    "# name = \"dqn_stable_step\"\n",
    "\n",
    "# print(\"Running DQN Step\")\n",
    "\n",
    "# rewards, memory, simulator = run_multi_seed(seed_list,policy,is_mcts=True,avg_reward=np.mean(results['linear_whittle_reward'][0]),test_length=episode_len,num_samples=1024)\n",
    "# results['{}_reward'.format(name)] = rewards['reward']\n",
    "# results['{}_match'.format(name)] =  rewards['match'] \n",
    "# results['{}_active'.format(name)] = rewards['active_rate']\n",
    "# results['{}_time'.format(name)] =  rewards['time']\n",
    "# print(np.mean(rewards['reward']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled experiment: Q-iteration baseline stored as \"optimal\". Uncomment to run.\n",
    "# It is guarded to n_arms * volunteers_per_arm <= 4, presumably because its cost\n",
    "# grows quickly with instance size - TODO confirm.\n",
    "# if n_arms * volunteers_per_arm <= 4:\n",
    "#     policy = q_iteration_policy\n",
    "#     per_epoch_function = q_iteration_custom_epoch()\n",
    "#     name = \"optimal\"\n",
    "\n",
    "#     rewards, memory, simulator = run_multi_seed(seed_list,policy,per_epoch_function=per_epoch_function,test_length=episode_len*(n_episodes%50))\n",
    "#     results['{}_reward'.format(name)] = rewards['reward']\n",
    "#     results['{}_match'.format(name)] =  rewards['match'] \n",
    "#     results['{}_active'.format(name)] = rewards['active_rate']\n",
    "#     results['{}_time'.format(name)] =  rewards['time']\n",
    "#     print(np.mean(rewards['reward']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Write Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive the output file name from the run configuration; it is later joined onto '../../results/'\n",
    "save_path = get_save_path(\n",
    "    out_folder, save_name, seed,\n",
    "    use_date=save_with_date)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Deleting ../../results/mcts_exploration/rl_exploration/63d9ed7d_44.json\n"
     ]
    }
   ],
   "source": [
    "# Presumably removes earlier result files in out_folder that duplicate this run (see rmab.utils)\n",
    "delete_duplicate_results(out_folder, \"\", results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write results to disk; the context manager guarantees the file is flushed and closed\n",
    "with open('../../results/'+save_path,'w') as results_file:\n",
    "    json.dump(results,results_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "food",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
