{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5e591f32",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from argparse import Namespace\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from train_models_dpt import get_probs\n",
    "from util import set_seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "20709991",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for this run, collected in a single Namespace.\n",
    "config = Namespace(\n",
    "    user='tcai',\n",
    "    data_dir='/shared/share_mala/implicitbayes/dataset_files/synthetic_data/bimodal/N=500,D=2500,D_eval=1000,cnts1=25,cnts2=25,forced',\n",
    "    marginal_vs_sequential='sequential',\n",
    "    batch_size=20,\n",
    "    eval_batch_size=20,\n",
    "    Z_dim=2,\n",
    "    learning_rate=0.001,\n",
    "    dataset_type='synthetic',\n",
    "    seed=2340923,\n",
    "    embed_data_dir=None,\n",
    "    extra_eval_data=None,\n",
    "    sequential_init_mean=0.5,\n",
    "    MLP_width=100,\n",
    "    MLP_layer=3,\n",
    "    repeat_suffstat=10,\n",
    "    device='cpu',\n",
    "    weight_decay=1e-2,\n",
    "    MLP_last_fn='none',\n",
    "    num_arms=10,\n",
    ")\n",
    "\n",
    "# Load the evaluation data stored as `eval_data.pt` under `config.data_dir`.\n",
    "data = torch.load(f'{config.data_dir}/eval_data.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8c132bf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed handed to the project's `set_seed` helper; the same value is embedded in the saved file name.\n",
    "# NOTE(review): this differs from `config.seed` (2340923) -- confirm which one is intended.\n",
    "seed = 234_234_234\n",
    "set_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a32bd777",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build E environments of `num_arms` arms each (rows sampled from `data`),\n",
    "# plus one length-H history of arm pulls per environment.\n",
    "\n",
    "# select 10 arms\n",
    "num_arms = 10\n",
    "\n",
    "T = data['Y'].shape[1]\n",
    "total_rows = len(data['Z'])\n",
    "E = 1000 # number of environments\n",
    "# assume one history per environment...\n",
    "H = T # length of history\n",
    "\n",
    "# Every environment needs `num_arms` distinct rows to draw from.\n",
    "assert total_rows >= num_arms, f'need at least {num_arms} rows, found {total_rows}'\n",
    "\n",
    "envs = []\n",
    "histories = []\n",
    "for e in range(E):\n",
    "    if e % 100 == 0:\n",
    "        print(e)\n",
    "    # Random subset of `num_arms` distinct row indices (in ascending order).\n",
    "    arm_idxs = (torch.randperm(total_rows) < num_arms).int().nonzero().flatten()\n",
    "\n",
    "    this_Z = data['Z'][arm_idxs]\n",
    "    this_Y = data['Y'][arm_idxs]\n",
    "    arm_rewards = data['click_rate'][arm_idxs]\n",
    "\n",
    "    # One-hot vector marking the argmax of the selected click rates.\n",
    "    this_best_arm = arm_rewards.argmax()\n",
    "    targets = torch.zeros(num_arms)\n",
    "    targets[this_best_arm] = 1\n",
    "\n",
    "    envs.append({'Z':this_Z, 'Y':this_Y, 'best_arm':targets, 'arm_rewards': arm_rewards})\n",
    "\n",
    "    # Sized by `num_arms` (not a literal 10) so it always matches the choices below.\n",
    "    probs = get_probs(num_arms)\n",
    "\n",
    "    # Draw H arm indices, one per step, using `probs` as the selection weights.\n",
    "    this_history = torch.zeros(H)\n",
    "    for h in range(H):\n",
    "        chosen_arm = np.random.choice(np.arange(num_arms), p=probs)\n",
    "        this_history[h] = chosen_arm\n",
    "    histories.append(this_history)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "4c364452",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output directory before writing so this cell works on a fresh data dir\n",
    "# (torch.save does not create missing parent directories).\n",
    "hist_dir = config.data_dir + '/dpt_histories'\n",
    "os.makedirs(hist_dir, exist_ok=True)\n",
    "\n",
    "# Persist the environments and their histories; the run parameters are encoded in the file name.\n",
    "torch.save(\n",
    "    {'env':envs, 'hist':histories},\n",
    "    hist_dir + f'/eval_hist_num_arms={num_arms},E={E},H={H},S=1,seed={seed}',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "126b80eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Ensure the histories directory exists; exist_ok=True keeps this cell safe to re-run\n",
    "# (a plain os.mkdir raises FileExistsError once the directory is there).\n",
    "os.makedirs(config.data_dir + '/dpt_histories', exist_ok=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
