{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c843b5d-2d50-42dd-9516-24f5afc1be8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "currentdir = os.getcwd()\n",
    "parentdir = os.path.dirname(currentdir)\n",
    "sys.path.insert(0, parentdir) \n",
    "print(parentdir)\n",
    "\n",
    "import trees.shap_model as shap_model\n",
    "import trees.tree_to_fourier as tree_to_fourier\n",
    "import trees.preprocess_forest as preprocess_forest\n",
    "import numpy as np\n",
    "from algorithms.jax_fourier_explainer import forest_fourier_shap, forest_fourier_shap_compact, \\\n",
    "    get_forest_multiplier_matrix, forest_fourier_shap_precompute\n",
    "from algorithms.fourier_explainer import fourier_explainer_matrix_batch\n",
    "from jax import vmap, jit\n",
    "import jax.numpy as jnp\n",
    "import time\n",
    "import tqdm\n",
    "from shap import TreeExplainer, GPUTreeExplainer, KernelExplainer\n",
    "import jax\n",
    "from sklearn.metrics import r2_score\n",
    "import pandas as pd\n",
    "import argparse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6823a016-b53b-4d39-87a3-ae9f868776d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show available GPUs (GPUTreeExplainer below needs one; JAX presumably runs on it too).\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "154b037d-5c3a-4513-823b-7c0dad69555b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fourier_to_jax(n, fourier_forest, amp_threshold=0.005):\n",
    "    freq_list = []\n",
    "    amp_list = []\n",
    "    for tree_fourier in fourier_forest:\n",
    "        freqs = []\n",
    "        amps = []\n",
    "        for j, (set_freq, amp) in enumerate(tree_fourier.items()):\n",
    "            if abs(amp) >= amp_threshold:\n",
    "                freq = np.zeros((n,), dtype=np.int32)\n",
    "                index = list(set_freq)\n",
    "                freq[index] = 1\n",
    "                freqs.append(freq)\n",
    "                amps.append(amp)\n",
    "\n",
    "        freq_list.append(np.array(freqs))\n",
    "        amp_list.append(np.array(amps))\n",
    "\n",
    "\n",
    "    no_coefficients = max([len(freq_array) for freq_array in freq_list])\n",
    "    for i in range(len(freq_list)):\n",
    "        freq_list[i] = np.concatenate((freq_list[i], np.zeros((no_coefficients-len(freq_list[i]), n), dtype=np.int32)), axis=0)\n",
    "        amp_list[i] = np.concatenate((amp_list[i], np.zeros((no_coefficients-len(amp_list[i]),), dtype=np.float32)), axis=0)\n",
    "    \n",
    "    return jnp.array(freq_list, dtype=jnp.int32), jnp.array(amp_list, dtype=jnp.float32)\n",
    "    \n",
    "def fourier_output(fourier, inputs):\n",
    "    pred = np.zeros((inputs.shape[0]))\n",
    "    for freq, amp in fourier.items():\n",
    "        sign = np.sum(inputs[:, list(freq)], axis=1, keepdims=False)\n",
    "        pred += float(amp) * np.where((sign % 2) == 1, -1, 1)\n",
    "    \n",
    "    return pred\n",
    "\n",
    "def _timed(fn):\n",
    "    \"\"\"Call `fn()` once; return `(result, elapsed wall-clock seconds)`.\"\"\"\n",
    "    start = time.perf_counter()\n",
    "    result = fn()\n",
    "    return result, time.perf_counter() - start\n",
    "\n",
    "\n",
    "def _time_repeated(fn, repeats):\n",
    "    \"\"\"Call `fn()` `repeats` times under a tqdm bar; return `(last result, per-call seconds)`.\"\"\"\n",
    "    result, durations = None, []\n",
    "    for _ in tqdm.tqdm(range(repeats)):\n",
    "        result, elapsed = _timed(fn)\n",
    "        durations.append(elapsed)\n",
    "    return result, durations\n",
    "\n",
    "\n",
    "def _bf16_fourier_inputs(sm, spectra):\n",
    "    \"\"\"Return `(freq_array, amp_array, X_train, X_test)` as bfloat16 JAX arrays.\n",
    "\n",
    "    The spectra are packed with `fourier_to_jax`; background/query data come from `sm`.\n",
    "    \"\"\"\n",
    "    freq_array, amp_array = fourier_to_jax(sm.n, spectra)\n",
    "    return (\n",
    "        jnp.array(freq_array, dtype=jnp.bfloat16),\n",
    "        jnp.array(amp_array, dtype=jnp.bfloat16),\n",
    "        jnp.array(sm.X_train, dtype=jnp.bfloat16),\n",
    "        jnp.array(sm.X_test, dtype=jnp.bfloat16),\n",
    "    )\n",
    "\n",
    "\n",
    "def measure_shap_computation(sm, fourier_spectra=None, repeats=5):\n",
    "    \"\"\"Benchmark several SHAP implementations on one model and return a log dict.\n",
    "\n",
    "    CPU TreeSHAP (interventional, `sm.X_train` as background) is the reference.\n",
    "    GPU TreeSHAP, JAX Fourier SHAP and JAX Fourier SHAP with a precomputed\n",
    "    multiplier matrix are scored against it with r2; times are wall-clock seconds.\n",
    "    If a Fourier variant raises, the error is printed and that variant reports\n",
    "    all-zero SHAP values, a compilation time of -1 and the time list [-1].\n",
    "\n",
    "    Args:\n",
    "        sm: model wrapper exposing `rf`, `X_train`, `X_test` and `n`\n",
    "            (`n` is passed to `fourier_to_jax` as the feature count).\n",
    "        fourier_spectra: per-tree Fourier spectra (dicts frequency -> amplitude).\n",
    "            Defaults to the notebook-global `fourier_forest`.\n",
    "        repeats: number of timed calls per method; must be >= 1.\n",
    "\n",
    "    Returns:\n",
    "        dict of scalar quality/time metrics plus the raw per-call time lists.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `repeats` < 1.\n",
    "    \"\"\"\n",
    "    if repeats < 1:\n",
    "        raise ValueError(f\"repeats must be >= 1, got {repeats}\")\n",
    "    # Fall back to the notebook global so existing `measure_shap_computation(sm)` calls keep working.\n",
    "    spectra = fourier_forest if fourier_spectra is None else fourier_spectra\n",
    "\n",
    "    # CPU TreeSHAP: the reference all other methods are scored against.\n",
    "    cpu_explainer = TreeExplainer(sm.rf, data=sm.X_train, feature_perturbation=\"interventional\")\n",
    "    cpu_explainer.shap_values(np.array(sm.X_test[-10]), check_additivity=False)  # warm-up, untimed\n",
    "    tree_shap, times_tree_shap = _time_repeated(\n",
    "        lambda: cpu_explainer.shap_values(sm.X_test, check_additivity=False), repeats)\n",
    "\n",
    "    # GPU TreeSHAP\n",
    "    gpu_explainer = GPUTreeExplainer(sm.rf, data=sm.X_train, feature_perturbation=\"interventional\")\n",
    "    gpu_explainer.shap_values(sm.X_test[-10], check_additivity=False)  # warm-up, untimed\n",
    "    gpu_tree_shap, times_gpu_tree_shap = _time_repeated(\n",
    "        lambda: gpu_explainer.shap_values(sm.X_test, check_additivity=False), repeats)\n",
    "\n",
    "    # JAX Fourier SHAP. The first call includes jit compilation and is reported separately.\n",
    "    try:\n",
    "        freq_array, amp_array, X_train, X_test = _bf16_fourier_inputs(sm, spectra)\n",
    "\n",
    "        def run_fourier():\n",
    "            return forest_fourier_shap(freq_array, amp_array, X_train, X_test).block_until_ready()\n",
    "\n",
    "        _, dummy_time = _timed(run_fourier)\n",
    "        jax_shap, times_fourier = _time_repeated(run_fourier, repeats)\n",
    "    except Exception as exc:  # report sentinels and keep benchmarking the other methods\n",
    "        print(f\"Fourier SHAP failed: {exc!r}\")\n",
    "        dummy_time = -1\n",
    "        jax_shap = np.zeros(shape=tree_shap.shape)\n",
    "        times_fourier = [-1]\n",
    "\n",
    "    # JAX Fourier SHAP with the multiplier matrix (built from the background data)\n",
    "    # computed once, outside the timed calls.\n",
    "    try:\n",
    "        freq_array, amp_array, X_train, X_test = _bf16_fourier_inputs(sm, spectra)\n",
    "        multiplier_matrix = get_forest_multiplier_matrix(freq_array, amp_array, X_train)\n",
    "\n",
    "        def run_precompute():\n",
    "            return forest_fourier_shap_precompute(\n",
    "                freq_array, multiplier_matrix, X_train, X_test).block_until_ready()\n",
    "\n",
    "        _, dummy_time_precompute = _timed(run_precompute)\n",
    "        jax_precompute_shap, times_fourier_precompute = _time_repeated(run_precompute, repeats)\n",
    "    except Exception as exc:\n",
    "        print(f\"Fourier SHAP (precompute) failed: {exc!r}\")\n",
    "        dummy_time_precompute = -1\n",
    "        jax_precompute_shap = np.zeros(shape=tree_shap.shape)\n",
    "        times_fourier_precompute = [-1]  # this variant's list only; `times_fourier` stays intact\n",
    "\n",
    "    # Compact-Fourier, KernelSHAP and classic-Fourier benchmarks were disabled\n",
    "    # (commented out) here and have been removed; see version history.\n",
    "\n",
    "    def quality(candidate):\n",
    "        \"\"\"r2 of `candidate` (cast to float32) against the CPU TreeSHAP reference.\"\"\"\n",
    "        return r2_score(tree_shap.flatten(), candidate.astype(jnp.float32).flatten())\n",
    "\n",
    "    log = {\n",
    "        \"Jax GPU tree shap quality\": r2_score(tree_shap.flatten(), gpu_tree_shap.flatten()),\n",
    "        \"Jax shap quality\": quality(jax_shap),\n",
    "        \"Jax precompute shap quality\": quality(jax_precompute_shap),\n",
    "        \"Jax compilation time\": dummy_time,\n",
    "        \"Jax precompute compilation time\": dummy_time_precompute,\n",
    "        \"Jax shap time\": np.mean(times_fourier),\n",
    "        \"Jax precompute shap time\": np.mean(times_fourier_precompute),\n",
    "        \"Treeshap time\": np.mean(times_tree_shap),\n",
    "        \"GPU Treeshap time\": np.mean(times_gpu_tree_shap),\n",
    "    }\n",
    "    print('\\n'.join(f'{key}: {value}' for key, value in log.items()))\n",
    "\n",
    "    # Raw per-call times are added after printing to keep the summary short.\n",
    "    log[\"Treeshap time list\"] = times_tree_shap\n",
    "    log[\"Treeshap GPU time list\"] = times_gpu_tree_shap\n",
    "    log[\"Jax shap time list\"] = times_fourier\n",
    "    log[\"Jax precompute shap time list\"] = times_fourier_precompute\n",
    "\n",
    "    return log\n",
    "\n",
    "def save_logs(name, logs):\n",
    "    # Write to csv\n",
    "    pd.DataFrame(logs).to_csv(f'logs/{dataset}.csv', index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e5f1986-d925-479e-96e4-94aa491d84b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Benchmark grid: every combination is trained and benchmarked once.\n",
    "dataset = \"avgfp\"\n",
    "\n",
    "background_sizes = [100]\n",
    "query_sizes = [100]\n",
    "n_est_list = [20]\n",
    "depth_list = [3, 4, 5, 6, 7, 8]\n",
    "\n",
    "# Smaller grid for quick debugging:\n",
    "# background_sizes = [100]\n",
    "# query_sizes = [300]\n",
    "# n_est_list = [10]\n",
    "# depth_list = [5]\n",
    "\n",
    "# Random binary queries for checking the Fourier approximation away from the data.\n",
    "RANDOM_SEED = 0\n",
    "N_RANDOM_QUERIES = 2000\n",
    "RANDOM_ONE_THRESHOLD = 0.6  # a random feature is 1 when a uniform draw exceeds this (p = 0.4)\n",
    "\n",
    "# Read as a global by measure_shap_computation; reassigned for every configuration below.\n",
    "fourier_forest = None\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    logs = []\n",
    "    for background_size in background_sizes:\n",
    "        for query_size in query_sizes:\n",
    "            for depth in depth_list:\n",
    "                for n_est in n_est_list:\n",
    "                    log = {\"Background size\": background_size, \"Query size\": query_size, \"Depth\": depth, \"# Est\": n_est}\n",
    "                    sm = shap_model.ShapModel(dataset, n_est, depth)\n",
    "                    fourier_forest = tree_to_fourier.forest_to_fourier(sm.rf)\n",
    "                    print(\"Fourier coefficient count per tree:\", [len(t) for t in fourier_forest])\n",
    "                    # Trees with identical split features and thresholds count once.\n",
    "                    ests = [tuple(list(est.tree_.feature) + list(est.tree_.threshold)) for est in sm.rf.estimators_]\n",
    "                    log[\"# unique trees\"] = len(set(ests))\n",
    "\n",
    "                    # Limit background (train) and query (test) sets\n",
    "                    sm.X_train = sm.X_train[:background_size]\n",
    "                    sm.X_test = sm.X_test[:query_size]\n",
    "\n",
    "                    # Check Fourier quality: r2 of the mean per-tree Fourier expansion against the\n",
    "                    # forest's prediction, on the test queries and on seeded random binary queries.\n",
    "                    X_test_f32 = np.asarray(sm.X_test, dtype=np.float32)\n",
    "                    rng = np.random.default_rng(RANDOM_SEED)  # same random queries for every configuration\n",
    "                    X_test_random = rng.random((N_RANDOM_QUERIES, sm.X_test.shape[1])) > RANDOM_ONE_THRESHOLD\n",
    "                    fourier_pred = np.mean(np.vstack([fourier_output(tree_fourier, X_test_f32) for tree_fourier in fourier_forest]), axis=0)\n",
    "                    fourier_pred_random = np.mean(np.vstack([fourier_output(tree_fourier, X_test_random) for tree_fourier in fourier_forest]), axis=0)\n",
    "                    rf_pred = sm.rf.predict(sm.X_test)\n",
    "                    rf_pred_random = sm.rf.predict(X_test_random)\n",
    "                    log[\"Fourier quality (test)\"] = r2_score(rf_pred, fourier_pred)\n",
    "                    log[\"Fourier quality (random)\"] = r2_score(rf_pred_random, fourier_pred_random)\n",
    "                    print(log)\n",
    "\n",
    "                    log.update(measure_shap_computation(sm))\n",
    "                    logs.append(log)\n",
    "\n",
    "                    # save_logs(dataset, logs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "677a803e-cf81-4c5f-9927-1c0bc22f89ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference interventional TreeSHAP values for the last model from the benchmark loop.\n",
    "explainer = TreeExplainer(\n",
    "    sm.rf,\n",
    "    data=sm.X_train,\n",
    "    feature_perturbation=\"interventional\",\n",
    ")\n",
    "warmup_row = np.array(sm.X_test[-10])\n",
    "dummy = explainer.shap_values(warmup_row, check_additivity=False)  # warm-up call\n",
    "tree_shap = explainer.shap_values(sm.X_test, check_additivity=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59591771-2b37-4dc5-a19c-e17c19873d55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TreeSHAP values for query row 13 (from the TreeSHAP cell above)\n",
    "tree_shap[13]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e13a33b-5e24-44e4-aa92-3b8014fd46e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# KernelSHAP on integer-cast data, using all of sm.X_train as the background set.\n",
    "background_int = sm.X_train.astype(np.int32)\n",
    "queries_int = sm.X_test.astype(np.int32)\n",
    "explainer = KernelExplainer(sm.rf.predict, background_int)\n",
    "kernel_shap = explainer.shap_values(queries_int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71e4cc79-a499-4fa3-b19a-ceaedd3c7560",
   "metadata": {},
   "outputs": [],
   "source": [
    "# KernelSHAP values for query row 3 (compare with jax_precompute_shap[3])\n",
    "kernel_shap[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "eec99014-8614-468f-9ab9-a04b94a55f7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ad-hoc float32 run of the precompute Fourier SHAP on the last `sm` from the\n",
    "# benchmark loop (the benchmark itself uses bfloat16). Rebinds the globals X_train/X_test.\n",
    "freq_array, amp_array = fourier_to_jax(sm.n, fourier_forest)\n",
    "\n",
    "X_train = jnp.array(sm.X_train, dtype=jnp.float32)\n",
    "X_test = jnp.array(sm.X_test, dtype=jnp.float32)\n",
    "freq_array = jnp.array(freq_array, dtype=jnp.float32)\n",
    "amp_array = jnp.array(amp_array, dtype=jnp.float32)\n",
    "multiplier_matrix = get_forest_multiplier_matrix(freq_array, amp_array, X_train)\n",
    "\n",
    "# Nothing is timed here, so a single call suffices (no separate jit warm-up).\n",
    "jax_precompute_shap = forest_fourier_shap_precompute(freq_array, multiplier_matrix, X_train, X_test).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc0f8299-589f-4ffd-a9a0-96b67e028907",
   "metadata": {},
   "outputs": [],
   "source": [
    "# float32 precompute Fourier SHAP values for query row 3 (compare with kernel_shap[3])\n",
    "jax_precompute_shap[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03cffd3b-4be2-4fcf-954f-6bb3c248c448",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-query model output minus the mean background prediction. With interventional SHAP\n",
    "# on this background, each row of SHAP values should sum to roughly this (efficiency check).\n",
    "background_mean = np.mean(sm.rf.predict(sm.X_train))\n",
    "sm.rf.predict(sm.X_test) - background_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a4dcd0b-b18b-40a6-81ba-6996c5696d19",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efebc7f5-4255-441c-bd14-02ae2a71ed19",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column sums of the background data (per-feature count of ones if features are binary — TODO confirm)\n",
    "np.sum(sm.X_train, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7988990-82d6-4274-a0ed-ad7b2f1b7d5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Agreement of KernelSHAP with TreeSHAP (r2, TreeSHAP as the reference)\n",
    "r2_score(tree_shap.flatten(), kernel_shap.flatten())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53e38eda-e6dd-49c6-8383-9873ca3d0293",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flattened TreeSHAP values, as used in the r2 comparisons\n",
    "tree_shap.flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff959f1f-a149-4faa-857c-d22d12466d0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (n_trees, max_coefficients, n_features), as packed by fourier_to_jax\n",
    "freq_array.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f63fb7c4-e623-4f8a-924d-50dbcb935bce",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
