{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "065512cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pathlib\n",
    "import pickle\n",
    "import os\n",
    "import numpy as np\n",
    "import mir_eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e5bdeda5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/bin/sh: 1: sox: not found\n",
      "SoX could not be found!\n",
      "\n",
      "    If you do not have SoX, proceed here:\n",
      "     - - - http://sox.sourceforge.net/ - - -\n",
      "\n",
      "    If you do (or think that you should) have SoX, double-check your\n",
      "    path variables.\n",
      "    \n"
     ]
    }
   ],
   "source": [
    "from functools import partial\n",
    "import os\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.plugins import DDPPlugin\n",
    "from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint\n",
    "\n",
    "from End2End.Data import DataModuleEnd2End, End2EndBatchDataPreprocessor, FullPreprocessor\n",
    "from End2End.tasks.transcription import Transcription, BaselineTranscription\n",
    "from End2End.models.transcription.seg_baseline import Semantic_Segmentation\n",
    "\n",
    "from End2End.MIDI_program_map import (\n",
    "                                      MIDI_Class_NUM,\n",
    "                                      MIDIClassName2class_idx,\n",
    "                                      class_idx2MIDIClass,\n",
    "                                      )\n",
    "from End2End.data.augmentors import Augmentor\n",
    "from End2End.lr_schedulers import get_lr_lambda\n",
    "import End2End.models.transcription.combined as TranscriptionModel\n",
    "from End2End.losses import get_loss_function\n",
    "\n",
    "# Libraries related to hydra\n",
    "import hydra\n",
    "from hydra.utils import to_absolute_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a7b4a986",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/tiger/anaconda3/envs/jointist/lib/python3.8/site-packages/hydra/_internal/defaults_list.py:251: UserWarning: In 'transcription_config': Defaults list is missing `_self_`. See https://hydra.cc/docs/upgrades/1.0_to_1.1/default_composition_order for more information\n",
      "  warnings.warn(msg, UserWarning)\n"
     ]
    }
   ],
   "source": [
    "# Used as a context manager, hydra.initialize() restores the global Hydra state on exit,\n",
    "# so this cell can be re-run without raising 'GlobalHydra is already initialized'.\n",
    "with hydra.initialize(config_path=\"End2End/config/\", job_name=\"debug\"):\n",
    "    cfg = hydra.compose(config_name=\"transcription_config\", overrides=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7db382fa",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading test hdf5 files: 100%|██████████| 10/10 [00:00<00:00, 156503.88it/s]\n",
      "Loading test pkl files: 100%|██████████| 10/10 [00:00<00:00, 43.94it/s]\n"
     ]
    }
   ],
   "source": [
    "# dataset locations (notes_pkls_dir is assigned exactly once)\n",
    "cfg.datamodule.waveform_hdf5s_dir = to_absolute_path(os.path.join('hdf5s', 'waveforms'))\n",
    "cfg.datamodule.notes_pkls_dir = to_absolute_path('instruments_classification_notes_MIDI_class/')\n",
    "cfg.batch_size = 8\n",
    "\n",
    "# instrument-class mapping imported from End2End.MIDI_program_map\n",
    "cfg.MIDI_MAPPING.plugin_labels_num = MIDI_Class_NUM\n",
    "cfg.MIDI_MAPPING.NAME_TO_IX = MIDIClassName2class_idx\n",
    "cfg.MIDI_MAPPING.IX_TO_NAME = class_idx2MIDIClass\n",
    "\n",
    "# augmentor (only built when the config enables augmentation)\n",
    "augmentor = Augmentor(augmentation=cfg.augmentation) if cfg.augmentation else None\n",
    "\n",
    "# data module\n",
    "data_module = DataModuleEnd2End(**cfg.datamodule, augmentor=augmentor, MIDI_MAPPING=cfg.MIDI_MAPPING)\n",
    "data_module.setup('test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "60d48b42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# nested dictionary with one (initially empty) entry under 'plugin_name'\n",
    "plugins_output_dict = {'plugin_name': {}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3cbb7ebc",
   "metadata": {},
   "outputs": [],
   "source": [
    "plugins_output_dict['plugin_name']['key_a'] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "197de474",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_F1(pred_path, label_path):\n",
    "    \"\"\"\n",
    "    Calculate piece-wise precision, recall and F1 for two note-level metrics:\n",
    "    'note' (onset and pitch) and 'note_w_off' (onset, pitch and offset).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    pred_path : str or path-like\n",
    "        Directory with the transcription outputs, one '*.pkl' file per piece.\n",
    "        Each pickle maps an instrument name to a list of note events with\n",
    "        'onset_time', 'offset_time' and 'midi_note' keys.\n",
    "    label_path : str or path-like\n",
    "        Directory with the label pickles, named like the prediction files.\n",
    "        Each pickle is a list of note events with 'plugin_name', 'start',\n",
    "        'end' and 'pitch' keys.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        Nested dictionary of the form\n",
    "        {'note':       {filename: {inst_name: {'precision': float, 'recall': float, 'f1': float}}},\n",
    "         'note_w_off': {filename: {inst_name: {'precision': float, 'recall': float, 'f1': float}}}}\n",
    "        Scores are computed for every instrument found in the label file.\n",
    "        An instrument without any transcribed note scores 0 on both metrics.\n",
    "    \"\"\"\n",
    "\n",
    "    def midi_to_hz(midi_notes):\n",
    "        # mir_eval compares pitches in Hz (50 cents tolerance by default). Passing raw\n",
    "        # MIDI numbers would let neighbouring semitones count as matching pitches.\n",
    "        return 440.0 * 2.0 ** ((np.asarray(midi_notes, dtype=float) - 69.0) / 12.0)\n",
    "\n",
    "    notewise_dict = {'note': {}, 'note_w_off': {}}\n",
    "    for pkl_file in pathlib.Path(pred_path).glob('*.pkl'):\n",
    "        # NOTE: pickle.load can execute arbitrary code; only load trusted files.\n",
    "        with open(os.path.join(label_path, pkl_file.name), 'rb') as label_file:\n",
    "            note_events = pickle.load(label_file)\n",
    "        with open(pkl_file, 'rb') as pred_file:\n",
    "            transcribed_dict = pickle.load(pred_file)\n",
    "\n",
    "        note_scores = notewise_dict['note'][pkl_file.name] = {}\n",
    "        note_w_off_scores = notewise_dict['note_w_off'][pkl_file.name] = {}\n",
    "\n",
    "        for plugin_name in sorted({event['plugin_name'] for event in note_events}):\n",
    "            ref_events = [event for event in note_events if event['plugin_name'] == plugin_name]\n",
    "            # an instrument missing from the predictions is treated as 'nothing transcribed'\n",
    "            est_events = transcribed_dict.get(plugin_name, [])\n",
    "\n",
    "            if len(est_events) == 0:\n",
    "                # nothing to match: both metrics are zero (previously only 'note' was set)\n",
    "                print('empty pianoroll')\n",
    "                note_scores[plugin_name] = {'precision': 0, 'recall': 0, 'f1': 0}\n",
    "                note_w_off_scores[plugin_name] = {'precision': 0, 'recall': 0, 'f1': 0}\n",
    "                continue\n",
    "\n",
    "            eval_inputs = dict(\n",
    "                ref_intervals=np.array([[event['start'], event['end']] for event in ref_events]),\n",
    "                ref_pitches=midi_to_hz([event['pitch'] for event in ref_events]),\n",
    "                est_intervals=np.array([[event['onset_time'], event['offset_time']] for event in est_events]),\n",
    "                est_pitches=midi_to_hz([event['midi_note'] for event in est_events]),\n",
    "                onset_tolerance=0.05,\n",
    "            )\n",
    "            # offset_ratio=None ignores offsets; 0.2 additionally requires the offsets to match\n",
    "            for scores, offset_ratio in ((note_scores, None), (note_w_off_scores, 0.2)):\n",
    "                precision, recall, f1, _ = mir_eval.transcription.precision_recall_f1_overlap(\n",
    "                    offset_ratio=offset_ratio, **eval_inputs\n",
    "                )\n",
    "                scores[plugin_name] = {'precision': precision, 'recall': recall, 'f1': f1}\n",
    "    return notewise_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "39848b13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# label pickles of the test split\n",
    "notes_pkls_dir = './instruments_classification_notes_MIDI_class'\n",
    "label_path = os.path.join(notes_pkls_dir, 'test')\n",
    "\n",
    "# transcription outputs to evaluate (lazy generator over the '*.pkl' files)\n",
    "pred_path = '/opt/tiger/kinwai/jointist/outputs/2021-12-27/21-58-29/MIDI_output/'\n",
    "pred_pkl_files = pathlib.Path(pred_path).glob('*.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "2daa8727",
   "metadata": {},
   "outputs": [],
   "source": [
    "# one result table per metric: onset-only ('note') and onset+offset ('note_w_off')\n",
    "notewise_dict = {'note': {}, 'note_w_off': {}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "15446c49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# take the first prediction file and load it together with its label file\n",
    "pkl_file = next(iter(pred_pkl_files))\n",
    "notewise_dict['note'][pkl_file.name] = {}\n",
    "notewise_dict['note_w_off'][pkl_file.name] = {}\n",
    "# context managers close the pickle files as soon as they are loaded\n",
    "with open(os.path.join(label_path, pkl_file.name), 'rb') as label_file:\n",
    "    note_events = pickle.load(label_file)\n",
    "with open(pkl_file, 'rb') as pred_file:\n",
    "    transcribed_dict = pickle.load(pred_file)\n",
    "# instruments present in the label file, in alphabetical order\n",
    "unique_plugin_names = sorted({note_event['plugin_name'] for note_event in note_events})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "6c3586e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "plugin_name = unique_plugin_names[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "496edc24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create empty result slots for the instrument selected so far\n",
    "for metric in ('note', 'note_w_off'):\n",
    "    notewise_dict[metric][pkl_file.name][plugin_name] = {}\n",
    "\n",
    "# from here on, look at the 'Bass' track\n",
    "plugin_name = 'Bass'\n",
    "ref_on_off_pairs, ref_pitches = [], []\n",
    "for note_event in note_events:\n",
    "    if note_event['plugin_name'] != plugin_name:\n",
    "        continue\n",
    "    ref_on_off_pairs.append([note_event['start'], note_event['end']])\n",
    "    ref_pitches.append(note_event['pitch'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "aa5b7588",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gather [onset, offset] pairs and pitches of the transcribed notes for plugin_name\n",
    "est_on_off_pairs, est_pitches = [], []\n",
    "for note_event in transcribed_dict[plugin_name]:\n",
    "    onset, offset = note_event['onset_time'], note_event['offset_time']\n",
    "    est_on_off_pairs.append([onset, offset])\n",
    "    est_pitches.append(note_event['midi_note'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "9de86869",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['Piano', 'Electric Piano', 'Organ', 'Electric Guitar', 'Bass', 'Strings', 'Voice', 'Synth Pad'])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transcribed_dict.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "id": "897a480c",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dict = torch.load('/opt/tiger/kinwai/jointist/outputs/2021-12-27/22-25-08/output_dict.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "7f0c0c17",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "id": "47d99de7",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = next(iter(data_module.test_dataloader()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "eb536ce9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5000, 88)"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch['target_dict'][0]['Bass']['frame_roll'][:5000].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "2276b17d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['Electric Guitar', 'Strings', 'Synth Pad', 'Electric Piano', 'Bass', 'Organ', 'Piano', 'Voice'])"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch['target_dict'][0].keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "id": "ba540b95",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fc836a63520>"
      ]
     },
     "execution_count": 116,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAQ4klEQVR4nO3da6wcZ33H8e+v8Y0EaOxALZOg4pYIFFXChKMQBEJqQi6kFXalKAqtikUjWaqghdKqmPICKvUFqVooVSsql6Q1FQ0JJpEjRAmOG4QqFYMTTK6kdkICMY7NJSEplUwC/77YMRxOznrnXPac4+d8P9JqZ56Z2f0/49mf58zszqSqkCSd+n5psQuQJM0PA12SGmGgS1IjDHRJaoSBLkmNWLGQb7Yqq2sNZyzkW0rSKe9pnvheVb141HwLGuhrOIPX5uKFfEtJOuXdXrse7TOfh1wkqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiN6BXqSP0lyX5J7k9yQZE2SjUn2JTmU5MYkq8ZdrCRpuJGBnuRs4I+Biar6DeA04GrgWuAjVfVy4AngmnEWKkk6ub6HXFYAz0uyAjgdOAJcBOzqpu8Etsx7dZKk3kYGelUdBv4G+BaDIP8hcCfwZFU92832GHD2dMsn2ZZkf5L9z3B8fqqWJD1Hn0Mua4HNwEbgJcAZwOV936CqdlTVRFVNrGT1rAuVJJ1cn0MubwK+WVXfrapngJuB1wNndodgAM4BDo+pRklSD30C/VvAhUlOTxLgYuB+4A7gym6ercDu8ZQoSeqjzzH0fQxOft4F3NMtswN4L/CeJIeAs4DrxlinJGmEFaNngar6APCBKc0PAxfMe0WSpFnxl6KS1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1os89RV+R5MCkx1NJ3p1kXZI9SQ52z2sXomBJ0vT63LHowaraVFWbgNcA/wfcAmwH9lbVucDeblyStEhmesjlYuChqnoU2Azs7Np3AlvmsS5J0gzNNNCvBm7ohtdX1ZFu+HFg/bxVJUmasd6BnmQV8Bbg01OnVVUBNWS5bUn2J9n/DMdnXagk6eRmsof+ZuCuqjrajR9NsgGgez423UJVtaOqJqpqYiWr51atJGmomQT6W/n54RaAW4Gt3fBWYPd8FSVJmrlegZ7kDOAS4OZJzR8CLklyEHhTNy5JWiQr+sxUVT8CzprS9n0G33qRJC0B/lJUkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1IglE+i3fefA0PGTTRv2WqPmURuW07/1ib4uhf6erIaTfXanzreU+tSCJRPokqS5yeDKtwvjhVlXr83JrxZw23cOcNlLNj1n+MQ4wGUv2fScaZIW1+TP53Tt001TP7fXrjuramLUfO6hS1IjDHRJasSSO+QiSfpFHnKRpGXGQJekRvS9Y9GZSXYl+UaSB5K8Lsm6JHuSHOye1467WEnScH330D8KfL6qXgm8CngA2A7srapzgb3duCRpkYwM9CS/DLwRuA6gqn5cVU8Cm4Gd3Ww7gS3jKVGS1EefPfSNwHeBf0nytSQf724avb6qjnTzPA6sn27hJNuS7E+y/xmOz0/VkqTn6BPoK4DzgY9V1auBHzHl8EoNvvs47fcfq2pHVU1U1cRKVs+1Xi1zXvNDGq5PoD8GPFZV+7rxXQwC/miSDQDd87HxlChJ6mNkoFfV48C3k7yia7oYuB+4FdjatW0Fdo+lQklSLyt6zvdHwCeTrAIeBt7O4D+Dm5JcAzwKXDWeEqWf8+JO0nC9Ar2qDgDT/ezU3/FL0hKxpH4pOt2NLE60nWyapIU37LM53bTpbnpxsuU1O0sq0CVJs2egS1IjllygT77H
