{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip3 install -r requirements.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In a terminal run ```sudo apt install grapviz``` to visualize trees."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import datasets\n",
    "import numpy as np\n",
    "from dpdt.utils import build_mdp, CartAIGSelector, average_traj_length_in_mdp, extract_tree\n",
    "from dpdt.solver import backward_induction_multiple_zetas\n",
    "import matplotlib.pyplot as plt\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import the Iris dataset and print the number of samples, the number of features, and the number of classe: $N, p, k$ ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "iris_X, iris_Y = datasets.load_iris(return_X_y=True)\n",
    "\n",
    "print(\"N = {}\".format(iris_X.shape[0]))\n",
    "print(\"p = {}\".format(iris_X.shape[1]))\n",
    "print(\"K = {}\".format(len(np.unique(iris_Y))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build the MDP for trees of max depth 3.\n",
    "We use CART with a maximum depth of 3 as a tests generating function $\\phi$. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "phi = CartAIGSelector(depth=3)\n",
    "mdp = build_mdp(iris_X, iris_Y, max_depth=3, aig_fn=phi)\n",
    "print(\"MDP states: {}\".format(sum([len(depth) for depth in mdp])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the optimal policies $\\pi^*(.,\\alpha)$ for some values of $\\alpha$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alphas = np.linspace(-1, 0, 1000)\n",
    "policies = backward_induction_multiple_zetas(mdp, alphas)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Extract and evaluate the trees associated with each $\\pi(.,\\alpha)$ (different $\\alpha$ can give the same tree)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores, average_nb_tests, trees = (\n",
    "        np.zeros(alphas.shape[0], dtype=np.float64),\n",
    "        np.zeros(alphas.shape[0], dtype=np.float64),\n",
    "        []\n",
    "        \n",
    "    )\n",
    "\n",
    "initial_mdp_state = mdp[0][0].obs\n",
    "for i, alpha in enumerate(alphas):\n",
    "    scores[i], average_nb_tests[i] = average_traj_length_in_mdp(iris_X, iris_Y, policies, initial_mdp_state, i)\n",
    "    tree_, _, _ = extract_tree(policies, initial_mdp_state, zeta=i)\n",
    "    trees.append(tree_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot the interpretability-performance pareto front."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(average_nb_tests, scores, \"-x\")\n",
    "plt.grid()\n",
    "plt.xlabel(\"Average tests\")\n",
    "plt.ylabel(\"Accuracy\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot some trees."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trees[0].graphviz()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trees[800].graphviz()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trees[-1].graphviz()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
