{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy import integrate\n",
    "from itertools import product\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def create_combinations(tuple_list):\n",
    "    combinations = product([0, 1], repeat=len(tuple_list))\n",
    "    \n",
    "    bounds = []\n",
    "    for combo in combinations:\n",
    "        new_list = [tuple_list[i][idx] for i, idx in enumerate(combo)]\n",
    "        bounds.append(new_list)\n",
    "    \n",
    "    return bounds\n",
    "\n",
    "def n_dim_integral(integrand, intervals, *args):\n",
    "    res = 0 \n",
    "    bounds = create_combinations(intervals)\n",
    "    for bound in bounds:\n",
    "        res += integrate.nquad(integrand, bound, args=args)[0]\n",
    "    \n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trs_biasvar(x,y,z,w):\n",
    "    return abs(x - y) * abs(z-w) * (y + w * (x - y)) * (1 - (y + w * (x - y)))\n",
    "\n",
    "def trs_bias(x,y,z,w):\n",
    "    return abs(x - y) * abs(z-w) \n",
    "\n",
    "def trs_bias_sq(x,y,z,w):\n",
    "    return (abs(x - y) * abs(z-w)) ** 2 \n",
    "\n",
    "def trs_var(x,y,w):\n",
    "    return (y + w * (x - y)) * (1 - (y + w * (x - y)))\n",
    "\n",
    "def trs_var_sq(x,y,w):\n",
    "    return ((y + w * (x - y)) * (1 - (y + w * (x - y)))) ** 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conf_biasvar_A(x,y,z,w):\n",
    "    return (abs(x - y) * abs(z - w) / (2 * (z + w))) * ((z + w) / 2) * (1 - ((z + w) / 2))\n",
    "\n",
    "def conf_bias_A(x,y,z,w):\n",
    "    return (abs(x - y) * abs(z - w) / (2 * (z + w)))\n",
    "\n",
    "def conf_bias_sq_A(x,y,z,w):\n",
    "    return (abs(x - y) * abs(z - w) / (2 * (z + w))) ** 2\n",
    "\n",
    "def conf_var_A(z,w):\n",
    "    return ((z + w) / 2) * (1 - ((z + w) / 2))\n",
    "\n",
    "def conf_var_sq_A(z,w):\n",
    "    return (((z + w) / 2) * (1 - ((z + w) / 2))) ** 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conf_biasvar_Y(x,y,z,w):\n",
    "    return (abs(x - y) * abs(z - w) / (2 * (z + w))) * ((x * z + y * w) / (z + w)) * (1 - ((x * z + y * w) / (z + w)))\n",
    "\n",
    "def conf_bias_Y(x,y,z,w):\n",
    "    return (abs(x - y) * abs(z - w) / (2 * (z + w)))\n",
    "\n",
    "def conf_bias_sq_Y(x,y,z,w):\n",
    "    return (abs(x - y) * abs(z - w) / (2 * (z + w))) ** 2\n",
    "\n",
    "def conf_var_Y(x,y,z,w):\n",
    "    return ((x * z + y * w) / (z + w)) * (1 - ((x * z + y * w) / (z + w)))\n",
    "\n",
    "def conf_var_sq_Y(x,y,z,w):\n",
    "    return (((x * z + y * w) / (z + w)) * (1 - ((x * z + y * w) / (z + w)))) ** 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "low_thr = 0.1\n",
    "high_thrs = [0.11, 0.2, 0.3, 0.4, 0.5]\n",
    "high_thrs = [0.11, 0.125, 0.15, 0.175, 0.2, 0.225, 0.25, 0.275, 0.3, 0.325, 0.35, 0.375, 0.4, 0.425, 0.45, 0.475, 0.5]\n",
    "# high_thrs = [0.2]\n",
    "\n",
    "res = {\"trans_Y\": [],\n",
    "       \"conf_Y\": [],\n",
    "       \"conf_A\": []}\n",
    "\n",
    "for i, high_thr in enumerate(high_thrs):\n",
    "    C = 1 / (2 * (high_thr - low_thr))\n",
    "\n",
    "    interval = [(low_thr, high_thr), (1 - high_thr, 1 - low_thr)]\n",
    "\n",
    "    ###################################\n",
    "\n",
    "    n_mul = 4\n",
    "    n_b = 4\n",
    "    n_v = 3\n",
    "\n",
    "    E_mul = (C ** n_mul) * n_dim_integral(trs_biasvar, [interval for _ in range(n_mul)])\n",
    "    E_b = (C ** n_b) * n_dim_integral(trs_bias, [interval for _ in range(n_b)])\n",
    "    E_v = (C ** n_v) * n_dim_integral(trs_var, [interval for _ in range(n_v)])\n",
    "\n",
    "    E_b_sq = (C ** n_b) * n_dim_integral(trs_bias_sq, [interval for _ in range(n_b)])\n",
    "    E_v_sq = (C ** n_v) * n_dim_integral(trs_var_sq, [interval for _ in range(n_v)])\n",
    "\n",
    "    b_sig = np.sqrt(E_b_sq - E_b ** 2)\n",
    "    v_sig = np.sqrt(E_v_sq - E_v ** 2)\n",
    "\n",
    "    cov = (E_mul - E_b * E_v)\n",
    "    trs_rho =  cov / (b_sig * v_sig)\n",
    "\n",
    "    ###################################\n",
    "\n",
    "    n_mul = 4\n",
    "    n_b = 4\n",
    "    n_v = 2\n",
    "\n",
    "    E_mul = (C ** n_mul) * n_dim_integral(conf_biasvar_A, [interval for _ in range(n_mul)])\n",
    "    E_b = (C ** n_b) * n_dim_integral(conf_bias_A, [interval for _ in range(n_b)])\n",
    "    E_v = (C ** n_v) * n_dim_integral(conf_var_A, [interval for _ in range(n_v)])\n",
    "\n",
    "    E_b_sq = (C ** n_b) * n_dim_integral(conf_bias_sq_A, [interval for _ in range(n_b)])\n",
    "    E_v_sq = (C ** n_v) * n_dim_integral(conf_var_sq_A, [interval for _ in range(n_v)])\n",
    "\n",
    "    b_sig = np.sqrt(E_b_sq - E_b ** 2)\n",
    "    v_sig = np.sqrt(E_v_sq - E_v ** 2)\n",
    "\n",
    "    cov = (E_mul - E_b * E_v)\n",
    "    conf_rho_A =  cov / (b_sig * v_sig)\n",
    "\n",
    "    ###################################\n",
    "\n",
    "    n_mul = 4\n",
    "    n_b = 4\n",
    "    n_v = 4\n",
    "\n",
    "    E_mul = (C ** n_mul) * n_dim_integral(conf_biasvar_Y, [interval for _ in range(n_mul)])\n",
    "    E_b = (C ** n_b) * n_dim_integral(conf_bias_Y, [interval for _ in range(n_b)])\n",
    "    E_v = (C ** n_v) * n_dim_integral(conf_var_Y, [interval for _ in range(n_v)])\n",
    "\n",
    "    E_b_sq = (C ** n_b) * n_dim_integral(conf_bias_sq_Y, [interval for _ in range(n_b)])\n",
    "    E_v_sq = (C ** n_v) * n_dim_integral(conf_var_sq_Y, [interval for _ in range(n_v)])\n",
    "\n",
    "    b_sig = np.sqrt(E_b_sq - E_b ** 2)\n",
    "    v_sig = np.sqrt(E_v_sq - E_v ** 2)\n",
    "\n",
    "    cov = (E_mul - E_b * E_v)\n",
    "    conf_rho_Y =  cov / (b_sig * v_sig)\n",
    "\n",
    "    print(f\"High thr: {high_thr}\")\n",
    "    print(f\"Trans (Rho_Y): {trs_rho:.4f}\")\n",
    "    print(f\"Conf. (Rho_Y): {conf_rho_Y:.4f}\")\n",
    "    print(f\"Conf. (Rho_A): {conf_rho_A:.4f}\")\n",
    "\n",
    "    res['trans_Y'].append(trs_rho)\n",
    "    res['conf_Y'].append(conf_rho_Y)\n",
    "    res['conf_A'].append(conf_rho_A)\n",
    "\n",
    "    plt.figure()\n",
    "    plt.plot(high_thrs[:i + 1], res['trans_Y'], marker='o', label='Trans_rho_Y')\n",
    "    plt.plot(high_thrs[:i + 1], res['conf_Y'], marker='x', label='Conf_rho_Y')\n",
    "    plt.plot(high_thrs[:i + 1], res['conf_A'], marker='d', label='Conf_rho_A')\n",
    "    plt.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "sns.set_theme(style=\"whitegrid\")\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(high_thrs[:i + 1], res['trans_Y'], marker='o', label=r'$\\rho(b, Y)$' + ' – Transportability bias')\n",
    "plt.plot(high_thrs[:i + 1], res['conf_Y'], marker='x', label=r'$\\rho(b, Y)$' + ' – Confounding bias')\n",
    "plt.plot(high_thrs[:i + 1], res['conf_A'], marker='d', label=r'$\\rho(b, A)$' + ' – Confounding bias')\n",
    "plt.axhline(0, linestyle='--', color='dimgray')\n",
    "plt.legend()\n",
    "plt.xlabel(\"p\")\n",
    "plt.tight_layout()\n",
    "plt.savefig('covs.png')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "falsification",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
