{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "40bd588c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bc0ae5bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) Facebook, Inc. and its affiliates.\n",
    "\n",
    "\"\"\"The Epic Kitchens dataset loaders.\"\"\"\n",
    "\n",
    "from typing import List, Dict, Sequence, Tuple, Union\n",
    "from datetime import datetime, date\n",
    "from collections import OrderedDict\n",
    "import pickle as pkl\n",
    "import csv\n",
    "import logging\n",
    "from pathlib import Path\n",
    "import lmdb\n",
    "import pandas as pd\n",
    "\n",
    "import numpy as np\n",
    "from omegaconf import OmegaConf\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from epic_kitchens.base_dataloader import BaseVideoDataset, RULSTM_TSN_FPS\n",
    "from epic_kitchens.reader_fns import Reader\n",
    "\n",
    "EGTEA_VERSION = -1  # This class also supports EGTEA Gaze+\n",
    "EPIC55_VERSION = 0.1\n",
    "EPIC100_VERSION = 0.2\n",
    "\n",
    "\n",
    "class EPICKitchens(BaseVideoDataset):\n",
    "    \"\"\"EPIC-Kitchens (55 / 100) and EGTEA Gaze+ dataloader.\n",
    "\n",
    "    Loads and concatenates the annotation files into one DataFrame, attaches\n",
    "    verb / noun / action class ids to every sample, and passes the result on\n",
    "    to ``BaseVideoDataset``.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "            self,\n",
    "            annotation_path: Sequence[Path],\n",
    "            only_keep_persons: str = None,\n",
    "            only_keep_videos: Path = None,\n",
    "            action_labels_fpath: Path = None,\n",
    "            annotation_dir: Path = None,\n",
    "            rulstm_annotation_dir: Path = None,\n",
    "            _precomputed_metadata: Path = None,\n",
    "            version: float = EPIC55_VERSION,\n",
    "            video_info_path: Path = None,\n",
    "            frame_root: str = '/',\n",
    "            process_inorder = False,\n",
    "            use_timestamps = True,\n",
    "            **other_kwargs,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            annotation_path (Sequence[Path]): Annotation files (str or Path)\n",
    "                to load and concatenate. ``.pkl`` files are read as the\n",
    "                original EPIC annotations, ``.csv`` files as RULSTM ones.\n",
    "            only_keep_persons (str): If None, ignore. Else, will only keep\n",
    "                videos of persons P<start> to P<end> (both included), where this\n",
    "                string is \"<start>-<end>\". This is used to create\n",
    "                the train_minus_val and val sets, as per\n",
    "                https://arxiv.org/abs/1806.06157\n",
    "            only_keep_videos (Path): Path to a file with list of videos to keep.\n",
    "                This was used to define the val set as used in anticipation\n",
    "                in https://arxiv.org/abs/1905.09035\n",
    "            action_labels_fpath (Path): Path to map the verb and noun labels to\n",
    "                actions. It was used in the anticipation paper, that defines\n",
    "                a set of actions and train for action prediction, as opposed\n",
    "                to verb and noun prediction. If None, every verb/noun pair\n",
    "                is treated as an action.\n",
    "            annotation_dir: Where all the other annotations are typically\n",
    "                stored. Defaults to the parent dir of the first annotation.\n",
    "            rulstm_annotation_dir: Stored on the instance as-is.\n",
    "            _precomputed_metadata: Not supported yet; must be None.\n",
    "            version: EPIC55_VERSION, EPIC100_VERSION or EGTEA_VERSION.\n",
    "            video_info_path (Path): CSV with ``video_id`` and ``fps`` columns,\n",
    "                used to fill the per-sample ``video_fps`` column. Required\n",
    "                for the EPIC versions, optional for EGTEA.\n",
    "            frame_root (str): Root dir holding one frame dir per video.\n",
    "            process_inorder (bool): If True, sort the samples by video and\n",
    "                start time after construction.\n",
    "            use_timestamps (bool): If True, convert the start/stop timestamps\n",
    "                of .pkl annotations into ``start``/``end`` seconds.\n",
    "            **other_kwargs: Forwarded to ``BaseVideoDataset.__init__``.\n",
    "        Raises:\n",
    "            AssertionError: if ``_precomputed_metadata`` is given.\n",
    "            ValueError: if ``video_info_path`` is None for an EPIC version.\n",
    "        \"\"\"\n",
    "        # Checked up front so we fail before any (slow) annotation loading.\n",
    "        assert _precomputed_metadata is None, 'Not supported yet'\n",
    "        self.version = version\n",
    "        self.frame_root = frame_root\n",
    "        if video_info_path is None:\n",
    "            if version != EGTEA_VERSION:\n",
    "                # pd.read_csv(None) used to raise a cryptic ValueError here.\n",
    "                raise ValueError(\n",
    "                    'video_info_path is required for the EPIC datasets')\n",
    "            # EGTEA video paths do not need per-video FPS.\n",
    "            self.video_info = None\n",
    "            self.video_to_fps = {}\n",
    "        else:\n",
    "            self.video_info = pd.read_csv(video_info_path)\n",
    "            self.video_to_fps = self.video_info.set_index(\n",
    "                'video_id')['fps'].to_dict()\n",
    "        self.use_timestamps = use_timestamps\n",
    "        # ignore_index gives one contiguous index across all the files.\n",
    "        df = pd.concat([self._load_df(el) for el in annotation_path],\n",
    "                       ignore_index=True)\n",
    "        df = self._subselect_df_by_videos(\n",
    "            self._subselect_df_by_person(df, only_keep_persons),\n",
    "            only_keep_videos)\n",
    "        # If no specific annotation_dir specified, use the parent dir of\n",
    "        # the first annot path\n",
    "        if annotation_dir is None:\n",
    "            self.annotation_dir = Path(annotation_path[0]).parent\n",
    "        else:\n",
    "            self.annotation_dir = Path(annotation_dir)\n",
    "        self.rulstm_annotation_dir = rulstm_annotation_dir\n",
    "        if self.version != EGTEA_VERSION:\n",
    "            epic_postfix = '_100' if self.version == EPIC100_VERSION else ''\n",
    "            verb_classes = self._load_class_names(\n",
    "                self.annotation_dir / f'EPIC{epic_postfix}_verb_classes.csv')\n",
    "            noun_classes = self._load_class_names(\n",
    "                self.annotation_dir / f'EPIC{epic_postfix}_noun_classes.csv')\n",
    "        else:\n",
    "            verb_classes, noun_classes = [], []\n",
    "        # Create action classes\n",
    "        if action_labels_fpath is not None:\n",
    "            load_action_fn = (self._load_action_classes_egtea\n",
    "                              if self.version == EGTEA_VERSION else\n",
    "                              self._load_action_classes)\n",
    "            action_classes, verb_noun_to_action = load_action_fn(\n",
    "                action_labels_fpath)\n",
    "        else:\n",
    "            action_classes, verb_noun_to_action = self._gen_all_actions(\n",
    "                verb_classes, noun_classes)\n",
    "        # Map each noun / verb id to all action ids containing it. Derived\n",
    "        # from verb_noun_to_action instead of re-reading the CSV by column\n",
    "        # name, so it also works for the header-less EGTEA action file and\n",
    "        # when actions are generated from all verb/noun pairs.\n",
    "        self.noun_to_actions, self.verb_to_actions = (\n",
    "            self._group_actions_by_noun_and_verb(verb_noun_to_action))\n",
    "        # Add the action classes to the data frame\n",
    "        if ('action_class' not in df.columns\n",
    "                and {'noun_class', 'verb_class'}.issubset(df.columns)):\n",
    "            # Unknown verb/noun pairs get -1. A comprehension (not a row-wise\n",
    "            # apply) also works when df is empty.\n",
    "            df['action_class'] = [\n",
    "                verb_noun_to_action.get((verb, noun), -1)\n",
    "                for verb, noun in zip(df['verb_class'], df['noun_class'])\n",
    "            ]\n",
    "        elif 'action_class' not in df.columns:\n",
    "            df.loc[:, 'action_class'] = -1\n",
    "            df.loc[:, 'verb_class'] = -1\n",
    "            df.loc[:, 'noun_class'] = -1\n",
    "        num_undefined_actions = int((df['action_class'] == -1).sum())\n",
    "        if num_undefined_actions > 0:\n",
    "            logging.error(\n",
    "                'Did not found valid action label for %d/%d samples!',\n",
    "                num_undefined_actions, len(df))\n",
    "        other_kwargs['verb_classes'] = verb_classes\n",
    "        other_kwargs['noun_classes'] = noun_classes\n",
    "        other_kwargs['action_classes'] = action_classes\n",
    "        super().__init__(df, **other_kwargs)\n",
    "\n",
    "        if process_inorder:\n",
    "            # RULSTM csv annotations have no 'start_timestamp' column; use\n",
    "            # the numeric 'start' (seconds) column for those.\n",
    "            time_col = ('start_timestamp'\n",
    "                        if 'start_timestamp' in self.df.columns else 'start')\n",
    "            self.df = self.df.sort_values(\n",
    "                by=['video_id', time_col]).reset_index(drop=True)\n",
    "\n",
    "        # following is used in the notebooks for marginalization, so save it\n",
    "        self.verb_noun_to_action = verb_noun_to_action\n",
    "        logging.info('Created EPIC %s dataset with %d samples', self.version,\n",
    "                     len(self))\n",
    "\n",
    "    @property\n",
    "    def primary_metric(self) -> str:\n",
    "        \"\"\"Metric to optimize: AR5 for EPIC-100, else the base default.\"\"\"\n",
    "        if self.version == EPIC100_VERSION:\n",
    "            # For EK100, we want to optimize for AR5\n",
    "            return 'final_acc/action/AR5'\n",
    "        return super().primary_metric\n",
    "\n",
    "    @property\n",
    "    def class_mappings(self) -> Dict[Tuple[str, str], torch.FloatTensor]:\n",
    "        \"\"\"Binary verb->action and noun->action membership matrices.\n",
    "\n",
    "        Returns a dict with keys ('verb', 'action') and ('noun', 'action')\n",
    "        mapping to float tensors of shape (num_actions, num_verbs) and\n",
    "        (num_actions, num_nouns), holding 1.0 where the action contains\n",
    "        that verb / noun and 0.0 elsewhere.\n",
    "        \"\"\"\n",
    "        pairs = self.verb_noun_to_action\n",
    "        # Without class-name lists (e.g. EGTEA), size each axis by the\n",
    "        # largest id seen so every id is a valid index, even when the ids\n",
    "        # are not contiguous (counting distinct ids would under-size it).\n",
    "        num_verbs = len(self.verb_classes) or (\n",
    "            max((verb for verb, _ in pairs), default=-1) + 1)\n",
    "        num_nouns = len(self.noun_classes) or (\n",
    "            max((noun for _, noun in pairs), default=-1) + 1)\n",
    "        num_actions = len(self.action_classes) or (\n",
    "            max(pairs.values(), default=-1) + 1)\n",
    "        verb_in_action = torch.zeros((num_actions, num_verbs),\n",
    "                                     dtype=torch.float)\n",
    "        noun_in_action = torch.zeros((num_actions, num_nouns),\n",
    "                                     dtype=torch.float)\n",
    "        for (verb, noun), action in pairs.items():\n",
    "            verb_in_action[action, verb] = 1.0\n",
    "            noun_in_action[action, noun] = 1.0\n",
    "        return {\n",
    "            ('verb', 'action'): verb_in_action,\n",
    "            ('noun', 'action'): noun_in_action\n",
    "        }\n",
    "\n",
    "    @property\n",
    "    def classes_manyshot(self) -> OrderedDict:\n",
    "        \"\"\"\n",
    "        In EPIC-55, the recall computation was done for \"many shot\" classes,\n",
    "        and not for all classes. So, for that version read the class names as\n",
    "        provided by RULSTM.\n",
    "        Function adapted from\n",
    "        https://github.com/fpv-iplab/rulstm/blob/57842b27d6264318be2cb0beb9e2f8c2819ad9bc/RULSTM/main.py#L386\n",
    "        \"\"\"\n",
    "        if self.version != EPIC55_VERSION:\n",
    "            return super().classes_manyshot\n",
    "        # read the list of many shot verbs\n",
    "        many_shot_verbs = {\n",
    "            el['verb']: el['verb_class']\n",
    "            for el in pd.read_csv(self.annotation_dir /\n",
    "                                  'EPIC_many_shot_verbs.csv').to_dict(\n",
    "                                      'records')\n",
    "        }\n",
    "        # read the list of many shot nouns\n",
    "        many_shot_nouns = {\n",
    "            el['noun']: el['noun_class']\n",
    "            for el in pd.read_csv(self.annotation_dir /\n",
    "                                  'EPIC_many_shot_nouns.csv').to_dict(\n",
    "                                      'records')\n",
    "        }\n",
    "        # create the list of many shot actions\n",
    "        # an action is \"many shot\" if at least one\n",
    "        # between the related verb and noun are many shot\n",
    "        # (sets hoisted out of the loop: O(1) membership per action)\n",
    "        many_shot_verb_ids = set(many_shot_verbs.values())\n",
    "        many_shot_noun_ids = set(many_shot_nouns.values())\n",
    "        action_names = {val: key for key, val in self.action_classes.items()}\n",
    "        many_shot_actions = {}\n",
    "        for (verb_id, noun_id), action_id in self.verb_noun_to_action.items():\n",
    "            if verb_id in many_shot_verb_ids or noun_id in many_shot_noun_ids:\n",
    "                many_shot_actions[action_names[action_id]] = action_id\n",
    "        return {\n",
    "            'verb': many_shot_verbs,\n",
    "            'noun': many_shot_nouns,\n",
    "            'action': many_shot_actions,\n",
    "        }\n",
    "\n",
    "    @staticmethod\n",
    "    def _load_action_classes(\n",
    "            action_labels_fpath: Path\n",
    "    ) -> Tuple[Dict[str, int], Dict[Tuple[int, int], int]]:\n",
    "        \"\"\"\n",
    "        Given a CSV file with the actions (as from RULSTM paper), construct\n",
    "        the set of actions and mapping from verb/noun to action\n",
    "        Args:\n",
    "            action_labels_fpath: path to the file, with a header containing\n",
    "                (at least) 'id', 'verb', 'noun' and 'action' columns\n",
    "        Returns:\n",
    "            class_names: Dict of action class name -> 0-based row index\n",
    "            verb_noun_to_action: Mapping from verb/noun to action IDs\n",
    "        \"\"\"\n",
    "        class_names = {}\n",
    "        verb_noun_to_action = {}\n",
    "        with open(action_labels_fpath, 'r') as fin:\n",
    "            reader = csv.DictReader(fin, delimiter=',')\n",
    "            for lno, line in enumerate(reader):\n",
    "                class_names[line['action']] = lno\n",
    "                verb_noun_to_action[(int(line['verb']),\n",
    "                                     int(line['noun']))] = int(line['id'])\n",
    "        return class_names, verb_noun_to_action\n",
    "\n",
    "    @staticmethod\n",
    "    def _load_action_classes_egtea(\n",
    "            action_labels_fpath: Path\n",
    "    ) -> Tuple[Dict[str, int], Dict[Tuple[int, int], int]]:\n",
    "        \"\"\"\n",
    "        Given a CSV file with the actions (as from RULSTM paper), construct\n",
    "        the set of actions and mapping from verb/noun to action\n",
    "        Args:\n",
    "            action_labels_fpath: path to a header-less file whose rows are\n",
    "                'id,<verb>_<noun>,action'\n",
    "        Returns:\n",
    "            class_names: Dict of action class name -> 0-based row index\n",
    "            verb_noun_to_action: Mapping from verb/noun to action IDs\n",
    "        \"\"\"\n",
    "        class_names = {}\n",
    "        verb_noun_to_action = {}\n",
    "        with open(action_labels_fpath, 'r') as fin:\n",
    "            reader = csv.DictReader(\n",
    "                fin,\n",
    "                delimiter=',',\n",
    "                # Assuming the order is verb/noun\n",
    "                # TODO check if that is correct\n",
    "                fieldnames=['id', 'verb_noun', 'action'])\n",
    "            for lno, line in enumerate(reader):\n",
    "                class_names[line['action']] = lno\n",
    "                verb, noun = [int(el) for el in line['verb_noun'].split('_')]\n",
    "                verb_noun_to_action[(verb, noun)] = int(line['id'])\n",
    "        return class_names, verb_noun_to_action\n",
    "\n",
    "    @staticmethod\n",
    "    def _gen_all_actions(\n",
    "            verb_classes: List[str], noun_classes: List[str]\n",
    "    ) -> Tuple[Dict[str, int], Dict[Tuple[int, int], int]]:\n",
    "        \"\"\"\n",
    "        Given all possible verbs and nouns, construct all possible actions\n",
    "        Args:\n",
    "            verb_classes: All verbs (iterated in order; position = verb id)\n",
    "            noun_classes: All nouns (iterated in order; position = noun id)\n",
    "        Returns:\n",
    "            class_names: Dict of '<verb>:<noun>' action name -> action ID\n",
    "            verb_noun_to_action: Mapping from verb/noun to action IDs\n",
    "        \"\"\"\n",
    "        class_names = {}\n",
    "        verb_noun_to_action = {}\n",
    "        action_id = 0\n",
    "        for verb_id, verb_cls in enumerate(verb_classes):\n",
    "            for noun_id, noun_cls in enumerate(noun_classes):\n",
    "                class_names[f'{verb_cls}:{noun_cls}'] = action_id\n",
    "                verb_noun_to_action[(verb_id, noun_id)] = action_id\n",
    "                action_id += 1\n",
    "        return class_names, verb_noun_to_action\n",
    "\n",
    "    @staticmethod\n",
    "    def _group_actions_by_noun_and_verb(\n",
    "            verb_noun_to_action: Dict[Tuple[int, int], int]\n",
    "    ) -> Tuple[Dict[int, List[int]], Dict[int, List[int]]]:\n",
    "        \"\"\"Invert verb_noun_to_action into noun->actions and verb->actions.\n",
    "\n",
    "        Action ids are listed in the iteration order of verb_noun_to_action.\n",
    "        \"\"\"\n",
    "        noun_to_actions = {}\n",
    "        verb_to_actions = {}\n",
    "        for (verb, noun), action in verb_noun_to_action.items():\n",
    "            noun_to_actions.setdefault(noun, []).append(action)\n",
    "            verb_to_actions.setdefault(verb, []).append(action)\n",
    "        return noun_to_actions, verb_to_actions\n",
    "\n",
    "    def _load_class_names(self, annot_path: Path):\n",
    "        \"\"\"Map each class key in an EPIC class CSV to its 0-based row index.\n",
    "\n",
    "        EPIC-55 files name the key column 'class_key'; other versions 'key'.\n",
    "        \"\"\"\n",
    "        key_col = 'class_key' if self.version == EPIC55_VERSION else 'key'\n",
    "        with open(annot_path, 'r') as fin:\n",
    "            reader = csv.DictReader(fin, delimiter=',')\n",
    "            return {line[key_col]: lno for lno, line in enumerate(reader)}\n",
    "\n",
    "    def _load_df(self, annotation_path):\n",
    "        \"\"\"Load one annotation file (str or Path), dispatching on extension.\n",
    "\n",
    "        Raises:\n",
    "            NotImplementedError: for extensions other than .pkl / .csv.\n",
    "        \"\"\"\n",
    "        # str() so Path inputs work too (Path has no .endswith).\n",
    "        path_str = str(annotation_path)\n",
    "        if path_str.endswith('.pkl'):\n",
    "            return self._init_df_orig(annotation_path)\n",
    "        if path_str.endswith('.csv'):\n",
    "            # Else, it must be the RULSTM annotations (which are a\n",
    "            # little different, perhaps due to quantization into frames)\n",
    "            return self._init_df_rulstm(annotation_path)\n",
    "        raise NotImplementedError(annotation_path)\n",
    "\n",
    "    def _init_df_gen_vidpath(self, df):\n",
    "        \"\"\"Add path columns (and, for EPIC, the FPS column) to df in place.\n",
    "\n",
    "        EGTEA: ``video_path`` = ``<video_id>.mp4``.\n",
    "        EPIC: ``video_path`` = ``<participant_id>/videos/<video_id>.MP4``,\n",
    "        ``frame_path`` = ``<frame_root>/<video_id>``, and ``video_fps``\n",
    "        looked up in the video-info CSV (KeyError if a video is missing).\n",
    "        \"\"\"\n",
    "        # Comprehensions instead of row-wise DataFrame.apply(axis=1): faster,\n",
    "        # and apply on an empty frame returns a DataFrame that cannot be\n",
    "        # assigned to a single column.\n",
    "        video_ids = list(df['video_id'])\n",
    "        if self.version == EGTEA_VERSION:\n",
    "            df['video_path'] = [Path(vid + '.mp4') for vid in video_ids]\n",
    "        else:  # For the EPIC datasets\n",
    "            df['video_path'] = [\n",
    "                Path(pid) / Path('videos/' + vid + '.MP4')\n",
    "                for pid, vid in zip(df['participant_id'], video_ids)\n",
    "            ]\n",
    "            df['frame_path'] = [Path(self.frame_root) / Path(vid)\n",
    "                                for vid in video_ids]\n",
    "            df['video_fps'] = [self.video_to_fps[vid] for vid in video_ids]\n",
    "        return df\n",
    "\n",
    "    def _init_df_rulstm(self, annotation_path):\n",
    "        \"\"\"Load a header-less RULSTM-style annotation CSV.\n",
    "\n",
    "        Adds ``start``/``end`` (frame index / RULSTM_TSN_FPS),\n",
    "        ``participant_id`` (video_id prefix before the first '_') and the\n",
    "        path columns.\n",
    "        \"\"\"\n",
    "        logging.info('Loading RULSTM EPIC csv annotations %s', annotation_path)\n",
    "        df = pd.read_csv(\n",
    "            annotation_path,\n",
    "            names=[\n",
    "                'uid',\n",
    "                'video_id',\n",
    "                'start_frame_30fps',\n",
    "                'end_frame_30fps',\n",
    "                'verb_class',\n",
    "                'noun_class',\n",
    "                'action_class',\n",
    "            ],\n",
    "            index_col=0,\n",
    "            skipinitialspace=True,\n",
    "            dtype={\n",
    "                'uid': str,  # In epic-100, this is a str\n",
    "                'video_id': str,\n",
    "                'start_frame_30fps': int,\n",
    "                'end_frame_30fps': int,\n",
    "                'verb_class': int,\n",
    "                'noun_class': int,\n",
    "                'action_class': int,\n",
    "            })\n",
    "        # Make a copy of the UID column, since that will be needed to gen\n",
    "        # output files\n",
    "        df.reset_index(drop=False, inplace=True)\n",
    "        # Convert the frame number to start and end (vectorized)\n",
    "        df['start'] = df['start_frame_30fps'] / RULSTM_TSN_FPS\n",
    "        df['end'] = df['end_frame_30fps'] / RULSTM_TSN_FPS\n",
    "        # Participant ID from video_id\n",
    "        df['participant_id'] = [vid.split('_')[0] for vid in df['video_id']]\n",
    "        df = self._init_df_gen_vidpath(df)\n",
    "        df.reset_index(inplace=True, drop=True)\n",
    "        return df\n",
    "\n",
    "    def _init_df_orig(self, annotation_path):\n",
    "        \"\"\"\n",
    "        Loading the original EPIC Kitchens annotations (a pickled DataFrame).\n",
    "\n",
    "        Adds ``start``/``end`` seconds (when ``self.use_timestamps``),\n",
    "        normalizes the ``noun``/``verb`` text and adds the path columns.\n",
    "        \"\"\"\n",
    "        def timestr_to_sec(s, fmt='%H:%M:%S.%f'):\n",
    "            # 'HH:MM:SS.ffffff' -> float seconds\n",
    "            timeobj = datetime.strptime(s, fmt).time()\n",
    "            td = datetime.combine(date.min, timeobj) - datetime.min\n",
    "            return td.total_seconds()\n",
    "\n",
    "        # Load the DF from annot path\n",
    "        logging.info('Loading original EPIC pkl annotations %s',\n",
    "                     annotation_path)\n",
    "        # NOTE(security): pickle.load can execute arbitrary code; only load\n",
    "        # annotation files from a trusted source.\n",
    "        with open(annotation_path, 'rb') as fin:\n",
    "            df = pkl.load(fin)\n",
    "        # Make a copy of the UID column, since that will be needed to gen\n",
    "        # output files\n",
    "        df.reset_index(drop=False, inplace=True)\n",
    "\n",
    "        # parse timestamps from the video\n",
    "        if self.use_timestamps:\n",
    "            df.loc[:, 'start'] = df.start_timestamp.apply(timestr_to_sec)\n",
    "            df.loc[:, 'end'] = df.stop_timestamp.apply(timestr_to_sec)\n",
    "\n",
    "        # original annotations have text in weird format - fix that\n",
    "        if 'noun' in df.columns:\n",
    "            # e.g. 'pan:frying' -> 'frying pan'\n",
    "            df.loc[:, 'noun'] = df.loc[:, 'noun'].apply(\n",
    "                lambda s: ' '.join(s.replace(':', ' ').split(sep=' ')[::-1]))\n",
    "        if 'verb' in df.columns:\n",
    "            # e.g. 'put-down' -> 'put down'\n",
    "            df.loc[:, 'verb'] = df.loc[:, 'verb'].apply(\n",
    "                lambda s: s.replace('-', ' '))\n",
    "        df = self._init_df_gen_vidpath(df)\n",
    "        df.reset_index(inplace=True, drop=True)\n",
    "        return df\n",
    "\n",
    "    @staticmethod\n",
    "    def _subselect_df_by_person(df, only_keep_persons):\n",
    "        \"\"\"Keep rows whose participant_id is in P<start>..P<end> (inclusive).\n",
    "\n",
    "        ``only_keep_persons`` is a \"<start>-<end>\" string; None keeps all.\n",
    "        \"\"\"\n",
    "        if only_keep_persons is None:\n",
    "            return df\n",
    "        start, end = [int(el) for el in only_keep_persons.split('-')]\n",
    "        persons = ['P{:02d}'.format(el) for el in range(start, end + 1)]\n",
    "        kept = df.loc[df['participant_id'].isin(persons), :]\n",
    "        return kept.reset_index(drop=True)\n",
    "\n",
    "    @staticmethod\n",
    "    def _subselect_df_by_videos(df, videos_fpath):\n",
    "        \"\"\"Keep rows whose video_id is listed (one per line) in videos_fpath.\n",
    "\n",
    "        None keeps all rows; whitespace around each listed id is ignored.\n",
    "        \"\"\"\n",
    "        if videos_fpath is None:\n",
    "            return df\n",
    "        with open(videos_fpath, 'r') as fin:\n",
    "            videos_to_keep = [el.strip() for el in fin.read().splitlines()]\n",
    "        kept = df.loc[df['video_id'].isin(videos_to_keep), :]\n",
    "        return kept.reset_index(drop=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3b1ff948",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "from torchvision import transforms\n",
    "import warnings\n",
    "from decord import VideoReader, cpu\n",
    "from torch.utils.data import Dataset\n",
    "from PIL import Image\n",
    "\n",
    "from augmentations import video_transforms as video_transforms\n",
    "from augmentations import volume_transforms as volume_transforms\n",
    "from augmentations.random_erasing import RandomErasing\n",
    "import einops\n",
    "\n",
    "class EpicKitchensWrapper(Dataset):\n",
    "    \"\"\"\n",
    "    Custom dataloader for EpicKitchens. Given dataset object that loads EpicKitchens videos and labels, and the mode (train, test, val), \n",
    "    return resized and augmented samples from the dataset. Adapted from: https://github.com/MCG-NJU/VideoMAE/blob/main/kinetics.py\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dataset, num_sample, \n",
    "                frame_sample_rate=4, mode='train', crop_size=224, short_side_size=240,\n",
    "                test_num_segment=5, test_num_crop=3, num_aug_sample=1, args=None, label_file=None, \n",
    "                load_from=\"video\", image_tmpl='img_{:05d}.jpg', index_bias=1, included_classes=None, \n",
    "                streaming=False, streaming_max_num_prev=4, **kwargs):\n",
    "        \"\"\"\n",
    "        Params:\n",
    "        dataset_path: str\n",
    "            Path to the folder containing Kinetics videos\n",
    "        label_path: str\n",
    "            Path to the folder containing kinetics_video_train_labels.txt and kinetics_video_val_labels.txt\n",
    "        num_sample: int\n",
    "            Number of frames to sample per clip\n",
    "        frame_sample_rate: int\n",
    "            Rate at which to sample frames (i.e. frame_sample_rate=2 would get frames 0, 2, 4, ...)\n",
    "        mode: str\n",
    "            One of \"train\", \"val\"\n",
    "        crop_size: int\n",
    "            Size to crop each frame\n",
    "        short_side_size: int\n",
    "            Size of the shorter side of the videos\n",
    "        test_num_segment: int\n",
    "            Number of temporal views to sample at test time\n",
    "        test_num_crop: int\n",
    "            Number of spatial views to sample at test time - uniformly sampled along the longer dimension\n",
    "            of the input video\n",
    "        num_aug_sample: int\n",
    "            Number of times to augment the input video (repeated augmentation)\n",
    "        args: argparse.Namespace\n",
    "            Arguments to pass to the augmentations\n",
    "        label_file: Optional[str]\n",
    "            Prefix for the file containing the labels - if None, assumed to be equal to mode. Used for the case where we want to\n",
    "            use validation/test view logic on the training set\n",
    "        load_from: str\n",
    "            Whether paths should contain video or image data to load from. \"video\" to use decord, or \"rgb\" to\n",
    "            use proprocessed frames\n",
    "        image_tmpl: str\n",
    "            String template to use for when loading images from preprocessed frames\n",
    "        index_bias: int\n",
    "            Value to add to each sampling index when loading images from preprocessed frames\n",
    "        \"\"\"        \n",
    "        self.ek_dataset = dataset\n",
    "    \n",
    "        self.short_side_size = short_side_size\n",
    "        self.frame_sample_rate = frame_sample_rate\n",
    "        self.crop_size = crop_size\n",
    "        self.mode = mode\n",
    "        self.num_sample = num_sample\n",
    "        self.test_num_segment = test_num_segment\n",
    "        self.test_num_crop = test_num_crop\n",
    "        self.num_aug_sample = num_aug_sample\n",
    "        self.args = args\n",
    "        self.load_from = load_from\n",
    "        self.image_tmpl = image_tmpl\n",
    "        self.index_bias = index_bias\n",
    "        self.streaming = streaming\n",
    "        self.streaming_max_num_prev = streaming_max_num_prev\n",
    "\n",
    "        if (mode == \"train\"):\n",
    "            assert self.args != None, \"Must pass arguments to augmentations\"\n",
    "            self.rand_erase = False\n",
    "            if self.args.reprob > 0:\n",
    "                self.rand_erase = True\n",
    "\n",
    "        elif mode == 'test' or mode == 'val':\n",
    "            self.data_resize = transforms.Compose([\n",
    "                video_transforms.Resize(size=(self.short_side_size), interpolation='bilinear')\n",
    "            ])\n",
    "            self.data_transform = transforms.Compose([\n",
    "                volume_transforms.ClipToTensor(),\n",
    "                video_transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],\n",
    "                                        std=[0.26862954, 0.26130258, 0.27577711])\n",
    "            ])\n",
    "            self.test_seg = []\n",
    "            self.test_indices = []\n",
    "            for ck in range(self.test_num_segment):\n",
    "                for cp in range(self.test_num_crop):\n",
    "                    for idx in range(len(self.ek_dataset)):\n",
    "                        self.test_indices.append(idx)\n",
    "                        self.test_seg.append((ck, cp))\n",
    "\n",
    "    def __len__(self):\n",
    "        if self.mode == \"train\":\n",
    "            return len(self.ek_dataset)\n",
    "        else:\n",
    "            return len(self.test_indices)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        if self.mode == 'train':\n",
    "            sample = self.ek_dataset.__getitem__(index)\n",
    "            buffer = sample[\"video\"]\n",
    "            label = sample[\"target\"][\"action\"]\n",
    "            \n",
    "            video_id = sample[\"video_id\"]\n",
    "            \n",
    "            noun_label = sample[\"target\"][\"noun\"]\n",
    "            verb_label = sample[\"target\"][\"verb\"]\n",
    "            \n",
    "            buffer = einops.rearrange(buffer, \"w t a h c -> t (a w) h c\", a=1)\n",
    "            \n",
    "            if len(buffer) == 0:\n",
    "                while len(buffer) == 0:\n",
    "                    warnings.warn(\"video {} not correctly loaded during validation\".format(sample))\n",
    "                    index = np.random.randint(self.__len__())\n",
    "                    sample = self.ek_dataset.__getitem__(index)\n",
    "                    buffer = sample[\"video\"]\n",
    "                    label = sample[\"target\"][\"action\"]\n",
    "                    \n",
    "                    video_id = sample[\"video_id\"]\n",
    "                    \n",
    "                    noun_label = sample[\"target\"][\"noun\"]\n",
    "                    verb_label = sample[\"target\"][\"verb\"]\n",
    "                    \n",
    "                    buffer = einops.rearrange(buffer, \"w t a h c -> t (a w) h c\", a=1)\n",
    "            \n",
    "            last_label = []\n",
    "            if isinstance(label, list) and -1 in label:\n",
    "                labels = torch.zeros(3806)\n",
    "            else:\n",
    "                labels = torch.nn.functional.one_hot(torch.LongTensor([label]), num_classes=3806).squeeze()\n",
    "                if len(labels.shape) == 1:\n",
    "                    last_label = labels\n",
    "                else:\n",
    "                    last_label = labels[-1]\n",
    "            \n",
    "            if self.streaming:\n",
    "                end_token_idxs = sample[\"end_token_idxs\"]\n",
    "            if self.streaming and len(labels.shape) == 1:\n",
    "                labels = labels.unsqueeze(0)\n",
    "            if self.streaming and labels.shape[0] < self.streaming_max_num_prev:\n",
    "                full_labels = torch.zeros(self.streaming_max_num_prev, 3806, dtype=torch.long)\n",
    "                for i in range(labels.shape[0]):\n",
    "                    full_labels[i] = labels[i]\n",
    "                for i in range(self.streaming_max_num_prev - labels.shape[0]):\n",
    "                    end_token_idxs.append(-1)\n",
    "                labels = full_labels\n",
    "\n",
    "            video_indices = index\n",
    "            if self.streaming:\n",
    "                video_indices = []\n",
    "                for i in range(len(end_token_idxs)):\n",
    "                    if end_token_idxs[i] == -1:\n",
    "                        video_indices.append(-1)\n",
    "                    else:\n",
    "                        video_indices.append(index)\n",
    "\n",
    "            if self.num_aug_sample > 1:\n",
    "                frame_list = []\n",
    "                label_list = []\n",
    "                index_list = []\n",
    "                path_list = []\n",
    "                verb_list = []\n",
    "                noun_list = []\n",
    "                last_label_list = []\n",
    "                for _ in range(self.num_aug_sample):\n",
    "                    new_frames = self._aug_frame(buffer, self.args)\n",
    "                    frame_list.append(new_frames)\n",
    "                    label_list.append(labels)\n",
    "                    verb_list.append(verb_label)\n",
    "                    noun_list.append(verb_label)\n",
    "\n",
    "                    index_list.append(index)\n",
    "                    path_list.append(video_id)    \n",
    "                    last_label_list.append(last_label)            \n",
    "                return {\n",
    "                    \"video_features\": frame_list,\n",
    "                    \"labels\": label_list,\n",
    "                    \"video_indices\": index_list,\n",
    "                    \"video_path\": path_list,\n",
    "                    \"verb_label\": verb_list,\n",
    "                    \"noun_label\": noun_list,\n",
    "                    \"end_token_idxs\": torch.LongTensor(sample[\"end_token_idxs\"]) if \"end_token_idxs\" in sample else None,\n",
    "                    \"last_label\": last_label_list\n",
    "                }\n",
    "            else:\n",
    "                buffer = self._aug_frame(buffer, self.args)\n",
    "            return {\n",
    "                \"video_features\": buffer, \n",
    "                \"labels\": labels,\n",
    "                \"video_indices\": index if 'end_token_idxs' not in sample else [index for _ in range(num_labels)],\n",
    "                \"video_path\": video_id,\n",
    "                \"end_token_idxs\": torch.LongTensor(end_token_idxs) if \"end_token_idxs\" in sample else [-1],\n",
    "                \"last_label\": last_label\n",
    "            }\n",
    "\n",
    "        elif self.mode == 'test' or self.mode == 'val':\n",
    "            dataset_idx = self.test_indices[index]\n",
    "            sample = self.ek_dataset.__getitem__(dataset_idx)\n",
    "            chunk_nb, split_nb = self.test_seg[index]\n",
    "            \n",
    "            buffer = sample[\"video\"]\n",
    "            buffer = einops.rearrange(buffer, \"w t a h c -> t (a w) h c\", a=1)\n",
    "            label = sample[\"target\"][\"action\"]\n",
    "            \n",
    "            video_id = sample[\"video_id\"]\n",
    "            \n",
    "            noun_label = sample[\"target\"][\"noun\"]\n",
    "            verb_label = sample[\"target\"][\"verb\"]\n",
    "            \n",
    "\n",
    "            while len(buffer) == 0:\n",
    "                warnings.warn(\"video {}, temporal {}, spatial {} not found during testing\".format(\\\n",
    "                    str(self.test_indices[index]), chunk_nb, split_nb))\n",
    "                index = np.random.randint(self.__len__())\n",
    "                dataset_idx = self.test_indices[index]\n",
    "                sample = self.ek_dataset.__getitem__(dataset_idx)\n",
    "                chunk_nb, split_nb = self.test_seg[index]\n",
    "                buffer = sample[\"video\"]\n",
    "                buffer = einops.rearrange(buffer, \"w t a h c -> t (a w) h c\", a=1)\n",
    "                label = sample[\"target\"][\"action\"]\n",
    "                \n",
    "                video_id = sample[\"video_id\"]\n",
    "\n",
    "                noun_label = sample[\"target\"][\"noun\"]\n",
    "                verb_label = sample[\"target\"][\"verb\"]\n",
    "            \n",
    "                \n",
    "            if self.test_num_segment == 1:\n",
    "                temporal_start = max((buffer.shape[0] - self.num_sample) // 2, 0)\n",
    "            else:\n",
    "                raise Error(\"num_segment > 1 not yet supported for EpicKitchens (we use center temporal clip)\")\n",
    "#                 temporal_step = max(1.0 * (buffer.shape[0] - self.num_sample) \\\n",
    "#                                     / (self.test_num_segment - 1), 0)\n",
    "#                 temporal_start = int(chunk_nb * temporal_step)\n",
    "            buffer = buffer[temporal_start:temporal_start + self.num_sample, :, :, :]\n",
    "            \n",
    "            buffer = self.data_resize(buffer.numpy())\n",
    "            if isinstance(buffer, list):\n",
    "                buffer = np.stack(buffer, 0)\n",
    "\n",
    "            if self.test_num_crop == 1:\n",
    "                spatial_step = self.short_side_size\n",
    "                if buffer.shape[1] >= buffer.shape[2]:\n",
    "                    spatial_start = (buffer.shape[1] - buffer.shape[2]) // 2\n",
    "                else:\n",
    "                    spatial_start = (buffer.shape[2] - buffer.shape[1]) // 2\n",
    "            else:\n",
    "                spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \\\n",
    "                                / (self.test_num_crop - 1)\n",
    "                spatial_start = int(split_nb * spatial_step)\n",
    "                # print(spatial_step, spatial_start)\n",
    "\n",
    "            if buffer.shape[1] >= buffer.shape[2]:\n",
    "                buffer = buffer[:, spatial_start:spatial_start + self.short_side_size, :, :]\n",
    "            else:\n",
    "                buffer = buffer[:, :, spatial_start:spatial_start + self.short_side_size, :]\n",
    "            buffer = self.data_transform(buffer)            \n",
    "            \n",
    "            if isinstance(label, list) and -1 in label:\n",
    "                labels = torch.zeros(3806)\n",
    "            else:\n",
    "                labels = torch.nn.functional.one_hot(torch.LongTensor([label]), num_classes=3806).squeeze()\n",
    "            if self.streaming:\n",
    "                end_token_idxs = sample[\"end_token_idxs\"]\n",
    "            if self.streaming and len(labels.shape) == 1 and self.streaming_max_num_prev > 1:\n",
    "                labels = labels.unsqueeze(0)\n",
    "            if self.streaming and len(labels.shape) > 1 and labels.shape[0] < self.streaming_max_num_prev:\n",
    "                full_labels = torch.zeros(self.streaming_max_num_prev, 3806, dtype=torch.long)\n",
    "                for i in range(labels.shape[0]):\n",
    "                    full_labels[i] = labels[i]\n",
    "                for i in range(self.streaming_max_num_prev - labels.shape[0]):\n",
    "                    end_token_idxs.append(-1)\n",
    "                labels = full_labels\n",
    "                \n",
    "            video_indices = self.test_indices[index]\n",
    "            if self.streaming:\n",
    "                video_indices = []\n",
    "                for i in range(len(end_token_idxs)):\n",
    "                    if end_token_idxs[i] == -1:\n",
    "                        video_indices.append(-1)\n",
    "                    else:\n",
    "                        video_indices.append(self.test_indices[index])\n",
    "\n",
    "                \n",
    "\n",
    "            return {\n",
    "                \"video_features\": buffer, \n",
    "                \"labels\": labels,\n",
    "                \"video_indices\": video_indices, \n",
    "                \"video_path\": video_id,\n",
    "                \"chunk_nbs\": chunk_nb, \"split_nbs\": split_nb,\n",
    "                \"end_token_idxs\": torch.LongTensor(end_token_idxs) if \"end_token_idxs\" in sample else [-1],\n",
    "            }\n",
    "        else:\n",
    "            raise NameError('mode {} unkown'.format(self.mode))\n",
    "\n",
    "    def _aug_frame(\n",
    "        self,\n",
    "        buffer,\n",
    "        args,\n",
    "    ):\n",
    "        \"\"\"Normalize and spatially augment one training clip.\n",
    "\n",
    "        Args:\n",
    "            buffer: clip tensor, T x H x W x C per the layout comments below;\n",
    "                uint8 input is rescaled to [0, 1] inside `tensor_normalize`.\n",
    "            args: parsed CLI namespace; reads `aa` and `train_interpolation`,\n",
    "                plus `reprob`, `remode` and `recount` when `self.rand_erase` is set.\n",
    "\n",
    "        Returns:\n",
    "            The augmented clip as C x T x H x W, random-resized-cropped to\n",
    "            `self.crop_size` and passed through a random horizontal flip\n",
    "            (and random erasing when `self.rand_erase` is set).\n",
    "        \"\"\"\n",
    "        # NOTE(review): `aug_transform` is constructed but never applied, so the\n",
    "        # `--aa` policy currently does not change the clip. Confirm whether this\n",
    "        # was disabled on purpose before relying on `--aa`.\n",
    "        if args.aa != 'None':\n",
    "            aug_transform = video_transforms.create_random_augment(\n",
    "                input_size=(self.crop_size, self.crop_size),\n",
    "                auto_augment=args.aa,\n",
    "                interpolation=args.train_interpolation,\n",
    "            )\n",
    "\n",
    "        # T H W C \n",
    "        # Mean/std are the OpenAI CLIP image statistics (RGB order), presumably\n",
    "        # chosen to match the CLIP backbone.\n",
    "        buffer = tensor_normalize(\n",
    "            buffer, [0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]\n",
    "        )\n",
    "        # T H W C -> C T H W.\n",
    "        buffer = buffer.permute(3, 0, 1, 2)\n",
    "        # Perform data augmentation.\n",
    "        # Both ranges are set, so spatial_sampling takes its random-resized-crop\n",
    "        # branch and ignores min_scale/max_scale. scl is the crop-area range and\n",
    "        # asp the aspect-ratio range (RandomResizedCrop-style; confirm in\n",
    "        # video_transforms).\n",
    "        scl, asp = (\n",
    "            [0.08, 1.0],\n",
    "            [0.75, 1.3333],\n",
    "        )\n",
    "        buffer = spatial_sampling(\n",
    "            buffer,\n",
    "            spatial_idx=-1,\n",
    "            min_scale=256,\n",
    "            max_scale=320,\n",
    "            crop_size=self.crop_size,\n",
    "            random_horizontal_flip=True,\n",
    "            inverse_uniform_sampling=False,\n",
    "            aspect_ratio=asp,\n",
    "            scale=scl,\n",
    "            motion_shift=False\n",
    "        )\n",
    "\n",
    "        # Optional random erasing. The permutes swap to T C H W and back,\n",
    "        # presumably because RandomErasing treats dim 0 as the batch dim.\n",
    "        if self.rand_erase:\n",
    "            erase_transform = RandomErasing(\n",
    "                args.reprob,\n",
    "                mode=args.remode,\n",
    "                max_count=args.recount,\n",
    "                num_splits=args.recount,\n",
    "                device=\"cpu\",\n",
    "            )\n",
    "            buffer = buffer.permute(1, 0, 2, 3)\n",
    "            buffer = erase_transform(buffer)\n",
    "            buffer = buffer.permute(1, 0, 2, 3)\n",
    "\n",
    "        return buffer\n",
    "    \n",
    "    def _load_image(self, directory, idx):\n",
    "        \"\"\"Read frame `idx` from `directory` as a one-element list of RGB arrays.\n",
    "\n",
    "        On any failure the offending path is printed and the frame formatted\n",
    "        with index 1 is read instead (a failure there propagates).\n",
    "        \"\"\"\n",
    "        def _read_rgb(frame_idx):\n",
    "            frame_file = os.path.join(directory, self.image_tmpl.format(frame_idx))\n",
    "            return [np.asarray(Image.open(frame_file).convert('RGB'))]\n",
    "\n",
    "        try:\n",
    "            return _read_rgb(idx)\n",
    "        except Exception:\n",
    "            print('error loading image:', os.path.join(directory, self.image_tmpl.format(idx)))\n",
    "            return _read_rgb(1)\n",
    "\n",
    "\n",
    "def tensor_normalize(tensor, mean, std):\n",
    "    \"\"\"\n",
    "    Normalize a given tensor by subtracting the mean and dividing the std.\n",
    "\n",
    "    uint8 input is first converted to float and rescaled to [0, 1]. `mean` and\n",
    "    `std` broadcast against the trailing dimension(s), so for a channel-last clip\n",
    "    (e.g. T x H x W x C) 3-element lists act as per-channel statistics.\n",
    "\n",
    "    Args:\n",
    "        tensor (tensor): tensor to normalize.\n",
    "        mean (tensor, list or tuple): mean value to subtract.\n",
    "        std (tensor, list or tuple): std to divide.\n",
    "    Returns:\n",
    "        tensor: the normalized tensor.\n",
    "    \"\"\"\n",
    "    if tensor.dtype == torch.uint8:\n",
    "        tensor = tensor.float() / 255.0\n",
    "    # Build list/tuple statistics on the input's device so non-CPU inputs work;\n",
    "    # tuples used to fall through and fail in `tensor - tuple`.\n",
    "    device = tensor.device if isinstance(tensor, torch.Tensor) else None\n",
    "    if isinstance(mean, (list, tuple)):\n",
    "        mean = torch.tensor(mean, device=device)\n",
    "    if isinstance(std, (list, tuple)):\n",
    "        std = torch.tensor(std, device=device)\n",
    "    return (tensor - mean) / std\n",
    "\n",
    "def spatial_sampling(\n",
    "    frames,\n",
    "    spatial_idx=-1,\n",
    "    min_scale=256,\n",
    "    max_scale=320,\n",
    "    crop_size=224,\n",
    "    random_horizontal_flip=True,\n",
    "    inverse_uniform_sampling=False,\n",
    "    aspect_ratio=None,\n",
    "    scale=None,\n",
    "    motion_shift=False,\n",
    "):\n",
    "    \"\"\"\n",
    "    Resize, crop and (optionally) flip a clip spatially.\n",
    "\n",
    "    With `spatial_idx == -1` the clip is sampled randomly (train-time): either\n",
    "    short-side scale jitter + random crop (when both `aspect_ratio` and `scale`\n",
    "    are None) or a random resized crop, then an optional random horizontal flip.\n",
    "    With `spatial_idx` in {0, 1, 2} the crop is deterministic: left/center/right\n",
    "    when width > height, top/center/bottom when height > width.\n",
    "\n",
    "    Args:\n",
    "        frames (tensor): 4-D clip tensor. `_aug_frame` in this notebook passes\n",
    "            C x T x H x W; the exact layout required is set by the\n",
    "            `video_transforms` helpers -- TODO confirm.\n",
    "        spatial_idx (int): -1 for random sampling, 0/1/2 for a uniform crop.\n",
    "        min_scale (int): smallest short-side size for scale jitter.\n",
    "        max_scale (int): largest short-side size for scale jitter.\n",
    "        crop_size (int): output height and width.\n",
    "        random_horizontal_flip (bool): on the random path, flip via\n",
    "            `video_transforms.horizontal_flip(0.5, ...)`.\n",
    "        inverse_uniform_sampling (bool): forwarded to the scale jitter; if True\n",
    "            sample uniformly in [1 / max_scale, 1 / min_scale] and take the\n",
    "            reciprocal, otherwise sample uniformly in [min_scale, max_scale].\n",
    "        aspect_ratio (list): aspect-ratio range for the random resized crop.\n",
    "        scale (list): scale range for the random resized crop.\n",
    "        motion_shift (bool): use `random_resized_crop_with_shift` instead of\n",
    "            `random_resized_crop`.\n",
    "    Returns:\n",
    "        frames (tensor): the spatially sampled clip.\n",
    "    \"\"\"\n",
    "    assert spatial_idx in [-1, 0, 1, 2]\n",
    "\n",
    "    if spatial_idx != -1:\n",
    "        # Deterministic (eval) path: no jitter, so all three sizes must agree.\n",
    "        assert len({min_scale, max_scale, crop_size}) == 1\n",
    "        frames, _ = video_transforms.random_short_side_scale_jitter(\n",
    "            frames, min_scale, max_scale\n",
    "        )\n",
    "        frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx)\n",
    "        return frames\n",
    "\n",
    "    if aspect_ratio is None and scale is None:\n",
    "        frames, _ = video_transforms.random_short_side_scale_jitter(\n",
    "            images=frames,\n",
    "            min_size=min_scale,\n",
    "            max_size=max_scale,\n",
    "            inverse_uniform_sampling=inverse_uniform_sampling,\n",
    "        )\n",
    "        cropped = video_transforms.random_crop(frames, crop_size)\n",
    "        frames = cropped[0]\n",
    "    else:\n",
    "        if motion_shift:\n",
    "            resized_crop = video_transforms.random_resized_crop_with_shift\n",
    "        else:\n",
    "            resized_crop = video_transforms.random_resized_crop\n",
    "        frames = resized_crop(\n",
    "            images=frames,\n",
    "            target_height=crop_size,\n",
    "            target_width=crop_size,\n",
    "            scale=scale,\n",
    "            ratio=aspect_ratio,\n",
    "        )\n",
    "\n",
    "    if random_horizontal_flip:\n",
    "        frames, _ = video_transforms.horizontal_flip(0.5, frames)\n",
    "    return frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bcea4c99",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# `string_classes` was re-exported by torch's collate module in PyTorch 1.x but\n",
    "# no longer exists there in 2.x; fall back to the builtin string types.\n",
    "try:\n",
    "    from torch.utils.data._utils.collate import string_classes\n",
    "except ImportError:\n",
    "    string_classes = (str, bytes)\n",
    "\n",
    "# Fix imports, but keep backwards compatibility\n",
    "TORCH_MAJOR = int(torch.__version__.split('.')[0])\n",
    "TORCH_MINOR = int(torch.__version__.split('.')[1])\n",
    "if TORCH_MAJOR == 1 and TORCH_MINOR < 8:\n",
    "    from torch._six import container_abcs, int_classes\n",
    "else:\n",
    "    import collections.abc as container_abcs\n",
    "    int_classes = int\n",
    "\n",
    "\n",
    "def collate_with_pad(batch, allow_pad=True, pad_right=True):\n",
    "    r\"\"\"Puts each data field into a tensor with outer dimension batch size.\n",
    "\n",
    "    Mirrors PyTorch's default collate, recursing into mappings, namedtuples and\n",
    "    sequences. If the tensors in `batch` have different shapes and `allow_pad`\n",
    "    is True, they are zero-padded along dim 1 (`tensor.size(1)`) instead:\n",
    "\n",
    "    * 2-D tensors are written into a `(len(batch), max_len)` buffer, which only\n",
    "      works when each tensor's dim 0 has size 1 -- TODO confirm with callers.\n",
    "    * 4-D tensors are treated as (C, T, H, W) and padded along T, returning\n",
    "      (len(batch), C, max_T, H, W). Padding is on the right unless\n",
    "      `pad_right=False`.\n",
    "\n",
    "    Padded results come from `torch.zeros`, i.e. the default float dtype,\n",
    "    whatever the input dtype was.\n",
    "\n",
    "    Args:\n",
    "        batch: list of samples (tensors, numpy arrays, numbers, strings, or\n",
    "            mappings / namedtuples / sequences of those).\n",
    "        allow_pad: if False, differently-shaped tensors raise RuntimeError.\n",
    "            Applied at every nesting level.\n",
    "        pad_right: pad at the end (True) or the start (False) of dim 1; only\n",
    "            used in the non-2-D case.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: differently-shaped tensors with `allow_pad=False`, or\n",
    "            sequences of unequal length.\n",
    "        TypeError: unsupported element type.\n",
    "    \"\"\"\n",
    "    def _collate(samples):\n",
    "        # Forward BOTH options: `allow_pad` used to be dropped on recursion, so\n",
    "        # allow_pad=False was silently ignored for dict / list samples.\n",
    "        return collate_with_pad(samples, allow_pad=allow_pad, pad_right=pad_right)\n",
    "\n",
    "    elem = batch[0]\n",
    "    elem_type = type(elem)\n",
    "    if isinstance(elem, torch.Tensor):\n",
    "        elem_size = elem.size()\n",
    "        if any(t.size() != elem_size for t in batch[1:]):\n",
    "            if not allow_pad:\n",
    "                raise RuntimeError('each element in list of batch should be of equal size')\n",
    "            max_tensor_len = max(t.size(1) for t in batch)\n",
    "            if elem.dim() == 2:\n",
    "                # Always padded on the right in this case.\n",
    "                padded = torch.zeros(len(batch), max_tensor_len)\n",
    "                for i, t in enumerate(batch):\n",
    "                    padded[i, :t.size(1), ...] = t\n",
    "                return padded\n",
    "            # Pad along T in a (B, T, C, H, W) buffer, then permute to (B, C, T, H, W).\n",
    "            padded = torch.zeros(len(batch), max_tensor_len, elem_size[0], *elem_size[2:])\n",
    "            for i, t in enumerate(batch):\n",
    "                time_first = t.permute((1, 0, 2, 3))\n",
    "                if pad_right:\n",
    "                    padded[i, :t.size(1), ...] = time_first\n",
    "                else:\n",
    "                    padded[i, max_tensor_len - t.size(1):, ...] = time_first\n",
    "            return padded.permute((0, 2, 1, 3, 4))\n",
    "        out = None\n",
    "        if torch.utils.data.get_worker_info() is not None:\n",
    "            # If we're in a background process, concatenate directly into a\n",
    "            # shared memory tensor to avoid an extra copy\n",
    "            numel = sum(x.numel() for x in batch)\n",
    "            storage = elem.storage()._new_shared(numel)\n",
    "            out = elem.new(storage).view(-1, *list(elem.size()))\n",
    "        return torch.stack(batch, 0, out=out)\n",
    "    elif elem_type.__module__ == 'numpy' and elem_type.__name__ not in ('str_', 'string_'):\n",
    "        if elem_type.__name__ in ('ndarray', 'memmap'):\n",
    "            # Unlike default_collate, string/object arrays are not rejected here;\n",
    "            # torch.as_tensor is expected to raise on them instead.\n",
    "            return _collate([torch.as_tensor(b) for b in batch])\n",
    "        elif elem.shape == ():  # scalars\n",
    "            return torch.as_tensor(batch)\n",
    "    elif isinstance(elem, float):\n",
    "        return torch.tensor(batch, dtype=torch.float64)\n",
    "    elif isinstance(elem, int_classes):\n",
    "        return torch.tensor(batch)\n",
    "    elif isinstance(elem, string_classes):\n",
    "        return batch\n",
    "    elif isinstance(elem, container_abcs.Mapping):\n",
    "        return {key: _collate([d[key] for d in batch]) for key in elem}\n",
    "    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple\n",
    "        return elem_type(*(_collate(samples) for samples in zip(*batch)))\n",
    "    elif isinstance(elem, container_abcs.Sequence):\n",
    "        # check to make sure that the elements in batch have consistent size\n",
    "        elem_size = len(elem)\n",
    "        if any(len(e) != elem_size for e in batch[1:]):\n",
    "            raise RuntimeError(\n",
    "                'each element in list of batch should be of equal size '\n",
    "                f'(got lengths {[len(e) for e in batch]})'\n",
    "            )\n",
    "        return [_collate(samples) for samples in zip(*batch)]\n",
    "\n",
    "    raise TypeError(f\"Unsupported type: {elem_type}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "55f1fb2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from argparse import ArgumentParser\n",
    "\n",
    "\n",
    "def _str2bool(value):\n",
    "    \"\"\"Parse a boolean CLI value: 'true' (any case) is True, anything else False.\"\"\"\n",
    "    return str(value).lower() == 'true'\n",
    "\n",
    "\n",
    "parser = ArgumentParser()\n",
    "\n",
    "# training arguments\n",
    "parser.add_argument('--batch_size', default=512, type=int)\n",
    "parser.add_argument(\"--criterion_name\", type=str, default=\"binary_crossentropy\", choices=[\"binary_crossentropy\"])\n",
    "parser.add_argument('--balance_classes', default=False, type=_str2bool)\n",
    "parser.add_argument('--epochs', default=40, type=int)\n",
    "parser.add_argument('--gradient_clip_val', default=1, type=float)\n",
    "parser.add_argument('--gpus', default=1, type=int)\n",
    "parser.add_argument('--num_workers', default=0, type=int)\n",
    "parser.add_argument('--seed', default=None, type=int)\n",
    "parser.add_argument('--checkpoint_every_n_epochs', type=int, default=5)\n",
    "parser.add_argument('--wandb_group', type=str, default=\"latest\")\n",
    "parser.add_argument('--freeze_backbone', default=False, type=_str2bool)\n",
    "parser.add_argument('--toy_dataloader', default=False, type=_str2bool)\n",
    "\n",
    "# Optimizer arguments\n",
    "parser.add_argument('--lr', default=5e-4, type=float)\n",
    "parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',\n",
    "                    help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')\n",
    "parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',\n",
    "                    help='epochs to warmup LR, if scheduler supports')\n",
    "parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',\n",
    "                    help='num of steps to warmup LR, will overload warmup_epochs if set > 0')\n",
    "parser.add_argument('--wd', default=0.05, type=float,\n",
    "                    help=\"Weight decay (will use Adam if set to 0, AdamW otherwise).\")\n",
    "parser.add_argument('--gradient_accumulation_steps', default=1, type=int)\n",
    "parser.add_argument('--adam_betas', nargs='+', type=float, default=(0.9, 0.999), help='Adam betas')\n",
    "parser.add_argument('--adam_eps', type=float, default=1e-8, help='Adam epsilon')\n",
    "parser.add_argument('--backbone_lr', type=float, default=-1, help='backbone learning rate (if -1 uses model lr)')\n",
    "parser.add_argument('--min_backbone_lr', type=float, default=-1, help='backbone min learning rate (if -1 uses model min lr)')\n",
    "\n",
    "# regularization\n",
    "parser.add_argument('--drop_path_rate', default=0.1, type=float,\n",
    "                    help=\"Drop path rate (Stochastic Depth) (default: 0.1). Currently only implemented for leaky_clip backbone\")\n",
    "\n",
    "# Augmentation params\n",
    "parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1) - only works if mixup is enabled')\n",
    "parser.add_argument('--num_aug_sample', type=int, default=2,\n",
    "                    help='Repeated_aug (default: 2)')\n",
    "parser.add_argument('--aa', type=str, default='rand-m7-n4-mstd0.5-inc1', metavar='NAME',\n",
    "                    help='Use AutoAugment policy. \"v0\" or \"original\" (default: rand-m7-n4-mstd0.5-inc1). Set to \"None\" to disable.')\n",
    "parser.add_argument('--train_interpolation', type=str, default='bicubic',\n",
    "                    help='Training interpolation (random, bilinear, bicubic default: \"bicubic\")')\n",
    "\n",
    "# Random Erase params\n",
    "parser.add_argument('--reprob', type=float, default=0, metavar='PCT',\n",
    "                    help='Random erase prob (default: 0)')\n",
    "parser.add_argument('--remode', type=str, default='pixel',\n",
    "                    help='Random erase mode (default: \"pixel\")')\n",
    "parser.add_argument('--recount', type=int, default=1,\n",
    "                    help='Random erase count (default: 1)')\n",
    "\n",
    "# Mixup params\n",
    "parser.add_argument('--mixup', type=float, default=0,\n",
    "                    help='mixup alpha, mixup enabled if > 0.')\n",
    "parser.add_argument('--cutmix', type=float, default=0,\n",
    "                    help='cutmix alpha, cutmix enabled if > 0.')\n",
    "parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,\n",
    "                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\n",
    "parser.add_argument('--mixup_prob', type=float, default=1.0,\n",
    "                    help='Probability of performing mixup or cutmix when either/both is enabled')\n",
    "parser.add_argument('--mixup_switch_prob', type=float, default=0.5,\n",
    "                    help='Probability of switching to cutmix when both mixup and cutmix enabled')\n",
    "parser.add_argument('--mixup_mode', type=str, default='batch',\n",
    "                    help='How to apply mixup/cutmix params. Per \"batch\", \"pair\", or \"elem\"')\n",
    "\n",
    "# dataset arguments\n",
    "parser.add_argument('--task_name', type=str, required=True)\n",
    "parser.add_argument('--data_path', type=str, required=True)\n",
    "parser.add_argument('--label_path', type=str, required=True)\n",
    "parser.add_argument('--n_frames', default=32, type=int)\n",
    "parser.add_argument('--test_temporal_views', default=1, type=int)\n",
    "parser.add_argument('--test_spatial_views', default=3, type=int)\n",
    "\n",
    "# model structure arguments\n",
    "parser.add_argument(\"--model_name\", type=str, default=\"encode_pool_classify\", choices=[\"encode_pool_classify\"])\n",
    "\n",
    "# backbone arguments\n",
    "parser.add_argument(\"--backbone_name\", type=str, default=\"clip_ViT-B/32\")\n",
    "\n",
    "# encoder arguments\n",
    "parser.add_argument(\"--temporal_pooling_name\", type=str, default=\"mean\", choices=[\"mean\", \"transformer\", \"identity\"])\n",
    "# transformer specific arguments used if `temporal_pooling_name` is `transformer`\n",
    "parser.add_argument('--temporal_pooling_transformer_depth', default=3, type=int)\n",
    "parser.add_argument('--temporal_pooling_transformer_heads', default=4, type=int)\n",
    "parser.add_argument('--temporal_pooling_transformer_dim', default=512, type=int)\n",
    "parser.add_argument('--temporal_pooling_transformer_ff_dim', default=512, type=int)\n",
    "parser.add_argument('--temporal_pooling_transformer_input_dim', default=512, type=int)\n",
    "parser.add_argument('--temporal_pooling_transformer_emb_dropout', default=0.1, type=float)\n",
    "\n",
    "# classifier arguments\n",
    "parser.add_argument(\"--classification_layer_name\", type=str, default=\"linear\", choices=[\"linear\"])\n",
    "parser.add_argument(\"--classification_input_dim\", type=int, default=512)\n",
    "parser.add_argument(\"--num_classes\", type=int, required=True)\n",
    "\n",
    "# NOTE(review): the data/label paths and --num_classes 3500 below look like\n",
    "# Something-Something-V2 settings, while the EPIC wrapper hard-codes 3806 action\n",
    "# classes. Confirm these args are only consumed for augmentation options here.\n",
    "args = parser.parse_args((\n",
    "    \"--task_name epic_kitchens_aa \"\n",
    "    \"--data_path /svl/data/SomethingSomethingV2/20bn-something-something-v2/ \"\n",
    "    \"--label_path /vision/u/eatang/SSV2 --n_frames 8 --num_classes 3500 \"\n",
    "    \"--num_aug_sample 1 --aa rand-m9-n2-mstd0.5-inc1\"\n",
    ").split())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "de9b3290",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "n_frames = 16\n",
    "\n",
    "# Alternative, non-streaming anticipation setup on the EPIC-100 train split:\n",
    "# anticipate_fn = {\n",
    "#     '_target_': 'epic_kitchens.base_dataloader.convert_to_anticipation',\n",
    "#     'tau_a': 1.0,\n",
    "#     'tau_o': 2.5,\n",
    "# }\n",
    "# label_file = \"/svl/data/EpicKitchens/annotations/EPIC_100_train.pkl\"\n",
    "# streaming = False\n",
    "# streaming_max_num_prev = 1\n",
    "\n",
    "# Active setup: streaming validation annotations; the wrapper pads each\n",
    "# sample's labels / end_token_idxs to `streaming_max_num_prev` slots.\n",
    "anticipate_fn = None\n",
    "label_file = '/vision/u/eatang/leaky_video/datasets/epic_kitchens/epic-kitchens-100-annotations/EPIC_100_streaming_stride_2.5_validation.pkl'\n",
    "streaming = True\n",
    "streaming_max_num_prev = 8\n",
    "\n",
    "dataset = EPICKitchens(\n",
    "    [label_file],\n",
    "    version=EPIC100_VERSION,\n",
    "    root=[\"/svl/data/EpicKitchens/EPIC-KITCHENS/\"],\n",
    "    label_type=[\"action\", \"noun\", \"verb\"],\n",
    "    frames_per_clip=n_frames,\n",
    "    subclips_options={\"stride\": 1, \"num_frames\": 1},\n",
    "    sample_strategy=\"uniform\",\n",
    "    action_labels_fpath=\"/vision/u/eatang/leaky_video/datasets/epic_kitchens/actions.csv\",\n",
    "    conv_to_anticipate_fn=anticipate_fn,\n",
    "    rulstm_annotation_dir=\"/vision/u/eatang/leaky_video/datasets/epic_kitchens/rulstm/RULSTM/data/ek100\",\n",
    "    frame_rate=1,\n",
    "    frame_root=\"/svl/data/kinetics-400/EpicKitchens/EPIC-KITCHENS/\",\n",
    "    video_info_path='/svl/data/kinetics-400/EpicKitchens/EPIC_100_video_info_updated.csv',\n",
    "    process_inorder=True,\n",
    "    use_timestamps=True,\n",
    "    jitter_frames=True,\n",
    ")\n",
    "# Debug subset: keep only the first 50 annotation rows.\n",
    "dataset.df = dataset.df[:50]\n",
    "\n",
    "# NOTE: despite its name, `train_dataset` runs the wrapper's val path (mode=\"val\").\n",
    "train_dataset = EpicKitchensWrapper(\n",
    "    dataset,\n",
    "    n_frames,\n",
    "    mode=\"val\",\n",
    "    crop_size=224,\n",
    "    short_side_size=224,\n",
    "    num_aug_sample=1,\n",
    "    args=args,\n",
    "    test_num_segment=1,\n",
    "    test_num_crop=1,\n",
    "    streaming=streaming,\n",
    "    streaming_max_num_prev=streaming_max_num_prev,\n",
    ")\n",
    "\n",
    "from torch.utils.data import DataLoader, SubsetRandomSampler\n",
    "from functools import partial\n",
    "\n",
    "# batch_size=1: each batch holds a single wrapped sample.\n",
    "train_loader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=1,\n",
    "    shuffle=False,\n",
    "    num_workers=10,\n",
    "    collate_fn=partial(collate_with_pad, allow_pad=False, pad_right=True),\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "49cd1b70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch one wrapped sample (index 16) for inspection.\n",
    "x = train_dataset[16]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "1e5bc5f0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['video_features', 'labels', 'video_indices', 'video_path', 'chunk_nbs', 'split_nbs', 'end_token_idxs'])"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keys of the sample dict returned by the wrapper's val/test path.\n",
    "x.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "03897919",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 1, 11, -1, -1, -1, -1, -1, -1])"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# End-token index per streamed label slot; -1 marks padding added up to\n",
    "# `streaming_max_num_prev` slots.\n",
    "x['end_token_idxs']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "edceb1a2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[16, 16]"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Expected: one entry per `end_token_idxs` slot -- the sample's `test_indices`\n",
    "# entry, or -1 for padded slots. NOTE(review): the saved output below has 2\n",
    "# entries while `end_token_idxs` above has 8, so it looks stale; re-run to confirm.\n",
    "x['video_indices']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "70522d93",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Gather the per-sample video indices emitted by the loader.\n",
    "video_indices = []\n",
    "for batch in train_loader:\n",
    "    video_indices += batch[\"video_indices\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "fb1aa8f1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({2: 1477, 1: 17607, 3: 173, 4: 13, 6: 1})"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distribution of the number of action ends per annotation row.\n",
    "from collections import Counter\n",
    "\n",
    "Counter(len(ends) for ends in dataset.df[\"ends\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "fa0f49d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Never truncate columns when displaying wide dataframes.\n",
    "pd.set_option(\"display.max_columns\", None)"
   ]
  },
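  {
   "cell_type": "markdown",
   "id": "6e193ac6",
   "metadata": {},
   "source": [
    "## Make Streaming DataFrame\n",
    "\n",
    "Cut each video into consecutive `stride`-second windows (the last one may be shorter) and label every window with the actions that *end* inside it; windows without an ending action get `-1` labels. The resulting dataframe is pickled at the end of this section."
   ]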
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "58ae56ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of extracted frames per video, keyed by video id (the name of the\n",
    "# frame directory). Directories are visited in sorted order so the dict's\n",
    "# insertion order -- and therefore the row order and uids of the streaming\n",
    "# dataframe built from it -- does not depend on set iteration (hash) order.\n",
    "# NOTE(review): counts every directory entry, i.e. assumes each directory\n",
    "# holds only frame images -- confirm.\n",
    "from pathlib import Path\n",
    "\n",
    "vid_id_to_num_frames = {}\n",
    "for frame_dir in map(Path, sorted(set(dataset.df[\"frame_path\"]), key=str)):\n",
    "    vid_id_to_num_frames[frame_dir.name] = sum(1 for _ in frame_dir.iterdir())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "03b15263",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-video lookups built from the annotation dataframe (for fps and paths the\n",
    "# last row of a video wins):\n",
    "#   video_id_to_labels[vid] -> deque of [end, verb, verb_class, noun, noun_class, action_class]\n",
    "#   vid_id_to_fps[vid]      -> video_fps\n",
    "#   vid_id_to_paths[vid]    -> [video_path, frame_path]\n",
    "# Each video's actions are sorted by end time because the streaming cell only\n",
    "# inspects the head of the deque: an action listed after one that ends later\n",
    "# (overlapping annotations) could otherwise fall behind the current window,\n",
    "# never match, and block every action queued behind it.\n",
    "from collections import deque, defaultdict\n",
    "\n",
    "video_id_to_labels = defaultdict(deque)\n",
    "vid_id_to_fps = {}\n",
    "vid_id_to_paths = {}\n",
    "\n",
    "for ann in dataset.df.itertuples(index=False):\n",
    "    vid_id_to_paths[ann.video_id] = [ann.video_path, ann.frame_path]\n",
    "    video_id_to_labels[ann.video_id].append(\n",
    "        [ann.end, ann.verb, ann.verb_class, ann.noun, ann.noun_class, ann.action_class]\n",
    "    )\n",
    "    vid_id_to_fps[ann.video_id] = ann.video_fps\n",
    "\n",
    "for vid_id, actions in video_id_to_labels.items():\n",
    "    video_id_to_labels[vid_id] = deque(sorted(actions, key=lambda action: action[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "1f21f7cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Video duration in seconds: extracted frame count divided by the video's fps.\n",
    "vid_id_to_len_in_sec = {\n",
    "    vid_id: num_frames / vid_id_to_fps[vid_id]\n",
    "    for vid_id, num_frames in vid_id_to_num_frames.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "30a51dc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cut every video into consecutive windows of `stride` seconds (the last one may\n",
    "# be shorter) and label each window with the actions whose end time falls in\n",
    "# (window start, window end]. Windows with no such action get \"\" for verb/noun\n",
    "# and [-1] for every label list. Other columns are copied from dataset.df.iloc[0].\n",
    "from collections import deque\n",
    "from datetime import datetime, timedelta\n",
    "\n",
    "stride = 2.5\n",
    "clock_zero = datetime.strptime(\"00:00:00.00\", \"%H:%M:%S.%f\")\n",
    "row_template = dataset.df.iloc[0].drop(['start_frame', 'stop_frame', 'narration_timestamp'])\n",
    "\n",
    "\n",
    "def format_offset(seconds):\n",
    "    \"\"\"Format an offset in seconds from the video start as HH:MM:SS.mmm.\"\"\"\n",
    "    return (clock_zero + timedelta(seconds=seconds)).strftime(\"%H:%M:%S.%f\")[:-3]\n",
    "\n",
    "\n",
    "rows = []\n",
    "uid_idx = 0\n",
    "for vid_id, len_in_sec in vid_id_to_len_in_sec.items():\n",
    "    # Consume a sorted *copy*: popping from the shared deques would leave them\n",
    "    # empty, so re-running this cell would label every window [-1]. Sorting by\n",
    "    # end time matters because only the queue head is inspected below.\n",
    "    actions_queue = deque(sorted(video_id_to_labels[vid_id], key=lambda action: action[0]))\n",
    "    for i, start in enumerate(np.arange(0, len_in_sec, stride)):\n",
    "        end = min(start + stride, len_in_sec)\n",
    "\n",
    "        row = row_template.copy()\n",
    "        row['narration_id'] = vid_id + f\"_{i}\"\n",
    "        row['participant_id'] = vid_id.split('_')[0]\n",
    "        row['video_id'] = vid_id\n",
    "        row['video_fps'] = vid_id_to_fps[vid_id]\n",
    "        row[\"video_path\"], row[\"frame_path\"] = vid_id_to_paths[vid_id]\n",
    "        row[\"uid\"] = uid_idx\n",
    "        uid_idx += 1\n",
    "\n",
    "        row['start_timestamp'] = format_offset(start)\n",
    "        row['stop_timestamp'] = format_offset(end)\n",
    "        row['start'] = start\n",
    "        row['end'] = end\n",
    "\n",
    "        row[\"narration\"] = \"\"\n",
    "        row[\"all_nouns\"] = []\n",
    "        row[\"all_noun_classes\"] = []\n",
    "        row[\"orig_start\"] = 0\n",
    "        row[\"orig_end\"] = 0\n",
    "        row[\"future_0_start\"] = 0\n",
    "        row[\"future_0_end\"] = 0\n",
    "\n",
    "        # Actions that ended at or before this window's start can never match\n",
    "        # this or a later window; drop them so they cannot block the queue head.\n",
    "        while actions_queue and actions_queue[0][0] <= start:\n",
    "            actions_queue.popleft()\n",
    "        matched = []\n",
    "        while actions_queue and actions_queue[0][0] <= end:\n",
    "            matched.append(actions_queue.popleft())\n",
    "\n",
    "        if matched:\n",
    "            # verb/noun keep only the last matched action; the label lists and\n",
    "            # `ends` keep every matched action, in end-time order.\n",
    "            row[\"verb\"], row[\"noun\"] = matched[-1][1], matched[-1][3]\n",
    "            row[\"verb_class\"] = [action[2] for action in matched]\n",
    "            row[\"noun_class\"] = [action[4] for action in matched]\n",
    "            row[\"action_class\"] = [action[5] for action in matched]\n",
    "            row[\"ends\"] = [action[0] for action in matched]\n",
    "        else:\n",
    "            row[\"verb\"] = \"\"\n",
    "            row[\"noun\"] = \"\"\n",
    "            row[\"verb_class\"] = [-1]\n",
    "            row[\"noun_class\"] = [-1]\n",
    "            row[\"action_class\"] = [-1]\n",
    "            row[\"ends\"] = [-1]\n",
    "\n",
    "        rows.append(row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "b7b37a97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each row Series was copied from dataset.df.iloc[0], so they all share that\n",
    "# row's index label (0 in the saved output); give the windows a unique 0..N-1 index.\n",
    "new_df = pd.DataFrame(rows).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "92b0eb57",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>narration_id</th>\n",
       "      <th>participant_id</th>\n",
       "      <th>video_id</th>\n",
       "      <th>start_timestamp</th>\n",
       "      <th>stop_timestamp</th>\n",
       "      <th>narration</th>\n",
       "      <th>verb</th>\n",
       "      <th>verb_class</th>\n",
       "      <th>noun</th>\n",
       "      <th>noun_class</th>\n",
       "      <th>...</th>\n",
       "      <th>video_path</th>\n",
       "      <th>frame_path</th>\n",
       "      <th>video_fps</th>\n",
       "      <th>action_class</th>\n",
       "      <th>uid</th>\n",
       "      <th>orig_start</th>\n",
       "      <th>orig_end</th>\n",
       "      <th>future_0_start</th>\n",
       "      <th>future_0_end</th>\n",
       "      <th>ends</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>P03_18_0</td>\n",
       "      <td>P03</td>\n",
       "      <td>P03_18</td>\n",
       "      <td>00:00:00.000</td>\n",
       "      <td>00:00:02.500</td>\n",
       "      <td></td>\n",
       "      <td>open</td>\n",
       "      <td>[3]</td>\n",
       "      <td>fridge</td>\n",
       "      <td>[12]</td>\n",
       "      <td>...</td>\n",
       "      <td>P03/videos/P03_18.MP4</td>\n",
       "      <td>/svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...</td>\n",
       "      <td>60.229027</td>\n",
       "      <td>[2345]</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[0.8700000000000001]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>P03_18_1</td>\n",
       "      <td>P03</td>\n",
       "      <td>P03_18</td>\n",
       "      <td>00:00:02.500</td>\n",
       "      <td>00:00:05.000</td>\n",
       "      <td></td>\n",
       "      <td>put in</td>\n",
       "      <td>[5]</td>\n",
       "      <td>grape</td>\n",
       "      <td>[202]</td>\n",
       "      <td>...</td>\n",
       "      <td>P03/videos/P03_18.MP4</td>\n",
       "      <td>/svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...</td>\n",
       "      <td>60.229027</td>\n",
       "      <td>[3077]</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[3.2300000000000004]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>P03_18_2</td>\n",
       "      <td>P03</td>\n",
       "      <td>P03_18</td>\n",
       "      <td>00:00:05.000</td>\n",
       "      <td>00:00:07.500</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>[-1]</td>\n",
       "      <td></td>\n",
       "      <td>[-1]</td>\n",
       "      <td>...</td>\n",
       "      <td>P03/videos/P03_18.MP4</td>\n",
       "      <td>/svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...</td>\n",
       "      <td>60.229027</td>\n",
       "      <td>[-1]</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[-1]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>P03_18_3</td>\n",
       "      <td>P03</td>\n",
       "      <td>P03_18</td>\n",
       "      <td>00:00:07.500</td>\n",
       "      <td>00:00:10.000</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>[-1]</td>\n",
       "      <td></td>\n",
       "      <td>[-1]</td>\n",
       "      <td>...</td>\n",
       "      <td>P03/videos/P03_18.MP4</td>\n",
       "      <td>/svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...</td>\n",
       "      <td>60.229027</td>\n",
       "      <td>[-1]</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[-1]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>P03_18_4</td>\n",
       "      <td>P03</td>\n",
       "      <td>P03_18</td>\n",
       "      <td>00:00:10.000</td>\n",
       "      <td>00:00:12.500</td>\n",
       "      <td></td>\n",
       "      <td>put in</td>\n",
       "      <td>[4, 5]</td>\n",
       "      <td>grape</td>\n",
       "      <td>[12, 202]</td>\n",
       "      <td>...</td>\n",
       "      <td>P03/videos/P03_18.MP4</td>\n",
       "      <td>/svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...</td>\n",
       "      <td>60.229027</td>\n",
       "      <td>[2729, 3077]</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[10.13, 11.36]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>P04_107_589</td>\n",
       "      <td>P04</td>\n",
       "      <td>P04_107</td>\n",
       "      <td>00:24:32.500</td>\n",
       "      <td>00:24:35.000</td>\n",
       "      <td></td>\n",
       "      <td>put down</td>\n",
       "      <td>[5, 17, 1]</td>\n",
       "      <td>spoon</td>\n",
       "      <td>[144, 144, 1]</td>\n",
       "      <td>...</td>\n",
       "      <td>P04/videos/P04_107.MP4</td>\n",
       "      <td>/svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...</td>\n",
       "      <td>50.002696</td>\n",
       "      <td>[3027, 985, 1112]</td>\n",
       "      <td>108002</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[1473.16, 1473.83, 1474.09]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>P04_107_590</td>\n",
       "      <td>P04</td>\n",
       "      <td>P04_107</td>\n",
       "      <td>00:24:35.000</td>\n",
       "      <td>00:24:37.500</td>\n",
       "      <td></td>\n",
       "      <td>pick up</td>\n",
       "      <td>[3, 0]</td>\n",
       "      <td>oven dish</td>\n",
       "      <td>[46, 35]</td>\n",
       "      <td>...</td>\n",
       "      <td>P04/videos/P04_107.MP4</td>\n",
       "      <td>/svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...</td>\n",
       "      <td>50.002696</td>\n",
       "      <td>[2429, 196]</td>\n",
       "      <td>108003</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[1476.35, 1477.29]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>P04_107_591</td>\n",
       "      <td>P04</td>\n",
       "      <td>P04_107</td>\n",
       "      <td>00:24:37.500</td>\n",
       "      <td>00:24:40.000</td>\n",
       "      <td></td>\n",
       "      <td>put into</td>\n",
       "      <td>[5]</td>\n",
       "      <td>oven dish</td>\n",
       "      <td>[35]</td>\n",
       "      <td>...</td>\n",
       "      <td>P04/videos/P04_107.MP4</td>\n",
       "      <td>/svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...</td>\n",
       "      <td>50.002696</td>\n",
       "      <td>[3138]</td>\n",
       "      <td>108004</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[1478.14]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>P04_107_592</td>\n",
       "      <td>P04</td>\n",
       "      <td>P04_107</td>\n",
       "      <td>00:24:40.000</td>\n",
       "      <td>00:24:42.500</td>\n",
       "      <td></td>\n",
       "      <td>put on</td>\n",
       "      <td>[4, 1]</td>\n",
       "      <td>oven glove</td>\n",
       "      <td>[46, 60]</td>\n",
       "      <td>...</td>\n",
       "      <td>P04/videos/P04_107.MP4</td>\n",
       "      <td>/svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...</td>\n",
       "      <td>50.002696</td>\n",
       "      <td>[2788, 1330]</td>\n",
       "      <td>108005</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[1480.2, 1481.05]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>P04_107_593</td>\n",
       "      <td>P04</td>\n",
       "      <td>P04_107</td>\n",
       "      <td>00:24:42.500</td>\n",
       "      <td>00:24:43.720</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td>[-1]</td>\n",
       "      <td></td>\n",
       "      <td>[-1]</td>\n",
       "      <td>...</td>\n",
       "      <td>P04/videos/P04_107.MP4</td>\n",
       "      <td>/svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...</td>\n",
       "      <td>50.002696</td>\n",
       "      <td>[-1]</td>\n",
       "      <td>108006</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[-1]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>108007 rows × 24 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   narration_id participant_id video_id start_timestamp stop_timestamp  \\\n",
       "0      P03_18_0            P03   P03_18    00:00:00.000   00:00:02.500   \n",
       "0      P03_18_1            P03   P03_18    00:00:02.500   00:00:05.000   \n",
       "0      P03_18_2            P03   P03_18    00:00:05.000   00:00:07.500   \n",
       "0      P03_18_3            P03   P03_18    00:00:07.500   00:00:10.000   \n",
       "0      P03_18_4            P03   P03_18    00:00:10.000   00:00:12.500   \n",
       "..          ...            ...      ...             ...            ...   \n",
       "0   P04_107_589            P04  P04_107    00:24:32.500   00:24:35.000   \n",
       "0   P04_107_590            P04  P04_107    00:24:35.000   00:24:37.500   \n",
       "0   P04_107_591            P04  P04_107    00:24:37.500   00:24:40.000   \n",
       "0   P04_107_592            P04  P04_107    00:24:40.000   00:24:42.500   \n",
       "0   P04_107_593            P04  P04_107    00:24:42.500   00:24:43.720   \n",
       "\n",
       "   narration      verb  verb_class        noun     noun_class  ...  \\\n",
       "0                 open         [3]      fridge           [12]  ...   \n",
       "0               put in         [5]       grape          [202]  ...   \n",
       "0                             [-1]                       [-1]  ...   \n",
       "0                             [-1]                       [-1]  ...   \n",
       "0               put in      [4, 5]       grape      [12, 202]  ...   \n",
       "..       ...       ...         ...         ...            ...  ...   \n",
       "0             put down  [5, 17, 1]       spoon  [144, 144, 1]  ...   \n",
       "0              pick up      [3, 0]   oven dish       [46, 35]  ...   \n",
       "0             put into         [5]   oven dish           [35]  ...   \n",
       "0               put on      [4, 1]  oven glove       [46, 60]  ...   \n",
       "0                             [-1]                       [-1]  ...   \n",
       "\n",
       "                video_path                                         frame_path  \\\n",
       "0    P03/videos/P03_18.MP4  /svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...   \n",
       "0    P03/videos/P03_18.MP4  /svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...   \n",
       "0    P03/videos/P03_18.MP4  /svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...   \n",
       "0    P03/videos/P03_18.MP4  /svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...   \n",
       "0    P03/videos/P03_18.MP4  /svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...   \n",
       "..                     ...                                                ...   \n",
       "0   P04/videos/P04_107.MP4  /svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...   \n",
       "0   P04/videos/P04_107.MP4  /svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...   \n",
       "0   P04/videos/P04_107.MP4  /svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...   \n",
       "0   P04/videos/P04_107.MP4  /svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...   \n",
       "0   P04/videos/P04_107.MP4  /svl/data/kinetics-400/EpicKitchens/EPIC-KITCH...   \n",
       "\n",
       "    video_fps       action_class     uid orig_start  orig_end future_0_start  \\\n",
       "0   60.229027             [2345]       0          0         0              0   \n",
       "0   60.229027             [3077]       1          0         0              0   \n",
       "0   60.229027               [-1]       2          0         0              0   \n",
       "0   60.229027               [-1]       3          0         0              0   \n",
       "0   60.229027       [2729, 3077]       4          0         0              0   \n",
       "..        ...                ...     ...        ...       ...            ...   \n",
       "0   50.002696  [3027, 985, 1112]  108002          0         0              0   \n",
       "0   50.002696        [2429, 196]  108003          0         0              0   \n",
       "0   50.002696             [3138]  108004          0         0              0   \n",
       "0   50.002696       [2788, 1330]  108005          0         0              0   \n",
       "0   50.002696               [-1]  108006          0         0              0   \n",
       "\n",
       "    future_0_end                         ends  \n",
       "0              0         [0.8700000000000001]  \n",
       "0              0         [3.2300000000000004]  \n",
       "0              0                         [-1]  \n",
       "0              0                         [-1]  \n",
       "0              0               [10.13, 11.36]  \n",
       "..           ...                          ...  \n",
       "0              0  [1473.16, 1473.83, 1474.09]  \n",
       "0              0           [1476.35, 1477.29]  \n",
       "0              0                    [1478.14]  \n",
       "0              0            [1480.2, 1481.05]  \n",
       "0              0                         [-1]  \n",
       "\n",
       "[108007 rows x 24 columns]"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "8bc10a38",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({1: 92969, 2: 12255, 3: 2378, 4: 352, 5: 46, 6: 6, 8: 1})"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distribution of the number of action ends per streaming window\n",
    "# (windows without an ending action hold [-1] and count as 1).\n",
    "from collections import Counter\n",
    "\n",
    "Counter(len(ends) for ends in new_df[\"ends\"])\n",
    "\n",
    "# NOTE(review): an earlier note said the max is 6 for train and 4 for val, but\n",
    "# the saved output of this (train) cell shows windows with up to 8 ends --\n",
    "# re-check before sizing any fixed-length padding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "eb996aae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the streaming dataframe. The file name is derived from the same\n",
    "# `stride` used to build the windows, so the two cannot drift apart.\n",
    "# NOTE(review): OUTPUT_DIR is an absolute, user-specific path; SPLIT must match\n",
    "# the split `dataset` was built from -- set both in a config cell.\n",
    "import pickle\n",
    "from pathlib import Path\n",
    "\n",
    "OUTPUT_DIR = Path('/vision/u/eatang/leaky_video/datasets/epic_kitchens/epic-kitchens-100-annotations')\n",
    "SPLIT = 'train'\n",
    "\n",
    "output_path = OUTPUT_DIR / f'EPIC_100_streaming_stride_{stride}_{SPLIT}.pkl'\n",
    "with open(output_path, 'wb') as f:\n",
    "    pickle.dump(new_df, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99987ff4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:root] *",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
