{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.06870317 0.54389303 0.03501452 0.01397393 0.00451768]\n",
      "[0.07891602 0.58712144 0.04014332 0.01799546 0.00186428]\n",
      "[0.10213739 0.61572108 0.05165734 0.01462464 0.00635071]\n",
      "[0.13780715 0.25304207 0.06891669 0.04006413 0.03263223]\n",
      "[0.02608216 0.35773065 0.01364033 0.00968953 0.00245542]\n",
      "[0.10564487 0.14426284 0.05308644 0.02324259 0.01901957]\n",
      "\n",
      "[0.06849189 0.54396769 0.03457817 0.01323775 0.00513704]\n",
      "[0.0787635  0.5874185  0.03982734 0.01724934 0.00249091]\n",
      "[0.10244618 0.61624541 0.05184821 0.01402729 0.00682254]\n",
      "[0.1383478  0.25311439 0.06932239 0.04040182 0.03303735]\n",
      "[0.02592071 0.35757497 0.01332267 0.00908277 0.00298264]\n",
      "[0.10548681 0.14434568 0.05291931 0.02295236 0.01874951]\n",
      "\n",
      "[0.06894858 0.54449343 0.03503044 0.01415689 0.00438927]\n",
      "[0.07889363 0.58793717 0.03985817 0.0176087  0.002336  ]\n",
      "[0.10202025 0.61658607 0.05131424 0.01395751 0.00711674]\n",
      "[0.13807896 0.25300758 0.06918895 0.04008467 0.03263864]\n",
      "[0.02631094 0.35790548 0.01344447 0.00996423 0.00220346]\n",
      "[0.10680531 0.14443471 0.0539806  0.02424454 0.02004143]\n",
      "\n",
      "[0.06856651 0.5440137  0.0345496  0.01326795 0.00529612]\n",
      "[0.07805067 0.58709063 0.03890663 0.01646918 0.00348963]\n",
      "[0.10154888 0.61569024 0.05077965 0.01336231 0.00770551]\n",
      "[0.13712434 0.25313907 0.06831212 0.03925181 0.03177061]\n",
      "[0.02526123 0.35757394 0.01245817 0.00880457 0.00339445]\n",
      "[0.10562095 0.14442156 0.05275069 0.02305233 0.01880716]\n",
      "\n",
      "[0.06906504 0.54476645 0.03514231 0.0135063  0.00509693]\n",
      "[0.07891246 0.58797207 0.03987787 0.01691447 0.00308236]\n",
      "[0.10126607 0.61630797 0.05055684 0.01287748 0.00825296]\n",
      "[0.13813371 0.25304755 0.0689642  0.03998605 0.03252524]\n",
      "[0.025705   0.35779008 0.01276316 0.00888774 0.00332984]\n",
      "[0.10537277 0.14433519 0.05268385 0.0226887  0.01842898]\n",
      "\n",
      "[0.06776553 0.54383578 0.03377169 0.01259385 0.0058854 ]\n",
      "[0.07842525 0.58757631 0.03951068 0.01671565 0.00314174]\n",
      "[0.10159185 0.61600689 0.05095312 0.01339637 0.00756544]\n",
      "[0.13809464 0.25313127 0.06907803 0.0400666  0.03265515]\n",
      "[0.02549523 0.35756583 0.01299885 0.00884583 0.00328439]\n",
      "[0.10521799 0.14425734 0.05248933 0.02261911 0.01838792]\n",
      "\n",
      "[0.0682004  0.54332939 0.03485705 0.01231063 0.00566137]\n",
      "[0.07799268 0.58656779 0.03959413 0.01573415 0.00358535]\n",
      "[0.10149806 0.61511487 0.05140285 0.01256178 0.00783126]\n",
      "[0.13720445 0.25312979 0.06766313 0.03939613 0.03216245]\n",
      "[0.02612651 0.35734722 0.01402869 0.00883216 0.00296312]\n",
      "[0.10433056 0.14408259 0.05238223 0.02179117 0.01765229]\n",
      "\n",
      "[0.06826813 0.54363158 0.03408063 0.01361819 0.00495108]\n",
      "[0.07880813 0.58723547 0.03967344 0.01781938 0.00213038]\n",
      "[0.10198131 0.61585407 0.05111147 0.0142442  0.00683349]\n",
      "[0.13736033 0.2530527  0.06865575 0.03956892 0.03208942]\n",
      "[0.02646771 0.35738236 0.01347585 0.00964818 0.00253485]\n",
      "[0.10601073 0.1443442  0.05323357 0.0236867  0.0194594 ]\n",
      "\n",
      "[0.06927107 0.54382059 0.0359468  0.01381907 0.00442881]\n",
      "[0.07889082 0.58711105 0.04042722 0.01693699 0.00269401]\n",
      "[0.10190195 0.61543855 0.05177316 0.01369759 0.00702026]\n",
      "[0.13761327 0.25313097 0.068337   0.03972056 0.03237287]\n",
      "[0.02567645 0.35736553 0.01345907 0.00881781 0.00318264]\n",
      "[0.10621362 0.14423237 0.05388967 0.02372334 0.01957423]\n",
      "\n",
      "[0.06794438 0.54383341 0.03422408 0.0129407  0.00542676]\n",
      "[0.07807702 0.58721396 0.03942847 0.01692966 0.00280064]\n",
      "[0.10122586 0.61587642 0.0510499  0.01351415 0.00732596]\n",
      "[0.13846877 0.25307764 0.06958183 0.0408225  0.03347904]\n",
      "[0.02586653 0.35780109 0.0133816  0.00925911 0.0027995 ]\n",
      "[0.10529272 0.1442466  0.05276626 0.02277309 0.01857043]\n",
      "\n",
      "[0.06717655 0.54299551 0.03268701 0.01203467 0.00653394]\n",
      "[0.07684913 0.58623559 0.03722964 0.01548777 0.00447226]\n",
      "[0.10031115 0.61491304 0.04917519 0.0124448  0.00862455]\n",
      "[0.13693069 0.25318341 0.06838065 0.03932327 0.03184421]\n",
      "[0.02548001 0.35726645 0.01242957 0.00891694 0.00326545]\n",
      "[0.10473962 0.14409104 0.05160137 0.02221176 0.01794944]\n",
      "\n",
      "[0.0678745  0.54308073 0.03473015 0.0120846  0.00575353]\n",
      "[0.07777405 0.58641522 0.03958179 0.01567066 0.00350277]\n",
      "[0.10133377 0.61503423 0.05155803 0.01266354 0.00757233]\n",
      "[0.13767142 0.2531902  0.06796714 0.03972684 0.03255581]\n",
      "[0.02449853 0.35733722 0.01275191 0.00768298 0.0040632 ]\n",
      "[0.10443242 0.14400973 0.05263001 0.02185117 0.01774741]\n",
      "\n",
      "[0.06789705 0.54411682 0.03383163 0.0128887  0.00579677]\n",
      "[0.07761636 0.5872512  0.03843018 0.01642111 0.00365731]\n",
      "[0.1004465  0.61585588 0.0497054  0.01274591 0.00846793]\n",
      "[0.13789435 0.25308837 0.06910115 0.03997328 0.03247716]\n",
      "[0.02560048 0.35747355 0.0125659  0.00907437 0.00318026]\n",
      "[0.10535085 0.14429024 0.05234738 0.02279013 0.01851648]\n",
      "\n",
      "[0.06699171 0.54281174 0.03249203 0.01198061 0.00675104]\n",
      "[0.077702   0.58655686 0.03829808 0.01619712 0.00392939]\n",
      "[0.09983432 0.61491898 0.04873336 0.01179322 0.00948953]\n",
      "[0.13669179 0.25321827 0.06824355 0.03904868 0.03149119]\n",
      "[0.02545739 0.35734309 0.01210797 0.0087419  0.00355933]\n",
      "[0.10483363 0.14401171 0.0516415  0.02254934 0.01826325]\n",
      "\n",
      "[0.06653697 0.54361066 0.03225707 0.01103532 0.0075768 ]\n",
      "[0.07670056 0.58708604 0.03744261 0.01477968 0.00522095]\n",
      "[0.10012014 0.61561723 0.0492281  0.01165154 0.0094555 ]\n",
      "[0.13717188 0.253147   0.06825036 0.03917283 0.03169499]\n",
      "[0.02356092 0.35720588 0.01042478 0.00680728 0.00543824]\n",
      "[0.10502865 0.14424049 0.05206178 0.02242097 0.0181637 ]\n",
      "\n",
      "[0.06850464 0.54333042 0.03482887 0.01352796 0.00467782]\n",
      "[0.07910659 0.58724614 0.04051471 0.01745053 0.00212955]\n",
      "[0.10233234 0.61567518 0.05204761 0.01418156 0.00648777]\n",
      "[0.13782275 0.25310579 0.06875895 0.04007459 0.03275413]\n",
      "[0.02687588 0.35764924 0.01449795 0.01014733 0.00179399]\n",
      "[0.10618416 0.14423492 0.0537998  0.02372622 0.01958514]\n",
      "\n",
      "[0.0685278  0.54395204 0.03389952 0.01400597 0.00479571]\n",
      "[0.07904885 0.58767913 0.03947165 0.01809908 0.00210875]\n",
      "[0.10226778 0.61621072 0.05095312 0.01478467 0.00655301]\n",
      "[0.13853483 0.25302896 0.07015615 0.04088304 0.03335643]\n",
      "[0.02591559 0.35780051 0.01272017 0.00962394 0.00272883]\n",
      "[0.10550835 0.14408905 0.05227018 0.02308529 0.01879226]\n",
      "\n",
      "[0.0686378  0.54410948 0.03487291 0.01400189 0.00449608]\n",
      "[0.07869903 0.58743078 0.03997578 0.01785348 0.00202177]\n",
      "[0.10180682 0.61609991 0.05134128 0.01407333 0.00693423]\n",
      "[0.13763718 0.25300975 0.06884314 0.03986863 0.03243112]\n",
      "[0.02563097 0.35757872 0.01299957 0.00958422 0.00255896]\n",
      "[0.10616295 0.14425886 0.05362744 0.02393259 0.01973383]\n",
      "\n",
      "[0.06797835 0.54408477 0.03464992 0.01336722 0.0051075 ]\n",
      "[0.07786007 0.58740732 0.03934827 0.01685447 0.00300696]\n",
      "[0.10093313 0.6159634  0.05089798 0.01357873 0.00739079]\n",
      "[0.13821578 0.25302942 0.06925339 0.04043718 0.03303666]\n",
      "[0.02567984 0.35754302 0.01341613 0.00921461 0.002911  ]\n",
      "[0.10557285 0.14435572 0.05316044 0.02311035 0.0188897 ]\n",
      "\n",
      "[0.06769087 0.54321373 0.03396393 0.01193513 0.00621364]\n",
      "[0.07835291 0.58677189 0.03963089 0.01614672 0.00334615]\n",
      "[0.10129281 0.61520669 0.050846   0.01256469 0.00802189]\n",
      "[0.1372674  0.25312666 0.06781235 0.03928563 0.03198071]\n",
      "[0.02521109 0.35723735 0.01283419 0.00795502 0.00397665]\n",
      "[0.10496748 0.14392054 0.05263081 0.02235253 0.0181997 ]\n",
      "\n",
      "[[0.06815205 0.54374455 0.03427042 0.01301437 0.00542466]\n",
      " [0.07827199 0.58716623 0.03935854 0.01676668 0.00305056]\n",
      " [0.10141483 0.61571684 0.05084664 0.01333727 0.00759112]\n",
      " [0.13770358 0.25310004 0.06873935 0.03985786 0.03244927]\n",
      " [0.02564116 0.35752361 0.01298605 0.00897903 0.00313031]\n",
      " [0.10543886 0.14422327 0.05279763 0.02294021 0.01872659]]\n",
      "[[7.00090704e-04 4.98952253e-04 9.37049059e-04 8.76860550e-04\n",
      "  8.58163249e-04]\n",
      " [7.05882458e-04 4.71261921e-04 8.86389441e-04 9.13053118e-04\n",
      "  8.80048689e-04]\n",
      " [7.58498010e-04 4.83663646e-04 9.36031203e-04 8.92936498e-04\n",
      "  9.09112195e-04]\n",
      " [5.28362600e-04 6.16798244e-05 6.25807432e-04 5.26392625e-04\n",
      "  5.40360327e-04]\n",
      " [7.01222521e-04 2.09564607e-04 8.34271614e-04 7.89172541e-04\n",
      "  7.74724409e-04]\n",
      " [6.30127530e-04 1.43367896e-04 6.81890486e-04 6.74735227e-04\n",
      "  6.75025586e-04]]\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Experiments on semi-synthetic dataset\n",
    "\"\"\"\n",
    "import pickle\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def getAccuracy(real, predList):\n",
    "    \"\"\"Relative absolute error of each estimate against the true value.\n",
    "\n",
    "    Despite the name this is an error (lower is better): for every est in\n",
    "    predList it computes |real - est| / real, returned as one numpy array.\n",
    "    \"\"\"\n",
    "    estimates = np.array(list(predList))\n",
    "    return np.abs(real - estimates) / real\n",
    "\n",
    "\n",
    "# NOTE(review): pickle.load can run arbitrary code -- only load trusted files.\n",
    "# Objects are read back in the order they were dumped: the ground truth first,\n",
    "# then the six prediction arrays that are evaluated below.\n",
    "# `with` guarantees the file is closed even if one of the loads fails.\n",
    "with open(\"data/synthetic_data\", \"rb\") as data_file:\n",
    "    ground_truth = pickle.load(data_file)\n",
    "    one = pickle.load(data_file)\n",
    "    three = pickle.load(data_file)\n",
    "    five = pickle.load(data_file)\n",
    "    rotate = pickle.load(data_file)\n",
    "    skew = pickle.load(data_file)\n",
    "    crs = pickle.load(data_file)\n",
    "\n",
    "# Map each ground-truth level (0.9, 0.7, 0.5, 0.3, 0.1) to an observation\n",
    "# propensity p ** k; lower levels get larger exponents, so they are observed\n",
    "# less often. Entries matching no level keep their ground_truth value.\n",
    "# Masks are taken on `ground_truth`, never on the partially-filled `propensity`:\n",
    "# with p = 0.5 the value written for level 0.9 (p ** 1 == 0.5) would otherwise be\n",
    "# matched again by the 0.5 rule and overwritten with p ** 3.\n",
    "# NOTE(review): levels 0.3 and 0.1 share p ** 4 -- confirm p ** 5 was not intended.\n",
    "p = 0.5\n",
    "exponent_by_level = {0.9: 1, 0.7: 2, 0.5: 3, 0.3: 4, 0.1: 4}\n",
    "propensity = np.copy(ground_truth)\n",
    "for level, exponent in exponent_by_level.items():\n",
    "    # isclose instead of == so float round-off in the stored levels cannot skip a level.\n",
    "    propensity[np.isclose(ground_truth, level)] = p ** exponent\n",
    "# Number of simulated rounds; res / res_var accumulate sums over the rounds.\n",
    "N_REPEATS = 20\n",
    "SEED = 0  # fixed so the printed tables are reproducible; change for a new draw\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Models under evaluation (one row of `res` each); constant, so built once.\n",
    "predList = [one, three, five, rotate, skew, crs]\n",
    "# Estimators reported as the columns of `res`, in this order.\n",
    "REPORTED = [\"naive\", \"eib\", \"ips\", \"dr\", \"tdr\"]\n",
    "res = np.zeros([len(predList), len(REPORTED)])\n",
    "res_var = np.zeros([len(predList), len(REPORTED)])\n",
    "for _ in range(N_REPEATS):\n",
    "    # Which entries are observed in this round.\n",
    "    observation = np.random.binomial(1, propensity)\n",
    "    observed = observation == 1\n",
    "    # Overall observation rate.\n",
    "    p_o = np.mean(observation)\n",
    "    # Fresh binary labels every round, under a new name: overwriting `ground_truth`\n",
    "    # would make every later round reuse round 1's labels (binomial of 0/1 is fixed).\n",
    "    labels = np.random.binomial(1, ground_truth)\n",
    "    # Estimated inverse propensities: per entry, a random mix (weight a ~ U[0, 1))\n",
    "    # of the true 1/propensity and the uniform 1/p_o -- presumably to mimic an\n",
    "    # imperfect propensity model.\n",
    "    a = np.random.random(propensity.shape)\n",
    "    p_hat = a / propensity + (1 - a) / p_o\n",
    "    for j, prediction in enumerate(predList):\n",
    "        # Per-entry cross-entropy of the model against the sampled labels.\n",
    "        ce = -labels * np.log(prediction) - (1 - labels) * np.log(1 - prediction)\n",
    "\n",
    "        # DR imputation: r tilde = p_hat-weighted mean prediction over observed\n",
    "        # entries; e hat = cross-entropy of the model against r tilde.\n",
    "        prediction_hat = np.sum(prediction * p_hat * observation) / np.sum(observation * p_hat)\n",
    "        ce_hat = -prediction_hat * np.log(prediction) - (1 - prediction_hat) * np.log(1 - prediction)\n",
    "\n",
    "        # MRDR imputation: as above, with weights p_hat * p_hat * (1 - 1 / p_hat).\n",
    "        mrdr_weight = p_hat * p_hat * (1 - 1 / p_hat) * observation\n",
    "        prediction_mrdr = np.sum(prediction * mrdr_weight) / np.sum(mrdr_weight)\n",
    "        ce_mrdr = -prediction_mrdr * np.log(prediction) - (1 - prediction_mrdr) * np.log(1 - prediction)\n",
    "\n",
    "        # TDR: the DR imputation plus a correction along (p_hat - 1) whose\n",
    "        # coefficient is the least-squares fit of the observed residuals.\n",
    "        direction = p_hat[observed] - 1\n",
    "        epsilon = np.sum((ce[observed] - ce_hat[observed]) * direction) / np.sum(direction ** 2)\n",
    "        ce_hat_tdr = ce_hat + epsilon * (p_hat - 1)\n",
    "\n",
    "        real_ce = np.mean(ce)\n",
    "        # All estimators are computed; REPORTED selects which ones enter `res`.\n",
    "        # np.where for EIB avoids 0 * inf = nan if ce is infinite at an unobserved entry.\n",
    "        estimates = {\n",
    "            \"naive\": np.mean(ce[observed]),\n",
    "            \"eib\": np.mean(np.where(observed, ce, ce_hat)),\n",
    "            \"ips\": np.mean(ce * observation * p_hat),\n",
    "            \"snips\": np.sum(ce * observation * p_hat) / np.sum(observation * p_hat),\n",
    "            \"dr\": np.mean(ce_hat + observation * (ce - ce_hat) * p_hat),\n",
    "            \"mrdr\": np.mean(ce_mrdr + observation * (ce - ce_mrdr) * p_hat),\n",
    "            \"tdr\": np.mean(ce_hat_tdr + observation * (ce - ce_hat_tdr) * p_hat),\n",
    "        }\n",
    "        acc = getAccuracy(real_ce, [estimates[name] for name in REPORTED])\n",
    "        res[j] += acc\n",
    "        res_var[j] += acc ** 2\n",
    "\n",
    "        print(acc)\n",
    "    print()\n",
    "res_mean = res / N_REPEATS\n",
    "print(res_mean)\n",
    "# Sample standard deviation (n - 1 denominator) computed from the running sums.\n",
    "print(np.sqrt((res_var - N_REPEATS * res_mean ** 2) / (N_REPEATS - 1)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch-gpu]",
   "language": "python",
   "name": "conda-env-pytorch-gpu-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
