{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9ee3da83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyTorch is used below for torch.load on the saved datasets and for torch.Generator RNG checks.\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "5b1f96d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the pre-generated synthetic dataset. The sub-directory name appears to encode the\n",
    "# generation parameters. Built without doubled or trailing slashes so that the\n",
    "# `dataset_base_dir + '/<file>.pt'` joins in later cells produce clean paths.\n",
    "DATA_ROOT = '/shared/share_mala/implicitbayes/dataset_files/synthetic_data/binary_context_redo'\n",
    "DATASET_NAME = 'N=2000,D=10000,D_eval=10000,binaryX=False,cnts=5.0,method=shareU_0702,one_X_per_col=False'\n",
    "dataset_base_dir = DATA_ROOT + '/' + DATASET_NAME"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "06617e03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training split. rstrip('/') avoids a doubled separator if dataset_base_dir ends in '/'.\n",
    "# NOTE(review): torch.load unpickles the file, so only load trusted data; if the file holds only\n",
    "# tensors / plain containers, passing weights_only=True is the safer choice (TODO confirm contents).\n",
    "train_data = torch.load(dataset_base_dir.rstrip('/') + '/train_data.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "4525ddcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the evaluation split. rstrip('/') avoids a doubled separator if dataset_base_dir ends in '/'.\n",
    "# NOTE(review): torch.load unpickles the file, so only load trusted data; if the file holds only\n",
    "# tensors / plain containers, passing weights_only=True is the safer choice (TODO confirm contents).\n",
    "eval_data = torch.load(dataset_base_dir.rstrip('/') + '/eval_data.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "9084720b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Thin wrapper around the project's data generator with this experiment's fixed settings.\n",
    "# cnts / one_X_per_col used to be read from an `args` object that is never defined in this\n",
    "# notebook; they are now keyword parameters defaulting to the values encoded in the dataset\n",
    "# path (cnts=5.0, one_X_per_col=False). TODO confirm these match the original run's args.\n",
    "# TODO: generate_data_beta_context_shareU is not imported in this notebook; import it from the\n",
    "# project module that defines it before calling generate_fn.\n",
    "def generate_fn(D, N, binary, ave_U, g, cnts=5.0, one_X_per_col=False):\n",
    "    \"\"\"Call generate_data_beta_context_shareU with num_Us=0 and uniform=False.\n",
    "\n",
    "    D, N and ave_U are passed through unchanged; `binary` is forwarded as binaryX and `g` as\n",
    "    generator= (presumably a torch.Generator). `cnts` and `one_X_per_col` may be overridden\n",
    "    to regenerate other dataset variants.\n",
    "    \"\"\"\n",
    "    return generate_data_beta_context_shareU(\n",
    "        D=D, N=N, binaryX=binary, cnts=cnts, one_X_per_col=one_X_per_col,\n",
    "        ave_U=ave_U, num_Us=0, uniform=False, generator=g,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "fc5ead47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check that the train and eval splits were not sampled identically (presumably they\n",
    "# should come from different generator states). Compare the full tensors, not one entry.\n",
    "assert not torch.equal(\n",
    "    train_data['click_rate'], eval_data['click_rate']\n",
    "), 'train and eval click_rate tensors are identical'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "1e47aee0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0786],\n",
       "        [0.3721],\n",
       "        [0.9007],\n",
       "        ...,\n",
       "        [0.9442],\n",
       "        [0.4205],\n",
       "        [0.3163]])"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Quick look at eval_data['Z']; torch's repr elides the middle rows of long tensors.\n",
    "eval_data['Z']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "4b160b88",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.2673, 0.8725, 0.3353, 0.4030, 0.7871, 0.4576, 0.0719, 0.9715, 0.7147,\n",
       "        0.4275, 0.0846, 0.4904, 0.1662, 0.1538, 0.3638, 0.5145, 0.7122, 0.8787,\n",
       "        0.8431, 0.9879])"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Draw 20 uniforms from a brand-new generator that was never explicitly seeded.\n",
    "g = torch.Generator()\n",
    "torch.rand(20, generator=g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "fa8c033d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.2673, 0.8725, 0.3353, 0.4030, 0.7871, 0.4576, 0.0719, 0.9715, 0.7147,\n",
       "        0.4275, 0.0846, 0.4904, 0.1662, 0.1538, 0.3638, 0.5145, 0.7122, 0.8787,\n",
       "        0.8431, 0.9879])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show which seed a fresh, unseeded generator starts from, then draw from it. print() is needed\n",
    "# for the seed because only a cell's last expression is displayed automatically.\n",
    "g = torch.Generator()\n",
    "print('initial seed:', g.initial_seed())\n",
    "torch.rand(20, generator=g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "807059e4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.2673, 0.8725, 0.3353, 0.4030, 0.7871, 0.4576, 0.0719, 0.9715, 0.7147,\n",
       "        0.4275, 0.0846, 0.4904, 0.1662, 0.1538, 0.3638, 0.5145, 0.7122, 0.8787,\n",
       "        0.8431, 0.9879])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reproducibility check: an unseeded torch.Generator() starts from a fixed default seed, so two\n",
    "# fresh generators must yield the same stream. Assert it rather than eyeballing printed tensors.\n",
    "g = torch.Generator()\n",
    "assert torch.equal(\n",
    "    torch.rand(20, generator=torch.Generator()), torch.rand(20, generator=torch.Generator())\n",
    "), 'fresh unseeded generators produced different streams'\n",
    "torch.rand(20, generator=g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc428c53",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
