{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e2cdcb67",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch, numpy as np\n",
    "from argparse import Namespace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "71758857",
   "metadata": {},
   "outputs": [],
   "source": [
    "from train_models_dpt import get_probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "id": "ba7962e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration handed to get_loaders_MIND; data_dir is also used for the output paths.\n",
    "config = Namespace(**{\n",
    "    'save_name': '0927_dpt_bimodal_test',\n",
    "    'data_dir': '/shared/share_mala/implicitbayes/dataset_files/MIND_data/filter100/',\n",
    "    'embed_data_dir': False,\n",
    "    'extra_eval_data': None,\n",
    "    'wandb_user': 'tcai',\n",
    "    'learning_rate': 0.001,\n",
    "    'weight_decay': 0.01,\n",
    "    'epochs': 1000,\n",
    "    'num_arms': 10,\n",
    "    'batch_size': 5,\n",
    "    'eval_batch_size': 50,\n",
    "    'sequential_beta_bernoulli': False,\n",
    "    'sequential_beta_bernoulli_alpha_beta': False,\n",
    "    'sequential_beta_bernoulli_less_params': False,\n",
    "    'seed': 2340923,\n",
    "    'onelayer': False,\n",
    "    'MLP_width': 100,\n",
    "    'MLP_layer': 3,\n",
    "    'MLP_last_fn': 'none',\n",
    "    'repeat_suffstat': 10,\n",
    "    'rand_prior': 0,\n",
    "    'prior_scale': 0,\n",
    "    'postprocess_often': 0,\n",
    "    'sequential_one_length': None,\n",
    "    'weight_factor': 1,\n",
    "    'scheduler_type': 'constant',\n",
    "    'dataset_type': 'MIND',\n",
    "    'sample_frac': 1.0,\n",
    "    'num_loader_obs': 500,\n",
    "    'num_loader_obs_train': 500,\n",
    "    'datasplit': None,\n",
    "    'trainlens': None,\n",
    "    'trainlensexact': None,\n",
    "    'savelens': None,\n",
    "    'use_X': True,\n",
    "    'use_X_model': False,\n",
    "    'Z_dim': 2,\n",
    "    'X_dim': 1,\n",
    "    'use_text': 1,\n",
    "    'use_category': 0,\n",
    "    'transform_success_p_alpha': 1,\n",
    "    'click_data_suffix': None,\n",
    "    'bootstrap_seed': None,\n",
    "    'bert_learning_rate': None,\n",
    "    'bert_weight_decay': None,\n",
    "    'freeze_bert': False,\n",
    "    'load_bert_file': None,\n",
    "    'gpu': None,\n",
    "    'aplusb_learning_rate': None,\n",
    "    'sequential_init_mean': 0.5,\n",
    "    'init_weights_path': '/shared/share_mala/implicitbayes/dataset_files/MIND_data/filter100/bertmodels/marg_bert_rate/marginal:epochs=500,bs=100,lr=1e-05,bert_lr=1e-05,wd=0.01,sample_frac=1.0,onelayer=False,freezebert=False,has_marg_model=False,min_obs_length=100,no_Z=False,transform_success_p_alpha=1,num_train_rows_for_eval=None,seed=2340923',\n",
    "    'test_idea': '',\n",
    "    'postproc_force_recalc': True,\n",
    "    'post_sample_all_num_prev_obs': [0, 1, 2, 5, 10, 25],\n",
    "    'post_sample_num_repetitions': 250,\n",
    "    'post_sample_num_imagined': 500,\n",
    "    'device': 'cpu',\n",
    "    'marginal_vs_sequential': 'sequential',\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "1762e64b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train dataset size: 9122\n",
      "eval dataset size: 2280\n"
     ]
    }
   ],
   "source": [
    "# Build the MIND loaders from `config`; the result is a dict indexed by name (e.g. 'val_loader', 'val_dataset').\n",
    "from dataset_MIND import get_loaders_MIND\n",
    "mind_loaders = get_loaders_MIND(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "4e7ec243",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the validation split (loader and underlying dataset) out of the loaders dict.\n",
    "val_loader, val_dataset = mind_loaders['val_loader'], mind_loaders['val_dataset']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "0e13e9b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "total_rows = len(val_dataset)  # candidate rows that each environment samples its arms from\n",
    "T = val_dataset.num_loader_obs  # reward draws per arm in each environment\n",
    "# Take the arm count from the run config rather than repeating the literal 10.\n",
    "num_arms = config.num_arms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "d7775bb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from util import set_seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "afa5f048",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed for environment/history sampling; it is also embedded in the saved history filename.\n",
    "seed = 234_234_234\n",
    "set_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "a81a1831",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "100\n",
      "200\n",
      "300\n",
      "400\n",
      "500\n",
      "600\n",
      "700\n",
      "800\n",
      "900\n"
     ]
    }
   ],
   "source": [
    "E = 1000  # number of environments\n",
    "H = T  # history length: one history per environment, as long as each arm's reward row\n",
    "assert total_rows >= num_arms, 'not enough rows to sample num_arms distinct arms'\n",
    "env_arm_idxs = []\n",
    "histories = []\n",
    "for e in range(E):\n",
    "    if e % 100 == 0:\n",
    "        print(e)\n",
    "    # Random subset of num_arms distinct row indices (nonzero returns them in ascending order).\n",
    "    arm_idxs = (torch.randperm(total_rows) < num_arms).nonzero().flatten()\n",
    "    click_rates = val_dataset.click_rates[arm_idxs]\n",
    "    # T Bernoulli(click_rate) draws per arm -> Y has shape (num_arms, T).\n",
    "    Y = torch.bernoulli(click_rates.unsqueeze(1).repeat(1, T))\n",
    "    this_best_arm = click_rates.argmax()\n",
    "    targets = torch.zeros(num_arms)\n",
    "    targets[this_best_arm] = 1\n",
    "\n",
    "    env_arm_idxs.append({'arm_idxs': arm_idxs,\n",
    "                         'Y': Y,\n",
    "                         'best_arm': this_best_arm,\n",
    "                         'best_arm_onehot': targets,\n",
    "                         'click_rates': click_rates,\n",
    "                        })\n",
    "\n",
    "    # Arm-sampling distribution for this environment (length num_arms; semantics per train_models_dpt.get_probs).\n",
    "    probs = get_probs(num_arms)\n",
    "    # A single vectorised draw consumes the global numpy stream exactly like H scalar draws would.\n",
    "    arm_choices = np.random.choice(num_arms, size=H, p=probs)\n",
    "    this_history = torch.as_tensor(arm_choices, dtype=torch.get_default_dtype())\n",
    "    histories.append(this_history)\n",
    "\n",
    "assert len(histories) == len(env_arm_idxs) == E"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "eaadf750",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Create the output directory for sampled histories; exist_ok keeps this cell safe to re-run.\n",
    "os.makedirs(os.path.join(config.data_dir, 'dpt_histories'), exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "5940183c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the sampled environments and their histories together; the filename records the sampling settings.\n",
    "save_path = config.data_dir + f'/dpt_histories/eval_hist_num_arms={num_arms},E={E},H={H},S=1,seed={seed}'\n",
    "torch.save(dict(env=env_arm_idxs, hist=histories), save_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
