{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# Identifiability Example\n",
    "# In this notebook, we will demonstrate how to use the `find_confounded_paths` function to identify confounded paths in a directed acyclic graph (DAG) and check the identifiability of queries using the `check_identifiable_query` function.\n",
    "\n",
    "Algorithms 5 and 6 in DeCaFlow paper."
   ],
   "id": "e664018afcca0d69"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-20T21:50:20.181899Z",
     "start_time": "2025-04-20T21:50:18.323063Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import pytest\n",
    "import torch\n",
    "from decaflow.utils.identifiability import (find_confounded_paths,\n",
    "                                            check_identifiable_query,\n",
    "                                            check_identifiable_query_on_all_descendants)\n",
    "import networkx as nx"
   ],
   "id": "5cab9211a4006b80",
   "outputs": [],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-20T21:50:20.192402Z",
     "start_time": "2025-04-20T21:50:20.184416Z"
    }
   },
   "cell_type": "code",
   "source": [
    "miao_graph = nx.DiGraph()\n",
    "\n",
    "miao_graph.add_edges_from([\n",
    "    ('Z', 'W'),\n",
    "    ('Z', 'T'),\n",
    "    ('Z', 'N'),\n",
    "    ('Z', 'Y'),\n",
    "    ('W', 'Y'),\n",
    "    ('N', 'T'),\n",
    "    ('T', 'Y')\n",
    "])\n",
    "non_confounded_set, confounded_dict, frontdoor_set  = find_confounded_paths(miao_graph, ['Z'], frontdoor=True)\n",
    "queries = [('N', 'T'), ('W', 'Y'), ('T', 'Y')]\n",
    "assert ('N', 'T') in confounded_dict.keys()\n",
    "assert ('W', 'Y') in confounded_dict.keys()\n",
    "assert('T', 'Y') in confounded_dict.keys()\n",
    "\n",
    "assert check_identifiable_query(miao_graph, ('T', 'Y'), confounded_dict, ['Z'])\n",
    "\n",
    "assert not check_identifiable_query(miao_graph, ('N', 'T'), confounded_dict, ['Z'])\n",
    "assert not check_identifiable_query(miao_graph, ('W', 'Y'), confounded_dict, ['Z'])\n",
    "\n",
    "for query in queries:\n",
    "    if query in non_confounded_set:\n",
    "        print(f'Query {query} is non-confounded')\n",
    "    else:\n",
    "        is_id = check_identifiable_query(miao_graph, query, confounded_dict, ['Z'])\n",
    "        if is_id:\n",
    "            print(f'Query {query} is identifiable')\n",
    "        else:\n",
    "            print(f'Query {query} is not identifiable')"
   ],
   "id": "d3a4bf12409c74f0",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query ('N', 'T') is not identifiable\n",
      "Query ('W', 'Y') is not identifiable\n",
      "Query ('T', 'Y') is identifiable\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-04-20T21:50:20.441165Z",
     "start_time": "2025-04-20T21:50:20.436100Z"
    }
   },
   "cell_type": "code",
   "source": [
    "two_outcomes_graph = nx.DiGraph()\n",
    "two_outcomes_graph.add_edges_from([\n",
    "    ('Z', 'T'),\n",
    "    ('Z', 'Y1'),\n",
    "    ('T', 'Y1'),\n",
    "    ('T', 'Y2')\n",
    "])\n",
    "\n",
    "non_confounded_set, confounded_dict, frontdoor_set  = find_confounded_paths(two_outcomes_graph, ['Z'], frontdoor=True)\n",
    "is_id = check_identifiable_query_on_all_descendants(two_outcomes_graph, 'T', confounded_dict, ['Z'])\n",
    "print('T is identifiable on all descendants:', is_id)"
   ],
   "id": "2be12d70c0e7a046",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "T is identifiable on all descendants: False\n"
     ]
    }
   ],
   "execution_count": 3
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "language": "python",
   "display_name": "Python 3 (ipykernel)"
  }
 },
 "nbformat": 5,
 "nbformat_minor": 9
}
