{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "f02d187d",
   "metadata": {},
   "outputs": [],
   "source": [
     "import pathlib\n",
     "\n",
     "# Libraries related to hydra\n",
     "import hydra\n",
     "from hydra.core.global_hydra import GlobalHydra\n",
     "from hydra.utils import to_absolute_path\n",
     "import librosa\n",
     "import numpy as np\n",
     "import torchaudio\n",
     "from torch.utils.data import Dataset\n",
     "\n",
     "from End2End.constants import SAMPLE_RATE\n",
     "from End2End.MIDI_program_map import (\n",
     "    MIDI_Class_NUM,\n",
     "    MIDIClassName2class_idx,\n",
     "    class_idx2MIDIClass,\n",
     ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "1aab4a86",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0000, 0.0000, 0.0000,  ..., 0.0001, 0.0002, 0.0001])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Load one example track; torchaudio returns a (channels, samples) tensor and the file's native sample rate.\n",
     "waveform, rate = torchaudio.load('datasets/wild/rock - Metropolis - Part I The Miracle and the Sleeper.mp3')\n",
     "# Average over the channel axis to get a mono signal (the channel dimension is dropped).\n",
     "waveform.mean(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "9ad8280f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 27464832])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# keepdim=True keeps the channel axis, giving shape (1, num_samples) as used by WildDataset.\n",
     "waveform.mean(0, keepdim=True).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6501fa3",
   "metadata": {},
   "outputs": [],
   "source": [
    "audiofolder = pathlib.Path('./datasets/wild')\n",
    "audio_list = list(audiofolder.glob('*.mp3'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "0b82b36a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/tiger/anaconda3/envs/jointist/lib/python3.8/site-packages/hydra/_internal/defaults_list.py:251: UserWarning: In 'transcription_config': Defaults list is missing `_self_`. See https://hydra.cc/docs/upgrades/1.0_to_1.1/default_composition_order for more information\n",
      "  warnings.warn(msg, UserWarning)\n"
     ]
    }
   ],
   "source": [
    "hydra.initialize(config_path=\"End2End/config/\", job_name=\"debug\")\n",
    "cfg = hydra.compose(config_name=\"transcription_config\", overrides=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "445e7454",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  6.6936e-05,\n",
       "           1.0514e-04,  2.8205e-04],\n",
       "         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -6.8963e-05,\n",
       "           4.0293e-05, -1.6093e-06]]),\n",
       " 44100)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Raw load of the first file: a (waveform of shape (channels, samples), sample_rate) tuple.\n",
     "torchaudio.load(audio_list[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "5b6a02ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "class WildDataset(Dataset):\n",
    "    def __init__(\n",
    "        self,\n",
    "        audio_path,\n",
    "        audio_ext,\n",
    "        MIDI_MAPPING,\n",
    "        segment_samples=None,\n",
    "    ):\n",
    "        r\"\"\"Instrument classification dataset takes the meta of an audio\n",
    "        segment as input, and return the waveform, onset_roll, and targets of\n",
    "        the audio segment. Dataset is used by DataLoader.\n",
    "\n",
    "        Args:\n",
    "            audio_path: str\n",
    "            audio_ext: str, e.g. mp3, wav, flac\n",
    "            segment_samples: int, how long you want to cut the audio. If set to None, get the full audio\n",
    "        \"\"\"\n",
    "        audiofolder = pathlib.Path(audio_path)\n",
    "        self.audio_name_list = list(audiofolder.glob(f'*.{audio_ext}'))\n",
    "\n",
    "        self.sample_rate = SAMPLE_RATE\n",
    "        self.segment_samples = segment_samples\n",
    "        \n",
    "        self.ix_to_name = MIDI_MAPPING.IX_TO_NAME\n",
    "        self.plugin_labels_num = MIDI_MAPPING.plugin_labels_num\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.audio_name_list)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        r\"\"\"Get input and target of a segment for training.\n",
    "\n",
    "        Args:\n",
    "            idx for the audio list\n",
    "\n",
    "        Returns:\n",
    "          data_dict: {\n",
    "            'waveform': (samples_num,)\n",
    "            'onset_roll': (frames_num, classes_num),\n",
    "            'offset_roll': (frames_num, classes_num),\n",
    "            'reg_onset_roll': (frames_num, classes_num),\n",
    "            'reg_offset_roll': (frames_num, classes_num),\n",
    "            'frame_roll': (frames_num, classes_num),\n",
    "            'velocity_roll': (frames_num, classes_num),\n",
    "            'mask_roll':  (frames_num, classes_num),\n",
    "            'pedal_onset_roll': (frames_num,),\n",
    "            'pedal_offset_roll': (frames_num,),\n",
    "            'reg_pedal_onset_roll': (frames_num,),\n",
    "            'reg_pedal_offset_roll': (frames_num,),\n",
    "            'pedal_frame_roll': (frames_num,)}\n",
    "        \"\"\"\n",
    "        \n",
    "        waveform, rate = torchaudio.load(self.audio_name_list[idx])\n",
    "        if waveform.shape[0]==2: # if the audio file is stereo take mean\n",
    "            waveform = waveform.mean(0, keepdim=True) # keep the dim as (1, audio length)\n",
    "            \n",
    "\n",
    "        data_dict = {}\n",
    "\n",
    "        # Load segment waveform.\n",
    "        # with h5py.File(waveform_hdf5_path, 'r') as hf:\n",
    "        audio_length = len(waveform)\n",
    "        if self.segment_samples:\n",
    "            assert (audio_length - self.segment_samples)>0, \\\n",
    "            f\"sequence_length={sequence_length} is longer than the \"\n",
    "            f\"audio_length={audio_length}. Please reduce the sequence_length\"\n",
    "            if self.random_crop:\n",
    "                start_sample = np.random.randint(audio_length - self.segment_samples)\n",
    "            else:\n",
    "                start_sample = self.sample_rate*10\n",
    "            start_time = start_sample/self.sample_rate\n",
    "            end_sample = start_sample + self.segment_samples\n",
    "\n",
    "            waveform_seg = waveform[start_sample : end_sample]            \n",
    "            while waveform_seg.sum()==0 or len(unique_plugin_names) == 0: # resample if the audio is empty                \n",
    "    #             if waveform_seg.sum()==0:\n",
    "    #                 print(f'{hdf5_name} waveform is empty')\n",
    "    #             elif unique_plugin_names==0:\n",
    "    #                 print(f'{hdf5_name} waveform not empty, but no instrumet ')\n",
    "                if self.random_crop:\n",
    "                    start_sample = np.random.randint(audio_length - self.segment_samples)\n",
    "                else:\n",
    "                    start_sample = start_sample + self.sample_rate # Shift the audio by 1 second if the audio is empty\n",
    "                start_time = start_sample/self.sample_rate\n",
    "                end_sample = start_sample + self.segment_samples\n",
    "                waveform_seg = waveform[start_sample : end_sample]\n",
    "\n",
    "                valid_length = waveform_seg.shape[1]\n",
    "                # (segment_samples,), e.g., (160000,)\n",
    "\n",
    "            \n",
    "        else:\n",
    "            start_sample = 0\n",
    "            start_time = 0\n",
    "            end_sample = audio_length\n",
    "\n",
    "            waveform_seg = waveform\n",
    "\n",
    "            valid_length = waveform_seg.shape[1]\n",
    "            # (segment_samples,), e.g., (160000,)            \n",
    "\n",
    "        \n",
    "        data_dict['waveform'] = waveform_seg\n",
    "        data_dict['start_sample'] = start_sample\n",
    "        data_dict['valid_length'] = valid_length\n",
    "        data_dict['file_name'] = self.audio_name_list[idx].name\n",
    "\n",
    "        return data_dict  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "0386a229",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Fill the MIDI mapping fields of the Hydra config; WildDataset reads IX_TO_NAME and plugin_labels_num.\n",
     "cfg.MIDI_MAPPING.plugin_labels_num = MIDI_Class_NUM\n",
     "cfg.MIDI_MAPPING.NAME_TO_IX = MIDIClassName2class_idx\n",
     "cfg.MIDI_MAPPING.IX_TO_NAME = class_idx2MIDIClass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "1b1341a5",
   "metadata": {},
   "outputs": [],
   "source": [
     "# segment_samples is left at its default (None), so each item is the whole track.\n",
     "dataset = WildDataset('datasets/wild/', 'mp3', cfg.MIDI_MAPPING)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "61fdc336",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'waveform': tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -1.0133e-06,\n",
       "           7.2718e-05,  1.4022e-04]]),\n",
       " 'start_sample': 0,\n",
       " 'valid_length': 11245824,\n",
       " 'file_name': 'yt1s.com - Nirvana  In Bloom.mp3'}"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Inspect one item: a dict with 'waveform', 'start_sample', 'valid_length' and 'file_name'.\n",
     "dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2df205e9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jointist",
   "language": "python",
   "name": "jointist"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
