{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "af048f80-e7d8-411d-acf3-46b9ec43e1db",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "46477d51-592c-4ac1-8750-789a73250827",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from sklearn import datasets\n",
    "from sklearn.model_selection import train_test_split\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f045e3fb-30cf-45d5-85f3-e925479799a6",
   "metadata": {},
   "source": [
    "# Load CoverType from OpenML"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e16c9219-649f-429d-8f5e-f1d5e7f06aee",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y_mc = datasets.fetch_openml(data_id=180, return_X_y=True, as_frame=False)\n",
    "X = X.astype(float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3774ac64-5166-4a81-9823-ccc32a674cbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# make it a binary classification problem for the sake of simplicity\n",
    "y = y_mc == \"Lodgepole_Pine\"\n",
    "xtrain, xtest, ytrain, ytest, ytrain_mc, ytest_mc = train_test_split(X, y, y_mc, test_size=0.2, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a6192bd-375f-428d-b94f-37cd77a4b241",
   "metadata": {},
   "source": [
    "# Train an XGBoost model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b1561682-2628-4447-b62b-f24d88c0ac08",
   "metadata": {},
   "outputs": [],
   "source": [
    "import veritas\n",
    "import xgboost as xgb\n",
    "import tqdm\n",
    "from sklearn.metrics import accuracy_score, roc_auc_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0686c394-8a18-4c7f-91a6-7a768732d955",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "XGB trained in 0.8577525615692139 seconds\n"
     ]
    }
   ],
   "source": [
    "params = {\n",
    "    \"n_estimators\": 100,\n",
    "    \"eval_metric\": \"auc\",\n",
    "    \"seed\": 197,\n",
    "    \"max_depth\": 10,\n",
    "    \"learning_rate\": 0.2\n",
    "}\n",
    "model = xgb.XGBClassifier(**params)\n",
    "\n",
    "t = time.time()\n",
    "model.fit(X, y)\n",
    "print(f\"XGB trained in {time.time()-t} seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "41e806a9-5015-40dd-9bf3-9ee5f56bd796",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train acc: 0.945, test acc: 0.945 wrt true labels\n",
      "Train auc: 0.946, test auc: 0.947 wrt true labels\n"
     ]
    }
   ],
   "source": [
    "ytrain_pred = model.predict(xtrain)\n",
    "ytest_pred = model.predict(xtest)\n",
    "acc_train = accuracy_score(ytrain, ytrain_pred)\n",
    "acc_test = accuracy_score(ytest, ytest_pred)\n",
    "auc_train = roc_auc_score(ytrain, ytrain_pred)\n",
    "auc_test = roc_auc_score(ytest, ytest_pred)\n",
    "\n",
    "print(f\"Train acc: {acc_train:.3f}, test acc: {acc_test:.3f} wrt true labels\")\n",
    "print(f\"Train auc: {auc_train:.3f}, test auc: {auc_test:.3f} wrt true labels\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3271f1f8-695d-4b83-bab2-6760e3558bd5",
   "metadata": {},
   "source": [
    "# Generate adversarial examples\n",
    "\n",
    "We only allow changes to the first 10 attributes, which are numerical. The remaining 44 attributes are binary and are kept fixed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b7ce0ec8-69a9-4188-a697-212828123232",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# FROM CLASS 0 -> CLASS 1\n",
    "number_of_adv_examples = 10\n",
    "\n",
    "rng = np.random.default_rng(seed=19656)\n",
    "xtest0 = xtest[ytest==0, :]\n",
    "subset = xtest0[rng.choice(range(xtest0.shape[0]), number_of_adv_examples), :]\n",
    "\n",
    "eps = 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "84380740-f95d-4110-a679-c66e886dfdff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "| XGBOOST's base_score\n",
      "|   base_score diff std      5.40878462362659e-07 OK\n",
      "|   base_score reported      0.46820667\n",
      "|   versus manually detected -0.1273450059940551\n",
      "|   abs err                  0.5955516759940551\n",
      "|   rel err                  1.2719846045637393\n",
      "|   (!) base_score NOT THE SAME with relative tolerance 0.001\n",
      "\n"
     ]
    }
   ],
   "source": [
    "at = veritas.get_addtree(model)\n",
    "Xstd = xtrain.std(axis=0)\n",
    "Xstd[10:] = 0.0 # allow no change in binary attributes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "142b4cf5-131c-4fc7-b991-a6bb7efa219c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  7.45it/s]\n"
     ]
    }
   ],
   "source": [
    "adv_examples = []\n",
    "for i in tqdm.tqdm(range(subset.shape[0])):\n",
    "    base_example = subset[i, :]\n",
    "    \n",
    "    # allow each attribute to vary by eps*stddev of attribute to either side\n",
    "    prune_box = [veritas.Interval(x-eps*s, x+eps*s) if s > 0.0\n",
    "                 else veritas.Interval.constant(x)\n",
    "                 for x, s in zip(base_example, Xstd)]\n",
    "    \n",
    "    config = veritas.Config(veritas.HeuristicType.MAX_OUTPUT)\n",
    "    search = config.get_search(at, prune_box)\n",
    "\n",
    "    # continue the search for at most 1 second\n",
    "    # stop early when the optimal solution is found\n",
    "    stop_reason = search.step_for(1.0, 1000)\n",
    "    \n",
    "    if search.num_solutions() > 0:\n",
    "        sol = search.get_solution(0)\n",
    "        adv_example = veritas.get_closest_example(sol, base_example, eps=1e-4)\n",
    "\n",
    "        res = {\"i\": i, \"adv_example\": adv_example, \"base_example\": base_example}\n",
    "\n",
    "        res[\"base_ypred_at\"] = at.predict(np.atleast_2d(base_example))[0]\n",
    "        res[\"base_ypred\"] = model.predict_proba(np.atleast_2d(base_example))[0,1]\n",
    "        res[\"adv_ypred\"] = model.predict_proba(np.atleast_2d(adv_example))[0,1]\n",
    "        res[\"adv_ypred_at\"] = at.predict(np.atleast_2d(adv_example))[0]\n",
    "        res[\"optimal\"] = search.is_optimal()\n",
    "\n",
    "        adv_examples.append(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "364a9553-7401-4150-b535-c439da90430d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "i  p(y=1|normal)  p(y=1|adv)     success?  optimal?\n",
      "0          14.2%       43.6%            -         y\n",
      "1          14.7%       55.3%            y         y\n",
      "2           6.2%       86.4%            y         y\n",
      "3           1.6%       74.7%            y         y\n",
      "4           1.1%       44.2%            -         y\n",
      "5          11.0%       79.1%            y         y\n",
      "6           5.3%       67.6%            y         y\n",
      "7           2.7%       36.8%            -         y\n",
      "8          14.7%       93.9%            y         y\n",
      "9          74.5%       99.1%            y         y\n"
     ]
    }
   ],
   "source": [
    "print(\"i  {:15}{:15}{:>8}{:>10}\".format(\"p(y=1|normal)\", \"p(y=1|adv)\", \"success?\", \"optimal?\"))\n",
    "for res in adv_examples:\n",
    "    i = res[\"i\"]\n",
    "    success = \"y\" if res[\"adv_ypred\"]>0.5 else \"-\"\n",
    "    optimal = \"y\" if res[\"optimal\"] else \"-\"\n",
    "    base = res[\"base_ypred\"]\n",
    "    adv = res[\"adv_ypred\"]\n",
    "    print(\"{:<3}{:>13}  {:>10}     {:>8}{:>10}\".format(i, f'{base*100:4.1f}%', f'{adv*100:4.1f}%', success, optimal))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2edde6d5-4377-458d-a0fa-4510e99c4af0",
   "metadata": {},
   "source": [
    "**Which attributes differ?**\n",
    "\n",
    "> Name / Data Type / Measurement / Description\n",
    "> --------------------------------------------\n",
    "> Elevation / quantitative / meters / Elevation in meters  \n",
    "> Aspect / quantitative / azimuth / Aspect in degrees azimuth  \n",
    "> Slope / quantitative / degrees / Slope in degrees  \n",
    "> Horizontal_Distance_To_Hydrology / quantitative / meters / Horz Dist to nearest surface water features  \n",
    "> Vertical_Distance_To_Hydrology / quantitative / meters / Vert Dist to nearest surface water features  \n",
    "> Horizontal_Distance_To_Roadways / quantitative / meters / Horz Dist to nearest roadway   \n",
    "> Hillshade_9am / quantitative / 0 to 255 index / Hillshade index at 9am, summer solstice   \n",
    "> Hillshade_Noon / quantitative / 0 to 255 index / Hillshade index at noon, summer solstice   \n",
    "> Hillshade_3pm / quantitative / 0 to 255 index / Hillshade index at 3pm, summer solstice   \n",
    "> Horizontal_Distance_To_Fire_Points / quantitative / meters / Horz Dist to nearest wildfire ignition points   \n",
    "> Wilderness_Area (4 binary columns) / qualitative / 0 (absence) or 1 (presence) / Wilderness area designation   \n",
    "> Soil_Type (40 binary columns) / qualitative / 0 (absence) or 1 (presence) / Soil Type designation   \n",
    "> Cover_Type (7 types) / integer / 1 to 7 / Forest Cover Type designation "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a6944eba-b6c9-4124-a109-6ad4053f9d40",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "i=0 p(y=1) = 14.2% -> 43.6%\n",
      "    attribute          normal      adv.     diff.\n",
      "    elevation          2806.0    2840.0      34.0\n",
      "    aspect              185.0     163.0     -22.0\n",
      "    slope                 6.0       6.0      -0.0\n",
      "    hoz_dist_hydro      120.0      85.0     -35.0\n",
      "    ver_dist_hydro       18.0      29.0      11.0\n",
      "    hoz_dist_road      1127.0     967.0    -160.0\n",
      "    shade_9am           221.0     221.0       0.0\n",
      "    shade_noon          244.0     245.0       1.0\n",
      "    shade_3am           158.0     158.0       0.0\n",
      "    hoz_dist_fire      1124.0    1383.0     259.0\n",
      "i=1 p(y=1) = 14.7% -> 55.3%\n",
      "    attribute          normal      adv.     diff.\n",
      "    elevation          2952.0    2915.0     -37.0\n",
      "    aspect              270.0     263.0      -7.0\n",
      "    slope                 9.0       9.0      -0.0\n",
      "    hoz_dist_hydro       30.0      67.0      37.0\n",
      "    ver_dist_hydro       -1.0      10.0      11.0\n",
      "    hoz_dist_road       330.0     433.0     103.0\n",
      "    shade_9am           197.0     197.0       0.0\n",
      "    shade_noon          243.0     243.0       0.0\n",
      "    shade_3am           186.0     189.0       3.0\n",
      "    hoz_dist_fire       721.0     573.0    -148.0\n",
      "i=2 p(y=1) = 6.2% -> 86.4%\n",
      "    attribute          normal      adv.     diff.\n",
      "    elevation          3195.0    3250.0      55.0\n",
      "    aspect              176.0     169.0      -7.0\n",
      "    slope                15.0      15.0       0.0\n",
      "    hoz_dist_hydro      300.0     300.0       0.0\n",
      "    ver_dist_hydro       31.0      20.0     -11.0\n",
      "    hoz_dist_road       551.0     551.0       0.0\n",
      "    shade_9am           226.0     230.0       4.0\n",
      "    shade_noon          247.0     246.0      -1.0\n",
      "    shade_3am           149.0     149.0       0.0\n",
      "    hoz_dist_fire      2657.0    2918.0     261.0\n",
      "i=3 p(y=1) = 1.6% -> 74.7%\n",
      "    attribute          normal      adv.     diff.\n",
      "    elevation          2224.0    2224.0       0.0\n",
      "    aspect               21.0      32.0      11.0\n",
      "    slope                14.0      14.0       0.0\n",
      "    hoz_dist_hydro      270.0     258.0     -12.0\n",
      "    ver_dist_hydro       96.0      86.0     -10.0\n",
      "    hoz_dist_road       595.0     324.0    -271.0\n",
      "    shade_9am           210.0     210.0       0.0\n",
      "    shade_noon          209.0     209.0       0.0\n",
      "    shade_3am           133.0     131.0      -2.0\n",
      "    hoz_dist_fire       120.0     324.0     204.0\n",
      "i=4 p(y=1) = 1.1% -> 44.2%\n",
      "    attribute          normal      adv.     diff.\n",
      "    elevation          3254.0    3254.0       0.0\n",
      "    aspect              350.0     356.0       6.0\n",
      "    slope                13.0      13.0      -0.0\n",
      "    hoz_dist_hydro      201.0     170.0     -31.0\n",
      "    ver_dist_hydro       65.0      65.0      -0.0\n",
      "    hoz_dist_road      2731.0    2854.0     123.0\n",
      "    shade_9am           197.0     197.0       0.0\n",
      "    shade_noon          218.0     217.0      -1.0\n",
      "    shade_3am           159.0     155.0      -4.0\n",
      "    hoz_dist_fire      2574.0    2742.0     168.0\n",
      "i=5 p(y=1) = 11.0% -> 79.1%\n",
      "    attribute          normal      adv.     diff.\n",
      "    elevation          3084.0    3032.0     -52.0\n",
      "    aspect              177.0     175.0      -2.0\n",
      "    slope                16.0      15.0      -1.0\n",
      "    hoz_dist_hydro       42.0      60.0      18.0\n",
      "    ver_dist_hydro        4.0       9.0       5.0\n",
      "    hoz_dist_road      1550.0    1550.0       0.0\n",
      "    shade_9am           226.0     229.0       3.0\n",
      "    shade_noon          247.0     250.0       3.0\n",
      "    shade_3am           148.0     145.0      -3.0\n",
      "    hoz_dist_fire      2313.0    2073.0    -240.0\n",
      "i=6 p(y=1) = 5.3% -> 67.6%\n",
      "    attribute          normal      adv.     diff.\n",
      "    elevation          3049.0    3011.0     -38.0\n",
      "    aspect              327.0     338.0      11.0\n",
      "    slope                11.0      11.0       0.0\n",
      "    hoz_dist_hydro      212.0     240.0      28.0\n",
      "    ver_dist_hydro       -4.0       6.0      10.0\n",
      "    hoz_dist_road      2336.0    2058.0    -278.0\n",
      "    shade_9am           193.0     189.0      -4.0\n",
      "    shade_noon          228.0     229.0       1.0\n",
      "    shade_3am           174.0     171.0      -3.0\n",
      "    hoz_dist_fire       757.0     999.0     242.0\n",
      "i=7 p(y=1) = 2.7% -> 36.8%\n",
      "    attribute          normal      adv.     diff.\n",
      "    elevation          3253.0    3230.0     -23.0\n",
      "    aspect              102.0     114.0      12.0\n",
      "    slope                 8.0       7.0      -1.0\n",
      "    hoz_dist_hydro       85.0      85.0       0.0\n",
      "    ver_dist_hydro       13.0      22.0       9.0\n",
      "    hoz_dist_road      1624.0    1570.0     -54.0\n",
      "    shade_9am           234.0     235.0       1.0\n",
      "    shade_noon          230.0     232.0       2.0\n",
      "    shade_3am           128.0     128.0       0.0\n",
      "    hoz_dist_fire       845.0     607.0    -238.0\n",
      "i=8 p(y=1) = 14.7% -> 93.9%\n",
      "    attribute          normal      adv.     diff.\n",
      "    elevation          3011.0    3044.0      33.0\n",
      "    aspect              180.0     185.0       5.0\n",
      "    slope                17.0      18.0       1.0\n",
      "    hoz_dist_hydro      247.0     285.0      38.0\n",
      "    ver_dist_hydro       23.0      16.0      -7.0\n",
      "    hoz_dist_road      5574.0    5851.0     277.0\n",
      "    shade_9am           224.0     228.0       4.0\n",
      "    shade_noon          248.0     249.0       1.0\n",
      "    shade_3am           151.0     145.0      -6.0\n",
      "    hoz_dist_fire      1936.0    1959.0      23.0\n",
      "i=9 p(y=1) = 74.5% -> 99.1%\n",
      "    attribute          normal      adv.     diff.\n",
      "    elevation          3034.0    3016.0     -18.0\n",
      "    aspect              140.0     154.0      14.0\n",
      "    slope                24.0      23.0      -1.0\n",
      "    hoz_dist_hydro      979.0     979.0       0.0\n",
      "    ver_dist_hydro      149.0     156.0       7.0\n",
      "    hoz_dist_road      3108.0    3150.0      42.0\n",
      "    shade_9am           247.0     245.0      -2.0\n",
      "    shade_noon          225.0     227.0       2.0\n",
      "    shade_3am            91.0      87.0      -4.0\n",
      "    hoz_dist_fire      1606.0    1606.0       0.0\n"
     ]
    }
   ],
   "source": [
    "import pprint\n",
    "attribute_names = [\"elevation\", \"aspect\", \"slope\", \"hoz_dist_hydro\", \"ver_dist_hydro\", \"hoz_dist_road\", \"shade_9am\", \"shade_noon\", \"shade_3am\", \"hoz_dist_fire\"]# + [f\"wilderness{k}\" for k in range(4)] + [f\"soil{k}\" for k in range(40)]\n",
    "\n",
    "for adv in adv_examples:\n",
    "    i = adv[\"i\"]\n",
    "    base_ex, adv_ex = adv[\"base_example\"], adv_examples[i][\"adv_example\"]\n",
    "    print(f'i={i} p(y=1) = {adv[\"base_ypred\"]*100:.1f}% -> {adv[\"adv_ypred\"]*100:.1f}%')\n",
    "    print(\"    {:15}{:>10}{:>10}{:>10}\".format(\"attribute\", \"normal\", \"adv.\", \"diff.\"))\n",
    "    for j, attrname in enumerate(attribute_names):\n",
    "        b, a = base_ex[j], adv_ex[j]\n",
    "        print(\"    {:15}{:>10}{:>10}{:>10}\".format(attrname, f'{b:.1f}', f'{a:.1f}', f'{a-b:.1f}'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94a534ef-c906-45a8-b51d-5c12baffbdc8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
