{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from warnings import simplefilter\n",
    "simplefilter(action='ignore', category=FutureWarning)\n",
    "\n",
    "import os\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.models import load_model\n",
    "\n",
    "from utils import *\n",
    "from model_architectures import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "FEATURE_EXTRACTOR = 'Resnet50V2'\n",
    "\n",
    "features_dir = f'../Data/Feature_vectors_{FEATURE_EXTRACTOR}'\n",
    "labels_cache_path = 'labels_100.pkl'\n",
    "features_cache_path = 'data.npy'\n",
    "\n",
    "# NOTE: unpickling can run arbitrary code, so only point these paths at trusted files.\n",
    "labels_df_filtered = pd.read_pickle('../Extracted_Concepts/labels_df_filtered_100.pkl')\n",
    "\n",
    "if os.path.isfile(labels_cache_path) and os.path.isfile(features_cache_path):\n",
    "    # Fast path: reuse the arrays cached by an earlier run of this cell.\n",
    "    labels = pd.read_pickle(labels_cache_path)\n",
    "    X = np.load(features_cache_path)\n",
    "else:\n",
    "    # Slow path: keep only the rows whose feature file exists on disk.\n",
    "    feature_paths = [os.path.join(features_dir, clip_id + '.npy')\n",
    "                     for clip_id in labels_df_filtered['Id']]\n",
    "    has_features = np.array([os.path.isfile(path) for path in feature_paths], dtype=bool)\n",
    "    if not has_features.any():\n",
    "        raise FileNotFoundError(f'No feature vectors found in {features_dir}')\n",
    "\n",
    "    labels = labels_df_filtered.loc[has_features].reset_index(drop=True)\n",
    "    # Each stored array is transposed before stacking along a new leading sample axis.\n",
    "    X = np.stack([np.load(path).T\n",
    "                  for path, found in zip(feature_paths, has_features) if found], axis=0)\n",
    "\n",
    "    # Write both caches only after both objects were built successfully.\n",
    "    labels.to_pickle(labels_cache_path)\n",
    "    np.save(features_cache_path, X)\n",
    "\n",
    "print(X.shape)\n",
    "print(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_dir = 'Models'\n",
    "os.makedirs(model_dir, exist_ok=True)\n",
    "\n",
    "# NOTE(review): this directory is spelled 'Extract_Concepts' while the labels are read\n",
    "# from 'Extracted_Concepts' -- confirm which spelling is correct.\n",
    "concepts_text = pd.read_csv('../Extract_Concepts/concepts_100.csv')\n",
    "\n",
    "# Class names in label-index order; the lookup tables below are derived from this list.\n",
    "classes = ['strike', 'ball', 'play', 'foul', 'out']\n",
    "num_classes = len(classes)\n",
    "class_dict = {name: index for index, name in enumerate(classes)}\n",
    "inv_class_dict = {index: name for name, index in class_dict.items()}\n",
    "\n",
    "# Optional cap on the number of concept columns; None keeps every non-empty concept.\n",
    "# The cap used to be read from n_concepts, which is first assigned in a later cell,\n",
    "# so this cell raised NameError on a fresh kernel.\n",
    "MAX_CONCEPTS = None\n",
    "# The first N_TRAIN samples are used for training, the remainder for testing.\n",
    "N_TRAIN = 1700\n",
    "\n",
    "concept_matrix = np.stack(labels['Concepts'].values, axis=0)\n",
    "# Drop the concept columns that are zero for every sample.\n",
    "concept_matrix = concept_matrix[:, ~np.all(concept_matrix == 0, axis=0)]\n",
    "concept_matrix = concept_matrix[:, :MAX_CONCEPTS]\n",
    "print(concept_matrix.shape)\n",
    "\n",
    "y = np.array([class_dict[label] for label in labels['Label']])\n",
    "\n",
    "y_binary = tf.keras.utils.to_categorical(y, num_classes)\n",
    "print(y_binary.shape)\n",
    "\n",
    "assert len(X) == len(y_binary) == len(concept_matrix), 'features, labels and concepts are misaligned'\n",
    "\n",
    "X_train0 = X[:N_TRAIN, :, :]\n",
    "X_test0 = X[N_TRAIN:, :, :]\n",
    "y_train_binary = y_binary[:N_TRAIN, :]\n",
    "y_test_binary = y_binary[N_TRAIN:, :]\n",
    "concept_train = concept_matrix[:N_TRAIN, :]\n",
    "concept_test = concept_matrix[N_TRAIN:, :]\n",
    "\n",
    "print(X_train0.shape)\n",
    "print(y_train_binary.shape)\n",
    "print(concept_train.shape)\n",
    "print(X_test0.shape)\n",
    "print(y_test_binary.shape)\n",
    "print(concept_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture selection: leave exactly one of the assignments below uncommented.\n",
    "# network_type = 'Conv1D'\n",
    "# network_type = 'concept_Conv'\n",
    "network_type = 'concept_Conv_attn'\n",
    "# network_type = 'LSTM'\n",
    "# network_type = 'concept_LSTM'\n",
    "# network_type = 'concept_LSTM_attn'\n",
    "\n",
    "# Hyper-parameters; the sizes are read from the training arrays.\n",
    "batch_size = 16\n",
    "_n_samples, win_len, dim = X_train0.shape\n",
    "n_concepts = concept_train.shape[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print('building the model ...')\n",
    "\n",
    "# NOTE(review): 'concept_LSTM' is listed as a network_type option but has no builder\n",
    "# here. This cell used to contain the 'concept_LSTM_attn' branch twice, so one copy\n",
    "# was presumably meant for 'concept_LSTM' -- confirm and add the missing branch.\n",
    "if network_type == 'Conv1D':\n",
    "    model = model_Conv1D(dim, win_len, num_classes, num_feat_map=128, p=0.5)\n",
    "elif network_type == 'concept_Conv':\n",
    "    model = model_Conv1D_concepts(dim, win_len, num_classes, n_concepts, p=0.5)\n",
    "elif network_type == 'concept_Conv_attn':\n",
    "    model = model_Conv1D_attn_concepts(dim, win_len, num_classes, n_concepts, num_feat_map=64, p=0.5)\n",
    "elif network_type == 'LSTM':\n",
    "    model = model_LSTM(dim, win_len, num_classes, num_hidden_lstm=512, p=0.2)\n",
    "elif network_type == 'concept_LSTM_attn':\n",
    "    model = model_LSTM_attn_concepts(dim, win_len, num_classes, n_concepts, num_hidden_lstm=512, p=0.2)\n",
    "else:\n",
    "    # Fail loudly instead of leaving model undefined or bound to a stale object.\n",
    "    raise ValueError(f'Unsupported network_type: {network_type!r}')\n",
    "\n",
    "# summary() prints the table itself and returns None, so it is not wrapped in print().\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Integer timestamp handed to the training helper together with the model directory.\n",
    "t = int(time.time())\n",
    "\n",
    "if network_type in ('Conv1D', 'LSTM'):\n",
    "    # Plain classifiers are trained on the class labels only.\n",
    "    model, H = train_model(\n",
    "        model, X_train0, y_train_binary, X_test0, y_test_binary,\n",
    "        model_dir, t, batch_size=batch_size, epochs=100, name=network_type)\n",
    "else:\n",
    "    # Concept models additionally receive the concept matrices and the concept count.\n",
    "    model, H = train_concept_model(\n",
    "        model, X_train0, y_train_binary, concept_train,\n",
    "        X_test0, y_test_binary, concept_test,\n",
    "        model_dir, t, n_concepts, batch_size=batch_size, epochs=100, name=network_type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Load the best trained model. This rebinds model to a fixed checkpoint file, so the\n",
    "# file name has to be updated by hand to evaluate a different training run.\n",
    "best_model_file = 'best_concept_Conv_attn_78_1622778374.h5'\n",
    "model = load_model(os.path.join(model_dir, best_model_file))\n",
    "\n",
    "# summary() prints the table itself and returns None; print() would add a stray 'None'.\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Plain classifiers and concept models are scored by different helpers.\n",
    "is_plain_classifier = network_type in ('Conv1D', 'LSTM')\n",
    "\n",
    "if is_plain_classifier:\n",
    "    cf_matrix, accuracy, macro_f1, mismatch, y_pred = calculate_metrics(\n",
    "        model, X_test0, y_test_binary)\n",
    "else:\n",
    "    (cf_matrix, accuracy, macro_f1, mismatch, y_pred,\n",
    "     cf_concepts, accuracy_concepts) = calculate_concept_metrics(\n",
    "        model, X_test0, y_test_binary, concept_test)\n",
    "\n",
    "# Class-level results are reported for every network type.\n",
    "print('Accuracy : {}'.format(accuracy))\n",
    "print('F1-score : {}'.format(macro_f1))\n",
    "print(cf_matrix)\n",
    "\n",
    "# Concept-level results only exist for the concept models.\n",
    "if not is_plain_classifier:\n",
    "    print(cf_concepts)\n",
    "    print(accuracy_concepts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results over multiple runs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The performance of our models lies between 60% and 70%. The main reason is the confusion between strike and ball: in certain cases these two activities are so similar to each other that they are difficult to distinguish even for the human eye. A simple way to improve the performance would be to increase the amount of training data for the strike and ball classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-recorded summary values; position i of every list belongs to indices[i].\n",
    "# indices is plotted on the 'Number of Concepts' axis further down.\n",
    "indices = [15, 30, 50, 78, 100, 200]\n",
    "\n",
    "# Accuracy curve and the values drawn as its error bars.\n",
    "mean_a = [0.6169, 0.6389, 0.6437, 0.66133, 0.6614, 0.6610]\n",
    "std_a = [0.0082, 0.0166, 0.0177, 0.00716, 0.0028, 0.0037]\n",
    "\n",
    "# F1-score curve and the values drawn as its error bars.\n",
    "mean_f = [0.6145, 0.6452, 0.64715, 0.6632, 0.6639, 0.6645]\n",
    "std_f = [0.0126, 0.01769, 0.0134, 0.00676, 0.0049, 0.0042]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and later\n",
    "# releases removed the old names, so use whichever spelling this installation provides.\n",
    "style_name = 'seaborn-whitegrid'\n",
    "if style_name not in plt.style.available:\n",
    "    style_name = 'seaborn-v0_8-whitegrid'\n",
    "plt.style.use(style_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_errorbar_figure(x, series, ylabel, filename, legend=None):\n",
    "    \"\"\"Draw one error-bar figure, save it under filename and show it.\n",
    "\n",
    "    x        -- values for the 'Number of Concepts' axis\n",
    "    series   -- iterable of (means, errors, fmt, ecolor) tuples, one per curve\n",
    "    ylabel   -- label of the y axis\n",
    "    filename -- target passed to plt.savefig\n",
    "    legend   -- optional list of legend entries, in the same order as series\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(10, 6))\n",
    "    for means, errors, fmt, ecolor in series:\n",
    "        plt.errorbar(x, means, yerr=errors, fmt=fmt, ecolor=ecolor, capsize=3)\n",
    "    plt.xticks(fontsize=18)\n",
    "    plt.yticks(fontsize=18)\n",
    "    plt.xlabel('Number of Concepts', fontsize=22)\n",
    "    plt.ylabel(ylabel, fontsize=22)\n",
    "    if legend is not None:\n",
    "        plt.legend(legend, fontsize=18, loc=4)\n",
    "    plt.savefig(filename)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "accuracy_curve = (mean_a, std_a, '-ok', 'gray')\n",
    "\n",
    "# One figure per metric, then both curves in a shared figure.\n",
    "plot_errorbar_figure(indices, [accuracy_curve], 'Accuracy', 'Accuracy')\n",
    "plot_errorbar_figure(indices, [(mean_f, std_f, '-ok', 'gray')], 'F1-Score', 'F1_score')\n",
    "plot_errorbar_figure(indices, [accuracy_curve, (mean_f, std_f, '-or', 'blue')],\n",
    "                     'Accuracy/F1-Score', 'both', legend=['Accuracy', 'F1-Score'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
