{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "10347669",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/bin/sh: 1: sox: not found\n",
      "SoX could not be found!\n",
      "\n",
      "    If you do not have SoX, proceed here:\n",
      "     - - - http://sox.sourceforge.net/ - - -\n",
      "\n",
      "    If you do (or think that you should) have SoX, double-check your\n",
      "    path variables.\n",
      "    \n"
     ]
    }
   ],
   "source": [
    "from functools import partial\n",
    "import os\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.plugins import DDPPlugin\n",
    "from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint\n",
    "\n",
    "from End2End.Data import DataModuleEnd2End\n",
    "import End2End.tasks.detection as Detection\n",
    "from End2End.MIDI_program_map import (\n",
    "                                      MIDI_Class_NUM,\n",
    "                                      MIDIClassName2class_idx,\n",
    "                                      class_idx2MIDIClass,\n",
    "                                      )\n",
    "# import End2End.models.instrument_detection as DectectionModel\n",
    "import End2End.models.instrument_detection.combined as CombinedModel\n",
    "import End2End.models.instrument_detection.backbone as BackBone\n",
    "import End2End.models.transformer as Transformer\n",
    "\n",
    "from End2End.data.augmentors import Augmentor\n",
    "from End2End.lr_schedulers import get_lr_lambda\n",
    "# from jointist.models.instruments_classification_models import get_model_class\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Libraries related to hydra\n",
    "from IPython.display import Audio\n",
    "\n",
    "import hydra\n",
    "from hydra.utils import to_absolute_path\n",
    "from omegaconf import OmegaConf\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6f72a80e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/tiger/anaconda3/envs/jointist/lib/python3.8/site-packages/hydra/_internal/defaults_list.py:251: UserWarning: In 'Instrument_Filter': Defaults list is missing `_self_`. See https://hydra.cc/docs/upgrades/1.0_to_1.1/default_composition_order for more information\n",
      "  warnings.warn(msg, UserWarning)\n"
     ]
    }
   ],
   "source": [
    "hydra.initialize(config_path=\"End2End/config/\", job_name=\"debug\")\n",
    "cfg = hydra.compose(config_name=\"Instrument_Filter\", overrides=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dc873e5b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading train hdf5 files: 100%|██████████| 1500/1500 [00:00<00:00, 1501540.81it/s]\n",
      "Loading train pkl files: 100%|██████████| 1500/1500 [00:23<00:00, 63.55it/s] \n",
      "Loading validation hdf5 files: 100%|██████████| 375/375 [00:00<00:00, 1807889.66it/s]\n",
      "Loading validation pkl files: 100%|██████████| 375/375 [00:05<00:00, 63.12it/s] \n",
      "Loading test hdf5 files: 100%|██████████| 225/225 [00:00<00:00, 1493225.32it/s]\n",
      "Loading test pkl files: 100%|██████████| 225/225 [00:01<00:00, 159.31it/s]\n"
     ]
    }
   ],
   "source": [
    "cfg.datamodule.waveform_hdf5s_dir = to_absolute_path(os.path.join('hdf5s', 'waveforms'))   \n",
    "\n",
    "if cfg.MIDI_MAPPING.type=='plugin_names':\n",
    "    cfg.MIDI_MAPPING.plugin_labels_num = PLUGIN_LABELS_NUM\n",
    "    cfg.MIDI_MAPPING.NAME_TO_IX = PLUGIN_LB_TO_IX\n",
    "    cfg.MIDI_MAPPING.IX_TO_NAME = PLUGIN_IX_TO_LB\n",
    "    cfg.datamodule.notes_pkls_dir = to_absolute_path('instruments_classification_notes3/')   \n",
    "elif cfg.MIDI_MAPPING.type=='MIDI_class':\n",
    "    cfg.MIDI_MAPPING.plugin_labels_num = MIDI_Class_NUM\n",
    "    cfg.MIDI_MAPPING.NAME_TO_IX = MIDIClassName2class_idx\n",
    "    cfg.MIDI_MAPPING.IX_TO_NAME = class_idx2MIDIClass\n",
    "    cfg.datamodule.notes_pkls_dir = to_absolute_path('instruments_classification_notes_MIDI_class/')\n",
    "else:\n",
    "    raise ValueError(f\"Please choose the correct MIDI_MAPPING.type\")        \n",
    "\n",
    "\n",
    "experiment_name = (\n",
    "                  f\"{cfg.transcription.model.type}-\"\n",
    "                  f\"{cfg.MIDI_MAPPING.type}-\"\n",
    "                  f\"hidden=256-\"\n",
    "                  f\"fps={cfg.transcription.model.args.frames_per_second}-\"\n",
    "                  f\"csize={cfg.transcription.model.args.condition_size}-\"\n",
    "                  f\"bz={cfg.batch_size}\"\n",
    "                  )\n",
    "\n",
    "# augmentor\n",
    "augmentor = Augmentor(augmentation=cfg.augmentation) if cfg.augmentation else None\n",
    "\n",
    "# data module\n",
    "data_module = DataModuleEnd2End(**cfg.datamodule,augmentor=augmentor, MIDI_MAPPING=cfg.MIDI_MAPPING)\n",
    "data_module.setup()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "acbce1be",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fa656578a60>\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/tiger/anaconda3/envs/jointist/lib/python3.8/site-packages/torch/utils/data/dataloader.py\", line 1328, in __del__\n",
      "    self._shutdown_workers()\n",
      "  File \"/home/tiger/anaconda3/envs/jointist/lib/python3.8/site-packages/torch/utils/data/dataloader.py\", line 1320, in _shutdown_workers\n",
      "    if w.is_alive():\n",
      "  File \"/home/tiger/anaconda3/envs/jointist/lib/python3.8/multiprocessing/process.py\", line 160, in is_alive\n",
      "    assert self._parent_pid == os.getpid(), 'can only test a child process'\n",
      "AssertionError: can only test a child process\n"
     ]
    }
   ],
   "source": [
    "train_loader = iter(data_module.train_dataloader())\n",
    "val_loader = iter(data_module.val_dataloader())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2fa18821",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['target_dict', 'waveform', 'start_sample', 'valid_length', 'hdf5_name', 'instruments', 'plugin_id'])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(train_loader).keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "45a6d968",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['target_dict', 'waveform', 'start_sample', 'valid_length', 'hdf5_name', 'instruments', 'plugin_id'])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(val_loader).keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7d2179b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jointist",
   "language": "python",
   "name": "jointist"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
