{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "[0.06995175 0.54413489 0.03543106 0.0139976  0.00437743]\n",
      "[0.07821001 0.5863064  0.03850752 0.01688889 0.00284763]\n",
      "[0.1009214  0.61505664 0.04992286 0.01348024 0.00737093]\n",
      "[0.13580497 0.25325603 0.06654522 0.03787627 0.03041102]\n",
      "[0.02319832 0.35805376 0.01026258 0.00739543 0.00475869]\n",
      "[0.10647116 0.1444494  0.0536818  0.02376009 0.0195769 ]\n",
      "\n",
      "\n",
      "[0.06827929 0.54553104 0.03203086 0.01593004 0.00419752]\n",
      "[7.98349594e-02 5.88877143e-01 3.88767779e-02 2.20158045e-02\n",
      " 4.72150948e-04]\n",
      "[0.10082125 0.618016   0.04834696 0.01578472 0.00710318]\n",
      "[0.13880692 0.25296998 0.07297317 0.04244582 0.03444481]\n",
      "[0.02777335 0.35949338 0.01262539 0.01209159 0.00111147]\n",
      "[0.10543783 0.14430449 0.05043798 0.02390001 0.01932402]\n",
      "\n",
      "\n",
      "[0.06558289 0.54580291 0.02725616 0.00938243 0.01156279]\n",
      "[0.07598454 0.58864653 0.03256471 0.01416989 0.00826128]\n",
      "[0.0999398  0.61792469 0.04482461 0.01093772 0.01276286]\n",
      "[0.13735124 0.253372   0.0703638  0.03901367 0.03064377]\n",
      "[0.02557122 0.35864112 0.00909864 0.00777136 0.00592879]\n",
      "[0.10447739 0.14483218 0.04804779 0.02102501 0.0161908 ]\n",
      "\n",
      "\n",
      "[0.0692723  0.54403448 0.03616036 0.01504563 0.00286073]\n",
      "[0.07863222 0.58704321 0.04031881 0.01827278 0.00098019]\n",
      "[0.10064889 0.61558739 0.05096437 0.01415371 0.00620492]\n",
      "[0.1380559  0.25299427 0.06864681 0.04060026 0.03341069]\n",
      "[0.02744391 0.35788555 0.01462808 0.01044594 0.00131495]\n",
      "[0.10637269 0.14365531 0.05456905 0.02445204 0.020417  ]\n",
      "\n",
      "\n",
      "[0.06956429 0.54377205 0.03345268 0.01521324 0.00397272]\n",
      "[8.16278109e-02 5.88510211e-01 4.08555955e-02 2.04991961e-02\n",
      " 1.18822211e-04]\n",
      "[0.10525894 0.61760496 0.05266752 0.01653127 0.00528688]\n",
      "[0.13941312 0.25347491 0.07218502 0.04229152 0.03462585]\n",
      "[0.02777083 0.35823473 0.01314221 0.01120743 0.00138918]\n",
      "[0.1041273  0.14417898 0.05010636 0.02202314 0.01758905]\n",
      "\n",
      "\n",
      "[0.06846153 0.54374046 0.03272076 0.01354564 0.00546317]\n",
      "[0.07998658 0.58846336 0.0395241  0.01805868 0.00239175]\n",
      "[0.10498423 0.61665219 0.0526164  0.01738394 0.00410131]\n",
      "[0.13788993 0.25317899 0.06968981 0.04033031 0.03269915]\n",
      "[0.024749   0.3575753  0.01087413 0.0085276  0.00398663]\n",
      "[0.10440074 0.14379138 0.05080841 0.02213816 0.01777659]\n",
      "\n",
      "\n",
      "[0.0681218  0.54166826 0.0340833  0.01206693 0.00617227]\n",
      "[0.07821364 0.58574955 0.03902919 0.01507592 0.00457779]\n",
      "[0.10333254 0.61415684 0.05226049 0.01392855 0.00674301]\n",
      "[0.13747875 0.25405546 0.06792753 0.03889079 0.03148063]\n",
      "[0.02360479 0.35847121 0.01101581 0.0087475  0.00329681]\n",
      "[0.10441608 0.14429618 0.0512686  0.02112681 0.01688492]\n",
      "\n",
      "\n",
      "[0.07578329 0.54523578 0.04298859 0.0198678  0.00272683]\n",
      "[0.08216619 0.58795242 0.04368326 0.01932881 0.00077656]\n",
      "[0.10680996 0.61642852 0.05659668 0.01770046 0.00183399]\n",
      "[0.13805136 0.25308669 0.06754959 0.0398596  0.03290273]\n",
      "[0.03065667 0.35755433 0.01809205 0.01234854 0.00108565]\n",
      "[0.10801134 0.14370416 0.05612723 0.02479406 0.02091443]\n",
      "\n",
      "\n",
      "[0.06684584 0.54613227 0.02870069 0.01387564 0.00743689]\n",
      "[0.07444169 0.5888     0.03082293 0.0154341  0.00752452]\n",
      "[0.09926263 0.61783993 0.04450987 0.01425004 0.00991881]\n",
      "[0.13792881 0.25346747 0.07288994 0.0403817  0.03184935]\n",
      "[0.02593494 0.35877835 0.0083028  0.00851144 0.00551389]\n",
      "[0.10543556 0.14426412 0.04807992 0.02300876 0.01815347]\n",
      "\n",
      "\n",
      "[0.06679637 0.54417225 0.03050529 0.01408259 0.00556505]\n",
      "[0.07902549 0.58885157 0.03799449 0.01980102 0.00129069]\n",
      "[0.10347189 0.61819488 0.05068025 0.01670421 0.00558989]\n",
      "[0.13623678 0.25259717 0.06866628 0.03875501 0.030829  ]\n",
      "[2.82942221e-02 3.59413160e-01 1.37187043e-02 1.26418334e-02\n",
      " 2.11156523e-04]\n",
      "[0.1110191  0.14398571 0.05625986 0.02855403 0.02427559]\n",
      "\n",
      "\n",
      "[0.0719101  0.54480351 0.03415465 0.02168649 0.00070686]\n",
      "[0.08539247 0.59083156 0.04332864 0.02770649 0.00512234]\n",
      "[1.07668141e-01 6.19193671e-01 5.36992226e-02 2.33643584e-02\n",
      " 5.31394292e-04]\n",
      "[0.13920831 0.25305192 0.07509525 0.04333808 0.0349275 ]\n",
      "[0.02848823 0.35828611 0.01265531 0.01607385 0.00234941]\n",
      "[0.11067586 0.14479689 0.05400502 0.02925158 0.02461763]\n",
      "\n",
      "\n",
      "[0.06716974 0.54472678 0.03372923 0.01280972 0.00589033]\n",
      "[0.07716186 0.58860365 0.03875851 0.01650279 0.00360996]\n",
      "[0.10124992 0.61797839 0.05151386 0.01346599 0.00778622]\n",
      "[0.13797785 0.25309017 0.06892288 0.03994219 0.03244877]\n",
      "[0.02755649 0.35892695 0.01619342 0.01075154 0.0014889 ]\n",
      "[0.10799223 0.14534528 0.0559395  0.02585712 0.0216564 ]\n",
      "\n",
      "\n",
      "[0.06815525 0.54584131 0.03185775 0.01310206 0.00681489]\n",
      "[0.07698788 0.59012868 0.03553308 0.01487383 0.00663261]\n",
      "[0.10262664 0.61879651 0.04993054 0.01501199 0.00756106]\n",
      "[0.13754161 0.25333969 0.06990139 0.03910862 0.03109939]\n",
      "[0.02842918 0.35906412 0.01345442 0.01121056 0.00179427]\n",
      "[0.10719117 0.14622851 0.05326776 0.0249574  0.02042357]\n",
      "\n",
      "\n",
      "[0.07054854 0.54383209 0.03558754 0.01468486 0.00416164]\n",
      "[0.07716447 0.58544553 0.03680924 0.01629152 0.00397989]\n",
      "[0.10109334 0.61452311 0.04955537 0.01407275 0.00731333]\n",
      "[0.13831964 0.25382504 0.07047644 0.0408211  0.03323459]\n",
      "[0.02370161 0.35736079 0.00985761 0.00637479 0.00612619]\n",
      "[0.10459262 0.14399225 0.05038475 0.02158955 0.01722997]\n",
      "\n",
      "\n",
      "[0.06964914 0.54606259 0.03484989 0.01442246 0.00482399]\n",
      "[0.07573962 0.58808466 0.03524289 0.01478039 0.00597439]\n",
      "[0.09961151 0.61719577 0.04805336 0.01262393 0.00926279]\n",
      "[0.13762954 0.25307365 0.06884124 0.0393149  0.0315764 ]\n",
      "[0.02337385 0.35827495 0.0101273  0.00717215 0.00553848]\n",
      "[0.10663868 0.14485697 0.05302063 0.02369098 0.01930453]\n",
      "\n",
      "\n",
      "[0.07072618 0.54292371 0.03697247 0.01512931 0.00282214]\n",
      "[0.0840864  0.58711372 0.04609863 0.02259022 0.00338954]\n",
      "[0.10729523 0.61593817 0.05735794 0.01848656 0.00183407]\n",
      "[0.13806298 0.25344888 0.06911673 0.04089022 0.03366064]\n",
      "[0.02620033 0.35686041 0.01342288 0.00867212 0.00317053]\n",
      "[0.10674    0.14420009 0.05474349 0.02462425 0.02055721]\n",
      "\n",
      "\n",
      "[0.06528636 0.54382274 0.03147362 0.01095855 0.00712498]\n",
      "[0.07740733 0.58766136 0.03912533 0.01732194 0.00203336]\n",
      "[0.10126398 0.61659931 0.05076192 0.01380036 0.0066565 ]\n",
      "[0.13852727 0.25310971 0.06994044 0.04151306 0.03433222]\n",
      "[2.68366013e-02 3.57966571e-01 1.52064050e-02 1.17626792e-02\n",
      " 1.37425781e-05]\n",
      "[0.10495744 0.14418989 0.05233966 0.02205879 0.01791409]\n",
      "\n",
      "\n",
      "[0.07058517 0.54647666 0.03599504 0.01699349 0.00289673]\n",
      "[0.07949916 0.58886529 0.03934499 0.02017364 0.00118946]\n",
      "[0.105525   0.61919328 0.0535945  0.01723227 0.00538061]\n",
      "[0.13969531 0.25330917 0.07211477 0.04168346 0.03373367]\n",
      "[0.02953708 0.35929963 0.01649456 0.014438   0.00148315]\n",
      "[0.10516911 0.14430687 0.051468   0.02349575 0.01895744]\n",
      "\n",
      "\n",
      "[0.06603527 0.54070001 0.0336506  0.00593391 0.01045588]\n",
      "[0.07698122 0.58374827 0.0396214  0.01065393 0.00691056]\n",
      "[0.09883326 0.61208362 0.04965454 0.00568858 0.01289653]\n",
      "[0.1353973  0.25400626 0.06379747 0.03711329 0.03046584]\n",
      "[0.02583546 0.35862195 0.01535745 0.00630287 0.00449519]\n",
      "[0.10178471 0.1427307  0.05091392 0.01831129 0.01449802]\n",
      "\n",
      "\n",
      "[0.0671597  0.5414917  0.03373623 0.01158041 0.00600328]\n",
      "[0.08192016 0.58629173 0.04433378 0.02040722 0.00162986]\n",
      "[0.10405237 0.6155417  0.0545553  0.0144545  0.00549098]\n",
      "[0.13711242 0.253554   0.0670087  0.03877782 0.03165105]\n",
      "[0.0260338  0.3590746  0.01394876 0.00875606 0.00286832]\n",
      "[0.1072565  0.14378156 0.05568955 0.02460865 0.02065421]\n",
      "\n",
      "[[0.06879424 0.54424528 0.03376684 0.01401544 0.00530181]\n",
      " [0.07902319 0.58779874 0.03901869 0.01804235 0.00348567]\n",
      " [0.10273355 0.61672528 0.05110333 0.01495281 0.00658146]\n",
      " [0.1378245  0.25331307 0.06963262 0.04014738 0.03252135]\n",
      " [0.02654949 0.35839185 0.01292393 0.01006016 0.00289627]\n",
      " [0.10615838 0.14429455 0.05255797 0.02366137 0.01934579]]\n",
      "[[0.0024712  0.00161282 0.00328743 0.00339008 0.00260195]\n",
      " [0.00284133 0.00167266 0.00374181 0.00372751 0.00253779]\n",
      " [0.00277977 0.001846   0.0033069  0.00343691 0.00319172]\n",
      " [0.00110745 0.00035875 0.0025825  0.00159809 0.00148186]\n",
      " [0.00210997 0.00072305 0.00266302 0.00267543 0.00196225]\n",
      " [0.0021997  0.00071327 0.00256356 0.00251109 0.00249839]]\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Experiments on semi-synthetic dataset\n",
    "\"\"\"\n",
    "import pickle\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def getAccuracy(real, predList):\n",
    "    returnList = []\n",
    "    for i in predList:\n",
    "        returnList.append(np.abs(real-i)/real)\n",
    "    return np.array(returnList)\n",
    "\n",
    "\n",
    "file = open(\"data/synthetic_data\", \"rb\")\n",
    "ground_truth = pickle.load(file)\n",
    "one = pickle.load(file)\n",
    "three = pickle.load(file)\n",
    "five = pickle.load(file)\n",
    "rotate = pickle.load(file)\n",
    "skew = pickle.load(file)\n",
    "crs = pickle.load(file)\n",
    "file.close()\n",
    "\n",
    "propensity = np.copy(ground_truth)\n",
    "p = 0.5\n",
    "propensity[np.where(propensity == 0.9)] = p ** 1\n",
    "propensity[np.where(propensity == 0.7)] = p ** 2\n",
    "propensity[np.where(propensity == 0.5)] = p ** 3\n",
    "propensity[np.where(propensity == 0.3)] = p ** 4\n",
    "propensity[np.where(propensity == 0.1)] = p ** 4\n",
    "res = np.zeros([6, 5])\n",
    "res_var = np.zeros([6, 5])\n",
    "for i in range(20):\n",
    "    observation = np.random.binomial(1, propensity)\n",
    "    ones = np.count_nonzero(observation)\n",
    "    zeros = observation.shape[0] - ones\n",
    "    p_o = ones/(ones+zeros)\n",
    "    ground_truth = np.random.binomial(1, ground_truth)\n",
    "    o = np.where(observation == 1)\n",
    "    a = np.random.random((943*1682))\n",
    "    p_hat = a/propensity + (1-a)/p_o\n",
    "    predList = [one, three, five, rotate, skew, crs]\n",
    "    for j in range(6):\n",
    "        prediction = predList[j]\n",
    "        ce = -ground_truth * np.log(prediction) - (1 - ground_truth) * np.log(1 - prediction)\n",
    "\n",
    "        # DR\n",
    "        prediction_hat = np.sum(prediction * p_hat * observation) / np.sum(observation * p_hat)\n",
    "        ce_hat = -prediction_hat * np.log(prediction) - (1 - prediction_hat) * np.log(1 - prediction)\n",
    "        \n",
    "        # TDR\n",
    "        prediction_hat = np.sum(prediction * p_hat * observation) / np.sum(observation * p_hat)\n",
    "        ce_hat_tdr = (-prediction_hat * np.log(prediction) - (1 - prediction_hat) * np.log(1 - prediction)\n",
    "                      ) + np.sum((ce[o]-ce_hat[o])*(p_hat[o]-1))/np.sum((p_hat[o]-1)**2)*(p_hat - 1)      \n",
    "        \n",
    "        real_ce = np.mean(ce)\n",
    "        naive_ce = np.mean(ce[o])\n",
    "        eib_ce = np.mean(ce_hat*(1-observation)+ce*observation)\n",
    "        ips_ce = np.mean(ce * observation * p_hat)\n",
    "        dr_ce = np.mean(ce_hat + observation * (ce - ce_hat) * p_hat)\n",
    "        tdr_ce = np.mean(ce_hat_tdr + observation * (ce - ce_hat_tdr) * p_hat)\n",
    "        \n",
    "        acc = getAccuracy(real_ce, [naive_ce, eib_ce, ips_ce, dr_ce,  tdr_ce])\n",
    "        res[j] += acc\n",
    "        res_var[j] += acc ** 2\n",
    "\n",
    "        print(acc)\n",
    "    print()\n",
    "print(res/20)\n",
    "print(np.sqrt((1/19)*(res_var - 20*(res/20)**2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch-gpu]",
   "language": "python",
   "name": "conda-env-pytorch-gpu-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
