{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from warnings import simplefilter \n",
    "simplefilter(action='ignore', category=FutureWarning)\n",
    "import numpy as np\n",
    "import os\n",
    "from  natsort import natsorted\n",
    "import imageio\n",
    "import time\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.models import Sequential, Model\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.layers import Dense, Activation, Flatten, Dropout, BatchNormalization, Conv1D\n",
    "from tensorflow.keras.layers import Input, GlobalAveragePooling1D, MaxPooling1D, Dot, Multiply\n",
    "from tensorflow.keras.layers import Conv2D, MaxPooling2D, Bidirectional, LSTM, Permute, Reshape\n",
    "from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard\n",
    "from tensorflow.keras.models import load_model, Model\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from sklearn.metrics import f1_score, accuracy_score\n",
    "from sklearn.metrics import precision_recall_curve\n",
    "from sklearn.metrics import roc_curve, auc, roc_auc_score\n",
    "import matplotlib.pyplot as plt\n",
    "import itertools\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load per-video feature tensors and their labels, caching both on disk.\n",
    "# The labels cache is read from and written to the SAME path so it always stays\n",
    "# row-aligned with data.npy (both files are written together below).\n",
    "features_dir = '/mnt/sdb1/Large_Datasets/Feature_vectors'\n",
    "labels_cache_path = 'labels_78'\n",
    "data_cache_path = 'data.npy'\n",
    "labels_df_filtered = pd.read_pickle('labels_df_filtered_100.pkl')\n",
    "\n",
    "try:\n",
    "    labels = pd.read_pickle(labels_cache_path)\n",
    "    X = np.load(data_cache_path)\n",
    "except (OSError, EOFError, ValueError):\n",
    "    # Cache missing or unreadable: rebuild from the per-video feature files and\n",
    "    # drop every label row whose feature file does not exist.\n",
    "    X = []\n",
    "    has_features = []\n",
    "    for video_id in labels_df_filtered['Id']:\n",
    "        feature_path = os.path.join(features_dir, video_id + '.npy')\n",
    "        found = os.path.isfile(feature_path)\n",
    "        has_features.append(found)\n",
    "        if found:\n",
    "            X.append(np.load(feature_path).T)\n",
    "    if not X:\n",
    "        raise FileNotFoundError(f'No feature files found under {features_dir}')\n",
    "\n",
    "    # One boolean mask instead of re-filtering the frame once per missing Id;\n",
    "    # row order matches the order in which features were appended to X.\n",
    "    labels = labels_df_filtered[np.array(has_features, dtype=bool)].reset_index(drop=True)\n",
    "    X = np.stack(X, axis=0)\n",
    "    labels.to_pickle(labels_cache_path)\n",
    "    np.save(data_cache_path, X)\n",
    "\n",
    "assert len(labels) == len(X), 'labels and features are not row-aligned'\n",
    "print(X.shape)\n",
    "print(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Experiment identifiers and the directory that receives checkpoints\n",
    "NAME = 'MLB_CNN'\n",
    "model_dir = 'Models/New'\n",
    "os.makedirs(model_dir, exist_ok=True)\n",
    "\n",
    "# Class vocabulary: a class's integer id is its position in `classes`\n",
    "classes = ['strike', 'ball', 'play', 'foul', 'out']\n",
    "num_classes = len(classes)\n",
    "class_dict = {class_name: index for index, class_name in enumerate(classes)}\n",
    "inv_class_dict = dict(enumerate(classes))\n",
    "n_concepts = 78\n",
    "\n",
    "# Concept annotations: drop concepts that are zero for every clip, then keep\n",
    "# the first n_concepts remaining columns\n",
    "concept_matrix = np.stack(labels['Concepts'].values, axis=0)\n",
    "never_active = np.all(concept_matrix == 0, axis=0)\n",
    "concept_matrix = concept_matrix[:, ~never_active][:, :n_concepts]\n",
    "print(concept_matrix.shape)\n",
    "\n",
    "# Integer and one-hot class targets\n",
    "y = np.array([class_dict[label] for label in labels['Label']])\n",
    "y_binary = tf.keras.utils.to_categorical(y, num_classes)\n",
    "print(y_binary.shape)\n",
    "\n",
    "# Unshuffled split: the first n_train rows train, the remainder is the test set\n",
    "n_train = 1700\n",
    "X_train0, X_test0 = X[:n_train, :, :], X[n_train:, :, :]\n",
    "y_train_binary, y_test_binary = y_binary[:n_train, :], y_binary[n_train:, :]\n",
    "concept_train, concept_test = concept_matrix[:n_train, :], concept_matrix[n_train:, :]\n",
    "\n",
    "for part in (X_train0, y_train_binary, concept_train, X_test0, y_test_binary, concept_test):\n",
    "    print(part.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, X_train, y_train, X_val, y_val, model_dir, t, batch_size, epochs=50, name = NAME):\n",
    "    \"\"\"Compile and fit a single-output classifier, checkpointing the best epoch.\n",
    "\n",
    "    Args:\n",
    "        model: uncompiled Keras model with a single softmax output.\n",
    "        X_train, y_train: training inputs and one-hot targets.\n",
    "        X_val, y_val: validation inputs and one-hot targets; val_accuracy selects the checkpoint.\n",
    "        model_dir: directory receiving best_{name}_{t}.h5 and final_{name}_{t}.h5.\n",
    "        t: run tag (this notebook passes a Unix timestamp) used in file and log names.\n",
    "        batch_size, epochs: forwarded to model.fit.\n",
    "        name: run name used for the checkpoint, final model and TensorBoard log names.\n",
    "\n",
    "    Returns:\n",
    "        (model, history): the model after the last epoch and its Keras History.\n",
    "    \"\"\"\n",
    "    model.compile(loss=tf.keras.losses.categorical_crossentropy,\n",
    "                  optimizer='adam',\n",
    "                  metrics=['accuracy'])\n",
    "\n",
    "    # Keep only the epoch with the highest validation accuracy\n",
    "    chk_path = os.path.join(model_dir, 'best_{}_{}.h5'.format(name, t))\n",
    "    checkpoint = ModelCheckpoint(chk_path, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')\n",
    "    tensorboard = TensorBoard(log_dir='logs/{}_{}'.format(name, t))\n",
    "    callbacks_list = [checkpoint, tensorboard]\n",
    "\n",
    "    history = model.fit(X_train, y_train,\n",
    "                        batch_size=batch_size,\n",
    "                        epochs=epochs,\n",
    "                        verbose=1,\n",
    "                        shuffle=True,\n",
    "                        validation_data=(X_val, y_val),\n",
    "                        callbacks=callbacks_list)\n",
    "\n",
    "    # Save the last-epoch model under the same run name as its checkpoint\n",
    "    # (this used the global NAME before, ignoring the `name` argument).\n",
    "    model.save(os.path.join(model_dir, 'final_{}_{}.h5'.format(name, t)))\n",
    "\n",
    "    return model, history\n",
    "\n",
    "def train_concept_model(model, X_train, y_train, c_train, X_val, y_val, c_val, \n",
    "                         model_dir, t, n_concepts, batch_size=256, epochs=50, name = NAME):\n",
    "    \"\"\"Compile and fit a concept-bottleneck model with outputs named 'c_probs' and 'probs'.\n",
    "\n",
    "    Concepts are trained with binary cross-entropy and classes with categorical\n",
    "    cross-entropy; the checkpoint keeps the epoch with the best val_probs_accuracy.\n",
    "\n",
    "    Args:\n",
    "        model: uncompiled Keras model with outputs named 'c_probs' and 'probs'.\n",
    "        X_train, y_train, c_train: training inputs, one-hot classes and concept targets.\n",
    "        X_val, y_val, c_val: validation inputs, one-hot classes and concept targets.\n",
    "        model_dir: directory receiving best_{name}_{n_concepts}_{t}.h5 and final_... files.\n",
    "        t: run tag used in file and log names.\n",
    "        n_concepts: number of concepts; only used in the saved file names.\n",
    "        batch_size, epochs: forwarded to model.fit.\n",
    "        name: run name used for the checkpoint, final model and TensorBoard log names.\n",
    "\n",
    "    Returns:\n",
    "        (model, history): the model after the last epoch and its Keras History.\n",
    "    \"\"\"\n",
    "    losses = {\n",
    "        'c_probs': tf.keras.losses.binary_crossentropy,\n",
    "        'probs': tf.keras.losses.categorical_crossentropy,\n",
    "    }\n",
    "    model.compile(loss=losses,\n",
    "                  optimizer='adam',\n",
    "                  metrics=['accuracy'])\n",
    "\n",
    "    # Select the checkpoint on class accuracy, not on concept accuracy\n",
    "    chk_path = os.path.join(model_dir, 'best_{}_{}_{}.h5'.format(name, n_concepts, t))\n",
    "    checkpoint = ModelCheckpoint(chk_path, monitor='val_probs_accuracy',\n",
    "                                 verbose=1, save_best_only=True, mode='max')\n",
    "    tensorboard = TensorBoard(log_dir='logs/{}_{}'.format(name, t))\n",
    "    callbacks_list = [checkpoint, tensorboard]\n",
    "\n",
    "    history = model.fit(X_train, {'probs': y_train, 'c_probs': c_train},\n",
    "                        batch_size=batch_size,\n",
    "                        epochs=epochs,\n",
    "                        verbose=1,\n",
    "                        shuffle=True,\n",
    "                        validation_data=(X_val, {'probs': y_val, 'c_probs': c_val}),\n",
    "                        callbacks=callbacks_list)\n",
    "\n",
    "    model.save(os.path.join(model_dir, 'final_{}_{}_{}.h5'.format(name, n_concepts, t)))\n",
    "\n",
    "    return model, history\n",
    "\n",
    "\n",
    "def calculate_metrics(model, X_test, y_test_binary):\n",
    "    \"\"\"Class metrics for a single-output classifier.\n",
    "\n",
    "    Returns:\n",
    "        (cf_matrix, accuracy, macro_f1, mismatch, y_pred) where mismatch is the\n",
    "        np.where tuple of misclassified row indices and y_pred the predicted class ids.\n",
    "    \"\"\"\n",
    "    y_pred = np.argmax(model.predict(X_test), axis=1)\n",
    "    y_true = np.argmax(y_test_binary, axis=1)\n",
    "    mismatch = np.where(y_true != y_pred)\n",
    "    cf_matrix = confusion_matrix(y_true, y_pred)\n",
    "    accuracy = accuracy_score(y_true, y_pred)\n",
    "    macro_f1 = f1_score(y_true, y_pred, average='macro')\n",
    "\n",
    "    return cf_matrix, accuracy, macro_f1, mismatch, y_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_MLP(dim, win_len, num_classes, num_hidden_mlp=256, p=0.3):\n",
    "    \"\"\"Two hidden-layer MLP over a flat input vector of length dim * win_len.\n",
    "\n",
    "    NOTE(review): X_train0 is (samples, win_len, dim), so it must be flattened\n",
    "    before being fed to this model.\n",
    "    \"\"\"\n",
    "    return Sequential([\n",
    "        Dense(num_hidden_mlp, activation='relu', input_shape=(dim * win_len,), name='dense_1'),\n",
    "        BatchNormalization(name='Bn_1'),\n",
    "        Dropout(p, name='Drop_1'),\n",
    "        Dense(num_hidden_mlp, activation='relu', name='dense_2'),\n",
    "        BatchNormalization(name='Bn_2'),\n",
    "        Dropout(p, name='Drop_2'),\n",
    "        Dense(num_classes, activation='softmax', name='dense_out'),\n",
    "    ])\n",
    "\n",
    "\n",
    "def model_CNN(dim, win_len, num_classes, num_feat_map=64, p=0.3):\n",
    "    \"\"\"Two (1 x 3) Conv2D blocks plus a dense head; input shape is (dim, win_len, 1).\n",
    "\n",
    "    NOTE(review): X_train0 is (samples, win_len, dim); it needs transposing and a\n",
    "    channel axis before being fed to this model.\n",
    "    \"\"\"\n",
    "    conv_kwargs = dict(kernel_size=(1, 3), activation='relu', padding='same')\n",
    "    return Sequential([\n",
    "        Conv2D(num_feat_map, input_shape=(dim, win_len, 1), name='Conv_1', **conv_kwargs),\n",
    "        BatchNormalization(name='Bn_1'),\n",
    "        MaxPooling2D(pool_size=(1, 2), name='Max_pool_1'),\n",
    "        Dropout(p, name='Drop_1'),\n",
    "        Conv2D(num_feat_map, name='Conv_2', **conv_kwargs),\n",
    "        BatchNormalization(name='Bn_2'),\n",
    "        MaxPooling2D(pool_size=(1, 2), name='Max_pool_2'),\n",
    "        Dropout(p, name='Drop_2'),\n",
    "        Flatten(name='Flatten_1'),\n",
    "        Dense(32, activation='relu'),\n",
    "        BatchNormalization(name='Bn_3'),\n",
    "        Dropout(p, name='Drop_3'),\n",
    "        Dense(num_classes, name='logits'),\n",
    "        Activation('softmax', name='probs'),\n",
    "    ])\n",
    "\n",
    "def model_Conv1D(dim, win_len, num_classes, num_feat_map=64, p=0.3):\n",
    "    \"\"\"Temporal Conv1D classifier over (win_len, dim) inputs.\n",
    "\n",
    "    num_feat_map is accepted for signature parity but unused (widths are fixed at 128/16).\n",
    "    \"\"\"\n",
    "    return Sequential([\n",
    "        Conv1D(128, kernel_size=3, activation='relu', padding='same',\n",
    "               input_shape=(win_len, dim), name='Conv_1'),\n",
    "        MaxPooling1D(pool_size=4, name='Max_pool_1'),\n",
    "        BatchNormalization(name='Bn_1'),\n",
    "        Dropout(p, name='Drop_1'),\n",
    "        Conv1D(16, kernel_size=3, activation='relu', padding='same', name='Conv_2'),\n",
    "        BatchNormalization(name='Bn_2'),\n",
    "        Dropout(p, name='Drop_2'),\n",
    "        Flatten(name='flatten'),\n",
    "        Dense(300, activation='relu'),\n",
    "        BatchNormalization(name='Bn_3'),\n",
    "        Dropout(p, name='Drop_3'),\n",
    "        Dense(num_classes, name='logits'),\n",
    "        Activation('softmax', name='probs'),\n",
    "    ])\n",
    "\n",
    "def model_LSTM(dim, win_len, num_classes, num_hidden_lstm=200, p=0.3):\n",
    "    \"\"\"Two stacked bidirectional LSTMs over (win_len, dim) inputs, then a softmax head.\"\"\"\n",
    "    return Sequential([\n",
    "        Bidirectional(LSTM(num_hidden_lstm, return_sequences=True, name='Lstm_1'),\n",
    "                      input_shape=(win_len, dim)),\n",
    "        BatchNormalization(name='Bn_1'),\n",
    "        Dropout(p, name='Drop_1'),\n",
    "        Bidirectional(LSTM(num_hidden_lstm, return_sequences=False, name='Lstm_2')),\n",
    "        Dropout(p, name='Drop_2'),\n",
    "        BatchNormalization(name='Bn_2'),\n",
    "        Dense(num_classes, name='logits'),\n",
    "        Activation('softmax', name='probs'),\n",
    "    ])\n",
    "\n",
    "def model_ConvLSTM(dim, win_len, num_classes, num_feat_map=64, p=0.3):\n",
    "    \"\"\"Conv2D feature extractor over (dim, win_len, 1) inputs followed by an LSTM over time.\"\"\"\n",
    "    conv_kwargs = dict(kernel_size=(1, 3), activation='relu', padding='same')\n",
    "    return Sequential([\n",
    "        Conv2D(num_feat_map, input_shape=(dim, win_len, 1), name='Conv_1', **conv_kwargs),\n",
    "        BatchNormalization(name='Bn_1'),\n",
    "        MaxPooling2D(pool_size=(1, 2), name='Max_pool_1'),\n",
    "        Dropout(p, name='Drop_1'),\n",
    "        Conv2D(num_feat_map, name='Conv_2', **conv_kwargs),\n",
    "        BatchNormalization(name='Bn_2'),\n",
    "        MaxPooling2D(pool_size=(1, 2), name='Max_pool_2'),\n",
    "        Dropout(p, name='Drop_2'),\n",
    "        # Move the (pooled) time axis first, then merge feature maps x dim per step\n",
    "        Permute((2, 1, 3), name='Permute_1'),\n",
    "        Reshape((-1, num_feat_map * dim), name='Reshape_1'),\n",
    "        LSTM(32, return_sequences=False, stateful=False, name='Lstm_1'),\n",
    "        Dropout(p, name='Drop_3'),\n",
    "        Dense(num_classes, activation='softmax', name='dense_out'),\n",
    "    ])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _conv1d_trunk(inputs, p):\n",
    "    \"\"\"Shared Conv1D feature extractor of the Conv1D concept models; returns flat features.\"\"\"\n",
    "    x = inputs\n",
    "    for layer in (\n",
    "        Conv1D(128, kernel_size=3, activation='relu', padding='same', name='Conv_1'),\n",
    "        MaxPooling1D(pool_size=4, name='Max_pool_1'),\n",
    "        BatchNormalization(name='Bn_1'),\n",
    "        Dropout(p, name='Drop_1'),\n",
    "        Conv1D(16, kernel_size=3, activation='relu', padding='same', name='Conv_2'),\n",
    "        BatchNormalization(name='Bn_2'),\n",
    "        Dropout(p, name='Drop_2'),\n",
    "        Flatten(name='flatten'),\n",
    "    ):\n",
    "        x = layer(x)\n",
    "    return x\n",
    "\n",
    "\n",
    "def _lstm_trunk(inputs, p, units, lstm_names):\n",
    "    \"\"\"Shared two-LSTM feature extractor; units/lstm_names give each LSTM's width and name.\"\"\"\n",
    "    x = inputs\n",
    "    for layer in (\n",
    "        LSTM(units[0], return_sequences=True, name=lstm_names[0]),\n",
    "        MaxPooling1D(pool_size=4, name='Max_pool_1'),\n",
    "        BatchNormalization(name='Bn_1'),\n",
    "        Dropout(p, name='Drop_1'),\n",
    "        LSTM(units[1], return_sequences=True, name=lstm_names[1]),\n",
    "        BatchNormalization(name='Bn_2'),\n",
    "        Dropout(p, name='Drop_2'),\n",
    "        Flatten(name='flatten'),\n",
    "    ):\n",
    "        x = layer(x)\n",
    "    return x\n",
    "\n",
    "\n",
    "def _concept_head(features, num_classes, n_concepts, use_attention):\n",
    "    \"\"\"Concept bottleneck: sigmoid concept scores ('c_probs') feeding a softmax classifier ('probs').\n",
    "\n",
    "    With use_attention, a tanh-then-softmax attention vector over the concepts\n",
    "    ('attn_score') is multiplied element-wise with the concept scores ('mul')\n",
    "    before the class layer.\n",
    "    \"\"\"\n",
    "    concept_logits = Dense(n_concepts, name='concept_logits')(features)\n",
    "    concepts = Activation('sigmoid', name='c_probs')(concept_logits)\n",
    "    class_input = concepts\n",
    "    if use_attention:\n",
    "        scores = Dense(n_concepts, name='attention_weights', activation='tanh')(concepts)\n",
    "        scores = Activation('softmax', name='attn_score')(scores)\n",
    "        class_input = Multiply(name='mul')([scores, concepts])\n",
    "    class_logits = Dense(num_classes, name='logits')(class_input)\n",
    "    probs = Activation('softmax', name='probs')(class_logits)\n",
    "    return concepts, probs\n",
    "\n",
    "\n",
    "def model_Conv1D_concepts(dim, win_len, num_classes, n_concepts, num_feat_map=64, p=0.3):\n",
    "    \"\"\"Conv1D concept-bottleneck model; outputs [c_probs, probs]. num_feat_map is unused.\"\"\"\n",
    "    inputs = Input(shape=(win_len, dim), name='Input_1')\n",
    "    features = _conv1d_trunk(inputs, p)\n",
    "    concepts, out = _concept_head(features, num_classes, n_concepts, use_attention=False)\n",
    "    return Model(inputs=inputs, outputs=[concepts, out], name='Video_concepts')\n",
    "\n",
    "def model_LSTM_concepts(dim, win_len, num_classes, n_concepts, p=0.3):\n",
    "    \"\"\"LSTM (128 then 16 units) concept-bottleneck model; outputs [c_probs, probs].\"\"\"\n",
    "    inputs = Input(shape=(win_len, dim), name='Input_1')\n",
    "    features = _lstm_trunk(inputs, p, units=(128, 16), lstm_names=('lstm_1', 'lstm_s'))\n",
    "    concepts, out = _concept_head(features, num_classes, n_concepts, use_attention=False)\n",
    "    return Model(inputs=inputs, outputs=[concepts, out], name='Video_concepts')\n",
    "\n",
    "def model_Conv1D_attn_concepts(dim, win_len, num_classes, n_concepts, num_feat_map=64, p=0.3):\n",
    "    \"\"\"Conv1D concept model with concept attention; outputs [c_probs, probs]. num_feat_map is unused.\"\"\"\n",
    "    inputs = Input(shape=(win_len, dim), name='Input_1')\n",
    "    features = _conv1d_trunk(inputs, p)\n",
    "    concepts, out = _concept_head(features, num_classes, n_concepts, use_attention=True)\n",
    "    return Model(inputs=inputs, outputs=[concepts, out], name='Video_concepts')\n",
    "\n",
    "def model_LSTM_attn_concepts(dim, win_len, num_classes, n_concepts, num_feat_map=64, p=0.3):\n",
    "    \"\"\"LSTM (512 then 16 units) concept model with concept attention. num_feat_map is unused.\"\"\"\n",
    "    inputs = Input(shape=(win_len, dim), name='Input_1')\n",
    "    features = _lstm_trunk(inputs, p, units=(512, 16), lstm_names=('Lstm_1', 'Lstm_2'))\n",
    "    concepts, out = _concept_head(features, num_classes, n_concepts, use_attention=True)\n",
    "    return Model(inputs=inputs, outputs=[concepts, out], name='Video_concepts')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture to build below.\n",
    "#   baselines:               'MLP', 'CNN', 'LSTM', 'ConvLSTM', 'Conv1D'\n",
    "#   concept-bottleneck nets: 'concept', 'concept_attn', 'concept_LSTM_attn'\n",
    "network_type = 'concept_attn'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training hyper-parameters\n",
    "batch_size = 16\n",
    "\n",
    "# Input geometry taken from the training tensor, which is (samples, win_len, dim)\n",
    "win_len, dim = X_train0.shape[1:]\n",
    "# Concept columns kept by the split cell\n",
    "n_concepts = concept_train.shape[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Build the architecture chosen by `network_type`. A dispatch table makes an unknown\n",
    "# value fail loudly instead of silently reusing a `model` left from an earlier run.\n",
    "print('building the model ...')\n",
    "model_builders = {\n",
    "    'CNN': lambda: model_CNN(dim, win_len, num_classes, num_feat_map=32, p=0.3),\n",
    "    'ConvLSTM': lambda: model_ConvLSTM(dim, win_len, num_classes, num_feat_map=32, p=0.3),\n",
    "    'LSTM': lambda: model_LSTM(dim, win_len, num_classes, num_hidden_lstm=32, p=0.3),\n",
    "    'MLP': lambda: model_MLP(dim, win_len, num_classes, num_hidden_mlp=256, p=0.3),\n",
    "    'Conv1D': lambda: model_Conv1D(dim, win_len, num_classes, num_feat_map=64, p=0.5),\n",
    "    'concept': lambda: model_Conv1D_concepts(dim, win_len, num_classes, n_concepts, p=0.2),\n",
    "    'concept_attn': lambda: model_Conv1D_attn_concepts(dim, win_len, num_classes, n_concepts,\n",
    "                                                       num_feat_map=64, p=0.5),\n",
    "    'concept_LSTM_attn': lambda: model_LSTM_attn_concepts(dim, win_len, num_classes, n_concepts,\n",
    "                                                          num_feat_map=64, p=0.2),\n",
    "}\n",
    "if network_type not in model_builders:\n",
    "    raise ValueError(f'Unknown network_type {network_type!r}; expected one of {sorted(model_builders)}')\n",
    "model = model_builders[network_type]()\n",
    "\n",
    "# summary() prints the table itself and returns None, so it is not wrapped in print()\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "## Train Baselines\n",
    "# Single-output models only; concept models are trained in the next cell.\n",
    "if network_type.startswith('concept'):\n",
    "    print('Skipping baseline training for concept model', network_type)\n",
    "else:\n",
    "    t = int(time.time())\n",
    "    model, H = train_model(model, X_train0, y_train_binary, X_test0, y_test_binary,\n",
    "                           model_dir, t, batch_size=batch_size, epochs=100, name=network_type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "## Train Models with Concept Bottleneck\n",
    "# Only the concept-bottleneck architectures have the 'c_probs'/'probs' outputs this needs.\n",
    "if not network_type.startswith('concept'):\n",
    "    print('Skipping concept training for baseline model', network_type)\n",
    "else:\n",
    "    t = int(time.time())\n",
    "    model, H = train_concept_model(model, X_train0, y_train_binary, concept_train,\n",
    "                                   X_test0, y_test_binary, concept_test,\n",
    "                                   model_dir, t, n_concepts, batch_size=batch_size, epochs=100,\n",
    "                                   name=network_type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# summarize history for accuracy and loss\n",
    "# Concept models log class accuracy as 'probs_accuracy'; single-output baselines as 'accuracy'.\n",
    "acc_key = 'probs_accuracy' if 'probs_accuracy' in H.history else 'accuracy'\n",
    "for metric, title, ylabel in ((acc_key, 'model accuracy', 'accuracy'),\n",
    "                              ('loss', 'model loss', 'loss')):\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(H.history[metric])\n",
    "    ax.plot(H.history['val_' + metric])\n",
    "    ax.set_title(title)\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.set_xlabel('epoch')\n",
    "    ax.legend(['train', 'test'], loc='upper left')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#Load Trained Model\n",
    "# File name follows train_concept_model's pattern: best_{name}_{n_concepts}_{t}.h5\n",
    "model = load_model(os.path.join(model_dir, 'best_concept_attn_78_1632654508.h5'))\n",
    "# summary() prints itself and returns None\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "## Calculate Metrics for Baseline models\n",
    "# calculate_metrics takes argmax over a single prediction array, so it only applies\n",
    "# to single-output models; concept models are evaluated in the section below.\n",
    "if len(model.outputs) == 1:\n",
    "    cf_matrix, accuracy, macro_f1, mismatch, y_pred = calculate_metrics(model, X_test0,\n",
    "                                                                        y_test_binary)\n",
    "    print('Accuracy : {}'.format(accuracy))\n",
    "    print('F1-score : {}'.format(macro_f1))\n",
    "    print(cf_matrix)\n",
    "else:\n",
    "    print('Multi-output (concept) model loaded; see the concept metrics below.')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculate Metrics for Concept Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_concept_metrics(model, X_test, y_test_binary, c_test):\n",
    "    \"\"\"Class and concept metrics for a model whose outputs are [c_probs, probs].\n",
    "\n",
    "    Class metrics use the argmax of the second output. Concept metrics treat every\n",
    "    (sample, concept) entry as one binary decision at threshold 0.5. Also prints the\n",
    "    concept-probability threshold that maximises F1 along the precision-recall curve.\n",
    "\n",
    "    Returns:\n",
    "        (cf_matrix, accuracy, macro_f1, mismatch, cf_concepts, accuracy_concepts),\n",
    "        where mismatch is the np.where tuple of misclassified test rows.\n",
    "    \"\"\"\n",
    "    pred = model.predict(X_test)\n",
    "    y_pred = np.argmax(pred[1], axis=1)\n",
    "    y_true = np.argmax(y_test_binary, axis=1)\n",
    "    mismatch = np.where(y_true != y_pred)\n",
    "    cf_matrix = confusion_matrix(y_true, y_pred)\n",
    "    accuracy = accuracy_score(y_true, y_pred)\n",
    "    macro_f1 = f1_score(y_true, y_pred, average='macro')\n",
    "\n",
    "    c_true = np.asarray(c_test).flatten()\n",
    "    c_prob = pred[0].flatten()\n",
    "    c_pred = (c_prob > 0.5).astype(int)\n",
    "    cf_concepts = confusion_matrix(c_true, c_pred)\n",
    "    accuracy_concepts = accuracy_score(c_true, c_pred)\n",
    "\n",
    "    # Best-F1 threshold. precision/recall carry one more entry than thresholds (the\n",
    "    # final point has no threshold), and P + R can be 0, which would yield NaN that\n",
    "    # argmax then selects; drop the last point and define F = 0 where P + R == 0.\n",
    "    precision, recall, thresholds = precision_recall_curve(c_true, c_prob)\n",
    "    precision, recall = precision[:-1], recall[:-1]\n",
    "    denom = precision + recall\n",
    "    fscore = np.divide(2 * precision * recall, denom, out=np.zeros_like(denom), where=denom > 0)\n",
    "    ix = np.argmax(fscore)\n",
    "    print('Best Threshold=%f, F-Score=%.3f' % (thresholds[ix], fscore[ix]))\n",
    "\n",
    "    return cf_matrix, accuracy, macro_f1, mismatch, cf_concepts, accuracy_concepts\n",
    "\n",
    "def get_attention(model, layer_name, input_data):\n",
    "    \"\"\"Return the activations of `layer_name` for `input_data` via a truncated sub-model.\"\"\"\n",
    "    intermediate_layer_model = Model(inputs=model.input,\n",
    "                                     outputs=model.get_layer(layer_name).output)\n",
    "    intermediate_output = intermediate_layer_model.predict(input_data)\n",
    "\n",
    "    return intermediate_output\n",
    "\n",
    "def get_predictions(model, X_test, attn=False):\n",
    "    \"\"\"Predicted classes and concepts (thresholded at 0.5) for a [c_probs, probs] model.\n",
    "\n",
    "    With attn=True also returns the squeezed output of the 'mul' layer, i.e. the\n",
    "    attention scores multiplied by the concept probabilities in the attention models.\n",
    "\n",
    "    Returns:\n",
    "        (y_pred, c_pred, c_prob) or (y_pred, c_pred, c_prob, attention).\n",
    "    \"\"\"\n",
    "    pred = model.predict(X_test)\n",
    "    y_pred = np.argmax(pred[1], axis=1)\n",
    "\n",
    "    c_prob = pred[0]\n",
    "    c_pred = c_prob.copy()\n",
    "    c_pred[c_pred <= 0.5] = 0\n",
    "    c_pred[c_pred > 0.5] = 1\n",
    "\n",
    "    if attn:\n",
    "        attention = np.squeeze(get_attention(model, 'mul', X_test))\n",
    "        return y_pred, c_pred, c_prob, attention\n",
    "    return y_pred, c_pred, c_prob\n",
    "\n",
    "def get_roc(model, X_test, c_test):\n",
    "    \"\"\"Print the concept ROC AUC, plot the ROC curve and print the Youden-J best threshold.\n",
    "\n",
    "    All (sample, concept) entries are pooled into one binary problem.\n",
    "    \"\"\"\n",
    "    c_true = np.asarray(c_test).flatten()\n",
    "    c_score = model.predict(X_test)[0].flatten()\n",
    "\n",
    "    fpr, tpr, thresh = roc_curve(c_true, c_score)\n",
    "    roc_auc = auc(fpr, tpr)\n",
    "    print(roc_auc)\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(fpr, tpr)\n",
    "    ax.set_xlim([0.0, 1.0])\n",
    "    ax.set_ylim([0.0, 1.05])\n",
    "    ax.set_xlabel('False Positive Rate')\n",
    "    ax.set_ylabel('True Positive Rate')\n",
    "    ax.set_title('Receiver operating characteristic')\n",
    "    plt.show()\n",
    "\n",
    "    # Youden's J statistic: the threshold maximising TPR - FPR\n",
    "    best_thresh = thresh[np.argmax(tpr - fpr)]\n",
    "    print('Best Threshold=%f' % (best_thresh))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the concept model on the held-out split\n",
    "(cf_matrix, accuracy, macro_f1, mismatch,\n",
    " cf_concepts, accuracy_concepts) = calculate_concept_metrics(model, X_test0, y_test_binary, concept_test)\n",
    "\n",
    "print('Accuracy : {}'.format(accuracy))\n",
    "print('F1-score : {}'.format(macro_f1))\n",
    "for report in (cf_matrix, cf_concepts, accuracy_concepts):\n",
    "    print(report)\n",
    "get_roc(model, X_test0, concept_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Per-sample predictions for the explanation study. The study cells below need\n",
    "# c_attn, which only models with a 'mul' (attention) layer can provide.\n",
    "has_attention = any(layer.name == 'mul' for layer in model.layers)\n",
    "if has_attention:\n",
    "    y_pred, c_pred, c_prob, c_attn = get_predictions(model, X_test0, attn=True)\n",
    "else:\n",
    "    y_pred, c_pred, c_prob = get_predictions(model, X_test0)\n",
    "    print('Model has no attention layer: c_attn is not available for the study cells.')\n",
    "mismatch_og = mismatch  # misclassified test rows from calculate_concept_metrics\n",
    "print(y_pred.shape)\n",
    "print(c_pred.shape)\n",
    "print(c_prob.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generating options for Explanation Evaluation Study:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Option 2: the k concepts with the largest attention value for each correctly\n",
    "# classified test clip, as (Id, Label, [concept texts]).\n",
    "# NOTE(review): concept indices are positions in the filtered concept matrix, while\n",
    "# the visualisation section reads 'concepts_78.csv' - confirm which file matches.\n",
    "concepts_text = pd.read_csv('concepts_100.csv')\n",
    "k = 3\n",
    "misclassified = mismatch_og[0]\n",
    "top_attn = [\n",
    "    (labels.iloc[n_train + i]['Id'], labels.iloc[n_train + i]['Label'],\n",
    "     list(concepts_text.iloc[np.argsort(weights)[::-1][:k]]['text']))\n",
    "    for i, weights in enumerate(c_attn)\n",
    "    if i not in misclassified\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Option 1: concepts predicted present (probability >= 0.5) for each correctly\n",
    "# classified test clip, falling back to the 3 most probable when none passes 0.5.\n",
    "concepts_text = pd.read_csv('concepts_100.csv')\n",
    "top_concepts = []\n",
    "misclassified = mismatch_og[0]\n",
    "\n",
    "for i, probs in enumerate(c_prob):\n",
    "    if i in misclassified:\n",
    "        continue\n",
    "    active = np.where(probs >= 0.5)[0]\n",
    "    chosen = active if len(active) else np.argsort(probs)[::-1][:3]\n",
    "    row = labels.iloc[n_train + i]\n",
    "    top_concepts.append((row['Id'], row['Label'], list(concepts_text.iloc[chosen]['text'])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Option 3 (distractor): the top-k attention concepts of a randomly drawn test clip\n",
    "# whose predicted class differs from the current clip's prediction.\n",
    "import random\n",
    "\n",
    "concepts_text = pd.read_csv('concepts_100.csv')\n",
    "option3 = [] # Concepts from another sample\n",
    "k = 3\n",
    "random.seed(42)  # reproducible distractor choice\n",
    "\n",
    "# The rejection loop below would spin forever if every prediction had the same class.\n",
    "if len(np.unique(y_pred)) < 2:\n",
    "    raise ValueError('Option 3 needs at least two distinct predicted classes in y_pred')\n",
    "\n",
    "for i, sample in enumerate(c_attn):\n",
    "    if i not in mismatch_og[0]:\n",
    "        temp_label = y_pred[i]\n",
    "        while temp_label == y_pred[i]:\n",
    "            temp_idx = random.randint(0, len(c_attn) - 1)\n",
    "            temp_label = y_pred[temp_idx]\n",
    "\n",
    "        option3_idx = np.argsort(c_attn[temp_idx])[::-1][:k]\n",
    "        option3.append((labels.iloc[n_train + temp_idx]['Id'], labels.iloc[n_train + temp_idx]['Label'],\n",
    "                        list(concepts_text.iloc[option3_idx]['text'])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Option 4 (random baseline): k random concepts per correctly classified clip,\n",
    "# where a single k in [2, 5] is drawn once for the whole study.\n",
    "import random\n",
    "\n",
    "concepts_text = pd.read_csv('concepts_100.csv')\n",
    "random.seed(42)\n",
    "k = random.randint(2, 5)\n",
    "\n",
    "option4 = []\n",
    "for i in range(len(c_attn)):\n",
    "    if i in mismatch_og[0]:\n",
    "        continue\n",
    "    option4_idx = random.sample(range(0, n_concepts), k)\n",
    "    row = labels.iloc[n_train + i]\n",
    "    option4.append((row['Id'], row['Label'], list(concepts_text.iloc[option4_idx]['text'])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the four candidate explanations for the n-th kept test clip\n",
    "n = 10\n",
    "for candidates in (top_concepts, top_attn, option3, option4):\n",
    "    print(candidates[n])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the study material: one double-quoted, comma-terminated entry per line.\n",
    "# Note the output is an array-element fragment, not a standalone JSON document.\n",
    "def _write_entries(path, rows, field):\n",
    "    with open(path, mode='wt', encoding='utf-8') as handle:\n",
    "        for row in rows:\n",
    "            handle.write(f'\"{row[field]}\",\\n')\n",
    "\n",
    "for path, rows, field in (('video_ids.json', top_attn, 0),\n",
    "                          ('video_labels.json', top_attn, 1),\n",
    "                          ('option1.json', top_concepts, 2),\n",
    "                          ('option2.json', top_attn, 2),\n",
    "                          ('option3.json', option3, 2),\n",
    "                          ('option4.json', option4, 2)):\n",
    "    _write_entries(path, rows, field)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize Concepts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "concepts_text = pd.read_csv('concepts_78.csv')\n",
    "# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and 3.8\n",
    "# removed the old names, so use whichever this installation provides.\n",
    "WHITEGRID_STYLE = ('seaborn-v0_8-whitegrid' if 'seaborn-v0_8-whitegrid' in plt.style.available\n",
    "                   else 'seaborn-whitegrid')\n",
    "plt.style.use(WHITEGRID_STYLE)\n",
    "\n",
    "# get_attention(model, layer_name, input_data) is defined once, in the concept-metrics cell above.\n",
    "\n",
    "def visualize_concepts(test_input, model, plot=False):\n",
    "    \"\"\"Predict one clip and return its 'attn_score' activations, optionally plotting them.\n",
    "\n",
    "    Prints the indices of concepts with probability >= 0.5 and the predicted class.\n",
    "    With plot=True draws a bar chart of the attention scores of those concepts.\n",
    "    Uses the globals inv_class_dict, concepts_text and get_attention.\n",
    "\n",
    "    Args:\n",
    "        test_input: one sample without a batch axis (a batch axis is added here).\n",
    "        model: concept model with outputs [c_probs, probs] and an 'attn_score' layer.\n",
    "        plot: draw the bar chart when True.\n",
    "\n",
    "    Returns:\n",
    "        The squeezed 'attn_score' activations for this sample.\n",
    "    \"\"\"\n",
    "    batch = np.expand_dims(test_input, axis=0)\n",
    "    pred = model.predict(batch)\n",
    "    pred_class = inv_class_dict[np.argmax(pred[1], axis=1)[0]]\n",
    "    pred_concepts = np.where(pred[0] >= 0.5)[1]\n",
    "    print(pred_concepts)\n",
    "    print(pred_class)\n",
    "    attention = np.squeeze(get_attention(model, 'attn_score', batch))\n",
    "    pred_attn = attention[pred_concepts]\n",
    "    pred_text = concepts_text['text'].iloc[pred_concepts]\n",
    "\n",
    "    if plot:\n",
    "        # Scoped style and font size so later figures keep the global rcParams\n",
    "        with plt.style.context(WHITEGRID_STYLE), plt.rc_context({'font.size': 18}):\n",
    "            fig, ax = plt.subplots(figsize=(6, 5))\n",
    "            y_pos = np.arange(len(pred_text))\n",
    "            ax.barh(y_pos, pred_attn, align='center')\n",
    "            ax.set_yticks(y_pos)\n",
    "            ax.set_yticklabels(pred_text, fontsize=16)\n",
    "            ax.invert_yaxis()  # labels read top-to-bottom\n",
    "            ax.set_xlabel('Concept Score', fontsize=20)\n",
    "            ax.set_title(f'Predicted Activity: {pred_class}', fontsize=20)\n",
    "            plt.show()\n",
    "\n",
    "    return attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "n = 1\n",
    "# Test sample n corresponds to row n_train + n of `labels` (the same offset used\n",
    "# by labels.iloc[n_train+n] in the inspection cell), not a hard-coded 1700.\n",
    "print(labels['Label'].iloc[n_train + n])\n",
    "input_data = X_test0[n]\n",
    "c_attn = visualize_concepts(input_data, model)\n",
    "\n",
    "\n",
    "#7, 24, 53, 55, 58, 59, 73, 79, 17\n",
    "\n",
    "#strike = array([  8,  16,  17,  35,  38,  42,  43,  45,  46,  49,  68,  71,  78,\n",
    "#          80,  84,  87,  93,  98, 110, 111, 112, 123, 129, 130, 133, 136,\n",
    "#         137, 140, 143, 144, 145, 146, 148, 150, 157, 163, 168, 170, 172,\n",
    "#         178, 189, 192, 198, 200, 205, 207, 208, 214, 215, 217]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set predictions: output 0 holds the concept scores, output 1 the\n",
    "# class scores, whose argmax is the predicted class index.\n",
    "pred = model.predict(X_test0)\n",
    "y_pred = pred[1].argmax(axis=1)\n",
    "for output_scores in (pred[0], pred[1]):\n",
    "    print(output_scores.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 22\n",
    "print(labels.iloc[n_train + n])\n",
    "print(inv_class_dict[y_pred[n]])\n",
    "# Indices of the concepts whose predicted score reaches 0.5 for this sample.\n",
    "active_idx = np.where(pred[0][n] >= 0.5)[0]\n",
    "print(active_idx)\n",
    "\n",
    "input_data = np.expand_dims(X_test0[n], axis=0)\n",
    "attention = np.squeeze(get_attention(model, 'mul', input_data))\n",
    "# Shift to non-negative and rescale to sum to 1. If all scores are equal the\n",
    "# shifted sum is 0, so fall back to a uniform distribution instead of NaNs.\n",
    "shifted = attention - np.min(attention)\n",
    "total = np.sum(shifted)\n",
    "if total != 0:\n",
    "    attn_normalized = shifted / total\n",
    "else:\n",
    "    attn_normalized = np.full(shifted.shape, 1.0 / shifted.size)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.bar(np.arange(len(attn_normalized)), attn_normalized)\n",
    "ax.set_xlabel('Concept index')\n",
    "ax.set_ylabel('Normalized score')\n",
    "ax.set_title(f'Test sample {n}')\n",
    "plt.show()\n",
    "\n",
    "pred_attn = attn_normalized[active_idx]\n",
    "print(pred_attn)\n",
    "pred_text = concepts_text['text'].iloc[active_idx]\n",
    "print(pred_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank concepts by normalized score, highest first.\n",
    "sorted_list = sorted(zip(attn_normalized, concepts_text['text']), reverse=True)\n",
    "\n",
    "N = 3  # concepts shown individually; the rest are summed into 'Others'\n",
    "n_top = min(N, len(sorted_list))  # avoid IndexError when fewer concepts exist\n",
    "\n",
    "df = pd.DataFrame(sorted_list, columns = ['Score','Concepts'])\n",
    "\n",
    "# barh places categories bottom-up, so feed the top concepts in ascending\n",
    "# order to put the highest score at the top of the chart.\n",
    "top_rows_asc = sorted_list[:n_top][::-1]\n",
    "importance_sorted = [score for score, _ in top_rows_asc]\n",
    "features_sorted = [text for _, text in top_rows_asc]\n",
    "\n",
    "# 'seaborn-whitegrid' was renamed 'seaborn-v0_8-whitegrid' in matplotlib 3.6.\n",
    "whitegrid_style = ('seaborn-whitegrid' if 'seaborn-whitegrid' in plt.style.available\n",
    "                   else 'seaborn-v0_8-whitegrid')\n",
    "plt.style.use(whitegrid_style)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title(f\"Predicted Class : '{inv_class_dict[y_pred[n]]}'\", fontsize=15)\n",
    "ax.set_xlabel(\"Concept Score\", fontsize=13)\n",
    "ax.tick_params(axis='both', labelsize=14)\n",
    "ax.barh(['Others'], [np.sum(df['Score'].values[n_top:])], label = \"Others\", color='grey')\n",
    "ax.barh(features_sorted, importance_sorted, label = \"Concepts\", color='green')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Performance for different numbers of concepts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean / standard deviation of accuracy (_a) and F1-score (_f) for each number\n",
    "# of concepts; hard-coded results, presumably aggregated over repeated runs.\n",
    "indices = [15, 30, 50, 80, 100, 200]\n",
    "mean_a = [0.6169, 0.6389, 0.6537, 0.67833, 0.6784, 0.6780 ]\n",
    "std_a = [0.0082, 0.0166, 0.0177, 0.00716, 0.0028, 0.0037]\n",
    "mean_f = [0.6145, 0.6452, 0.65715, 0.6802, 0.6809, 0.6805]\n",
    "std_f = [0.0126, 0.01769, 0.0134, 0.00676, 0.0049, 0.0042]\n",
    "\n",
    "\n",
    "def _plot_vs_num_concepts(x, curves, ylabel, filename, legend=None):\n",
    "    \"\"\"Draw error-bar curves against the number of concepts, save and show.\n",
    "\n",
    "    x: the numbers of concepts (shared x values of every curve).\n",
    "    curves: sequence of (means, stds, fmt, ecolor) tuples, drawn in order.\n",
    "    ylabel: y-axis label.\n",
    "    filename: passed to plt.savefig before the figure is shown.\n",
    "    legend: optional curve labels, shown in the lower-right corner (loc=4).\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(10,6))\n",
    "    for means, stds, fmt, ecolor in curves:\n",
    "        plt.errorbar(x, means, yerr=stds, fmt=fmt, ecolor=ecolor, capsize=3)\n",
    "    plt.xticks(fontsize=18)\n",
    "    plt.yticks(fontsize=18)\n",
    "    plt.xlabel('Number of Concepts', fontsize=22)\n",
    "    plt.ylabel(ylabel, fontsize=22)\n",
    "    if legend is not None:\n",
    "        plt.legend(legend, fontsize=18, loc=4)\n",
    "    plt.savefig(filename)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "_plot_vs_num_concepts(indices, [(mean_a, std_a, '-ok', 'gray')], 'Accuracy', 'Accuracy')\n",
    "_plot_vs_num_concepts(indices, [(mean_f, std_f, '-ok', 'gray')], 'F1-Score', 'F1_score')\n",
    "_plot_vs_num_concepts(indices,\n",
    "                      [(mean_a, std_a, '-ok', 'gray'), (mean_f, std_f, '-or', 'blue')],\n",
    "                      'Accuracy/F1-Score', 'both', legend=['Accuracy','F1-Score'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Miscellaneous - Closest Training Sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from heapq import nlargest\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "\n",
    "# Predict on the full training set and save the predicted labels / concept scores.\n",
    "pred_train = model.predict(X_train0)\n",
    "labels_train = np.argmax(pred_train[1],axis=1)\n",
    "np.save('pred_labels_train',labels_train)\n",
    "np.save('pred_train_concepts',pred_train[0])\n",
    "\n",
    "n = 79\n",
    "# Test sample n corresponds to labels row n_train + n (same offset as above).\n",
    "print(labels.iloc[n_train + n])\n",
    "test_input = np.expand_dims(X_test0[n], axis=0)\n",
    "# Keep this single-sample prediction under its own name: reusing `pred` would\n",
    "# overwrite the test-set predictions that earlier cells index as pred[0][n].\n",
    "pred_sample = model.predict(test_input)\n",
    "pred_label = np.argmax(pred_sample[1],axis=1)[0]\n",
    "pred_class = inv_class_dict[pred_label]\n",
    "pred_concepts = np.where(pred_sample[0]>=0.5)\n",
    "print(pred_concepts[1])\n",
    "print(pred_label)\n",
    "print(pred_class)\n",
    "print(pred_sample[0].shape)\n",
    "\n",
    "# Cosine similarity between every training sample's concept scores and this sample's.\n",
    "sim = cosine_similarity(pred_train[0], pred_sample[0])[:, 0]\n",
    "# Only training samples predicted as the same class are candidates, so an\n",
    "# other-class sample can never be returned (fewer than 3 if too few exist).\n",
    "same_class_idx = np.flatnonzero(labels_train == pred_label).tolist()\n",
    "topk = nlargest(3, same_class_idx, key=lambda idx: sim[idx])\n",
    "\n",
    "print(topk)\n",
    "# NOTE(review): assumes training samples occupy the first rows of `labels` -- confirm.\n",
    "print(labels.iloc[topk])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
