{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from operator import itemgetter\n",
    "from mmaction.apis import init_recognizer, inference_recognizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Config file used to build the recognizer. The path is relative, so it is\n",
    "# resolved against the kernel's current working directory.\n",
    "config_file = '../demo/demo_configs/tsn_r50_1x1x8_video_infer.py'\n",
    "# Download the checkpoint from the model zoo and put it in `checkpoints/`\n",
    "# before running; the file is not fetched by this notebook.\n",
    "checkpoint_file = '../checkpoints/tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb_20220818-2692d16c.pth'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loads checkpoint by local backend from path: ../checkpoints/tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb_20220818-2692d16c.pth\n"
     ]
    }
   ],
   "source": [
    "# Build the model from a config file and a checkpoint file.\n",
    "# The checkpoint is read from the local path given above (see the stored cell output).\n",
    "# NOTE(review): device='cpu' is hard-coded; presumably a CUDA device string can be\n",
    "# passed instead when a GPU is available - confirm against the mmaction API docs.\n",
    "model = init_recognizer(config_file, checkpoint_file, device='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Test a single video and collect the five highest-scoring classes.\n",
    "video = 'demo.mp4'\n",
    "label = '../tools/data/kinetics/label_map_k400.txt'\n",
    "# Keep the raw inference output under its own name so that `results`\n",
    "# only ever refers to the final list of (label, score) pairs.\n",
    "prediction = inference_recognizer(model, video)\n",
    "\n",
    "# Pair every class index with its score, then rank by score, highest first.\n",
    "pred_scores = prediction.pred_scores.item.tolist()\n",
    "score_tuples = tuple(enumerate(pred_scores))\n",
    "score_sorted = sorted(score_tuples, key=itemgetter(1), reverse=True)\n",
    "top5_label = score_sorted[:5]\n",
    "\n",
    "# Read the label file with a context manager so the handle is closed\n",
    "# as soon as the lines are consumed; line i is used as the name of class index i.\n",
    "with open(label) as label_file:\n",
    "    labels = [line.strip() for line in label_file]\n",
    "# Translate each top-5 class index into its label text.\n",
    "results = [(labels[index], score) for index, score in top5_label]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "arm wrestling:  1.0\n",
      "rock scissors paper:  1.698846019067312e-15\n",
      "massaging feet:  5.157996544393221e-16\n",
      "stretching leg:  1.018867278715779e-16\n",
      "bench pressing:  7.110452486439706e-17\n"
     ]
    }
   ],
   "source": [
    "# Print each predicted label followed by its score, one pair per line.\n",
    "for label_name, score in results:\n",
    "    print(f'{label_name}: ', score)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mmact_dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13 (default, Mar 29 2022, 02:18:16) \n[GCC 7.5.0]"
  },
  "vscode": {
   "interpreter": {
    "hash": "189c342a4747645665e89db23000ac4d4edb7a87c4cd0b2f881610f468fb778d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
