{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import combinations\n",
    "\n",
    "import causaleffect as ce"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_id(graph, y: set[str], x: set[str]) -> bool:\n",
    "    \"\"\"Return whether P(y | do(x)) is identifiable in `graph`.\n",
    "\n",
    "    Runs causaleffect's ID algorithm (`ce.ID`) on the query and maps a\n",
    "    `ce.id.HedgeFound` failure to False.\n",
    "\n",
    "    Args:\n",
    "        graph: Causal graph built with `ce.createGraph`.\n",
    "        y: Outcome variable names.\n",
    "        x: Intervened-on variable names (may be empty).\n",
    "\n",
    "    Returns:\n",
    "        True if `ce.ID` succeeds, False if it raises `ce.id.HedgeFound`.\n",
    "\n",
    "    Raises:\n",
    "        ce.id.NoCaseTriggered: re-raised unchanged; it is not expected to\n",
    "            occur, so it is surfaced instead of being read as a verdict.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        ce.ID(y, x, graph)\n",
    "    except ce.id.NoCaseTriggered:\n",
    "        # Should NOT happen: surface it rather than report an\n",
    "        # identifiability verdict we cannot justify.\n",
    "        raise\n",
    "    except ce.id.HedgeFound:\n",
    "        # ID found a hedge: the effect is not identifiable.\n",
    "        return False\n",
    "    return True\n",
    "\n",
    "\n",
    "def do_shap_id(graph, target: str, V: tuple[str, ...]) -> bool:\n",
    "    \"\"\"Return whether P(target | do(S)) is identifiable for every subset S of V.\n",
    "\n",
    "    All 2**len(V) subsets are checked, from the empty set up to V itself, in\n",
    "    order of increasing size. (Presumably these are the interventions needed\n",
    "    for do-Shapley values over V -- NOTE(review): confirm.)\n",
    "\n",
    "    Args:\n",
    "        graph: Causal graph built with `ce.createGraph`.\n",
    "        target: Outcome variable name.\n",
    "        V: Variable names whose subsets are intervened on.\n",
    "\n",
    "    Returns:\n",
    "        True if every subset is identifiable; False as soon as one is not.\n",
    "    \"\"\"\n",
    "    # all() stops at the first non-identifiable subset: a single failure is\n",
    "    # enough to make the whole check False.\n",
    "    return all(\n",
    "        is_id(graph, {target}, set(subset))\n",
    "        for size in range(len(V) + 1)\n",
    "        for subset in combinations(V, size)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Markovian case: no bidirected '<->' edges, i.e. no latent confounders\n",
    "graph = ce.createGraph([\n",
    "    # edges among the variables in V\n",
    "    'u->x', 'u->b',\n",
    "    'z->x',\n",
    "    'x->a', 'a->b', 'b->c',\n",
    "    # edges into the target y\n",
    "    'z->y', 'x->y', 'c->y',\n",
    "])\n",
    "\n",
    "assert do_shap_id(graph, 'y', ('u', 'z', 'x', 'a', 'b', 'c'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Semi-Markovian case: 'x<->b' encodes a latent confounder of x and b\n",
    "graph = ce.createGraph([\n",
    "    'x<->b',\n",
    "    # directed edges among the variables in V\n",
    "    'z->x', 'x->a', 'a->b', 'b->c',\n",
    "    # edges into the target y\n",
    "    'z->y', 'x->y', 'c->y',\n",
    "])\n",
    "\n",
    "assert do_shap_id(graph, 'y', ('z', 'x', 'a', 'b', 'c'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bike Rental case (edges grouped by child node)\n",
    "graph = ce.createGraph([\n",
    "    'hour->temperature', 'season->temperature',\n",
    "    'season->weather',\n",
    "    'hour->humidity', 'weather->humidity',\n",
    "    'hour->windspeed', 'weather->windspeed',\n",
    "    'hour->bikes', 'temperature->bikes', 'humidity->bikes',\n",
    "    'windspeed->bikes', 'workingday->bikes',\n",
    "])\n",
    "\n",
    "assert do_shap_id(\n",
    "    graph,\n",
    "    'bikes',\n",
    "    ('hour', 'season', 'temperature', 'weather',\n",
    "     'humidity', 'windspeed', 'workingday'),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Diabetes case\n",
    "# Edges are written one per line; blank lines are only visual separators.\n",
    "# Blank lines and surrounding whitespace are dropped before building the\n",
    "# graph, so indentation or trailing spaces cannot produce bogus node names.\n",
    "diabetes_edges = '''\n",
    "PhysActivity->BMI\n",
    "PhysActivity->HighBP\n",
    "\n",
    "Fruits->BMI\n",
    "Fruits->HighChol\n",
    "\n",
    "Veggies->BMI\n",
    "Veggies->HighChol\n",
    "\n",
    "BMI->HighChol\n",
    "BMI->CholCheck\n",
    "\n",
    "HighChol->HighBP\n",
    "HighChol->CholCheck\n",
    "\n",
    "HighBP->Stroke\n",
    "HighBP->HeartAttack\n",
    "\n",
    "Smoker->Stroke\n",
    "\n",
    "HeartAttack->CholCheck\n",
    "Stroke->CholCheck\n",
    "\n",
    "BMI->Diabetes\n",
    "HighBP->Diabetes\n",
    "HighChol->Diabetes\n",
    "'''\n",
    "graph = ce.createGraph([\n",
    "    edge.strip() for edge in diabetes_edges.splitlines() if edge.strip()\n",
    "])\n",
    "\n",
    "assert do_shap_id(graph, 'Diabetes', (\n",
    "    'PhysActivity', 'Fruits', 'Veggies',\n",
    "    'BMI', 'HighChol', 'HighBP'\n",
    "))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dcg_shap",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
