{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "currentdir = os.getcwd()\n",
    "parentdir = os.path.dirname(currentdir)\n",
    "sys.path.insert(0, parentdir) \n",
    "print(parentdir)\n",
    "import utils\n",
    "\n",
    "import pathlib\n",
    "import pickle\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "from jax import jit, vmap\n",
    "\n",
    "from models import catboost_model\n",
    "\n",
    "\n",
    "# Check if we are on a GPU/Metal\n",
    "import jax\n",
    "from jax.lib import xla_bridge\n",
    "print(xla_bridge.get_backend().platform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(23424, 34)"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from fourier_extractor.jax_fourier_extractor import compute_fourier\n",
    "from algorithms.jax_fourier_explainer_new import pad_freqs, fourier_shap_broadcast, broadcast_slice_inputs\n",
    "import time\n",
    "\n",
    "dataset = 'sgemm'\n",
    "model = 'nn'\n",
    "depth = ''\n",
    "b = 15\n",
    "no_background_samples = 50\n",
    "no_test_samples = 50\n",
    "\n",
    "def freq_array_to_indices(freq_array):\n",
    "    row_indices, col_indices = jnp.nonzero(freq_array)\n",
    "    # Split the column indices by row\n",
    "    return tuple([tuple(col_indices[row_indices == i].tolist()) for i in range(freq_array.shape[0])])\n",
    "\n",
    "x_train, x_test, _, _ = utils.get_dataset(dataset, with_splits=True)\n",
    "r = compute_fourier(dataset, model, b, depth, save_result=True)\n",
    "if len(r) == 2:\n",
    "    freq_array, amp_array = r\n",
    "else:\n",
    "    freq_array, amp_array, _ = r\n",
    "freq_array = freq_array.transpose()\n",
    "\n",
    "# Prune spectrum\n",
    "threshold = 0.0005\n",
    "mask = abs(amp_array) > threshold\n",
    "freq_array = freq_array[mask]\n",
    "amp_array = amp_array[mask]\n",
    "\n",
    "if 1 in amp_array.shape:\n",
    "    amp_array = amp_array.squeeze()\n",
    "dataset_settings = utils.get_task_settings()\n",
    "background_samples = dataset_settings[\"background_samples\"][dataset]\n",
    "test_samples = dataset_settings[\"test_samples\"][dataset]\n",
    "shap_values, times = [], []\n",
    "\n",
    "X_train = jnp.array(x_train[0:no_background_samples], dtype=jnp.int32)\n",
    "X_test = jnp.array(x_test[0:no_test_samples], dtype=jnp.int32)\n",
    "\n",
    "# padded_freq_array, freq_mask = pad_freqs(freq_array_to_indices(freq_array))\n",
    "# broadcasted_sliced_x_background_array = broadcast_slice_inputs(X_train, padded_freq_array, freq_mask)\n",
    "\n",
    "freq_ones = freq_array_to_indices(freq_array)\n",
    "padded_freq_array = pad_freqs(freq_ones)\n",
    "tuple_amps = tuple(amp_array.tolist())\n",
    "\n",
    "# # Warm-up round\n",
    "# now = time.time()\n",
    "# fourier_shap_broadcast(freq_ones, tuple_amps, padded_freq_array, X_train, X_test)\n",
    "# then = time.time()\n",
    "# print(then - now)\n",
    "# now = time.time()\n",
    "# shap_values = fourier_shap_broadcast(freq_ones, tuple_amps, padded_freq_array, X_train, X_test).block_until_ready()\n",
    "# then = time.time()\n",
    "# print(then - now)\n",
    "freq_array.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.1782529354095459\n",
      "0.003831624984741211\n"
     ]
    }
   ],
   "source": [
    "from algorithms.jax_fourier_explainer import fourier_shap\n",
    "\n",
    "X_train = X_train.astype(jnp.bfloat16)\n",
    "X_test = X_test.astype(jnp.bfloat16)\n",
    "freq_array = freq_array.astype(jnp.bfloat16)\n",
    "amp_array = amp_array.astype(jnp.bfloat16)\n",
    "\n",
    "# X_train = X_train.astype(jnp.int8)\n",
    "# X_test = X_test.astype(jnp.int8)\n",
    "# freq_array = freq_array.astype(jnp.int8)\n",
    "# # amp_array = amp_array.astype(jnp.int32)\n",
    "\n",
    "\n",
    "# Warm-up round\n",
    "now = time.time()\n",
    "shap_values = fourier_shap(freq_array, amp_array, X_train, X_test).block_until_ready()\n",
    "then = time.time()\n",
    "print(then - now)\n",
    "now = time.time()\n",
    "shap_values = fourier_shap(freq_array, amp_array, X_train, X_test).block_until_ready()\n",
    "then = time.time()\n",
    "print(then - now)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0026810169219970703\n",
      "0.0025413036346435547\n"
     ]
    }
   ],
   "source": [
    "from algorithms.jax_fourier_explainer import fourier_shap_precompute, get_multiplier_matrix\n",
    "\n",
    "multiplier_matrix = get_multiplier_matrix(freq_array, amp_array, X_train).block_until_ready()\n",
    "\n",
    "now = time.time()\n",
    "shap_values = fourier_shap_precompute(freq_array, multiplier_matrix, X_train, X_test).block_until_ready()\n",
    "then = time.time()\n",
    "print(then - now)\n",
    "now = time.time()\n",
    "shap_values = fourier_shap_precompute(freq_array, multiplier_matrix, X_train, X_test).block_until_ready()\n",
    "then = time.time()\n",
    "print(then - now)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shap\n",
    "import models.nn_model, models.random_forest_model, models.catboost_model\n",
    "\n",
    "def load_model(dataset, model, depth=\"\"):\n",
    "    if model == \"nn\":\n",
    "        model = models.nn_model.load_model(dataset, best=True)\n",
    "\n",
    "    elif model == \"random_forest\":\n",
    "        model = models.random_forest_model.load_model(dataset, depth)\n",
    "\n",
    "    elif model == \"catboost\":\n",
    "        model = models.catboost_model.load_model(dataset, depth)\n",
    "\n",
    "    return model\n",
    "\n",
    "f = load_model(dataset, model, depth)\n",
    "if model == \"nn\":\n",
    "    explainer = shap.KernelExplainer(f.custom_forward, x_train[0:no_background_samples])\n",
    "else:\n",
    "    explainer = shap.KernelExplainer(f.predict, x_train[0:no_background_samples])\n",
    "result_kernel_shap = explainer.shap_values(x_test[0:no_test_samples]).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(0.90719265, dtype=float64)"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def r2_score(y_true, y_pred):\n",
    "    # Calculate the residual sum of squares\n",
    "    ss_res = np.sum((y_true - y_pred) ** 2)\n",
    "    \n",
    "    # Calculate the total sum of squares\n",
    "    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)\n",
    "    \n",
    "    # Calculate R^2\n",
    "    r2 = 1 - (ss_res / ss_tot)\n",
    "    return r2\n",
    "\n",
    "r2_score(result_kernel_shap.flatten(), shap_values.flatten())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