4HR3Phl2VxSpRUv5cMSJO4fB8MMuw+4uNt20pdzXU8WSC3RJ0ux4gwtJWuK8wYUkLTMGuiQ1oolA92SKJDUS6JKkRgLdrzBKUiOBLkky0CWpGX0vn3tKmHpy1EMxkpYT99AlqRFN7aG7Ry5pOesV6EkeAZ4GfgI8W1UTSdYBNwIvAx4BrqqqJ8ZTpiRplJkccvnNqto06XoC24G9VXUusLcblyQtkrkcQ98M7OyGdwJb5lzNHPhrUUnLXd9AL+ALSe5Msq1rW19VR7rhx4H10y2YZFuS/Un2P8PxOZYrSRqm70nRN1TV4SS/AuxJ8o3JE6uqkkx7Hd6q2gHsgMHlc+dU7Ul4QlTSctdrD72qDnfPx4BbgAuAo0k2AHTPx8ZVpCRptJGBnuSMJC84MQxcCtwL3Aps7WbbCuweV5GSpNH6HHJZD9yS5MT8/15Vn0/yVeCmJNcAjwJXja9MSdIoIwO9qh4GXjVN+/cB7ycnSUuEP/2XpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDWid6AnOS3J15J8thvfmGRfkkNJbkyyanxlSpJGmcke+ruAByaNXwt8pKpeDjwBXDOfhUmSZqZXoCc5B/gt4OPdeICLgF3dLDuBLWOoT5LUU9899L8D/hz4aTd+FvBkVT3bjT8GnD3dgkm2JdmfZP8zHJ9LrZKkkxgZ6El+GzhWVXfO5g2qakdVTVTVxEpWz+YlJEk9rOgxz+uBtyS5AlgDvBD4KHBmkhXdXvo5wOHxlSlJGmXkHnpVva+qzqmqlwFXA/9ZVb8H3AFc2c22Fdg9tiolSSPN5Xvo7wXek+QQg2Pq181PSZKk2ehzyOVnquqLwBe74YeBC+a/JEnSbPhLUUlqhIEuSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhrR556ia5J8JcnXk9yX5C+79o1J9iU5lOTGJKvGX64kaZg+e+jHgYuq6lXAJuDyJBcC1wIfqaqXA08A14ytSknSSH3uKVpV9b/d6MruUcBFwK6ufSewZRwFSpL66XUMPclpSQ4Ax4A9wEPAk1X1bDfLY8DZQ5bdlmR/kv3PcHweSpYkTadXoFfVT6pqE3AOg/uIvrLvG1TVjqqaqKqJlayeXZWSpJFm9C2XqnoSuAN4HXBmkhM3mT4HODy/pUmSZqLPt1xenOTMbvh5wCXAAwyC/cputq3A7jHVKEnqYcXoWdgA7ExyGoP/AG6qqs8muR/4VJK/Ar4GXDfGOiVJI4wM9Kq6G3j1NO0PMzieLklaAvylqCQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpEX1uQffSJHckuT/JfUne1bWvS7InycHuee34y5UkDdNnD/1Z4E+r6jzgQuAdSc4DtgN7q+pcYG83LklaJCMDvaqOVNVd3fDTDG4QfTawGdjZzbYT2DKmGiVJPfS5SfTPJHkZg/uL7gPWV9WRbtLjwPohy2wDtgGs4fRZFypJOrneJ0WTPB/4DPDuqnpq8rSqKqCmW66qdlTVRFVNrGT1nIqVJA3XK9CTrGQQ5p+sqpu75qNJNnTTNwDHxlOiJKmPPt9yCXAd8EBVfXjSpFuBrd3wVmD3/JcnSeqrzzH01wO/D9yT5EDX9hfAh4CbklwDPApcNZYKJUm9jAz0qvovIEMmXzy/5UiSZstfikpSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGtHnFnTXJzmW5N5JbeuS7ElysHteO94yJUmj9NlD/1fg8ilt24G9VXUusLcblyQt
opGBXlVfAn4wpXkzsLMb3glsmd+yJEkz1ecm0dNZX1VHuuHHgfXDZkyyDdgGsIbTZ/l2kqRR5nxStKoKqJNM31FVE1U1sZLVc307SdIQsw30o0k2AHTPx+avJEnSbMw20G8FtnbDW4Hd81OOJGm2+nxt8Qbgv4FXJHksyTXAh4BLkhwE3tSNS5IW0ciTolX11iGTLp7nWiRJc+AvRSWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWrEnAI9yeVJHkxyKMn2+SpKkjRzsw70JKcB/wi8GTgPeGuS8+arMEnSzMxlD/0C4FBVPVxVPwY+BWyen7IkSTM1l0A/G/j2pPHHujZJ0iIYeU/RuUqyDdgGsIbTx/12krRszSXQDwMvnTR+Ttf2C6pqB7ADIMnTt9euB+fwnqe6FwHfW+wiFpnrwHWw3PsPM18Hv9pnprkE+leBc5NsZBDkVwO/O2KZB6tqYg7veUpLsn859x9cB+A6WO79h/Gtg1kHelU9m+SdwG3AacD1VXXfvFUmSZqROR1Dr6rPAZ+bp1okSXOw0L8U3bHA77fULPf+g+sAXAfLvf8wpnWQqhrH60qSFpjXcpGkRhjoktSIBQn01i/ileSRJPckOZBkf9e2LsmeJAe757Vde5L8fbcu7k5y/qTX2drNfzDJ1sXqTx9Jrk9yLMm9k9rmrc9JXtOt00PdslnYHp7ckP5/MMnhbjs4kOSKSdPe1/XlwSSXTWqf9rORZGOSfV37jUlWLVzvRkvy0iR3JLk/yX1J3tW1L6dtYNg6WLztoKrG+mDwlcaHgF8DVgFfB84b9/su5AN4BHjRlLa/BrZ3w9uBa7vhK4D/AAJcCOzr2tcBD3fPa7vhtYvdt5P0+Y3A+cC94+gz8JVu3nTLvnmx+9yj/x8E/myaec/rtvvVwMbu83DayT4bwE3A1d3wPwF/uNh9ntKnDcD53fALgP/p+rmctoFh62DRtoOF2ENfrhfx2gzs7IZ3AlsmtX+iBr4MnJlkA3AZsKeqflBVTwB7gMsXuObequpLwA+mNM9Ln7tpL6yqL9dgS/7EpNdaEob0f5jNwKeq6nhVfRM4xOBzMe1no9sTvQjY1S0/eV0uCVV1pKru6oafBh5gcC2n5bQNDFsHw4x9O1iIQF8OF/Eq4AtJ7szg2jUA66vqSDf8OLC+Gx62PlpYT/PV57O74antp4J3docUrj9xuIGZ9/8s4MmqenZK+5KU5GXAq4F9LNNtYMo6gEXaDjwpOj/eUFXnM7g2/DuSvHHyxG4PY1l9P3Q59hn4GPDrwCbgCPC3i1rNAkjyfOAzwLur6qnJ05bLNjDNOli07WAhAr3XRbxOZVV1uHs+BtzC4E+oo92fjXTPx7rZh62PFtbTfPX5cDc8tX1Jq6qjVfWTqvop8M8MtgOYef+/z+CQxIop7UtKkpUMguyTVXVz17ystoHp1sFibgcLEeg/u4hXd4b2auDWBXjfBZHkjCQvODEMXArcy6CPJ87YbwV2d8O3Am/rzvpfCPyw+xP1NuDSJGu7P9Eu7dpOJfPS527aU0ku7I4jvm3Say1ZJ4Ks8zsMtgMY9P/qJKszuJjduQxO+E372ej2bO8AruyWn7wul4Tu3+U64IGq+vCkSctmGxi2DhZ1O1igs8FXMDgD/BDw/oV4z4V6MDgz/fXucd+J/jE4/rUXOAjcDqzr2sPg1n0PAfcAE5Ne6w8YnCg5BLx9sfs2ot83MPhz8hkGx/aumc8+AxPdB+Eh4B/oftW8VB5D+v9vXf/u7j68GybN//6uLw8y6dsawz4b3Xb1lW69fBpYvdh9ntL/NzA4nHI3cKB7XLHMtoFh62DRtgN/+i9JjfCkqCQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5Jjfh/s+DOCF8tAHoAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(batch['target_dict'][0]['Synth Pad']['frame_roll'][:].T, aspect='auto', origin='lower', interpolation='none')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "id": "3e9973dc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fc8369bf220>"
      ]
     },
     "execution_count": 117,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXMAAAD4CAYAAAAeugY9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAQJUlEQVR4nO3dbYwdZ3nG8f/V+I0EaOxALTdBjSkRKKpECKs0CITUmISQVtiVoii0Kha1ZKmCFkqrYsoHqNQPpGqhVK2oXJLWrWhIMIkcIQoYNwhVKgYnhLyS2gkJxHJsXhKSUsk4cPfDeUwXd3fP2Zezaz/7/0mrmXlmZud+xnMuz86ccyZVhSTpzPZzS12AJGn+DHNJ6oBhLkkdMMwlqQOGuSR1YMVibmxVVtcazlnMTUrSGe9ZnvpuVb14pmUWNczXcA6/mk2LuUlJOuN9oXY/PmwZL7NIUgcMc0nqgGEuSR0wzCWpA4a5JHXAMJekDhjmktQBw1ySOmCYS1IHDHNJ6oBhLkkdMMwlqQOGuSR1wDCXpA6MFOZJ/jDJA0nuT3JzkjVJNibZn+RQkluSrBp3sZKkqQ0N8yTnA38ATFTVrwBnAdcDNwAfrqqXAU8B28ZZqCRpeqNeZlkBPC/JCuBs4AhwBbC7zd8FbFnw6iRJIxka5lV1GPhL4FsMQvwHwF3A01X1XFvsCeD8qdZPsj3JgSQHTnB8YaqWJP2MUS6zrAU2AxuBXwTOAa4edQNVtbOqJqpqYiWr51yoJGl6o1xmeQPwzar6TlWdAG4DXguc2y67AFwAHB5TjZKkIUYJ828Blyc5O0mATcCDwJ3AtW2ZrcCe8ZQoSRpmlGvm+xnc6LwbuK+tsxN4D/DuJIeA84Abx1inJGkGK4YvAlX1fuD9pzQ/Cly24BVJkmbNT4BKUgcMc0nqgGEuSR0wzCWpA4a5JHXAMJekDhjmktQBw1ySOmCYS1IHDHNJ6oBhLkkdMMwlqQOGuSR1wDCXpA4Y5pLUgVGeAfryJPdM+nkmybuSrEuyN8nBNly7GAVLkv6/UZ409HBVXVJVlwCvBv4HuB3YAeyrqouAfW1akrQEZnuZZRPwSFU9DmwGdrX2XcCWBaxLkjQLsw3z64Gb2/j6qjrSxp8E1i9YVZKkWRk5zJOsAt4MfPLUeVVVQE2z3vYkB5IcOMHxORcqSZrebM7M3wTcXVVH2/TRJBsA2vDYVCtV1c6qmqiqiZWsnl+1kqQpzSbM38L/XWIBuAPY2sa3AnsWqihJ0uyMFOZJzgGuBG6b1PxB4MokB4E3tGlJ0hJYMcpCVfVD4LxT2r7H4N0tkqQl5idAJakDhrkkdcAwl6QOGOaS1AHDXJI6YJhLUgcMc0nqgGEuSR0wzCWpA4a5JHXAMJekDhjmktQBw1ySOmCYS1IHDHNJ6oBhLkkdGPVJQ+cm2Z3kG0keSvKaJOuS7E1ysA3XjrtYSdLURj0z/wjw2ap6BfBK4CFgB7Cvqi4C9rVpSdISGBrmSX4eeD1wI0BV/aiqngY2A7vaYruALeMpUZI0zChn5huB7wD/mORrST7WHvC8vqqOtGWeBNZPtXKS7UkOJDlwguMLU7Uk6WeMEuYrgEuBj1bVq4AfcsollaoqoKZauap2VtVEVU2sZPV865UkTWGUMH8CeKKq9rfp3QzC/WiSDQBteGw8JUqShhka5lX1JPDtJC9vTZuAB4E7gK2tbSuwZywVSpKGWjHicr8PfDzJKuBR4G0M/iO4Nck24HHguvGUKEkaZqQwr6p7gIkpZm1a0GokSXPiJ0AlqQOGuSR1wDCXpA4Y5pLUAcNckjpgmEtSBwxzSeqAYS5JHTDMJakDhrkkdcAwl6QOGOaS1AHDXJI6YJhLUgcMc0nqgGEuSR0Y6eEUSR4DngV+DDxXVRNJ1gG3ABcCjwHXVdVT4ylTkjST2ZyZ/1pVXVJVJ584tAPYV1UXAfvatCRpCcznMstmYFcb3wVs
