{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random \n",
    "import matplotlib.pyplot as plt\n",
    "import json \n",
    "import argparse \n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rmab.simulator import RMABSimulator, random_valid_transition, random_valid_transition_round_down, synthetic_transition_small_window\n",
    "from rmab.uc_whittle import UCWhittle, UCWhittleFixed \n",
    "from rmab.ucw_value import UCWhittle_value, UCWhittle_value_fixed\n",
    "from rmab.baselines import optimal_policy, random_policy, WIQL\n",
    "from rmab.fr_dynamics import get_all_transitions\n",
    "from rmab.utils import get_save_path, delete_duplicate_results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "is_jupyter = 'ipykernel' in sys.modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_arms = 2\n",
    "all_population_size = 100\n",
    "q = 0.5\n",
    "p = 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_transitions = get_all_transitions(all_population_size)\n",
    "all_transitions = all_transitions[[25,99]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[0.86206897, 0.13793103],\n",
       "        [0.47011952, 0.52988048]],\n",
       "\n",
       "       [[0.55952381, 0.44047619],\n",
       "        [0.28521127, 0.71478873]]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_transitions[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_q(time,volunteer_num,decisions_volunteer):\n",
    "    q_vals = [q]\n",
    "\n",
    "    assert len(decisions_volunteer) == time \n",
    "\n",
    "    for t in range(1,time+1):\n",
    "        this_q = q_vals[t-1]*decisions_volunteer[t-1]*all_transitions[volunteer_num][1][1][1]\n",
    "        this_q += q_vals[t-1]*(1-decisions_volunteer[t-1])*all_transitions[volunteer_num][1][0][1]\n",
    "        this_q += (1-q_vals[t-1])*decisions_volunteer[t-1]*all_transitions[volunteer_num][0][1][1]\n",
    "        this_q += (1-q_vals[t-1])*(1-decisions_volunteer[t-1])*all_transitions[volunteer_num][0][0][1]\n",
    "\n",
    "        q_vals.append(this_q)\n",
    "\n",
    "    return q_vals[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reward(all_decisions,total_time):\n",
    "    for t in range(total_time):\n",
    "        q_vals_by_volunteer = [compute_q(t,i,all_decisions[:t]) for i in range(n_arms)]\n",
    "\n",
    "        reward = "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.28920361247947457"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "decisions_volunteer = [0]\n",
    "compute_q(1,0,decisions_volunteer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "food",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
