{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy\n",
    "import math\n",
    "from scipy.special import rel_entr\n",
    "from scipy.stats import entropy\n",
    "import copy\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "from bounds import *"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
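    "### Compute bounds for synthetic data\n",
    "\n",
    "Each trial samples a random $2 \\times 2 \\times 2$ joint distribution $P$ (normalized i.i.d. exponential weights, i.e. a uniform draw from the probability simplex) and evaluates it with the functions from the `bounds` module."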
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate synthetic data: each trial draws a random 2x2x2 joint distribution P\n",
    "# and stores it together with the quantities and bounds returned by get_bounds.\n",
    "\n",
    "trials = 100\n",
    "# Derive the file name from `trials` so a quick run cannot silently overwrite a\n",
    "# larger saved dataset (a fixed '..._1000_...' name was used regardless of `trials`).\n",
    "out_path = '/content/drive/My Drive/neurips_bounds_data/all_Ps_{}_new_upper.pkl'.format(trials)\n",
    "\n",
    "allPs = []\n",
    "for _ in tqdm(range(trials)):\n",
    "    # Normalised i.i.d. Exp(1) weights are a uniform (Dirichlet(1, ..., 1)) draw from the simplex.\n",
    "    k = np.random.exponential(scale=1.0, size=8)\n",
    "    P = (k / k.sum()).reshape((2, 2, 2))\n",
    "    all_quantities, all_bounds = get_bounds(P)\n",
    "    allPs.append((P, all_quantities, all_bounds))\n",
    "\n",
    "with open(out_path, 'wb') as fp:\n",
    "    pickle.dump(allPs, fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a previously saved sample of distributions (pickle: only load files you trust).\n",
    "with open('/content/drive/My Drive/neurips_bounds_data/all_Ps_10000.pkl', 'rb') as fp:\n",
    "  savedPs = pickle.load(fp)\n",
    "print (len(savedPs))\n",
    "\n",
    "# Plot S against the rescaled lower bound, keeping only samples where that bound is positive.\n",
    "all_S = []\n",
    "all_lower_diff = []\n",
    "for (P, all_quantities, all_bounds) in savedPs:\n",
    "  max_u = max(all_quantities['U1'], all_quantities['U2'])\n",
    "  # 'lower_diff' is stored as diff - max(U1, U2): add the offset back, scale diff by 1/4,\n",
    "  # then subtract max(U1, U2) again.\n",
    "  lower_diff = (all_bounds['lower_diff'] + max_u)/4 - max_u\n",
    "  if lower_diff > 0:\n",
    "    all_S.append(all_quantities['S'])\n",
    "    all_lower_diff.append(lower_diff)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "# No `cmap`: a colormap only applies when `c` holds numeric values; with a fixed color it was ignored.\n",
    "ax.scatter(all_lower_diff, all_S, color='blue', zorder=10)\n",
    "ax.set_xlabel('diff/4 - max(U1, U2)')\n",
    "ax.set_ylabel('S')\n",
    "ax.set_title('S vs. lower bound (samples with positive bound)')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample random 2x2x2 distributions and track the sample with the largest diff - max(U1, U2).\n",
    "allPs = []\n",
    "trials = 100000\n",
    "s_max = 0\n",
    "diff_max = 0\n",
    "P_max = None\n",
    "n_failed = 0\n",
    "last_error = None\n",
    "for _ in tqdm(range(trials)):\n",
    "    k = np.random.exponential(scale=1.0, size=8)\n",
    "    P = (k / k.sum()).reshape((2,2,2))\n",
    "    try:\n",
    "        r, u1, u2, s, diff, Py_givenx1, Py_givenx2, Px1x2, Px1, Px2, I_x1x2 = get_rus_diff(P)\n",
    "    except Exception as err:\n",
    "        # Skip samples on which get_rus_diff raises, but keep count. A bare `except:` here also\n",
    "        # swallowed KeyboardInterrupt, so interrupting the kernel only skipped a single trial.\n",
    "        n_failed += 1\n",
    "        last_error = err\n",
    "        continue\n",
    "    allPs.append((P, r, u1, u2, s, diff, Py_givenx1, Py_givenx2, Px1x2, Px1, Px2, I_x1x2))\n",
    "    # NOTE(review): this cell uses diff - max(U1, U2), while other cells in this notebook use\n",
    "    # diff/4 - max(U1, U2) for the same quantity; confirm which scaling is intended here.\n",
    "    diff = diff - max(u1, u2)\n",
    "    if s >= 0.3 and diff >= 0.5:\n",
    "        print ('s=', s, 'diff-U=', diff, 'P=', P)\n",
    "    if diff >= diff_max:\n",
    "        s_max = s\n",
    "        diff_max = diff\n",
    "        P_max = P\n",
    "print ('s_max=', s_max, 'diff_max=', diff_max, 'P_max=', P_max)\n",
    "print ('failed trials:', n_failed, 'last error:', repr(last_error))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print (len(allPs))\n",
    "# Report samples whose S and rescaled gap diff/4 - max(U1, U2) both exceed their thresholds.\n",
    "for P, r, u1, u2, s, diff, *_rest in allPs:\n",
    "    gap = diff/4.0 - max(u1, u2)\n",
    "    if s >= 0.2 and gap >= 0.1:\n",
    "        print ('s=', s, 'diff-U=', gap, 'P=', P)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example distribution; alternative candidates are kept commented out below.\n",
    "P = np.array([[[0.23148881,0.00941881],[0.07064852,0.00650789]],[[0.01758162,0.42417211],[0.22418906,0.01599317]]])\n",
    "# P = np.array([[[0.00941881,0.23148881],[0.07064852,0.00650789]],[[0.42417211,0.01758162],[0.01599317,0.22418906]]])\n",
    "# P = np.array([[[0.0,0.25],[0.15,0.1]],[[0.25,0.0],[0.1,0.15]]])\n",
    "# P = np.array([[[0.03349886,0.02460521], [0.23188986,0.0167147 ]], [[0.04237604,0.64494174], [0.00478614,0.00118744]]])\n",
    "# P = np.array([[[0.11564164,0.01350078], [0.03169016,0.00144381]], [[0.05276016,0.66327317], [0.10926603,0.01242425]]])\n",
    "all_quantities, all_bounds = get_bounds(P)\n",
    "\n",
    "# Unpack the terms reported by get_bounds.\n",
    "r, u1, u2, s = (all_quantities[key] for key in ('R', 'U1', 'U2', 'S'))\n",
    "# 'lower_diff' is reported as diff - max(U1, U2); adding max(U1, U2) back recovers diff.\n",
    "diff = all_bounds['lower_diff']\n",
    "Py_givenx1, Py_givenx2 = all_quantities['Py_given_x1'], all_quantities['Py_given_x2']\n",
    "Px1x2, Px1, Px2, I_x1x2 = (all_quantities[key] for key in ('Px1x2', 'Px1', 'Px2', 'I_x1x2'))\n",
    "max_u = max(u1, u2)\n",
    "total = r + u1 + u2 + s\n",
    "\n",
    "print ('r=', r, 'u1=', u1, 'u2=', u2, 's=', s, 'total=', total)\n",
    "print ('y|x1=', Py_givenx1, 'y|x2=', Py_givenx2, 'diff=', diff + max_u, 'diff-U=', diff)\n",
    "print ('Px1x2=', Px1x2, 'Px1=', Px1, 'Px2=', Px2, 'I_x1x2=', I_x1x2)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
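    "### Generate synthetic binary data\n",
    "\n",
    "`gen_binary_data` encodes $X_1$ and $X_2$ as 3-bit codes, sets the label $Y$ to the parity (XOR) of their bits, and then randomly zeroes some bits of $X_2$. The cells below build $P(x_1, x_2, y)$ from the samples and check $H(Y)$ and $H(Y|X_1,X_2)$."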
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def group(L, support):\n",
    "    \"\"\"Pack a list of digit arrays into one array of base-`support` codes.\n",
    "\n",
    "    L[0] is the least-significant digit, so the result is sum_i L[i] * support**i,\n",
    "    accumulated into a float array with the same shape as L[0].\n",
    "    \"\"\"\n",
    "    codes = np.zeros(L[0].shape)\n",
    "    for position, digit in enumerate(L):\n",
    "        codes += digit * support ** position\n",
    "    return codes\n",
    "\n",
    "def gen_binary_data(num_data, seed=None):\n",
    "    \"\"\"Generate samples of two 3-bit features and a parity label.\n",
    "\n",
    "    X1 is three fresh random bits. X2 shares X1's third bit and draws fresh first\n",
    "    and second bits. The label is the parity (XOR) of all six bits, computed before\n",
    "    any corruption; the shared third bit appears twice and cancels, so the label is\n",
    "    the XOR of the first two bits of X1 and X2. Afterwards X2's third bit is kept\n",
    "    with probability 0.5 and its first bit with probability 0.3 (otherwise set to 0).\n",
    "\n",
    "    Args:\n",
    "        num_data: number of samples to draw.\n",
    "        seed: optional int. If given, a private np.random.RandomState(seed) is used so\n",
    "            the data are reproducible; by default the global np.random state is used.\n",
    "\n",
    "    Returns:\n",
    "        {'s': (x1, x2, y_s)}, each entry an array of shape (num_data, 1) holding\n",
    "        codes packed by `group` (least-significant bit first).\n",
    "    \"\"\"\n",
    "    rng = np.random if seed is None else np.random.RandomState(seed)\n",
    "    x1dim = 3\n",
    "\n",
    "    x1 = [rng.randint(0, 2, (num_data, 1)) for _ in range(x1dim)]\n",
    "    # X2 reuses X1's third bit and gets fresh first and second bits. A shallow list copy\n",
    "    # suffices because entries are only rebound below, never modified in place.\n",
    "    x2 = list(x1)\n",
    "    x2[0] = rng.randint(0, 2, (num_data, 1))\n",
    "    x2[1] = rng.randint(0, 2, (num_data, 1))\n",
    "\n",
    "    # Label: parity of all six bits, fixed before X2 is corrupted.\n",
    "    ydim1 = (x1[0] + x2[0] + x1[1] + x2[1] + x1[2] + x2[2]) % 2\n",
    "    y_s = group([ydim1], support=2)\n",
    "\n",
    "    # Corrupt X2: keep each bit with the given probability, otherwise zero it.\n",
    "    keep_prob_bit2 = 0.5\n",
    "    x2[2] = (rng.rand(num_data, 1) < keep_prob_bit2) * x2[2]\n",
    "    keep_prob_bit0 = 0.3\n",
    "    x2[0] = (rng.rand(num_data, 1) < keep_prob_bit0) * x2[0]\n",
    "\n",
    "    x1 = group(x1, support=2)\n",
    "    x2 = group(x2, support=2)\n",
    "    data = {\n",
    "        's': (x1, x2, y_s),\n",
    "    }\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = gen_binary_data(10000)\n",
    "x1, x2, y = data['s']\n",
    "# label balance\n",
    "unique, counts = np.unique(y, return_counts=True)\n",
    "print (unique, counts)\n",
    "\n",
    "# build the distribution from the samples, then report r, u1, u2, s from `test` and their sum\n",
    "P, maps = convert_data_to_distribution(x1, x2, y)\n",
    "r, u1, u2, s = test(P)\n",
    "total = r + u1 + u2 + s\n",
    "print (r, u1, u2, s, total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# checking H(Y), H(Y|X_1,X_2) and the P(x1,x2)-weighted squared distance between P(y|x1) and P(y|x2)\n",
    "# P is indexed as P[x1, x2, y]; entropies are in bits.\n",
    "Py = P.sum(axis=(0, 1))\n",
    "y_ent = scipy.special.entr(Py).sum() / np.log(2)\n",
    "print ('H(Y) =', y_ent)\n",
    "\n",
    "Px1x2 = np.sum(P, axis=2, keepdims=True)\n",
    "# An empirical P can contain (x1, x2) cells with zero mass; plain P/Px1x2 gives 0/0 = NaN there,\n",
    "# which turns every weighted sum below into NaN. Use 0 on those cells instead: they carry\n",
    "# zero weight, so the results are otherwise unchanged.\n",
    "Py_givenx1x2 = np.divide(P, Px1x2, out=np.zeros(P.shape), where=Px1x2 > 0)\n",
    "Py_givenx1x2_ent = scipy.special.entr(Py_givenx1x2) / np.log(2)\n",
    "# per-y contributions to H(Y|X1,X2); their sum is the conditional entropy\n",
    "Py_givenx1x2_ent = np.einsum('ijk,ij->k', Py_givenx1x2_ent, Px1x2.squeeze(axis=-1))\n",
    "print ('H(Y|X1,X2) =', Py_givenx1x2_ent.sum())\n",
    "\n",
    "Px1y = np.sum(P, axis=1, keepdims=True)\n",
    "Px2y = np.sum(P, axis=0, keepdims=True)\n",
    "Px1 = np.sum(Px1y, axis=2, keepdims=True)\n",
    "Px2 = np.sum(Px2y, axis=2, keepdims=True)\n",
    "Py_givenx1 = np.divide(Px1y, Px1, out=np.zeros(Px1y.shape), where=Px1 > 0)\n",
    "Py_givenx2 = np.divide(Px2y, Px2, out=np.zeros(Px2y.shape), where=Px2 > 0)\n",
    "\n",
    "# (n1, 1, ny) - (1, n2, ny) broadcasts to (n1, n2, ny); weight the summed squares by P(x1, x2)\n",
    "diff = (Py_givenx1 - Py_givenx2)**2\n",
    "diff = np.einsum('ij,ij', Px1x2.squeeze(axis=-1), diff.sum(axis=-1))\n",
    "print ('E[||P(y|x1) - P(y|x2)||^2] =', diff)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