mXc1kqQ5GTXMC/h8kruSbG9t66vqSBt/Elg/1YpJtic5kOTACY7Ps1xJ0lRGumYOvK6qDif5BWBvkm9MnllVlaSmWrGqdgI7AV6YdVMuI0man5HOzKvqcBseA24HLgOOJtkA0IbHxlWkJGlmQ8M8yTlJXnByHLgKuB+4A9jaFtsK7BlXkZKkmY1ymWU9cHuSk8v/a1V9NslXgVuTbAMeB64bX5mSpJkMDfOqehR45RTt3wM2jaMoSdLs+AlQSeqAYS5JHTDMJakDhrkkdcAwl6QOGOaS1AHDXJI6YJhLUgcMc0nqgGEuSR0wzCWpA4a5JHXAMJekDhjmktQBw1ySOmCYS1IHRg7zJGcl+VqST7fpjUn2JzmU5JYkq8ZXpiRpJrM5M38n8NCk6RuAD1fVy4CngG0LWZgkaXQjhXmSC4BfBz7WpgNcAexui+wCtoyhPknSCEY9M/9r4E+An7Tp84Cnq+q5Nv0EcP5UKybZnuRAkgMnOD6fWiVJ0xga5kl+AzhWVXfNZQNVtbOqJqpqYiWr5/IrJElDrBhhmdcCb05yDbAGeCHwEeDcJCva2fkFwOHxlSlJmsnQM/Oqem9VXVBVFwLXA/9eVb8N3Alc2xbbCuwZW5WSpBnN533m7wHeneQQg2voNy5MSZKk2RrlMstPVdUXgS+28UeByxa+JEnSbPkJUEnqgGEuSR0wzCWpA4a5JHXAMJekDhjmktQBw1ySOmCYS1IHDHNJ6oBhLkkdMMwlqQOGuSR1wDCXpA4Y5pLUAcNckjowyjNA1yT5SpKvJ3kgyZ+19o1J9ic5lOSWJKvGX64kaSqjnJkfB66oqlcClwBXJ7kcuAH4cFW9DHgK2Da2KiVJMxrlGaBVVf/dJle2nwKuAHa39l3AlnEUKEkabqRr5knOSnIPcAzYCzwCPF1Vz7VFngDOn2bd7UkOJDlwguMLULIk6VQjhXlV/biqLgEuYPDcz1eMuoGq2llVE1U1sZLVc6tSkjSjWb2bpaqeBu4EXgOcm+TkA6EvAA4vbGmSpFGN8m6WFyc5t40/D7gSeIhBqF/bFtsK7BlTjZKkIVYMX4QNwK4kZzEI/1ur6tNJHgQ+keTPga8BN46xTknSDIaGeVXdC7xqivZHGVw/lyQtMT8BKkkdMMwlqQOGuSR1wDCXpA4Y5pLUAcNckjpgmEtSBwxzSeqAYS5JHTDMJakDhrkkdcAwl6QOGOaS1AHDXJI6YJhLUgcMc0nqwCiPjXtJkjuTPJjkgSTvbO3rkuxNcrAN146/XEnSVEY5M38O+KOquhi4HHh7kouBHcC+qroI2NemJUlLYGiYV9WRqrq7jT/L4GHO5wObgV1tsV3AljHVKEkaYpQHOv9UkgsZPA90P7C+qo60WU8C66dZZzuwHWANZ8+5UEnS9Ea+AZrk+cCngHdV1TOT51VVATXVelW1s6omqmpiJavnVawkaWojhXmSlQyC/ONVdVtrPppkQ5u/ATg2nhIlScOM8m6WADcCD1XVhybNugPY2sa3AnsWvjxJ0ihGuWb+WuB3gPuS3NPa/hT4IHBrkm3A48B1Y6lQkjTU0DCvqv8AMs3sTQtbjiRpLvwEqCR1wDCXpA4Y5pLUAcNckjpgmEtSBwxzSeqAYS5JHTDMJakDhrkkdcAwl6QOGOaS1AHDXJI6YJhLUgcMc0nqgGEuSR0wzCWpA6M8Nu6mJMeS3D+pbV2SvUkOtuHa8ZYpSZrJKGfm/wRcfUrbDmBfVV0E7GvTkqQlMjTMq+pLwPdPad4M7Grju4AtC1uWJGk2Rnmg81TWV9WRNv4ksH66BZNsB7YDrOHsOW5OkjSTed8AraoCaob5O6tqoqomVrJ6vpuTJE1hrmF+NMkGgDY8tnAlSZJma65hfgewtY1vBfYsTDmSpLkY5a2JNwP/Cbw8yRNJtgEfBK5MchB4Q5uWJC2RoTdAq+ot08zatMC1SJLmyE+ASlIHDHNJ6oBhLkkdMMwlqQOGuSR1wDCXpA4Y
5pLUAcNckjpgmEtSBwxzSeqAYS5JHTDMJakDhrkkdcAwl6QOGOaS1IF5hXmSq5M8nORQkh0LVZQkaXbmHOZJzgL+DngTcDHwliQXL1RhkqTRzefM/DLgUFU9WlU/Aj4BbF6YsiRJszGfMD8f+Pak6SdamyRpkQ19Buh8JdkObAdYw9nj3pwkLUvzCfPDwEsmTV/Q2n5GVe0EdgIkefYLtfvheWzzTPci4LtLXcQSW+77wP4v7/7D3PbBLw1bYD5h/lXgoiQbGYT49cBvDVnn4aqamMc2z2hJDizn/oP7wP4v7/7D+PbBnMO8qp5L8g7gc8BZwE1V9cCCVSZJGtm8rplX1WeAzyxQLZKkOVrsT4DuXOTtnW6We//BfWD/NZZ9kKoax++VJC0iv5tFkjpgmEtSBxYlzHv+Qq4kjyW5L8k9SQ60tnVJ9iY52IZrW3uS/E3bD/cmuXTS79nalj+YZOtS9WcUSW5KcizJ/ZPaFqzPSV7d9umhtm4Wt4czm6b/H0hyuB0H9yS5ZtK897a+PJzkjZPap3xdJNmYZH9rvyXJqsXr3XBJXpLkziQPJnkgyTtb+3I6BqbbB0t3HFTVWH8YvG3xEeClwCrg68DF497uYv0AjwEvOqXtL4AdbXwHcEMbvwb4NyDA5cD+1r4OeLQN17bxtUvdtxn6/HrgUuD+cfQZ+EpbNm3dNy11n0fo/weAP55i2YvbMb8a2NheC2fN9LoAbgWub+N/D/zeUvf5lD5tAC5t4y8A/qv1czkdA9PtgyU7DhbjzHw5fiHXZmBXG98FbJnU/s818GXg3CQbgDcCe6vq+1X1FLAXuHqRax5ZVX0J+P4pzQvS5zbvhVX15Rocxf886XedFqbp/3Q2A5+oquNV9U3gEIPXxJSvi3YGegWwu60/eV+eFqrqSFXd3cafBR5i8L1My+kYmG4fTGfsx8FihHnvX8hVwOeT3JXB99AArK+qI238SWB9G59uX/Swjxaqz+e38VPbzwTvaJcRbjp5iYHZ9/884Omqeu6U9tNSkguBVwH7WabHwCn7AJboOPAG6Py9rqouZfC97m9P8vrJM9uZxbJ6/+dy7DPwUeCXgUuAI8BfLWk1iyDJ84FPAe+qqmcmz1sux8AU+2DJjoPFCPORvpDrTFVVh9vwGHA7gz+bjrY/FWnDY23x6fZFD/toofp8uI2f2n5aq6qjVfXjqvoJ8A8MjgOYff+/x+AyxIpT2k8rSVYyCLGPV9VtrXlZHQNT7YOlPA4WI8x/+oVc7W7s9cAdi7DdsUtyTpIXnBwHrgLuZ9C/k3fmtwJ72vgdwFvb3f3LgR+0P0s/B1yVZG37s+yq1nYmWZA+t3nPJLm8XTd866Tfddo6GWLNbzI4DmDQ/+uTrM7gS+kuYnBzb8rXRTujvRO4tq0/eV+eFtq/y43AQ1X1oUmzls0xMN0+WNLjYJHu/F7D4G7vI8D7FmObi9SvlzK4+/x14IGTfWNwvWsfcBD4ArCutYfBo/YeAe4DJib9rt9lcFPkEPC2pe7bkH7fzOBPyBMMruVtW8g+AxPtRfAI8Le0TyqfLj/T9P9fWv/ubS/cDZOWf1/ry8NMelfGdK+Ldlx9pe2XTwKrl7rPp/T/dQwuodwL3NN+rllmx8B0+2DJjgM/zi9JHfAGqCR1wDCXpA4Y5pLUAcNckjpgmEtSBwxzSeqAYS5JHfhfbAcXC8lldcIAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "idx = 7\n",
    "# columns idx*88 .. (idx+1)*88 of the frame output, binarized at 0.5\n",
    "pitch_columns = slice(idx * 88, (idx + 1) * 88)\n",
    "thresholded = output_dict['frame_output'][:, pitch_columns] > 0.5\n",
    "plt.imshow(thresholded.t(), aspect='auto', origin='lower', interpolation='none')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23285102",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "034dc4b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'onset_time': 16.34, 'offset_time': 16.71, 'midi_note': 39, 'velocity': 100},\n",
       " {'onset_time': 16.7, 'offset_time': 17.4, 'midi_note': 37, 'velocity': 100},\n",
       " {'onset_time': 17.05, 'offset_time': 17.28, 'midi_note': 36, 'velocity': 100},\n",
       " {'onset_time': 17.39, 'offset_time': 18.14, 'midi_note': 35, 'velocity': 100},\n",
       " {'onset_time': 18.05, 'offset_time': 18.45, 'midi_note': 46, 'velocity': 100},\n",
       " {'onset_time': 18.51, 'offset_time': 18.73, 'midi_note': 46, 'velocity': 100},\n",
       " {'onset_time': 18.74, 'offset_time': 18.75, 'midi_note': 35, 'velocity': 100},\n",
       " {'onset_time': 18.74, 'offset_time': 19.62, 'midi_note': 37, 'velocity': 100},\n",
       " {'onset_time': 19.78, 'offset_time': 19.98, 'midi_note': 37, 'velocity': 100},\n",
       " {'onset_time': 19.99, 'offset_time': 20.0, 'midi_note': 37, 'velocity': 100}]"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transcribed_dict[plugin_name][:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "ef48e0ac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "39"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ref_pitches[10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "eb2c64bc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[16.363632, 16.708092135416667],\n",
       " [16.70098986458333, 17.389910135416667],\n",
       " [17.386359, 18.046870187499998],\n",
       " [18.06462586458333, 18.76775067708333],\n",
       " [18.753546135416666, 19.406955052083333]]"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ref_on_off_pairs[10:15]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "2eea14c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "empty pianoroll\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'note_woffset_precision' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_2125481/1625350293.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     56\u001b[0m                                                 }\n\u001b[1;32m     57\u001b[0m     notewise_dict['note_w_off'][pkl_file.name][plugin_name] = {\n\u001b[0;32m---> 58\u001b[0;31m                                                 \u001b[0;34m'precision'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnote_woffset_precision\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     59\u001b[0m                                                 \u001b[0;34m'recall'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnote_woffset_recall\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     60\u001b[0m                                                 \u001b[0;34m'f1'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnote_woffset_f1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'note_woffset_precision' is not defined"
     ]
    }
   ],
   "source": [
    "# Note-level precision / recall / F1 for one track, stored per plugin name.\n",
    "# 'note' scores ignore offsets (offset_ratio=None); 'note_w_off' scores also\n",
    "# require the offset to match (offset_ratio=0.2).\n",
    "notewise_dict['note'][pkl_file.name] = {}\n",
    "notewise_dict['note_w_off'][pkl_file.name] = {}\n",
    "for plugin_name in unique_plugin_names:\n",
    "    # Reference notes that belong to this plugin.\n",
    "    ref_on_off_pairs = []\n",
    "    ref_pitches = []\n",
    "    for note_event in note_events:\n",
    "        if note_event['plugin_name'] == plugin_name:\n",
    "            ref_on_off_pairs.append([note_event['start'], note_event['end']])\n",
    "            ref_pitches.append(note_event['pitch'])\n",
    "\n",
    "    # Estimated (transcribed) notes for the same plugin.\n",
    "    est_on_off_pairs = []\n",
    "    est_pitches = []\n",
    "    for note_event in transcribed_dict[plugin_name]:\n",
    "        est_on_off_pairs.append([note_event['onset_time'], note_event['offset_time']])\n",
    "        est_pitches.append(note_event['midi_note'])\n",
    "\n",
    "    ref_on_off_pairs = np.array(ref_on_off_pairs)\n",
    "    ref_pitches = np.array(ref_pitches)\n",
    "    est_on_off_pairs = np.array(est_on_off_pairs)\n",
    "    est_pitches = np.array(est_pitches)\n",
    "\n",
    "    # NOTE(review): mir_eval documents pitches in Hz; the 'midi_note' key suggests\n",
    "    # MIDI note numbers are passed here - confirm whether a conversion is needed.\n",
    "    if est_on_off_pairs.shape[0] != 0:\n",
    "        (note_precision, note_recall, note_f1, _,) = mir_eval.transcription.precision_recall_f1_overlap(\n",
    "            ref_intervals=ref_on_off_pairs,\n",
    "            ref_pitches=ref_pitches,\n",
    "            est_intervals=est_on_off_pairs,\n",
    "            est_pitches=est_pitches,\n",
    "            onset_tolerance=0.05,\n",
    "            offset_ratio=None,\n",
    "        )\n",
    "\n",
    "        (note_woffset_precision, note_woffset_recall, note_woffset_f1, _,) = mir_eval.transcription.precision_recall_f1_overlap(\n",
    "            ref_intervals=ref_on_off_pairs,\n",
    "            ref_pitches=ref_pitches,\n",
    "            est_intervals=est_on_off_pairs,\n",
    "            est_pitches=est_pitches,\n",
    "            onset_tolerance=0.05,\n",
    "            offset_ratio=0.2,\n",
    "        )\n",
    "    else:\n",
    "        # Nothing was transcribed for this plugin: score every metric as zero.\n",
    "        # The offset-aware metrics must be set here too, otherwise the dict\n",
    "        # built below raises NameError on the first empty transcription.\n",
    "        print(\"empty pianoroll\")\n",
    "        note_precision = note_recall = note_f1 = 0.0\n",
    "        note_woffset_precision = note_woffset_recall = note_woffset_f1 = 0.0\n",
    "\n",
    "    notewise_dict['note'][pkl_file.name][plugin_name] = {\n",
    "        'precision': note_precision,\n",
    "        'recall': note_recall,\n",
    "        'f1': note_f1,\n",
    "    }\n",
    "    notewise_dict['note_w_off'][pkl_file.name][plugin_name] = {\n",
    "        'precision': note_woffset_precision,\n",
    "        'recall': note_woffset_recall,\n",
    "        'f1': note_woffset_f1,\n",
    "    }\n",
    "    # NOTE(review): this break stops after the first plugin - presumably a\n",
    "    # debugging leftover; confirm before evaluating all plugins.\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "83eb4c42",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'note': {'Track01909.pkl': {'Bass': {'precision': 0.0,\n",
       "    'recall': 0.0,\n",
       "    'f1': 0.0}}},\n",
       " 'note_w_off': {'Track01909.pkl': {'Bass': {'precision': 0.0,\n",
       "    'recall': 0.0,\n",
       "    'f1': 0.0}}}}"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the nested note-metric dictionary (metric type -> file name -> plugin -> scores).\n",
    "notewise_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c3b4524c",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'outputs' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_374245/3649859659.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mframewise_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'framewise'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_dict\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m     \u001b[0mkey\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput_dict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m     \u001b[0mframewise_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'framewise'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'outputs' is not defined"
     ]
    }
   ],
   "source": [
    "# Evaluate note-wise F1 of the saved predictions against the test labels\n",
    "# and persist the resulting dictionary as a pickle file.\n",
    "# NOTE(review): `self` must already be bound in the notebook namespace for the\n",
    "# next line to work - confirm, or replace with the actual output path.\n",
    "pred_path = self.evaluation_output_path\n",
    "label_path = os.path.join(notes_pkls_dir, 'test')\n",
    "notewise_dict = evaluate_F1(pred_path, label_path)\n",
    "# Use a context manager so the file handle is flushed and closed deterministically.\n",
    "with open(\"notewise_dict.pkl\", 'wb') as pkl_out:\n",
    "    pickle.dump(notewise_dict, pkl_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abfdc279",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jointist",
   "language": "python",
   "name": "jointist"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
