{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8c2492f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from cppkauri import Kauri\n",
    "from torchdouglas import TorchDouglas\n",
    "from data import get_data\n",
    "from sklearn import metrics\n",
    "import numpy as np\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "70998849",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the congressional votes dataset; X is the model input and y the reference labels used for ARI below\n",
    "X,y = get_data(\"congressional_votes\", \"../data/datasets\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1d541bb1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((435, 16), (435,))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: X and y should have the same number of rows\n",
    "(X.shape, y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0696773a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kauri tree limited to 2 clusters; random_state is fixed (presumably for reproducible fits)\n",
    "model = Kauri(\n",
    "    max_clusters=2,\n",
    "    verbose=True,\n",
    "    random_state=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "de67f98e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initialising variables\n",
      "Starting main loop\n",
      "Gain is: 2396.9158191126116\n",
      "=> Cut is on feature 4 <= 0.0\n",
      "=> From (0), assignments are L = 1 / R = 0\n",
      "=> Sizes are: L = 223 / R = 212\n",
      "Gain is: 18.835162148544413\n",
      "=> Cut is on feature 7 <= -1.0\n",
      "=> From (0), assignments are L = 0 / R = 1\n",
      "=> Sizes are: L = 6 / R = 217\n",
      "Gain is: 39.09781114249009\n",
      "=> Cut is on feature 8 <= -1.0\n",
      "=> From (0), assignments are L = 0 / R = 1\n",
      "=> Sizes are: L = 4 / R = 2\n"
     ]
    }
   ],
   "source": [
    "# Fit Kauri on X and keep its cluster assignment for every row (verbose=True logs each chosen split)\n",
    "y_pred = model.fit_predict(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "137c29c9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.4713023654569267"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ARI of the Kauri clustering against the reference labels (ARI is symmetric in its arguments)\n",
    "metrics.adjusted_rand_score(labels_true=y, labels_pred=y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5e4b73ee",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2696.5269533231863"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Kauri's own score of the fitted tree on X - NOTE(review): criterion is defined by Kauri.score, confirm its meaning\n",
    "model.score(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2438efbb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.7314814814814815, 0.045662100456621)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean reference label inside each predicted cluster (boolean masks select the members)\n",
    "y[y_pred == 0].mean(), y[y_pred == 1].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ce162a7a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.15632183908045977, 0.8436781609195402)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fraction of rows whose cluster id equals the label, as-is and with the two cluster ids swapped\n",
    "(y_pred == y).mean(), ((1 - y_pred) == y).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f7e08508",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_kauri_tree(kauri_tree, feature_names=None):\n",
    "\n",
    "    def print_node(node_id):\n",
    "        current_depth = kauri_tree.tree_.depths[node_id]\n",
    "        print(\"| \" * current_depth, f\"Node {node_id}\", sep=\"\")\n",
    "        left_child = kauri_tree.tree_.children_left[node_id]\n",
    "        right_child = kauri_tree.tree_.children_right[node_id]\n",
    "        if left_child == -1:\n",
    "            print(\"| \" * current_depth, f\"Cluster: {kauri_tree.tree_.target[node_id]}\")\n",
    "            return\n",
    "        feature = kauri_tree.tree_.features[node_id]\n",
    "        threshold = kauri_tree.tree_.thresholds[node_id]\n",
    "        if feature_names is not None:\n",
    "            feature_name = feature_names[feature]\n",
    "        else:\n",
    "            feature_name = f\"X[:, {feature}]\"\n",
    "        print(\"| \" * current_depth, \"|=\", f\"{feature_name} <= {threshold}\", sep=\"\")\n",
    "        print_node(left_child)\n",
    "        print(\"| \" * current_depth, \"|=\", f\"{feature_name} > {threshold}\", sep=\"\")\n",
    "        print_node(right_child)\n",
    "\n",
    "    print_node(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7eef14ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node 0\n",
      "|=X[:, 4] <= 0.0\n",
      "| Node 1\n",
      "| |=X[:, 7] <= -1.0\n",
      "| | Node 3\n",
      "| | |=X[:, 8] <= -1.0\n",
      "| | | Node 5\n",
      "| | |  Cluster: 0\n",
      "| | |=X[:, 8] > -1.0\n",
      "| | | Node 6\n",
      "| | |  Cluster: 1\n",
      "| |=X[:, 7] > -1.0\n",
      "| | Node 4\n",
      "| |  Cluster: 1\n",
      "|=X[:, 4] > 0.0\n",
      "| Node 2\n",
      "|  Cluster: 0\n"
     ]
    }
   ],
   "source": [
    "print_kauri_tree(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e19d6ad",
   "metadata": {},
   "source": [
    "# Same analysis with the Douglas tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "fe4363d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73335bff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6ba12b5c9a5a4343be651e9be38a0157",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/30 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Fit the Douglas model several times and record, for each run, which features it\n",
    "# cuts on non-trivially and how its clustering compares with the reference labels y\n",
    "# and with the Kauri clustering y_pred.\n",
    "# NOTE(review): no seed is set for TorchDouglas, so the averages below vary between executions.\n",
    "n_runs = 30\n",
    "active_points = np.zeros(X.shape[1])  # per feature: number of active cuts seen across runs\n",
    "aris = []  # per run: (ARI vs y, ARI vs Kauri y_pred)\n",
    "for run in tqdm(range(n_runs)):\n",
    "    douglas_model = TorchDouglas(n_clusters=2, n_cuts=1, n_epochs=200, temperature=1)\n",
    "    douglas_model.fit(X)\n",
    "\n",
    "    y_douglas = douglas_model.predict(X)\n",
    "\n",
    "    # Distinct names so the per-cut variables no longer shadow the run counter.\n",
    "    for (feature_idx, cut) in douglas_model.cut_points_list_:\n",
    "        feature = X[:, feature_idx]\n",
    "        cut = cut.item()  # presumably a 0-d torch tensor; TODO confirm\n",
    "        # Skip cuts that do not separate the data (all values <= cut, or all >= cut).\n",
    "        # NOTE(review): ties at the cut are treated as inactive by the >= test; confirm Douglas's split convention.\n",
    "        if np.all(feature <= cut) or np.all(feature >= cut):\n",
    "            continue\n",
    "        active_points[feature_idx] += 1\n",
    "\n",
    "    aris.append((\n",
    "        metrics.adjusted_rand_score(y, y_douglas),\n",
    "        metrics.adjusted_rand_score(y_pred, y_douglas),\n",
    "    ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c05c3d20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average ARI over all runs: (vs reference labels y, vs Kauri clustering y_pred)\n",
    "mean_ari_true = sum(ari_true for ari_true, _ in aris) / len(aris)\n",
    "mean_ari_kauri = sum(ari_kauri for _, ari_kauri in aris) / len(aris)\n",
    "mean_ari_true, mean_ari_kauri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b46f2762",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Average number of active Douglas cuts per feature per recorded run\n",
    "# (normalise by the runs actually recorded instead of a hard-coded 30)\n",
    "active_points / len(aris)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25fcb315",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature indices ordered from least to most often cut by Douglas\n",
    "active_points.argsort()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
