{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from do_shap.frontiers import parts_of\n",
    "from causaleffect import ID, createGraph"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "# Alternative setup, kept as a raw cell so it is NOT executed: builds G from a\n",
    "# randomly sampled DAG instead of the hand-written graph in the next cell.\n",
    "from random import seed\n",
    "from string import ascii_uppercase\n",
    "\n",
    "from do_shap.frontiers import sample_dag_all_ancestors\n",
    "\n",
    "\n",
    "SEED = 123\n",
    "K = 5      # number of nodes besides the outcome Y (see `nodes` below)\n",
    "p = .25    # passed to the sampler; presumably an edge probability - TODO confirm\n",
    "\n",
    "# NOTE(review): this seeds only Python's `random` module; if the sampler draws\n",
    "# from another RNG (e.g. numpy) the sampled graph is not reproducible - confirm.\n",
    "seed(SEED)\n",
    "\n",
    "# ascii_uppercase[:K] would itself contain 'Y' for K >= 25, silently merging a\n",
    "# regular node with the outcome, so cap K explicitly.\n",
    "assert K <= 24, 'K must be <= 24 so node names do not collide with the outcome Y'\n",
    "\n",
    "graph = sample_dag_all_ancestors(K, p, rejection=True)\n",
    "\n",
    "# K letter names followed by the outcome 'Y' (last, as the subset loop below expects).\n",
    "nodes = ascii_uppercase[:K] + 'Y'\n",
    "\n",
    "# assumes graph.edges yields integer (i, j) index pairs into `nodes` - TODO confirm\n",
    "G = createGraph([\n",
    "    f'{nodes[i]}->{nodes[j]}'\n",
    "    for i, j in graph.edges\n",
    "])\n",
    "\n",
    "G.get_edgelist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(0, 1), (1, 2), (2, 3), (3, 4), (0, 5), (1, 5), (4, 5), (1, 3), (3, 1)]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hand-written graph. Z -> X and Z -> Y make Z an observed common cause of X and Y;\n",
    "# X reaches Y directly (X -> Y) and through the chain X -> A -> B -> C -> Y.\n",
    "# 'X<->B' is a bidirected edge (unobserved common cause of X and B); it shows up as\n",
    "# both (1, 3) and (3, 1) at the end of the edge list printed by this cell.\n",
    "# The outcome Y is listed last in `nodes`; the other names are the candidate\n",
    "# intervention sets enumerated in the next cell.\n",
    "nodes = list('ZXABCY')\n",
    "G = createGraph(['Z->X', 'X->A', 'A->B', 'B->C',\n",
    "                 'Z->Y', 'X->Y', 'C->Y', 'X<->B'])\n",
    "\n",
    "G.get_edgelist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "() P(y)\n",
      "('Z',) P(y|z)\n",
      "('X',) \\sum_{a, b, c, z}P(a|x, z)P(c, y|a, b, x, z)P(z)\\left(\\sum_{x}P(b|a, x, z)P(x|z)\\right)\n",
      "('A',) \\sum_{b, c, x, z}P(b, c, y|a, x, z)P(x, z)\n",
      "('B',) \\sum_{c, x, z}P(c, y|a, b, x, z)P(x, z)\n",
      "('C',) \\sum_{x, z}P(x, z)P(y|a, b, c, x, z)\n",
      "('Z', 'X') \\sum_{a, b, c}P(a|x, z)P(c, y|a, b, x, z)\\left(\\sum_{x}P(b|a, x, z)P(x|z)\\right)\n",
      "('Z', 'A') \\sum_{b, c, x}P(b, c, y|a, x, z)P(x|z)\n",
      "('Z', 'B') \\sum_{c, x}P(c, y|a, b, x, z)P(x|z)\n",
      "('Z', 'C') \\sum_{x}P(x|z)P(y|a, b, c, x, z)\n",
      "('X', 'A') \\sum_{b, c, z}P(c, y|a, b, x, z)P(z)\\left(\\sum_{x}P(b|a, x, z)P(x|z)\\right)\n",
      "('X', 'B') \\sum_{c, z}P(c, y|a, b, x, z)P(z)\n",
      "('X', 'C') \\sum_{z}P(y|a, b, c, x, z)P(z)\n",
      "('A', 'B') \\sum_{c, x, z}P(c, y|a, b, x, z)P(x, z)\n",
      "('A', 'C') \\sum_{x, z}P(x, z)P(y|a, b, c, x, z)\n",
      "('B', 'C') \\sum_{x, z}P(x, z)P(y|a, b, c, x, z)\n",
      "('Z', 'X', 'A') \\sum_{b, c}P(c, y|a, b, x, z)\\left(\\sum_{x}P(b|a, x, z)P(x|z)\\right)\n",
      "('Z', 'X', 'B') P(y|a, b, x, z)\n",
      "('Z', 'X', 'C') P(y|a, b, c, x, z)\n",
      "('Z', 'A', 'B') \\sum_{c, x}P(c, y|a, b, x, z)P(x|z)\n",
      "('Z', 'A', 'C') \\sum_{x}P(x|z)P(y|a, b, c, x, z)\n",
      "('Z', 'B', 'C') \\sum_{x}P(x|z)P(y|a, b, c, x, z)\n",
      "('X', 'A', 'B') \\sum_{c, z}P(c, y|a, b, x, z)P(z)\n",
      "('X', 'A', 'C') \\sum_{z}P(y|a, b, c, x, z)P(z)\n",
      "('X', 'B', 'C') \\sum_{z}P(y|a, b, c, x, z)P(z)\n",
      "('A', 'B', 'C') \\sum_{x, z}P(x, z)P(y|a, b, c, x, z)\n",
      "('Z', 'X', 'A', 'B') P(y|a, b, x, z)\n",
      "('Z', 'X', 'A', 'C') P(y|a, b, c, x, z)\n",
      "('Z', 'X', 'B', 'C') P(y|a, b, c, x, z)\n",
      "('Z', 'A', 'B', 'C') \\sum_{x}P(x|z)P(y|a, b, c, x, z)\n",
      "('X', 'A', 'B', 'C') \\sum_{z}P(y|a, b, c, x, z)P(z)\n",
      "('Z', 'X', 'A', 'B', 'C') P(y|a, b, c, x, z)\n"
     ]
    }
   ],
   "source": [
    "# For every subset S of the non-outcome nodes (parts_of also yields the empty\n",
    "# subset), print the LaTeX expression causaleffect's ID returns for P(outcome | do(S)).\n",
    "# The outcome is taken from `nodes` instead of being hard-coded, so it always\n",
    "# agrees with the treatment candidates `nodes[:-1]` of whichever graph cell ran.\n",
    "outcome = nodes[-1]\n",
    "for S in parts_of(nodes[:-1]):\n",
    "    print(S, ID({outcome}, set(S), G).printLatex())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
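    "Identification formula returned by `ID` for $P(y \\mid do(x))$ (the `('X',)` line above), rendered as LaTeX:\n",
    "\n",
    "$$\\sum_{a, b, c, z}P(a|x, z)P(c, y|a, b, x, z)P(z)\\left(\\sum_{x}P(b|a, x, z)P(x|z)\\right)$$\n",
    "\n",
    "Note that the inner $\\sum_{x}$ ranges over a bound copy of $x$, distinct from the intervened value $x$ that appears outside the parentheses."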
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dcg_shap",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
