{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:528: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:529: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:530: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:535: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "import argparse\n",
    "import cxplain\n",
    "\n",
    "from tensorflow.python.keras.layers import Dense, Input, Flatten, Add, Multiply, Lambda\n",
    "from tensorflow.python.keras.layers.normalization import BatchNormalization\n",
    "from tensorflow.python.keras import regularizers\n",
    "from tensorflow.python.keras.models import Model, Sequential\n",
    "\n",
    "from cxplain import MLPModelBuilder, ZeroMasking, CXPlain\n",
    "from tensorflow.python.keras.losses import categorical_crossentropy\n",
    "\n",
    "\n",
    "from utils.explanations import calculate_robust_astute_sampled\n",
    "\n",
    "\n",
    "np.random.seed(0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-classifier black-box MLP spec: (hidden layer widths, output layer name, weight-file suffix).\n",
    "# Layer names must match those used when the weights were saved, since weights are loaded by name.\n",
    "_BLACKBOX_MLP_SPECS = {\n",
    "    '2layer': ([200, 200], 'dense4', '_blackbox.hdf5'),\n",
    "    '4layer': ([50, 50, 50, 50], 'dense5', '_blackbox_extra.hdf5'),\n",
    "    'linear': ([200], 'dense4', '_blackbox_linear.hdf5'),\n",
    "}\n",
    "_SUPPORTED_CLASSIFIERS = ('2layer', '4layer', 'linear', 'svm')\n",
    "\n",
    "\n",
    "def _build_blackbox_mlp(input_shape, hidden_units, activation, output_name):\n",
    "    \"\"\"Rebuild a black-box Keras MLP so its saved weights can be loaded by layer name.\n",
    "\n",
    "    Args:\n",
    "        input_shape: number of input features (the Input layer has shape (input_shape,)).\n",
    "        hidden_units: width of each hidden Dense layer, in order; named 'dense1', 'dense2', ...\n",
    "        activation: activation of the hidden Dense layers (None gives linear layers).\n",
    "        output_name: name of the 2-unit softmax output layer.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled keras Model mapping the input to the softmax predictions.\n",
    "\n",
    "    NOTE(review): the BatchNormalization layers are not named explicitly, so Keras assigns\n",
    "    them auto-generated names; if those differ from the names stored in the .hdf5 file\n",
    "    (e.g. on repeated calls in one session), load_weights(by_name=True) may skip them.\n",
    "    Confirm against the saved file.\n",
    "    \"\"\"\n",
    "    model_input = Input(shape=(input_shape,), dtype='float32')\n",
    "    net = model_input\n",
    "    for layer_idx, units in enumerate(hidden_units, start=1):\n",
    "        net = Dense(units, activation=activation, name='dense%d' % layer_idx,\n",
    "                    kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "        net = BatchNormalization()(net)  # Add batchnorm for stability.\n",
    "    preds = Dense(2, activation='softmax', name=output_name,\n",
    "                  kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "    return Model(model_input, preds)\n",
    "\n",
    "\n",
    "def cxplain_explainer(datatype, ball_r, epsilon, prop_points, exponentiate, classifier,\n",
    "                      fit_on_subsample=False):\n",
    "    \"\"\"Fit a CXPlain explainer for a saved black-box classifier and return absolute attributions.\n",
    "\n",
    "    Args:\n",
    "        datatype: dataset name; reads 'data/<datatype>.pk' and the black-box file under 'models/'.\n",
    "        ball_r: forwarded to calculate_robust_astute_sampled.\n",
    "        epsilon: forwarded to calculate_robust_astute_sampled.\n",
    "        prop_points: fraction of x_val to evaluate (num_points = int(prop_points * len(x_val))).\n",
    "        exponentiate: forwarded to calculate_robust_astute_sampled.\n",
    "        classifier: black box to explain: '2layer', '4layer', 'linear' (Keras MLPs)\n",
    "            or 'svm' (pickled model).\n",
    "        fit_on_subsample: if True, fit CXPlain on a random 1% of x_train instead of all of it.\n",
    "\n",
    "    Returns:\n",
    "        np.abs of the explanation returned by calculate_robust_astute_sampled.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `classifier` is not a supported name.\n",
    "    \"\"\"\n",
    "    if classifier not in _SUPPORTED_CLASSIFIERS:\n",
    "        raise ValueError('Unknown classifier %r; expected one of %s' % (classifier, _SUPPORTED_CLASSIFIERS))\n",
    "\n",
    "    # NOTE: pickle.load can execute arbitrary code -- only load trusted data/model files.\n",
    "    with open('data/' + datatype + '.pk', 'rb') as data_file:\n",
    "        data_dict = pickle.load(data_file)\n",
    "    x_train, y_train = data_dict['x_train'], data_dict['y_train']\n",
    "    x_val, input_shape = data_dict['x_val'], data_dict['input_shape']\n",
    "\n",
    "    masking_operation = ZeroMasking()\n",
    "    loss = categorical_crossentropy\n",
    "    if classifier == 'linear':\n",
    "        activation = None\n",
    "    else:\n",
    "        activation = 'relu' if datatype in ['orange_skin', 'XOR'] else 'selu'\n",
    "\n",
    "    if classifier == 'svm':\n",
    "        with open('models/' + datatype + '_svm.pk', 'rb') as model_file:\n",
    "            pred_model = pickle.load(model_file)\n",
    "        # The explanation network for the SVM uses the same size as the '2layer' setup.\n",
    "        num_layers, num_units = 2, 200\n",
    "    else:\n",
    "        hidden_units, output_name, weights_suffix = _BLACKBOX_MLP_SPECS[classifier]\n",
    "        pred_model = _build_blackbox_mlp(input_shape, hidden_units, activation, output_name)\n",
    "        pred_model.load_weights('models/' + datatype + weights_suffix, by_name=True)\n",
    "        # The CXPlain explanation network mirrors the black-box depth and width.\n",
    "        num_layers, num_units = len(hidden_units), hidden_units[0]\n",
    "\n",
    "    model_builder = MLPModelBuilder(num_layers=num_layers, num_units=num_units, activation=activation, verbose=1,\n",
    "                                    batch_size=1000, learning_rate=0.001, num_epochs=5, early_stopping_patience=15,\n",
    "                                    with_bn=True)\n",
    "\n",
    "    # A 1% training subsample is always drawn so the global NumPy RNG advances the same way\n",
    "    # whether or not it is used, keeping later random sampling reproducible.\n",
    "    training_indices = np.random.choice(len(x_train), int(0.01 * len(x_train)), replace=False)\n",
    "    if fit_on_subsample:\n",
    "        # assumes x_train / y_train support NumPy fancy indexing -- TODO confirm\n",
    "        x_fit, y_fit = x_train[training_indices], y_train[training_indices]\n",
    "    else:\n",
    "        x_fit, y_fit = x_train, y_train\n",
    "\n",
    "    explainer = CXPlain(pred_model, model_builder, masking_operation, loss, num_models=1)\n",
    "    explainer.fit(x_fit, y_fit)\n",
    "\n",
    "    # Only the SVM call passes NN=False; the Keras black boxes use the helper's default.\n",
    "    extra_kwargs = {'NN': False} if classifier == 'svm' else {}\n",
    "    explanation = calculate_robust_astute_sampled(data=x_val,\n",
    "                                                  explainer=explainer,\n",
    "                                                  explainer_type='cxplain',\n",
    "                                                  explanation_type='attribution',\n",
    "                                                  ball_r=ball_r,\n",
    "                                                  epsilon=epsilon,\n",
    "                                                  num_points=int(prop_points * len(x_val)),\n",
    "                                                  exponentiate=exponentiate,\n",
    "                                                  calculate_astuteness=False,\n",
    "                                                  **extra_kwargs)\n",
    "\n",
    "    del pred_model\n",
    "    return np.abs(explanation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/ops/resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-09-29 20:39:10.464585: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA\n",
      "2021-09-29 20:39:10.485944: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 3600000000 Hz\n",
      "2021-09-29 20:39:10.486599: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x556091789e20 executing computations on platform Host. Devices:\n",
      "2021-09-29 20:39:10.486622: I tensorflow/compiler/xla/service/service.cc:158]   StreamExecutor device (0): <undefined>, <undefined>\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/keras/utils/losses_utils.py:170: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n",
      "WARNING:tensorflow:From /home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/network.py:1436: update_checkpoint_state (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.train.CheckpointManager to manage checkpoints rather than manually editing the Checkpoint proto.\n",
      "Train on 900000 samples, validate on 100000 samples\n",
      "WARNING:tensorflow:From /home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n",
      "Epoch 1/5\n",
      "900000/900000 [==============================] - 11s 12us/sample - loss: 0.4103 - dense_2_loss: 0.3736 - all_loss: 0.0080 - lambda_1_loss: 0.0286 - val_loss: 0.2640 - val_dense_2_loss: 0.2271 - val_all_loss: 0.0081 - val_lambda_1_loss: 0.0288\n",
      "Epoch 2/5\n",
      "900000/900000 [==============================] - 10s 11us/sample - loss: 0.2221 - dense_2_loss: 0.1855 - all_loss: 0.0080 - lambda_1_loss: 0.0286 - val_loss: 0.2134 - val_dense_2_loss: 0.1766 - val_all_loss: 0.0081 - val_lambda_1_loss: 0.0288\n",
      "Epoch 3/5\n",
      "900000/900000 [==============================] - 10s 12us/sample - loss: 0.1860 - dense_2_loss: 0.1494 - all_loss: 0.0080 - lambda_1_loss: 0.0286 - val_loss: 0.1727 - val_dense_2_loss: 0.1359 - val_all_loss: 0.0081 - val_lambda_1_loss: 0.0288\n",
      "Epoch 4/5\n",
      "900000/900000 [==============================] - 10s 11us/sample - loss: 0.1709 - dense_2_loss: 0.1342 - all_loss: 0.0080 - lambda_1_loss: 0.0286 - val_loss: 0.1674 - val_dense_2_loss: 0.1305 - val_all_loss: 0.0081 - val_lambda_1_loss: 0.0288\n",
      "Epoch 5/5\n",
      "900000/900000 [==============================] - 11s 13us/sample - loss: 0.1605 - dense_2_loss: 0.1238 - all_loss: 0.0080 - lambda_1_loss: 0.0286 - val_loss: 0.1536 - val_dense_2_loss: 0.1168 - val_all_loss: 0.0081 - val_lambda_1_loss: 0.0288\n",
      "Train on 900000 samples, validate on 100000 samples\n",
      "Epoch 1/5\n",
      "900000/900000 [==============================] - 6s 7us/sample - loss: 1.1868 - dense_4_loss: 0.7790 - all_loss: 0.2029 - lambda_3_loss: 0.2049 - val_loss: 1.1627 - val_dense_4_loss: 0.7551 - val_all_loss: 0.2028 - val_lambda_3_loss: 0.2048\n",
      "Epoch 2/5\n",
      "900000/900000 [==============================] - 5s 6us/sample - loss: 1.1639 - dense_4_loss: 0.7562 - all_loss: 0.2029 - lambda_3_loss: 0.2049 - val_loss: 1.1637 - val_dense_4_loss: 0.7561 - val_all_loss: 0.2028 - val_lambda_3_loss: 0.2048\n",
      "Epoch 3/5\n",
      "900000/900000 [==============================] - 5s 6us/sample - loss: 1.1639 - dense_4_loss: 0.7561 - all_loss: 0.2029 - lambda_3_loss: 0.2049 - val_loss: 1.1631 - val_dense_4_loss: 0.7555 - val_all_loss: 0.2028 - val_lambda_3_loss: 0.2048\n",
      "Epoch 4/5\n",
      "900000/900000 [==============================] - 5s 6us/sample - loss: 1.1637 - dense_4_loss: 0.7559 - all_loss: 0.2029 - lambda_3_loss: 0.2049 - val_loss: 1.1622 - val_dense_4_loss: 0.7546 - val_all_loss: 0.2028 - val_lambda_3_loss: 0.2048\n",
      "Epoch 5/5\n",
      "900000/900000 [==============================] - 5s 6us/sample - loss: 1.1635 - dense_4_loss: 0.7558 - all_loss: 0.2029 - lambda_3_loss: 0.2049 - val_loss: 1.1634 - val_dense_4_loss: 0.7559 - val_all_loss: 0.2028 - val_lambda_3_loss: 0.2048\n",
      "Train on 900000 samples, validate on 100000 samples\n",
      "Epoch 1/5\n",
      "900000/900000 [==============================] - 11s 12us/sample - loss: 1.0521 - dense_9_loss: 0.3489 - all_loss: 0.3495 - lambda_5_loss: 0.3536 - val_loss: 0.8994 - val_dense_9_loss: 0.1962 - val_all_loss: 0.3496 - val_lambda_5_loss: 0.3536\n",
      "Epoch 2/5\n",
      "900000/900000 [==============================] - 11s 12us/sample - loss: 0.8758 - dense_9_loss: 0.1726 - all_loss: 0.3495 - lambda_5_loss: 0.3536 - val_loss: 0.8509 - val_dense_9_loss: 0.1477 - val_all_loss: 0.3496 - val_lambda_5_loss: 0.3536\n",
      "Epoch 3/5\n",
      "900000/900000 [==============================] - 11s 12us/sample - loss: 0.8507 - dense_9_loss: 0.1475 - all_loss: 0.3495 - lambda_5_loss: 0.3536 - val_loss: 0.8613 - val_dense_9_loss: 0.1581 - val_all_loss: 0.3496 - val_lambda_5_loss: 0.3536\n",
      "Epoch 4/5\n",
      "900000/900000 [==============================] - 10s 11us/sample - loss: 0.8346 - dense_9_loss: 0.1315 - all_loss: 0.3495 - lambda_5_loss: 0.3536 - val_loss: 0.8143 - val_dense_9_loss: 0.1111 - val_all_loss: 0.3496 - val_lambda_5_loss: 0.3536\n",
      "Epoch 5/5\n",
      "900000/900000 [==============================] - 10s 11us/sample - loss: 0.8230 - dense_9_loss: 0.1198 - all_loss: 0.3495 - lambda_5_loss: 0.3536 - val_loss: 0.8118 - val_dense_9_loss: 0.1086 - val_all_loss: 0.3496 - val_lambda_5_loss: 0.3536\n",
      "Train on 900000 samples, validate on 100000 samples\n",
      "Epoch 1/5\n",
      "900000/900000 [==============================] - 11s 13us/sample - loss: 0.5569 - dense_12_loss: 0.4032 - all_loss: 0.0728 - lambda_7_loss: 0.0809 - val_loss: 0.4400 - val_dense_12_loss: 0.2856 - val_all_loss: 0.0731 - val_lambda_7_loss: 0.0813\n",
      "Epoch 2/5\n",
      "900000/900000 [==============================] - 9s 10us/sample - loss: 0.3811 - dense_12_loss: 0.2274 - all_loss: 0.0728 - lambda_7_loss: 0.0809 - val_loss: 0.3901 - val_dense_12_loss: 0.2358 - val_all_loss: 0.0731 - val_lambda_7_loss: 0.0813\n",
      "Epoch 3/5\n",
      "900000/900000 [==============================] - 11s 12us/sample - loss: 0.3468 - dense_12_loss: 0.1931 - all_loss: 0.0728 - lambda_7_loss: 0.0809 - val_loss: 0.3305 - val_dense_12_loss: 0.1761 - val_all_loss: 0.0731 - val_lambda_7_loss: 0.0813\n",
      "Epoch 4/5\n",
      "900000/900000 [==============================] - 9s 10us/sample - loss: 0.3261 - dense_12_loss: 0.1724 - all_loss: 0.0728 - lambda_7_loss: 0.0809 - val_loss: 0.3177 - val_dense_12_loss: 0.1633 - val_all_loss: 0.0731 - val_lambda_7_loss: 0.0813\n",
      "Epoch 5/5\n",
      "900000/900000 [==============================] - 9s 10us/sample - loss: 0.3117 - dense_12_loss: 0.1580 - all_loss: 0.0728 - lambda_7_loss: 0.0809 - val_loss: 0.2977 - val_dense_12_loss: 0.1434 - val_all_loss: 0.0731 - val_lambda_7_loss: 0.0813\n",
      "Train on 900000 samples, validate on 100000 samples\n",
      "Epoch 1/5\n",
      "900000/900000 [==============================] - 6s 6us/sample - loss: 0.9994 - dense_15_loss: 0.2405 - all_loss: 0.3684 - lambda_9_loss: 0.3905 - val_loss: 0.9221 - val_dense_15_loss: 0.1641 - val_all_loss: 0.3680 - val_lambda_9_loss: 0.3900\n",
      "Epoch 2/5\n",
      "900000/900000 [==============================] - 6s 6us/sample - loss: 0.8972 - dense_15_loss: 0.1383 - all_loss: 0.3684 - lambda_9_loss: 0.3905 - val_loss: 0.8793 - val_dense_15_loss: 0.1212 - val_all_loss: 0.3680 - val_lambda_9_loss: 0.3900\n",
      "Epoch 3/5\n",
      "900000/900000 [==============================] - 6s 7us/sample - loss: 0.8743 - dense_15_loss: 0.1154 - all_loss: 0.3684 - lambda_9_loss: 0.3905 - val_loss: 0.8683 - val_dense_15_loss: 0.1102 - val_all_loss: 0.3680 - val_lambda_9_loss: 0.3900\n",
      "Epoch 4/5\n",
      "900000/900000 [==============================] - 7s 7us/sample - loss: 0.8622 - dense_15_loss: 0.1033 - all_loss: 0.3684 - lambda_9_loss: 0.3905 - val_loss: 0.8643 - val_dense_15_loss: 0.1063 - val_all_loss: 0.3680 - val_lambda_9_loss: 0.3900\n",
      "Epoch 5/5\n",
      "900000/900000 [==============================] - 8s 9us/sample - loss: 0.8541 - dense_15_loss: 0.0952 - all_loss: 0.3684 - lambda_9_loss: 0.3905 - val_loss: 0.8447 - val_dense_15_loss: 0.0866 - val_all_loss: 0.3680 - val_lambda_9_loss: 0.3900\n",
      "Train on 900000 samples, validate on 100000 samples\n",
      "Epoch 1/5\n",
      "900000/900000 [==============================] - 4s 5us/sample - loss: 0.9652 - dense_17_loss: 0.6597 - all_loss: 0.1415 - lambda_11_loss: 0.1641 - val_loss: 0.9341 - val_dense_17_loss: 0.6279 - val_all_loss: 0.1419 - val_lambda_11_loss: 0.1643\n",
      "Epoch 2/5\n",
      "900000/900000 [==============================] - 3s 4us/sample - loss: 0.9339 - dense_17_loss: 0.6284 - all_loss: 0.1415 - lambda_11_loss: 0.1641 - val_loss: 0.9348 - val_dense_17_loss: 0.6286 - val_all_loss: 0.1419 - val_lambda_11_loss: 0.1643\n",
      "Epoch 3/5\n",
      "900000/900000 [==============================] - 4s 4us/sample - loss: 0.9339 - dense_17_loss: 0.6283 - all_loss: 0.1415 - lambda_11_loss: 0.1641 - val_loss: 0.9357 - val_dense_17_loss: 0.6295 - val_all_loss: 0.1419 - val_lambda_11_loss: 0.1643\n",
      "Epoch 4/5\n",
      "900000/900000 [==============================] - 4s 4us/sample - loss: 0.9338 - dense_17_loss: 0.6282 - all_loss: 0.1415 - lambda_11_loss: 0.1641 - val_loss: 0.9342 - val_dense_17_loss: 0.6280 - val_all_loss: 0.1419 - val_lambda_11_loss: 0.1643\n",
      "Epoch 5/5\n",
      "900000/900000 [==============================] - 4s 5us/sample - loss: 0.9336 - dense_17_loss: 0.6280 - all_loss: 0.1415 - lambda_11_loss: 0.1641 - val_loss: 0.9341 - val_dense_17_loss: 0.6279 - val_all_loss: 0.1419 - val_lambda_11_loss: 0.1643\n",
      "Train on 900000 samples, validate on 100000 samples\n",
      "Epoch 1/5\n",
      "900000/900000 [==============================] - 12s 13us/sample - loss: 0.6307 - dense_22_loss: 0.3189 - all_loss: 0.1369 - lambda_13_loss: 0.1748 - val_loss: 0.5130 - val_dense_22_loss: 0.2013 - val_all_loss: 0.1368 - val_lambda_13_loss: 0.1748\n",
      "Epoch 2/5\n",
      "900000/900000 [==============================] - 13s 14us/sample - loss: 0.4920 - dense_22_loss: 0.1802 - all_loss: 0.1369 - lambda_13_loss: 0.1748 - val_loss: 0.4725 - val_dense_22_loss: 0.1608 - val_all_loss: 0.1368 - val_lambda_13_loss: 0.1748\n",
      "Epoch 3/5\n",
      "900000/900000 [==============================] - 10s 11us/sample - loss: 0.4637 - dense_22_loss: 0.1519 - all_loss: 0.1369 - lambda_13_loss: 0.1748 - val_loss: 0.4524 - val_dense_22_loss: 0.1407 - val_all_loss: 0.1368 - val_lambda_13_loss: 0.1748\n",
      "Epoch 4/5\n",
      "900000/900000 [==============================] - 10s 11us/sample - loss: 0.4467 - dense_22_loss: 0.1350 - all_loss: 0.1369 - lambda_13_loss: 0.1748 - val_loss: 0.4319 - val_dense_22_loss: 0.1202 - val_all_loss: 0.1368 - val_lambda_13_loss: 0.1748\n",
      "Epoch 5/5\n",
      "900000/900000 [==============================] - 9s 11us/sample - loss: 0.4367 - dense_22_loss: 0.1250 - all_loss: 0.1369 - lambda_13_loss: 0.1748 - val_loss: 0.4320 - val_dense_22_loss: 0.1203 - val_all_loss: 0.1368 - val_lambda_13_loss: 0.1748\n",
      "Train on 900000 samples, validate on 100000 samples\n",
      "Epoch 1/5\n",
      "900000/900000 [==============================] - 8s 9us/sample - loss: 0.5065 - dense_25_loss: 0.3976 - all_loss: 0.0373 - lambda_15_loss: 0.0715 - val_loss: 0.3984 - val_dense_25_loss: 0.2892 - val_all_loss: 0.0372 - val_lambda_15_loss: 0.0719\n",
      "Epoch 2/5\n",
      "900000/900000 [==============================] - 9s 10us/sample - loss: 0.3702 - dense_25_loss: 0.2614 - all_loss: 0.0373 - lambda_15_loss: 0.0715 - val_loss: 0.3693 - val_dense_25_loss: 0.2601 - val_all_loss: 0.0372 - val_lambda_15_loss: 0.0719\n",
      "Epoch 3/5\n",
      "900000/900000 [==============================] - 9s 10us/sample - loss: 0.3457 - dense_25_loss: 0.2369 - all_loss: 0.0373 - lambda_15_loss: 0.0715 - val_loss: 0.3422 - val_dense_25_loss: 0.2331 - val_all_loss: 0.0372 - val_lambda_15_loss: 0.0719\n",
      "Epoch 4/5\n",
      "900000/900000 [==============================] - 9s 10us/sample - loss: 0.3295 - dense_25_loss: 0.2206 - all_loss: 0.0373 - lambda_15_loss: 0.0715 - val_loss: 0.3195 - val_dense_25_loss: 0.2103 - val_all_loss: 0.0372 - val_lambda_15_loss: 0.0719\n",
      "Epoch 5/5\n",
      "900000/900000 [==============================] - 9s 10us/sample - loss: 0.3175 - dense_25_loss: 0.2086 - all_loss: 0.0373 - lambda_15_loss: 0.0715 - val_loss: 0.3057 - val_dense_25_loss: 0.1966 - val_all_loss: 0.0372 - val_lambda_15_loss: 0.0719\n",
      "Train on 900000 samples, validate on 100000 samples\n",
      "Epoch 1/5\n",
      "900000/900000 [==============================] - 7s 8us/sample - loss: 1.1389 - dense_28_loss: 0.3459 - all_loss: 0.3916 - lambda_17_loss: 0.4015 - val_loss: 1.0210 - val_dense_28_loss: 0.2316 - val_all_loss: 0.3897 - val_lambda_17_loss: 0.3997\n",
      "Epoch 2/5\n",
      "900000/900000 [==============================] - 6s 7us/sample - loss: 0.9628 - dense_28_loss: 0.1697 - all_loss: 0.3916 - lambda_17_loss: 0.4015 - val_loss: 0.9386 - val_dense_28_loss: 0.1492 - val_all_loss: 0.3897 - val_lambda_17_loss: 0.3997\n",
      "Epoch 3/5\n",
      "900000/900000 [==============================] - 8s 9us/sample - loss: 0.9204 - dense_28_loss: 0.1274 - all_loss: 0.3916 - lambda_17_loss: 0.4015 - val_loss: 0.9054 - val_dense_28_loss: 0.1160 - val_all_loss: 0.3897 - val_lambda_17_loss: 0.3997\n",
      "Epoch 4/5\n",
      "900000/900000 [==============================] - 13s 15us/sample - loss: 0.9046 - dense_28_loss: 0.1116 - all_loss: 0.3916 - lambda_17_loss: 0.4015 - val_loss: 0.8929 - val_dense_28_loss: 0.1034 - val_all_loss: 0.3897 - val_lambda_17_loss: 0.3997\n",
      "Epoch 5/5\n",
      "900000/900000 [==============================] - 9s 10us/sample - loss: 0.8932 - dense_28_loss: 0.1002 - all_loss: 0.3916 - lambda_17_loss: 0.4015 - val_loss: 0.8815 - val_dense_28_loss: 0.0921 - val_all_loss: 0.3897 - val_lambda_17_loss: 0.3997\n",
      "Train on 900000 samples, validate on 100000 samples\n",
      "Epoch 1/5\n",
      "900000/900000 [==============================] - 12s 13us/sample - loss: 1.1346 - dense_30_loss: 0.7866 - all_loss: 0.1689 - lambda_19_loss: 0.1792 - val_loss: 1.0977 - val_dense_30_loss: 0.7508 - val_all_loss: 0.1682 - val_lambda_19_loss: 0.1786\n",
      "Epoch 2/5\n",
      "900000/900000 [==============================] - 8s 9us/sample - loss: 1.1004 - dense_30_loss: 0.7524 - all_loss: 0.1689 - lambda_19_loss: 0.1792 - val_loss: 1.0984 - val_dense_30_loss: 0.7515 - val_all_loss: 0.1682 - val_lambda_19_loss: 0.1786\n",
      "Epoch 3/5\n",
      "900000/900000 [==============================] - 7s 8us/sample - loss: 1.1003 - dense_30_loss: 0.7522 - all_loss: 0.1689 - lambda_19_loss: 0.1792 - val_loss: 1.0979 - val_dense_30_loss: 0.7510 - val_all_loss: 0.1682 - val_lambda_19_loss: 0.1786\n",
      "Epoch 4/5\n",
      "900000/900000 [==============================] - 8s 9us/sample - loss: 1.1001 - dense_30_loss: 0.7520 - all_loss: 0.1689 - lambda_19_loss: 0.1792 - val_loss: 1.1006 - val_dense_30_loss: 0.7538 - val_all_loss: 0.1682 - val_lambda_19_loss: 0.1786\n",
      "Epoch 5/5\n",
      "900000/900000 [==============================] - 8s 9us/sample - loss: 1.1000 - dense_30_loss: 0.7520 - all_loss: 0.1689 - lambda_19_loss: 0.1792 - val_loss: 1.0981 - val_dense_30_loss: 0.7513 - val_all_loss: 0.1682 - val_lambda_19_loss: 0.1786\n",
      "Train on 900000 samples, validate on 100000 samples\n",
      "Epoch 1/5\n",
      "900000/900000 [==============================] - 14s 16us/sample - loss: 1.0325 - dense_35_loss: 0.4391 - all_loss: 0.2825 - lambda_21_loss: 0.3110 - val_loss: 0.8806 - val_dense_35_loss: 0.2912 - val_all_loss: 0.2804 - val_lambda_21_loss: 0.3090\n",
      "Epoch 2/5\n",
      "900000/900000 [==============================] - 24s 27us/sample - loss: 0.8477 - dense_35_loss: 0.2543 - all_loss: 0.2825 - lambda_21_loss: 0.3110 - val_loss: 0.8221 - val_dense_35_loss: 0.2327 - val_all_loss: 0.2804 - val_lambda_21_loss: 0.3090\n",
      "Epoch 3/5\n",
      "900000/900000 [==============================] - 16s 17us/sample - loss: 0.8178 - dense_35_loss: 0.2244 - all_loss: 0.2825 - lambda_21_loss: 0.3110 - val_loss: 0.8127 - val_dense_35_loss: 0.2234 - val_all_loss: 0.2804 - val_lambda_21_loss: 0.3090\n",
      "Epoch 4/5\n",
      "900000/900000 [==============================] - 15s 17us/sample - loss: 0.8000 - dense_35_loss: 0.2066 - all_loss: 0.2825 - lambda_21_loss: 0.3110 - val_loss: 0.7798 - val_dense_35_loss: 0.1904 - val_all_loss: 0.2804 - val_lambda_21_loss: 0.3090\n",
      "Epoch 5/5\n",
      "900000/900000 [==============================] - 15s 17us/sample - loss: 0.7863 - dense_35_loss: 0.1929 - all_loss: 0.2825 - lambda_21_loss: 0.3110 - val_loss: 0.7768 - val_dense_35_loss: 0.1874 - val_all_loss: 0.2804 - val_lambda_21_loss: 0.3090\n",
      "Train on 900000 samples, validate on 100000 samples\n",
      "Epoch 1/5\n",
      "900000/900000 [==============================] - 18s 20us/sample - loss: 0.6819 - dense_38_loss: 0.4995 - all_loss: 0.0821 - lambda_23_loss: 0.1003 - val_loss: 0.5739 - val_dense_38_loss: 0.3927 - val_all_loss: 0.0815 - val_lambda_23_loss: 0.0997\n",
      "Epoch 2/5\n",
      "900000/900000 [==============================] - 14s 16us/sample - loss: 0.5282 - dense_38_loss: 0.3458 - all_loss: 0.0821 - lambda_23_loss: 0.1003 - val_loss: 0.4903 - val_dense_38_loss: 0.3091 - val_all_loss: 0.0815 - val_lambda_23_loss: 0.0997\n",
      "Epoch 3/5\n",
      "900000/900000 [==============================] - 14s 16us/sample - loss: 0.4721 - dense_38_loss: 0.2897 - all_loss: 0.0821 - lambda_23_loss: 0.1003 - val_loss: 0.4499 - val_dense_38_loss: 0.2687 - val_all_loss: 0.0815 - val_lambda_23_loss: 0.0997\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4/5\n",
      "900000/900000 [==============================] - 15s 16us/sample - loss: 0.4358 - dense_38_loss: 0.2535 - all_loss: 0.0821 - lambda_23_loss: 0.1003 - val_loss: 0.4240 - val_dense_38_loss: 0.2428 - val_all_loss: 0.0815 - val_lambda_23_loss: 0.0997\n",
      "Epoch 5/5\n",
      "900000/900000 [==============================] - 15s 16us/sample - loss: 0.4129 - dense_38_loss: 0.2305 - all_loss: 0.0821 - lambda_23_loss: 0.1003 - val_loss: 0.4166 - val_dense_38_loss: 0.2354 - val_all_loss: 0.0815 - val_lambda_23_loss: 0.0997\n",
      "Train on 900000 samples, validate on 100000 samples\n",
      "Epoch 1/5\n",
      "900000/900000 [==============================] - 17s 19us/sample - loss: 1.1110 - dense_41_loss: 0.3209 - all_loss: 0.3924 - lambda_25_loss: 0.3977 - val_loss: 0.9871 - val_dense_41_loss: 0.1945 - val_all_loss: 0.3935 - val_lambda_25_loss: 0.3991\n",
      "Epoch 2/5\n",
      "900000/900000 [==============================] - 14s 16us/sample - loss: 0.9336 - dense_41_loss: 0.1436 - all_loss: 0.3924 - lambda_25_loss: 0.3977 - val_loss: 0.9160 - val_dense_41_loss: 0.1235 - val_all_loss: 0.3935 - val_lambda_25_loss: 0.3991\n",
      "Epoch 3/5\n",
      "900000/900000 [==============================] - 14s 16us/sample - loss: 0.9064 - dense_41_loss: 0.1163 - all_loss: 0.3924 - lambda_25_loss: 0.3977 - val_loss: 0.8954 - val_dense_41_loss: 0.1029 - val_all_loss: 0.3935 - val_lambda_25_loss: 0.3991\n",
      "Epoch 4/5\n",
      "900000/900000 [==============================] - 15s 16us/sample - loss: 0.8928 - dense_41_loss: 0.1028 - all_loss: 0.3924 - lambda_25_loss: 0.3977 - val_loss: 0.9076 - val_dense_41_loss: 0.1151 - val_all_loss: 0.3935 - val_lambda_25_loss: 0.3991\n",
      "Epoch 5/5\n",
      "900000/900000 [==============================] - 14s 16us/sample - loss: 0.8844 - dense_41_loss: 0.0944 - all_loss: 0.3924 - lambda_25_loss: 0.3977 - val_loss: 0.8867 - val_dense_41_loss: 0.0941 - val_all_loss: 0.3935 - val_lambda_25_loss: 0.3991\n",
      "Train on 900000 samples, validate on 100000 samples\n",
      "Epoch 1/5\n",
      "900000/900000 [==============================] - 10s 12us/sample - loss: 1.1838 - dense_43_loss: 0.8089 - all_loss: 0.1857 - lambda_27_loss: 0.1892 - val_loss: 1.1657 - val_dense_43_loss: 0.7921 - val_all_loss: 0.1850 - val_lambda_27_loss: 0.1886\n",
      "Epoch 2/5\n",
      "900000/900000 [==============================] - 8s 9us/sample - loss: 1.1671 - dense_43_loss: 0.7922 - all_loss: 0.1857 - lambda_27_loss: 0.1892 - val_loss: 1.1667 - val_dense_43_loss: 0.7932 - val_all_loss: 0.1850 - val_lambda_27_loss: 0.1886\n",
      "Epoch 3/5\n",
      "900000/900000 [==============================] - 8s 9us/sample - loss: 1.1668 - dense_43_loss: 0.7919 - all_loss: 0.1857 - lambda_27_loss: 0.1892 - val_loss: 1.1659 - val_dense_43_loss: 0.7924 - val_all_loss: 0.1850 - val_lambda_27_loss: 0.1886\n",
      "Epoch 4/5\n",
      "900000/900000 [==============================] - 7s 8us/sample - loss: 1.1667 - dense_43_loss: 0.7918 - all_loss: 0.1857 - lambda_27_loss: 0.1892 - val_loss: 1.1681 - val_dense_43_loss: 0.7945 - val_all_loss: 0.1850 - val_lambda_27_loss: 0.1886\n",
      "Epoch 5/5\n",
      "900000/900000 [==============================] - 7s 8us/sample - loss: 1.1666 - dense_43_loss: 0.7917 - all_loss: 0.1857 - lambda_27_loss: 0.1892 - val_loss: 1.1655 - val_dense_43_loss: 0.7919 - val_all_loss: 0.1850 - val_lambda_27_loss: 0.1886\n",
      "Train on 900000 samples, validate on 100000 samples\n",
      "Epoch 1/5\n",
      "900000/900000 [==============================] - 17s 19us/sample - loss: 1.1182 - dense_48_loss: 0.3696 - all_loss: 0.3716 - lambda_29_loss: 0.3771 - val_loss: 0.9790 - val_dense_48_loss: 0.2292 - val_all_loss: 0.3720 - val_lambda_29_loss: 0.3778\n",
      "Epoch 2/5\n",
      "900000/900000 [==============================] - 13s 15us/sample - loss: 0.9444 - dense_48_loss: 0.1957 - all_loss: 0.3716 - lambda_29_loss: 0.3771 - val_loss: 0.9198 - val_dense_48_loss: 0.1700 - val_all_loss: 0.3720 - val_lambda_29_loss: 0.3778\n",
      "Epoch 3/5\n",
      "900000/900000 [==============================] - 12s 14us/sample - loss: 0.9143 - dense_48_loss: 0.1656 - all_loss: 0.3716 - lambda_29_loss: 0.3771 - val_loss: 0.8953 - val_dense_48_loss: 0.1455 - val_all_loss: 0.3720 - val_lambda_29_loss: 0.3778\n",
      "Epoch 4/5\n",
      "900000/900000 [==============================] - 12s 14us/sample - loss: 0.8974 - dense_48_loss: 0.1488 - all_loss: 0.3716 - lambda_29_loss: 0.3771 - val_loss: 0.8901 - val_dense_48_loss: 0.1403 - val_all_loss: 0.3720 - val_lambda_29_loss: 0.3778\n",
      "Epoch 5/5\n",
      "900000/900000 [==============================] - 14s 15us/sample - loss: 0.8861 - dense_48_loss: 0.1375 - all_loss: 0.3716 - lambda_29_loss: 0.3771 - val_loss: 0.8766 - val_dense_48_loss: 0.1268 - val_all_loss: 0.3720 - val_lambda_29_loss: 0.3778\n",
      "Train on 900000 samples, validate on 100000 samples\n",
      "Epoch 1/5\n",
      "900000/900000 [==============================] - 16s 18us/sample - loss: 0.6589 - dense_51_loss: 0.4795 - all_loss: 0.0851 - lambda_31_loss: 0.0944 - val_loss: 0.5699 - val_dense_51_loss: 0.3901 - val_all_loss: 0.0852 - val_lambda_31_loss: 0.0947\n",
      "Epoch 2/5\n",
      "900000/900000 [==============================] - 13s 14us/sample - loss: 0.4836 - dense_51_loss: 0.3042 - all_loss: 0.0851 - lambda_31_loss: 0.0944 - val_loss: 0.4516 - val_dense_51_loss: 0.2717 - val_all_loss: 0.0852 - val_lambda_31_loss: 0.0947\n",
      "Epoch 3/5\n",
      "900000/900000 [==============================] - 14s 15us/sample - loss: 0.4381 - dense_51_loss: 0.2587 - all_loss: 0.0850 - lambda_31_loss: 0.0944 - val_loss: 0.4246 - val_dense_51_loss: 0.2448 - val_all_loss: 0.0852 - val_lambda_31_loss: 0.0947\n",
      "Epoch 4/5\n",
      "900000/900000 [==============================] - 14s 15us/sample - loss: 0.4125 - dense_51_loss: 0.2331 - all_loss: 0.0850 - lambda_31_loss: 0.0944 - val_loss: 0.4099 - val_dense_51_loss: 0.2300 - val_all_loss: 0.0852 - val_lambda_31_loss: 0.0947\n",
      "Epoch 5/5\n",
      "900000/900000 [==============================] - 14s 15us/sample - loss: 0.3947 - dense_51_loss: 0.2153 - all_loss: 0.0851 - lambda_31_loss: 0.0944 - val_loss: 0.3971 - val_dense_51_loss: 0.2172 - val_all_loss: 0.0852 - val_lambda_31_loss: 0.0947\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Hyper-parameters forwarded unchanged to cxplain_explainer for every\n",
    "# (dataset, classifier) pair.\n",
    "ball_radius = 2\n",
    "epsilon = 0.05\n",
    "prop_points = 0.05\n",
    "run_times = 1  # number of repeated explanation runs per (dataset, classifier)\n",
    "exponentiate = 0\n",
    "classifiers = ['2layer', 'linear', '4layer', 'svm']\n",
    "datatypes = ['switch_[10, 10, 10, 70]gt', 'switch_[30, 60, 10]gt',\n",
    "             'switch_[25, 30, 15, 30]gt', 'switch_[25, 15, 10, 50]gt']\n",
    "\n",
    "# np.savetxt does not create missing parent directories, so make sure the\n",
    "# output folder exists; otherwise this cell fails on a fresh checkout.\n",
    "out_dir = os.path.join('explained_weights', 'cxplain')\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "for datatype in datatypes:\n",
    "    for classifier in classifiers:\n",
    "        for i in range(run_times):\n",
    "            # File name: cxplain_{datatype}_{classifier}_{run}.gz inside out_dir.\n",
    "            fname = os.path.join(out_dir, f'cxplain_{datatype}_{classifier}_{i}.gz')\n",
    "            # NOTE(review): the run index i is only used in the file name; repeated\n",
    "            # runs differ only if cxplain_explainer is itself stochastic - consider\n",
    "            # seeding each run explicitly for reproducibility.\n",
    "            explanation = cxplain_explainer(datatype=datatype,\n",
    "                                           ball_r=ball_radius,\n",
    "                                           epsilon=epsilon,\n",
    "                                           prop_points=prop_points,\n",
    "                                           exponentiate=exponentiate,\n",
    "                                           classifier=classifier)\n",
    "            # A .gz suffix makes np.savetxt write a gzip-compressed, comma-separated file.\n",
    "            np.savetxt(X=explanation, fname=fname, delimiter=',')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
