{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/asteroid-team/asteroid/blob/master/notebooks/03_PITLossWrapper.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Permutation invariant training\n",
    "Permutation invariant training (PIT) was successfully introduced to train DNN-based speaker-independent speech separation systems [1, 2, 3, 4]. Since then, it has been applied to environmental source separation [5] and classification [6], and to end-to-end diarization [7, 8].\n",
    "There has also been recent work to extend or improve on PIT-based training [9, 10, 11].\n",
    "\n",
    "Asteroid provides `PITLossWrapper`, a flexible class that turns simple loss functions into permutation invariant losses, for any loss function and any number of sources!\n",
    "\n",
    "It supports three types of loss functions:\n",
    "- 1) The loss function computes the average loss for a given permutation (over all source-estimate pairs). `PITLossWrapper` loops over the permutations and returns the minimum loss, which is the one to backpropagate. (`pit_from='perm_avg'`)\n",
    "- 2) Pairwise losses are computed, and `PITLossWrapper` takes their mean for each permutation.\n",
    "  - a) The pairwise losses can be computed using one function, which returns a pairwise matrix. In this case `PITLossWrapper` finds the best permutation and returns the minimum loss. (`pit_from='pw_mtx'`)\n",
    "  - b) The provided function computes the loss for one given target-estimate pair (a point in the pairwise matrix). `PITLossWrapper` computes the pairwise loss matrix by calling this function on each pair. It then finds the best permutation and returns the minimum loss, as in 2a. (`pit_from='pw_pt'`)\n",
    " \n",
    "In addition, we provide common loss functions in these three forms.  \n",
    "Let's try to understand these three ways of computing PIT losses."
   ]
  },
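  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When the loss of a permutation is the average of single-pair losses, all three modes compute the same quantity. Writing $\\hat{s}_i$ for the $i$-th estimate, $s_j$ for the $j$-th source, $\\ell$ for the single-pair loss and $\\mathcal{P}_N$ for the set of permutations of $N$ sources (this notation is ours, not Asteroid's), the PIT loss is\n",
    "\n",
    "$$\\mathcal{L}_{\\text{PIT}} = \\min_{\\pi \\in \\mathcal{P}_N} \\frac{1}{N} \\sum_{i=1}^{N} \\ell\\left(\\hat{s}_i, s_{\\pi(i)}\\right) = \\min_{\\pi \\in \\mathcal{P}_N} \\frac{1}{N} \\sum_{i=1}^{N} L_{i,\\pi(i)},$$\n",
    "\n",
    "where $L_{ij} = \\ell(\\hat{s}_i, s_j)$ is the pairwise loss matrix. With `perm_avg`, the loss function evaluates the inner average directly for each $\\pi$, while `pw_mtx` and `pw_pt` first build $L$ and then pick the best permutation from it."
   ]
  },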
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First, install asteroid and its dependencies\n",
    "!pip install git+https://github.com/asteroid-team/asteroid.git@master --quiet\n"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "### After installing the requirements, you need to restart the runtime (Ctrl + M).\n",
    "\n",
    "Otherwise, importing asteroid will fail."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from itertools import permutations\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from asteroid.losses import PITLossWrapper\n",
    "from asteroid.losses import pairwise_mse, singlesrc_mse, multisrc_mse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To be able to visualize some results, we will take a batch size of 1.\n",
    "batch_size, n_sources, feat_dim = 1, 4, 50\n",
    "# First, take random sources\n",
    "sources = torch.randn(batch_size, n_sources, feat_dim)\n",
    "# Generate estimates: randomly permute the sources and add some noise.\n",
    "random_permutation = torch.randperm(n_sources)\n",
    "estimate_sources = sources[:, random_permutation] + torch.randn(batch_size, n_sources, feat_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. The naive way\n",
    "The naive way consists of looping over all permutations of the source axis to find the best one.\n",
    "It corresponds to the mode `pit_from='perm_avg'`, short for *permutation average*, because the loss function computes the average loss over all sources and their estimates for one given permutation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
       "Best permutation : (2, 0, 3, 1). Original permutation : tensor([2, 0, 3, 1])\n"
     ]
    }
   ],
   "source": [
    "# The naive way. Find the best loss by looping over the permutations.\n",
    "perms = list(permutations(range(n_sources)))\n",
    "all_losses = torch.stack([multisrc_mse(estimate_sources, sources[:, p]) for p in perms])\n",
    "\n",
    "# argmin over the flattened tensor is only valid here because batch_size == 1\n",
    "best_loss_idx = torch.argmin(all_losses)\n",
    "# all_losses[best_loss_idx] is the loss we would backpropagate\n",
    "print(\"Best permutation : {}. Original permutation : {}\".format(perms[best_loss_idx], random_permutation))\n",
    "\n",
    "# This is equivalent to:\n",
    "loss_func = PITLossWrapper(multisrc_mse, pit_from='perm_avg')\n",
    "best_loss = loss_func(estimate_sources, sources)"
   ]
  },
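  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `batch_size = 1`, `torch.argmin` over the flattened `all_losses` directly gives the best permutation. With a larger batch, each example has its own best permutation, so the minimum must be taken along the permutation axis only. The cell below is a sketch of our own, not Asteroid code: `perm_avg_mse` is a hypothetical stand-in for `multisrc_mse` (assumed to average the squared error over sources and features), and the `demo_*` names are chosen so that the notebook's own variables are not overwritten."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from itertools import permutations\n",
    "\n",
    "# Hypothetical stand-in for multisrc_mse: mean over sources and features -> one loss per example\n",
    "def perm_avg_mse(est_targets, targets):\n",
    "    return ((targets - est_targets) ** 2).mean(dim=(1, 2))\n",
    "\n",
    "# Local generator so this demo does not change the notebook's global random state\n",
    "demo_gen = torch.Generator().manual_seed(0)\n",
    "demo_bs, demo_n_src, demo_feat = 3, 3, 20\n",
    "demo_targets = torch.randn(demo_bs, demo_n_src, demo_feat, generator=demo_gen)\n",
    "# Permute the sources differently for each example, then add a little noise\n",
    "demo_true_perms = torch.stack([torch.randperm(demo_n_src, generator=demo_gen) for _ in range(demo_bs)])\n",
    "demo_estimates = torch.stack([demo_targets[b, demo_true_perms[b]] for b in range(demo_bs)])\n",
    "demo_estimates = demo_estimates + 0.1 * torch.randn(demo_estimates.shape, generator=demo_gen)\n",
    "\n",
    "demo_perms = list(permutations(range(demo_n_src)))\n",
    "# demo_losses[k, b] is the loss of example b under permutation demo_perms[k]\n",
    "demo_losses = torch.stack([perm_avg_mse(demo_estimates, demo_targets[:, list(p)]) for p in demo_perms])\n",
    "# Take the minimum along the permutation axis only: one best permutation per example\n",
    "demo_min_loss, demo_best_idx = demo_losses.min(dim=0)\n",
    "demo_best_perms = [demo_perms[k] for k in demo_best_idx.tolist()]\n",
    "print(demo_best_perms, demo_true_perms.tolist())"
   ]
  },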
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. More efficient way\n",
    "The first thing to notice is that the loss for one permutation is the average of the losses between each estimate and the source it is assigned to. Each source-estimate pair appears in many permutations, so instead of recomputing these losses for every permutation, we can compute all pairwise losses once and average the relevant entries for each permutation. We can compute the pairwise losses in two ways:\n",
    "- The given function computes the loss for a single pair (a point in the pairwise matrix), and we loop over all pairs (`pit_from='pw_pt'`, for *pairwise point*).\n",
    "- The given function computes the whole pairwise matrix directly (`pit_from='pw_mtx'`, for *pairwise matrix*)."
   ]
  },
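  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see where the savings come from, count the single-pair loss evaluations for the $N = 4$ sources used in this notebook. The naive approach evaluates $N$ pairs for each of the $N!$ permutations, whereas the pairwise approach evaluates each of the $N^2$ pairs once:\n",
    "\n",
    "$$N! \\cdot N = 24 \\cdot 4 = 96 \\qquad \\text{versus} \\qquad N^2 = 4^2 = 16.$$\n",
    "\n",
    "Both approaches still search over all $N!$ permutations, but with the pairwise matrix, scoring a permutation only means averaging $N$ precomputed entries instead of recomputing $N$ losses on the full signals."
   ]
  },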
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQcAAAD8CAYAAAB6iWHJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAANH0lEQVR4nO3da6wc9XnH8e+vxlxSKDdTYRkXEoEQKW0gUJcIqUJcJBOlOGqJCi8SQCBXUWhI1agJvVA1b0ryIqkIUaoIUCHKBQQpdSOqlBaiBLVQDDKUSyEGVcWEFmLABHGL0dMXO9CTw//Y4J2d3ePz/Ugrz+78vc+zAv28Z2bOPKkqJGm+X5h2A5Jmk+EgqclwkNRkOEhqMhwkNRkOkprGCockByW5NcmPuj8PXGDd60k2dY8N49SUNIyMc51Dki8Az1bV5Uk+CxxYVZ9prHuxqvYdo09JAxs3HB4BTqmqp5KsBL5fVUc31hkO0iIzbjg8X1UHdNsBnnvj+bx124FNwHbg8qq6eYH3Ww+sB1i2zx4n/NLhb3mrRW//PV6edgsTs/WBPafdwkSsPPalabcwMY8+8OpPquqQ1r6dhkOSfwYObez6U+DauWGQ5LmqestxhySrqurJJO8BbgNOq6rHdlT3oGMOqdOv+Z0d9rYY/faK+6bdwsRcd/TqabcwEZc+dv+0W5iY097z6D1VdWJr3x47+8tVdfpC+5L8b5KVc36seHqB93iy+/PxJN8Hjgd2GA6SpmvcU5kbgPO67fOAv5+/IMmBSfbqtlcAJwMPjVlX0oSNGw6XA2ck+RFwevecJCcmuapbcwywMcl9wO2MjjkYDtKM2+mPFTtSVVuB0xqvbwQu6rb/Ffi1cepIGp5XSEpqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ19RIOSdYmeSTJ5m7y1fz9eyW5vtt/V5Ij+qgraXLGDocky4CvAGcC7wXOTfLeecsuZDTw5kjgS8Dnx60rabL6+OawBthcVY9X1WvAt4F189asA67ttm8ETusmZEmaUX2EwyrgiTnPt3SvNddU1XZgG3BwD7UlTchMHZBMsj7JxiQbX33ulWm3Iy1pfYTDk8DcIYmHda811yTZA9gf2Dr/jarqa1V1YlWduNeBe/fQmqRd1Uc43A0cleTdSfYEzmE0Jm+uuWPzzgZuq3HGe0uauLEmXsHoGEKSi4HvAcuAa6rqwSSfAzZW1QbgauDrSTYDzzIKEEkzbOxwAKiqW4Bb5r122ZztV4CP9FFL0jBm6oCkpNlhOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1DTUr8/wkzyTZ1D0u6qOupMkZ+wazc2ZlnsFo2tXdSTZU1UPzll5fVRePW0/SMPq4+/SbszIBkrwxK3N+OLwjrz9abFv7sx7amy1/dsWHp93CxPzGHf817RYm4oJ/2p2/6P7xgnuGmpUJ8LtJ7k9yY5LVjf0/Nw7vtXIcnjRNQx2Q/AfgiKr6deBW/n/i9s+ZOw5vzzgOT5qmQWZlVtXWqnq1e3oVcEIPdSVN0CCzMpOsnPP0LODhHupKmqChZmV+MslZwHZGszLPH7eupMkaalbmpcClfdSSNAyvkJTUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhq6msc3jVJnk7ywAL7k+SKblze/Une30ddSZPT1zeHvwXW7mD/mcBR3WM98NWe6kqakF7Coap+wOiu0gtZB1xXI3cCB8y7Xb2kGTPUMYe3NTLPcXjS7JipA5KOw5Nmx1DhsNOReZJmy1DhsAH4WHfW
4iRgW1U9NVBtSbugl4lXSb4FnAKsSLIF+AtgOUBV/Q2jaVgfBDYDLwEX9FFX0uT0NQ7v3J3sL+ATfdSSNIyZOiApaXYYDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpqGGod3SpJtSTZ1j8v6qCtpcnq5hySjcXhXAtftYM0Pq+pDPdWTNGFDjcOTtMj09c3h7fhAkvuAHwOfrqoH5y9Isp7RoF32+uX9eOWbBw3Y3kAeWzbtDvQOHfPXu++/e/+9g31DHZC8Fzi8qt4HfBm4ubVo7ji85Qe8a6DWJLUMEg5V9UJVvdht3wIsT7JiiNqSds0g4ZDk0CTpttd0dbcOUVvSrhlqHN7ZwMeTbAdeBs7ppmBJmlFDjcO7ktGpTkmLhFdISmoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDWNHQ5JVie5PclDSR5MckljTZJckWRzkvuTvH/cupImq497SG4H/qiq7k2yH3BPklur6qE5a84Ejuoevwl8tftT0owa+5tDVT1VVfd22z8FHgZWzVu2DriuRu4EDkiyctzakian12MOSY4AjgfumrdrFfDEnOdbeGuAkGR9ko1JNv7s+Zf6bE3SO9RbOCTZF7gJ+FRVvbAr7+E4PGl29BIOSZYzCoZvVNV3GkueBFbPeX5Y95qkGdXH2YoAVwMPV9UXF1i2AfhYd9biJGBbVT01bm1Jk9PH2YqTgY8C/5FkU/fanwC/Am+Ow7sF+CCwGXgJuKCHupImaOxwqKo7gOxkTQGfGLeWpOF4haSkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FS01Dj8E5Jsi3Jpu5x2bh1JU3WUOPwAH5YVR/qoZ6kAQw1Dk/SItPHN4c37WAcHsAHktwH/Bj4dFU92Pj764H1APse+i5+9YD/6bO9mbDPn+++Y/5uuO9fpt3CRBz/hXOm3cLknLXwrqHG4d0LHF5V7wO+DNzceo+54/D2OXDvvlqTtAsGGYdXVS9U1Yvd9i3A8iQr+qgtaTIGGYeX5NBuHUnWdHW3jltb0uQMNQ7vbODjSbYDLwPndFOwJM2oocbhXQlcOW4tScPxCklJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkpj5uMLt3kn9Pcl83Du8vG2v2SnJ9ks1J7urmW0iaYX18c3gVOLWbSXEcsDbJSfPWXAg8V1VHAl8CPt9DXUkT1Mc4vHpjJgWwvHvMv7P0OuDabvtG4LQ3blUvaTb1NdRmWXdb+qeBW6tq/ji8VcATAFW1HdgGHNxHbUmT0Us4VNXrVXUccBiwJsmxu/I+SdYn2Zhk48vPvdJHa5J2Ua9nK6rqeeB2YO28XU8CqwGS7AHsT2PilbMypdnRx9mKQ5Ic0G3vA5wB/Oe8ZRuA87rts4HbnHglzbY+xuGtBK5NsoxR2NxQVd9N8jlgY1VtYDRL8+tJNgPPArvxTHNp99DHOLz7geMbr182Z/sV4CPj1pI0HK+QlNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1DTUr8/wkzyTZ1D0uGreupMnq4+7Tb8zKfDHJcuCOJP9YVXfOW3d9VV3cQz1JA+jj7tMF7GxWpqRFJn3MlulmVtwDHAl8pao+M2//+cBfAc8AjwJ/WFVPNN5nPbC+e3o08MjYzb19K4CfDFhvKH6uxWfIz3Z4VR3S2tFLOLz5ZqPJV38H/EFVPTDn9YOBF6vq1SS/D/xe
VZ3aW+EeJNlYVSdOu4+++bkWn1n5bIPMyqyqrVX1avf0KuCEPutK6t8gszKTrJzz9Czg4XHrSpqsoWZlfjLJWcB2RrMyz++hbt++Nu0GJsTPtfjMxGfr9ZiDpN2HV0hKajIcJDUt+XBIsjbJI0k2J/nstPvpS5Jrkjyd5IGdr148kqxOcnuSh7rL9S+Zdk99eDu/hjB4T0v5mEN3EPVRRmdYtgB3A+dW1UNTbawHSX6L0ZWr11XVsdPupy/dma+VVXVvkv0YXXz34cX+3yxJgF+c+2sIwCWNX0MYzFL/5rAG2FxVj1fVa8C3gXVT7qkXVfUDRmeGditV9VRV3dtt/5TRafFV0+1qfDUyU7+GsNTDYRUw9zLuLewG/6MtFUmOAI4H7ppuJ/1IsizJJuBp4NaqmurnWurhoEUqyb7ATcCnquqFaffTh6p6vaqOAw4D1iSZ6o+DSz0cngRWz3l+WPeaZlj3M/lNwDeq6jvT7qdvC/0awtCWejjcDRyV5N1J9gTOATZMuSftQHfg7mrg4ar64rT76cvb+TWEoS3pcKiq7cDFwPcYHdi6oaoenG5X/UjyLeDfgKOTbEly4bR76snJwEeBU+fcWeyD026qByuB25Pcz+gfrVur6rvTbGhJn8qUtLAl/c1B0sIMB0lNhoOkJsNBUpPhIKnJcJDUZDhIavo/liMm2T5kwzgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# With a loss that has no source axis, compute the loss for each estimate-source pair.\n",
    "def mse(est_target, target):\n",
    "    \"\"\" Batch MSE between a source and its estimate\"\"\"\n",
    "    loss = (target - est_target)**2\n",
    "    return loss.mean(-1)\n",
    "\n",
    "# Compute pairwise losses: row i is estimate i, column j is source j\n",
    "pairwise_losses = torch.zeros(batch_size, n_sources, n_sources)\n",
    "for i in range(n_sources):\n",
    "    for j in range(n_sources):\n",
    "        pairwise_losses[:, i, j] = mse(estimate_sources[:, i], sources[:, j])\n",
    "# Plot the pairwise losses\n",
    "ax = plt.imshow(pairwise_losses[0].data.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQcAAAD8CAYAAAB6iWHJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAANH0lEQVR4nO3da6wc9XnH8e+vxlxSKDdTYRkXEoEQKW0gUJcIqUJcJBOlOGqJCi8SQCBXUWhI1agJvVA1b0ryIqkIUaoIUCHKBQQpdSOqlBaiBLVQDDKUSyEGVcWEFmLABHGL0dMXO9CTw//Y4J2d3ePz/Ugrz+78vc+zAv28Z2bOPKkqJGm+X5h2A5Jmk+EgqclwkNRkOEhqMhwkNRkOkprGCockByW5NcmPuj8PXGDd60k2dY8N49SUNIyMc51Dki8Az1bV5Uk+CxxYVZ9prHuxqvYdo09JAxs3HB4BTqmqp5KsBL5fVUc31hkO0iIzbjg8X1UHdNsBnnvj+bx124FNwHbg8qq6eYH3Ww+sB1i2zx4n/NLhb3mrRW//PV6edgsTs/WBPafdwkSsPPalabcwMY8+8OpPquqQ1r6dhkOSfwYObez6U+DauWGQ5LmqestxhySrqurJJO8BbgNOq6rHdlT3oGMOqdOv+Z0d9rYY/faK+6bdwsRcd/TqabcwEZc+dv+0W5iY097z6D1VdWJr3x47+8tVdfpC+5L8b5KVc36seHqB93iy+/PxJN8Hjgd2GA6SpmvcU5kbgPO67fOAv5+/IMmBSfbqtlcAJwMPjVlX0oSNGw6XA2ck+RFwevecJCcmuapbcwywMcl9wO2MjjkYDtKM2+mPFTtSVVuB0xqvbwQu6rb/Ffi1cepIGp5XSEpqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ19RIOSdYmeSTJ5m7y1fz9eyW5vtt/V5Ij+qgraXLGDocky4CvAGcC7wXOTfLeecsuZDTw5kjgS8Dnx60rabL6+OawBthcVY9X1WvAt4F189asA67ttm8ETusmZEmaUX2EwyrgiTnPt3SvNddU1XZgG3BwD7UlTchMHZBMsj7JxiQbX33ulWm3Iy1pfYTDk8DcIYmHda811yTZA9gf2Dr/jarqa1V1YlWduNeBe/fQmqRd1Uc43A0cleTdSfYEzmE0Jm+uuWPzzgZuq3HGe0uauLEmXsHoGEKSi4HvAcuAa6rqwSSfAzZW1QbgauDrSTYDzzIKEEkzbOxwAKiqW4Bb5r122ZztV4CP9FFL0jBm6oCkpNlhOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1DTUr8/wkzyTZ1D0u6qOupMkZ+wazc2ZlnsFo2tXdSTZU1UPzll5fVRePW0/SMPq4+/SbszIBkrwxK3N+OLwjrz9abFv7sx7amy1/dsWHp93CxPzGHf817RYm4oJ/2p2/6P7xgnuGmpUJ8LtJ7k9yY5LVjf0/Nw7vtXIcnjRNQx2Q/AfgiKr6deBW/n/i9s+ZOw5vzzgOT5qmQWZlVtXWqnq1e3oVcEIPdSVN0CCzMpOsnPP0LODhHupKmqChZmV+MslZwHZGszLPH7eupMkaalbmpcClfdSSNAyvkJTUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhq6msc3jVJnk7ywAL7k+SKblze/Une30ddSZPT1zeHvwXW7mD/mcBR3WM98NWe6kqakF7Coap+wOiu0gtZB1xXI3cCB8y7Xb2kGTPUMYe3NTLPcXjS7JipA5KOw5Nmx1DhsNOReZJmy1DhsAH4WHfW
4iRgW1U9NVBtSbugl4lXSb4FnAKsSLIF+AtgOUBV/Q2jaVgfBDYDLwEX9FFX0uT0NQ7v3J3sL+ATfdSSNIyZOiApaXYYDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpqGGod3SpJtSTZ1j8v6qCtpcnq5hySjcXhXAtftYM0Pq+pDPdWTNGFDjcOTtMj09c3h7fhAkvuAHwOfrqoH5y9Isp7RoF32+uX9eOWbBw3Y3kAeWzbtDvQOHfPXu++/e/+9g31DHZC8Fzi8qt4HfBm4ubVo7ji85Qe8a6DWJLUMEg5V9UJVvdht3wIsT7JiiNqSds0g4ZDk0CTpttd0dbcOUVvSrhlqHN7ZwMeTbAdeBs7ppmBJmlFDjcO7ktGpTkmLhFdISmoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDWNHQ5JVie5PclDSR5MckljTZJckWRzkvuTvH/cupImq497SG4H/qiq7k2yH3BPklur6qE5a84Ejuoevwl8tftT0owa+5tDVT1VVfd22z8FHgZWzVu2DriuRu4EDkiyctzakian12MOSY4AjgfumrdrFfDEnOdbeGuAkGR9ko1JNv7s+Zf6bE3SO9RbOCTZF7gJ+FRVvbAr7+E4PGl29BIOSZYzCoZvVNV3GkueBFbPeX5Y95qkGdXH2YoAVwMPV9UXF1i2AfhYd9biJGBbVT01bm1Jk9PH2YqTgY8C/5FkU/fanwC/Am+Ow7sF+CCwGXgJuKCHupImaOxwqKo7gOxkTQGfGLeWpOF4haSkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FS01Dj8E5Jsi3Jpu5x2bh1JU3WUOPwAH5YVR/qoZ6kAQw1Dk/SItPHN4c37WAcHsAHktwH/Bj4dFU92Pj764H1APse+i5+9YD/6bO9mbDPn+++Y/5uuO9fpt3CRBz/hXOm3cLknLXwrqHG4d0LHF5V7wO+DNzceo+54/D2OXDvvlqTtAsGGYdXVS9U1Yvd9i3A8iQr+qgtaTIGGYeX5NBuHUnWdHW3jltb0uQMNQ7vbODjSbYDLwPndFOwJM2oocbhXQlcOW4tScPxCklJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkpj5uMLt3kn9Pcl83Du8vG2v2SnJ9ks1J7urmW0iaYX18c3gVOLWbSXEcsDbJSfPWXAg8V1VHAl8CPt9DXUkT1Mc4vHpjJgWwvHvMv7P0OuDabvtG4LQ3blUvaTb1NdRmWXdb+qeBW6tq/ji8VcATAFW1HdgGHNxHbUmT0Us4VNXrVXUccBiwJsmxu/I+SdYn2Zhk48vPvdJHa5J2Ua9nK6rqeeB2YO28XU8CqwGS7AHsT2PilbMypdnRx9mKQ5Ic0G3vA5wB/Oe8ZRuA87rts4HbnHglzbY+xuGtBK5NsoxR2NxQVd9N8jlgY1VtYDRL8+tJNgPPArvxTHNp99DHOLz7geMbr182Z/sV4CPj1pI0HK+QlNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1DTUr8/wkzyTZ1D0uGreupMnq4+7Tb8zKfDHJcuCOJP9YVXfOW3d9VV3cQz1JA+jj7tMF7GxWpqRFJn3MlulmVtwDHAl8pao+M2//+cBfAc8AjwJ/WFVPNN5nPbC+e3o08MjYzb19K4CfDFhvKH6uxWfIz3Z4VR3S2tFLOLz5ZqPJV38H/EFVPTDn9YOBF6vq1SS/D/xe
VZ3aW+EeJNlYVSdOu4+++bkWn1n5bIPMyqyqrVX1avf0KuCEPutK6t8gszKTrJzz9Czg4XHrSpqsoWZlfjLJWcB2RrMyz++hbt++Nu0GJsTPtfjMxGfr9ZiDpN2HV0hKajIcJDUt+XBIsjbJI0k2J/nstPvpS5Jrkjyd5IGdr148kqxOcnuSh7rL9S+Zdk99eDu/hjB4T0v5mEN3EPVRRmdYtgB3A+dW1UNTbawHSX6L0ZWr11XVsdPupy/dma+VVXVvkv0YXXz34cX+3yxJgF+c+2sIwCWNX0MYzFL/5rAG2FxVj1fVa8C3gXVT7qkXVfUDRmeGditV9VRV3dtt/5TRafFV0+1qfDUyU7+GsNTDYRUw9zLuLewG/6MtFUmOAI4H7ppuJ/1IsizJJuBp4NaqmurnWurhoEUqyb7ATcCnquqFaffTh6p6vaqOAw4D1iSZ6o+DSz0cngRWz3l+WPeaZlj3M/lNwDeq6jvT7qdvC/0awtCWejjcDRyV5N1J9gTOATZMuSftQHfg7mrg4ar64rT76cvb+TWEoS3pcKiq7cDFwPcYHdi6oaoenG5X/UjyLeDfgKOTbEly4bR76snJwEeBU+fcWeyD026qByuB25Pcz+gfrVur6rvTbGhJn8qUtLAl/c1B0sIMB0lNhoOkJsNBUpPhIKnJcJDUZDhIavo/liMm2T5kwzgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# We can also compute the pairwise losses directly using broadcasting.\n",
    "# Note: this redefines the pairwise_mse imported from asteroid above.\n",
    "def pairwise_mse(est_targets, targets):\n",
    "    \"\"\" Batch pairwise MSE. \"\"\"\n",
    "    targets = targets.unsqueeze(1)\n",
    "    est_targets = est_targets.unsqueeze(2)\n",
    "    pw_loss = (targets - est_targets)**2\n",
    "    mean_over = list(range(3, pw_loss.ndim))\n",
    "    return pw_loss.mean(dim=mean_over)\n",
    "# Compute pairwise losses using broadcasting and check that they match the loop above\n",
    "direct_pairwise_losses = pairwise_mse(estimate_sources, sources)\n",
    "torch.testing.assert_close(pairwise_losses, direct_pairwise_losses)\n",
    "# Plot the pairwise losses\n",
    "ax = plt.imshow(direct_pairwise_losses[0].data.numpy())"
   ]
  },
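  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The broadcasting version builds an intermediate tensor of shape `(batch, n_sources, n_sources, feat_dim)`, which can become large for long signals. For the MSE specifically, entry $(i, j)$ of the pairwise matrix is the squared Euclidean distance between estimate $i$ and source $j$ divided by `feat_dim`, so `torch.cdist` can compute it without that intermediate tensor. The sketch below is an alternative we add for illustration (the helper `pairwise_mse_cdist` is hypothetical and not part of Asteroid); it assumes 3D inputs of shape `(batch, n_sources, feat_dim)` and only applies to squared-error losses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def pairwise_mse_cdist(est_targets, targets):\n",
    "    # Hypothetical helper for inputs of shape (batch, n_src, feat_dim).\n",
    "    # cdist gives the Euclidean distance between every estimate (rows) and every source (columns).\n",
    "    dist = torch.cdist(est_targets, targets, p=2.0)\n",
    "    # Squaring and dividing by feat_dim turns the distance into a mean squared error.\n",
    "    return dist.pow(2) / est_targets.shape[-1]\n",
    "\n",
    "# Compare against the broadcasting formula on fresh random data (local generator, global RNG untouched)\n",
    "demo_gen = torch.Generator().manual_seed(0)\n",
    "demo_est = torch.randn(2, 4, 50, generator=demo_gen)\n",
    "demo_src = torch.randn(2, 4, 50, generator=demo_gen)\n",
    "demo_ref = ((demo_src.unsqueeze(1) - demo_est.unsqueeze(2)) ** 2).mean(dim=-1)\n",
    "demo_fast = pairwise_mse_cdist(demo_est, demo_src)\n",
    "torch.testing.assert_close(demo_fast, demo_ref, rtol=1e-4, atol=1e-5)\n",
    "print(demo_fast.shape)"
   ]
  },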
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
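    "Now that we have the loss for each estimate-source pair, we can compute the loss of each permutation by averaging the entries of this matrix that the permutation selects.\n",
    "The one-hot permutation matrices are plotted below. Each one is multiplied element-wise with `pairwise_losses`; summing the result and dividing by the number of sources gives the average loss for that permutation."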
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABGUAAADHCAYAAACwce6oAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAATxklEQVR4nO3de7QuZV0H8O8PjiZaqRhejlmUka4yZdlNrAgtPbW6YK4oyliRmWV5SU3JWx2NzO6meUtDRE2LytQkjwcDJC2j7KSmQBYoehQRDEsFtPP0x8yWt5d99tn7nIFnOPvzWetdL+/MvM8885th2O+XeWaqtRYAAAAAblqH9O4AAAAAwGYklAEAAADoQCgDAAAA0IFQBgAAAKADoQwAAABAB0IZAAAAgA6EMgDMWlU9uKr+pqqurKprquriqvrNqrp9p/4cWVWtqh4xYXvbq+qrp2jv5qaqbjdu/30PoI3tVfXAVaafXlWXHlAHZ6qqHlJVT9jgd7ZXVbux+gQAbJxQBoDZqqqnJtmR5Jokj0iyLcmLk5yc5IKqulu/3k3myCS/mmRThjJJbpdh+/c7lBm/f4NQJsmvJfmhA2h3zh6SZEOhTJKXJTnmRugLALCftvTuAACspqoekOTUJM9trT1+YdZ5VfW6JP+c5IwkD+jRv4NZVX1Ra+3a3v04UK21/+jdhzlY2Z+ttQ8n+XDv/gAA13OlDABz9eQkVyV5yvKM1tolSZ6T5Liq+taV6eOwolOr6rFVdUlV/XdVnVdVX7/cRlU9tKr+oao+U1X/VVVnVtVXbKB/h1bVs6rqo+P331hVX760jluM/bm0qq4b30+tqluM849Lcs64+M6x/22cvqqqOreq/q6qjq+q91bVtVV1YVX9yCrL3qeq3lBVn6yqz1bV26vqO5aWOb2qPlxVx1TVO6rqs0l+a5x3aVW9qqpOqqqLxjbOr6qjquo2VfWScVjZ5VX1u1W1ZaHdk8dtOXJpfV8YQjPOu2Sc9dKF7T95nP/gqjprrPFnxu19YlUdutDeynCcpy18f/vCtl26tP67VNUZVfWJsXbvrqqfWFpmpe/3q6pXV9Wnqmp3VT2vqm61t32z2KdxPz+xqj449v1NVXXH8fVnVXV1VV1WVacsffeIsa4Xj9+7rKr+pKruurjPkvxkkrsubPOl47zjxs8PraqXVtUVSS5frv1CW1uq6pSqel8NwwOvqKo3V9U9l/r04qr6yMLx9sildu5cVa8Y63TtuM/+uqruuK96AcBm5koZAGZn/HH/nUle31q7Zi+LvSHJb2YYtvLOhek/keSiJI9Lcsskv53k9VV1z9ba58f2fy7Ji5K8PMmzknxJku0ZrsK5d2vtv9fRzackeUeShye5Y5LfTfKqJMctLPOKJD+S5NlJ/i7J/ZM8LcNQpR9P8q4kv5DkBUkem+SC8Xvv28e6vybJ88Y+fzzJo5K8tqquaK2dM27jfZOcn+RfkvxMks8k+bkkZ1fV/Vtr/7zQ3m2TvDbJ7yR5apLPLsw7Nsndk5ySoZ7PTfIXSf4zyQeSnDgu8/Qk/5Hkhfvo+6KPJnlokr9M8hsZ9mnGdpKhTm9N8vwMQ9i+adzmI5L88rjMMUn+PsnpSV4yTlv1apCquk2S85LcftzOyzIcL6+sqlu31v5o6SuvTPKasY/HjOv+ZIbhUvtyUpL3Jvn5JHfKULczMhxrf5Pkj5KckOQ5VfWe1tpZ4/cOH7f1KUmuSLI1yROTvH08hq/JMCzriCTfnOQHx+8tX9n0/HE9JyVZK0h6bYahUM9Ncva47LFJ7pLkwqr60gzH7mHj9l+SYRjhi2q4Auf5YzuvTPKVSZ6Uoa53SvJdSW6971IBwCbWWvPy8vLy8prVK8MPupbkN9ZY5lbjMi9cmNaS/HuSWyxM++Fx+v3Hz1+c5Ookpy2191VJrkvyi/vo25Fje+cuTf+lcfrW8fO9xs/b
l5Z7+jj93uPn48bP373O2pw7Ln+/hWmHJrkwyfkL096a5P1Jbrm03PuT/NXCtNPH9o5fZV2XZrha6bYL0x47Lv+ypWXfleSchc8nj8sdubTc9uHPjxvU8xH72O7K8D+TnpYhGDlkab+fusp3Tk9y6cLnR4/LHre03NkZwq1Dl/r+zKXl/jrJxevYRy3JxUm2LEz7vXH60xembRnX+/I12jo0yd3G7/7Q0rZ9eJXlV46n160yb7n2DxyXfewa639GhpDoqKXpL03yiZVtTPI/a7Xj5eXl5eXltfrL8CUADjY7W2ufW/j8nvF9ZWjSMUm+NMmrx6EbW8Yrcy7LEGwcmyRVdcji/Kpa/m/mWUufl9dz7Pj+qqXlVj5/54a26v+7rLX2DysfWmv/m+TMJN8y9vuwsf0zk+xZ2MbKEEAcu9Te5zIEDqv5+9ba1QufLxzfdywtd2GG8GAy41Cjl1TVBzMEZp/LcJ+h22W4Ommjjk3ykdbauUvTX5XhypOvW5r+pqXP78n1+3dfdrbxyqzRDeo2zv9AlupWVY+qqn+tqv9J8vkkHxpn3WOd606S161jmQdnCGVeusYy35PhSrRLlv592ZHkDrm+ZhckeVJVPa6qvqGqagN9BYBNSygDwBxdmeH/zh+5xjIr8y5bmn7V0ueVYR0rQzhWfsyfneFH/uLrGzL80EyS05bmnbbB9Rw+vn90abmPLc3fH5fvZdotM4QLh2e4wuIZueE2PjrJ7ZdCpivGYGc1n1z6fN0a0/d5v5X1Gvv3hiTfnyGIeWCG4Tq/Pi6yP+s6PDfcH8ne98lq+/iL1rmu/apbVT0mwxCwszMMm/qWJPcbZ29km1fbzmV3SHJVa+2zayxzxwxh1vJxdOZCG0nyoxn215OTvDvJR6rqV1YJMwGABe4pA8DstNY+X1XnJXlQVd2qrX5fmZV7afztBpu/cnw/Ocm/rTJ/5X4y25P84cL0T2xwPSs/6O+c6++RsvJ5cf7+uNNepl2X4T4khyXZk+FeNWes1kBrbc/ixwPoy96s7LNbLk2/w/KCe3H3DPeQOam19oWrjarqBw6gT1dl9atNptgnUzkxyVtba09cmVBVX7Uf7axnn34iyeFVddgawcyVGYZYPW4v8y9KktbaxzPcH+kXquoeGW5E/MwMx+OLNtJxANhMhDIAzNXvJNmZ4Sa5T1icMf5IPSXJ21pr71zlu2t5R4bg5Wtaa6/Y20KttUsz3FNlf71tfD8x11/dkSQPG9/PHd9XrrA5bANt362q7rcyhGl8GtEJSf5xDFs+XVXnJ7lPknctBTA3lQ+O7/fKcH+VlRs4P3hpub1t/8oNYr8wFK2Gp1Y9LDd03SrfX815SU6oqm9rrb19YfqPZwge9nWD5ZvCrZN8amnaT62y3LXZ2DGzmrdkuGHyIzLcGHg1b07ymCQfGoOXfWqtXZTkqeMNte91gH0EgIOaUAaAWWqtnV1Vv5rkmTU8OvmMDEM/7pvhh+TVGZ4ss9F2P1VVT0rygqo6IsMTaq5OctcM92E5t7X2JxP0/71V9Zok28cw4h0Z7mfzjCSvaa2t3IPm4gz3DXl4VV2V4cf2RW3tJ0BdnuRPx/pckeHpS187vq94QoZgaEdV/XGG4SxflqF+h7bWfjk3rgsyXCH02+MQlmszPIloefjP5Rmuxjixqt6d5NMZnvDz/gzBzq9X1f9mCGcev5d1vS/J91XVmzMcI7tba7tXWe70DFd8/GVVPS3DU5oeluRBSX52jSFcN6U3Jzmlqp6a5B8zDNv64VWWe1+Gq1weleSfklyzcEytS2vtnKr6iyS/V1V3y3DV2S0yDFd603jvnd/PMDTp/Kr6/QxXxtwmyT2TfEdr7fiqum2G4VavznDvnM8lOT7DU67espE+AcBmY5wvALPVWntWku/N8CPw5Rl+4P18hoDmm1prH1rj62u1+5IMw5/ukeFRvmdlGK60JcmuA+749U7O8Njuh4/r+Onx808u9OXKDPd5
uU+GKzkuSPKN+2j3AxmuXvilDI+TPirJj7Xxcdhju+/KcA+WKzM8PvstSf4gw31z3rbc4NTGm9gen+GeP6dnGEq1c/znxeX2ZLhS4/YZfthfkOQHWmvXZXhU88cy7O8XjP1+ziqre3SGMOeN4/cfuZc+fTpD8PaWsZ3XZ6j7Se2Gj8Pu5VkZHu39+Aw36713hkdQL3tZhsdZPztDePPG/VzfiRmO/YdkuCfMaUm+PuM9acabPN8/w/F7SoYb/J6WYd+uHG/XZHj61s8k+fOx38ckeVhr7fX72S8A2BSqtRtjGDkAcGOoqnMzPIb423v3BQCAA+NKGQAAAIAOhDIAAAAAHRi+BAAAANCBK2UAAAAAOhDKAAAAAHQglAEAAADoQCgDAAAA0IFQBgAAAKADoQwAAABAB1vWmrnnY0dN9rzsbVuPnqqpWdq558yasj21X78pa6/u6zdl3R90yAmT1X3H7l2TtDPX/TfXuk9lqv2XTLsPpz7Hq/36zfWYd65ZP8f7+jnXbMxca+9cs35zrftUNsPxnqj9RqxWe1fKAAAAAHQglAEAAADoQCgDAAAA0IFQBgAAAKADoQwAAABAB0IZAAAAgA6EMgAAAAAdCGUAAAAAOhDKAAAAAHQglAEAAADoQCgDAAAA0IFQBgAAAKADoQwAAABAB0IZAAAAgA6EMgAAAAAdCGUAAAAAOhDKAAAAAHSwZa2Z27YePdmKduzeNUk7U/Zpzua4nVPtw2Se25fMs1+boe6bYRvnaI513yz7T+1v/qaq1xyPhTnz9+TN3xx/X8yZc00fzjU3fzenc40rZQAAAAA6EMoAAAAAdCCUAQAAAOhAKAMAAADQgVAGAAAAoAOhDAAAAEAHQhkAAACADoQyAAAAAB0IZQAAAAA6EMoAAAAAdCCUAQAAAOhAKAMAAADQgVAGAAAAoAOhDAAAAEAHQhkAAACADoQyAAAAAB0IZQAAAAA62HJTrWjb1qMnaWfH7l2TtJNM16e5m6pmm6VeczNl3af892dKczy2NsO5Zo7H1lxrNbWDfTvneq6Z47/XB/uxkMyz7pvFwV77Kfu0c89kTc3SHP+bP2dzPN43Q92Tef5NeWOfa1wpAwAAANCBUAYAAACgA6EMAAAAQAdCGQAAAIAOhDIAAAAAHQhlAAAAADoQygAAAAB0IJQBAAAA6EAoAwAAANCBUAYAAACgA6EMAAAAQAdCGQAAAIAOhDIAAAAAHQhlAAAAADoQygAAAAB0IJQBAAAA6EAoAwAAANCBUAYAAACggy29O7BR27YePVlbO3bvmqytOZuyZlPZDLWfahun3H9TtrVzz2RNzZJzzcY4z/TjXLN+czxOpzTXY17d+5njf8sO9uMhmfaYmGO95nqOV/f1m/rv+IO99jc2V8oAAAAAdCCUAQAAAOhAKAMAAADQgVAGAAAAoAOhDAAAAEAHQhkAAACADoQyAAAAAB0IZQAAAAA6EMoAAAAAdCCUAQAAAOhAKAMAAADQgVAGAAAAoAOhDAAAAEAHQhkAAACADoQyAAAAAB0IZQAAAAA6EMoAAAAAdCCUAQAAAOigWmu9+wAAAACw6bhSBgAAAKADoQwAAABAB0IZAAAAgA6EMgAAAAAdCGUAAAAAOhDKAAAAAHQglAEAAADoQCgDAAAA0IFQBgAAAKADoQwAAABAB0IZAAAAgA6EMgAAAAAdCGUAAAAAOhDKAAAAAHSwZa2Zez52VJtqRdu2Hj1VU7O0c8+ZNWV7ar9+U9Ze3ddvyro/6JATJqv7jt27JmlnrvtvrnWfo6mOhSQ55M7/Puk5Xu3Xb8raO9esn3PN+s31eE/UfiOca/pwrlm/uR7vidpvxGq1d6UMAAAAQAdCGQAAAIAOhDIAAAAAHQhlAAAAADoQygAAAAB0IJQBAAAA6EAoAwAAANCBUAYAAACgA6EMAAAAQAdCGQAAAIAOhDIAAAAAHQhlAAAAADoQygAAAAB0IJQBAAAA6EAoAwAAANCB
UAYAAACgA6EMAAAAQAdb1pq5bevRk61ox+5dk7QzZZ/m7GDfzqmOh6mpex9T9utg34dTOtjrPmWfdu6ZrKkkar8RU9d+KnOs+1zP8VOa49+Tm+F4n9Jcz39zrb1zTR/ONTd/N6dzjStlAAAAADoQygAAAAB0IJQBAAAA6EAoAwAAANCBUAYAAACgA6EMAAAAQAdCGQAAAIAOhDIAAAAAHQhlAAAAADoQygAAAAB0IJQBAAAA6EAoAwAAANCBUAYAAACgA6EMAAAAQAdCGQAAAIAOhDIAAAAAHQhlAAAAADrYclOtaNvWo2+qVa3bjt27enfhJjHVdk65D6dsa+eeyZqapSmP07nWfY7nhynN9VwzZd3neJ6Zszlu51zPNVM62Ldxruf4g73uczbH2tuHfczxWEjm+3f8VNs417rP2Wb8m9KVMgAAAAAdCGUAAAAAOhDKAAAAAHQglAEAAADoQCgDAAAA0IFQBgAAAKADoQwAAABAB0IZAAAAgA6EMgAAAAAdCGUAAAAAOhDKAAAAAHQglAEAAADoQCgDAAAA0IFQBgAAAKADoQwAAABAB0IZAAAAgA6EMgAAAAAdCGUAAAAAOtjSuwMbtWP3rsna2rb16Mna2rlnsqYmN9V2zrX2czVVvTZDreZorsf7XM81zjP9ONes3xy3cTMc83Ps12aoezJt35xr1m+Ox5e6b4y6b8wca39z4koZAAAAgA6EMgAAAAAdCGUAAAAAOhDKAAAAAHQglAEAAADoQCgDAAAA0IFQBgAAAKADoQwAAABAB0IZAAAAgA6EMgAAAAAdCGUAAAAAOhDKAAAAAHQglAEAAADoQCgDAAAA0IFQBgAAAKADoQwAAABAB0IZAAAAgA6EMgAAAAAdVGutdx8AAAAANh1XygAAAAB0IJQBAAAA6EAoAwAAANCBUAYAAACgA6EMAAAAQAdCGQAAAIAO/g9VT5v4zR2zpAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 1440x216 with 24 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABGUAAADHCAYAAACwce6oAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAfJ0lEQVR4nO3defgkVXkv8O8rI0ZcEFcy0QQMKC4RTNQAcSEmBq9bNF7URKNoooYY9arJVYlrNKi5KmpUomKURGIE18QNURnUAGpMwBWVCCKOIJsIKgJO3T9ONdP0/NahZqqd+Xye5/fMdHX1W6dOn6ruevucU9V1XQAAAADYuq4zdgEAAAAAtkeSMgAAAAAjkJQBAAAAGIGkDAAAAMAIJGUAAAAARiApAwAAADACSRlgFFV1cFV1U3+XVtVpVfUXVbVmM+J1VfWiLVDUSfy3V9VZWyr+Ats7q6revrW2Nw+q6qFV9cxVvuZFVdVtqTJtCbNlrqqb9Mt+fYF111XVus3czjWOiaHrauoY3m2Z9dZV1WeG2u68q6oD+nr53YHi7dO/dzcdIt7Pm6rard//227m67fI8TXv+uPzCat8zVb9nAOgkZQBxnZQkv2SPDzJ55L8fZIXbEac/ZIcOWC5Zr0kycO2YHyShyZZVVIm7T3fbwuUZWu6SZIXJtnkojHJn/d/Q9gW6mp7tE9a+9gukzJJdkvb/81KymTrHV/z5uAkq0rKxOccwChW/Ws0wMBO7brujP7/H6uqPZI8PatMzHRdd8py61TV9bqu++lmlDFd1/3P5ryOLWPyXnZdd06Sc8Yuz5bSdd1XB4y1TdcVW8+1OZfOkyGPr59nU+dTn3MAI9BTBpg3n09y46q6ZZJU1aOq6pNVdX5VXVZV/11Vj5t90WJDNarqzlV1XFVdluSYqnpWVf24qnacWvc9s8MNquqJVXVVVd24f3yNbt1VtaaqXlJV/1NVl1fVBVX1maq650y5ntQPy5qs89bNHYZQVfeoqo/39fCjqvpEVd1jZp27V9XxVXVhVf2kqr5VVW+cen7XqjqqqtZX1U+r6ntV9cFJfS+x7a6qXtrX37f7OvxQVd2y/zumqi6pqu9U1bNnXnuLqnpTVX2jf913qupfquqXptZ5e5LHJfml2jik7az+uclwkD+oqrdU1flJzuufmx0K9Cf9ug+dWrZDVZ3Yv1c3XmIfJ8Nx9u/359KqOq+qnts/f/++/f2oqj5fVb8x8/oFh5zNts2Z53ZLcmb/8C1T+35w//w1hldM1cXD+zZ5cVX9sKqOrqqbLbZvC9VVv2xNVT23qk7v28P6qnpVVf3CzHq37d/vH1c7Fl+b5HpLbW+B7f9+VX25387pVfWIqece3u/X3gu8bl1VLZl0rarr9u3zrKq6ov/3pVV13al1duu38eSq+pu+7f+gqv69qm69QMxre+zuVFWv7197QVW9o6puMrONG/frTI7Hr1fVM6qq+ucPTvK2fvVvTrWP3Zaoi7P6bT2xqs7oy/9fVfXbC6x7n2rnkUv7dn1cVd15Zp111c5tD+7b/0/T9y6pa3deWHA4XU2da6vqgCQn9E8dP7X/B/TPL/n5sNrjq192+6p6X982flJVp1TV/Rcqe1Xt2e/vZf3+v6CqlvxuPdUO/6yqXlZV5/b1/46q2qmq9ujfh8v69+9xM6/fo6r+uarOrI3n+COqapfp9yzJfZL81tQ+r+ufm5zn7l1Vx1bVD5J8drbup2LdoKpeXu38+dO+vO+pqltNrbN7tXPQ+f06p1bVw2bi3K6v1+/3bfLsfvt+IAa2e06EwLzZPcnPklzWP75tkncneXmSDUnuneTIqrp+13X/sIJ4H0jy1iSv6F//wyTXT7Jvkk/1Fz8HJPlJkvsm+Xj/uvsm+ULXdT9cJO6zkzwjyV8nOTXJjZPcLVNDDKrq5UmeleR1Sf4qyS8leWmSO1fV/l3X/WwF
5Z/EukuSE5N8Na1bepfkOUlOrKp9u647rapumOS4tGFgBye5NK3r//5Tof45ya/05flOklsl+Z0kO62gGH+c5MtpF2S3SvKaJP+U5EZJPpLkzWnD0V5eVV/quu7D/etumuTyJM9Ncn6StX29/EdV7dV13eVp3eZvkeTuSR7Sv272l/i/77fzx0l+IQvouu6tVXVgWhv5fNd1303y/L4O7rnE+zntqH6/JvtzWLWL6Qck+du0tvl3Sd5fVb/add0VK4i5mO8l+YMk703ysiT/1i9f7hfr16S11T9MsmeSw9LqdZML72W8I8mD046Pk5LcIe292C1tSGGqJTCPTztunpLk+0me3Jd7pfZIOw5e1L/+kCT/WlXnd113Qtpxur6Pe/VwkqraK+3i8vHLxD8qySPS6uEzae/3X6edP/5oZt3n9vv6hCS3TPKqtHo4YGq7Qxy7r03ywX77t09rMz9LSz6mv3j/UNqwmhck+VKSByZ5ddqxcGj//EuTPC+tLU56On1vmW0fkOQ3+jr4adr56iNVtXfXdV/vt//AtHr/UJLH9K97dpJPV9Vduq77zlS82/V18ZIk30py0dRzm3teWIn/Smtzb0jytLSkfdLOg8nynw+rOr6qam1a+7k0yV8kuaTf/oeq6kFd131k5iXvS0uaHZ52HL047bz6tizvuUnWpbWHO6a1jw1J7prkLUlemXacvK2q/rPruq/0r1vbb+P/JLm4r4NDk3w4G4cn/nlam94h7ZhK2mfftKOTvDPJ/84i1wNTx/7eaXV8SpKdkxyYZJck51XVbdKSOt9P+0w8P8kjk7ynqh7add2kzj/Ul/eQJBekHVMPiB+IAZKu6/z58+dvq/9lY2Lh9mlfCHdJ+/L4syTvX+Q11+nXfUuS02ae65K8aOrxi/plT18gxkVJXtg/3ifti/Crk5w8td73krx86vHbk5w19fiDSd67xP7t1u/LC2aW/1ZfrocuUz9nJXn71ON3J/lBkptMLbtxvy/v7R/frY99lyXiXpbkaZvxfnVJvpFkzdSyV/fLnze1bE3al/O3LRFrhyS36V/7sJk6PmeB9Q/o133fAs+9qH2UXWPZTZJ8O8kn0y7or0ry3FW0yRcssD9XJtl9avlD+nXvs9h7tlzbnGkrXZI/XeC165KsW6AuPjqz3qP75b+ziu3eq1/nsYvE2qd//MT+8b4zx9FX+uW7LVOv6xZ4/Q5JTk/y6ZnyXZLkBjNt7OIk118i/p1n97Vf/rxMHQ9T9bxuZr2/7JevnVrv2hy7k/foqJnlr09LTlb/+EH9egfPrHdkWiLl5jPtco8VHqtnJbkiyW2mlt0o7Vzxz1PLzkjyiZnX3jjtgvk1M+/fhkl7WKBtb9Z5YbY9Ti1/e655rp3U5+8us98Lfj5kdcfXK9POF3tMLdshydeT/Nds2ZM8fibel5J8bJlyTsrzyZnl7+2XP2Zq2S59eV64RLw1Se7Zv/auM/v2mQXWn7Snw1dQ90/o133IEtt/a1oi5mYzy49PG56cJDdfLo4/f/78bc9/stPA2E5Pu+C9KMkb0369u3pywr57+Dur6rv9elcm+dO0ZM5KvG/6Qdd1G9J6nNy3X3TfJF9McmySu1XVjarqjkl2zcZu8wv5fJIHVNXfVtU9a2o4VO9+aRcJR1cbIrKm76b92bRfYe+9wvJP3DvJB7uu+8HUvvww7Zff+/SLvpmWuHlTVT2m/wVzoXL/VVU9vap+re8ptFLHd1131dTj0/t/j5sq01VpF3vX2HZVHVJtKMhlaRcZZ/dPrfR9TGbey8X0dfRHaXV2XJJPpfUEWamrfw2f2p9vdF135tQ6k31fqI63hmNmHh+bduG8mol875928f7umTb6sf75SRvdL8l3uql5m/rjaLYMS5l9/c/6Mt9jarjHm9N6bP1hklQbQvW4JP/Udd1Plog9Kec7ZpZPHt9nZvlsT40v9f/+cv/vio7dasPi1kz9zX6n+tAC27leWm+SSbk3
JPmXBcq9Y67dpMyndFM9Xbquu7Qvz3592fdM8qsL7OOPk5ycTc9PZ3Vdd+oi29rs88K1NcDnw6x7p9XdZJ6zSVt9Z5J9atPhj7Pv8ZezsR0tZ7bXzUL1dnFaMuvqequqHavq0GpDAH+Sts+f7p8e+nz6e0nO7Tb2dlnI/dOOqUtm2tJxSfbu6+zCtB5WL682rG7PVZQTYJsnKQOM7WFpQ1b2SvuF/LFd112UJP1wnEnX6eek/bJ/9yT/mJXPZ7FQN/8TkuxbVddPG+5xQlqy4vJ+G7+d9kV3qdv4HpZ2R4+HpH0hvrCq3lZVN++fn8zRckY2XixM/m6UZMn5PxZw00X25dy0X1PTdd0lfdnXpyW4zq42h8fDp9Z/ZFoi5/+mJaO+WyuYB6F38czjK5ZYfvXwoqp6al+ej6cNJbhH2vCxZJFhSItYbsjGtFPSft2+XpLX9UmElVpofxbb99WUf0jnTT/o2hCqi9OGBKzULdMu/n+Ua7bP7/fPT9roL85ub6EyrKa8U8t2TBuqk67r1qcNp/mz/vmD0tr9m5aJPRkyONs+zp15fuKimceTYXKT93Klx+4nZp6bnZx8ue3cNMlF3abD3xYr92osVt+T9jHZx7dm0318UDY9Py117G3WeeHaGujzYdZS59lKf66dstB7vNJ93Nx6e1laT513pA13u0c2DiUc+nx6syTfXWadWyZ5bDZtR/9vEqPrui4t2fmfaeX/Rj8XziGrKC/ANsucMsDYvjz9q+SM/dLmP7lX13VXJ0hWOTFgt8CyE9IuBu/d/72567qrqurTaT1ndk/yua7rfrRo0K67Mq33xSuqate0C5lXp/3S/8i0XwaT9kvj7JfsTD2/Uhel9d6Ztet0/P7X7If3dXS3tHkLjunnkvhy13XfT5sj4SlVdfu0nggvTut+fsQqy7RSj0obJvGsyYKq2n0z4iz0Xi7mhWlzrXwxyeFVdUKftNqSLk9rV1erZSbfvRZuNf2g76m1S5a/gJp2YTYmIheyvv/3e0nutFwZlrHQurdKu+A8f2rZG5N8otokyk9OG9603B1yJhfGu+aac4XsOvP8Sq302H1yWpJmYv0C6y7loiQ3raodZxIzm1vuaYvV96R9TPbhudk4j9a02UTRao69lbo8aW13Zv9XeswM8fkwa6nzbJeF28PW9qi03mMvnSzoE1SrtZL39IK04YFLuTDth4nFeiOuT5Ku676V5LF978y90+bseWNVndVtOlcPwHZFTxlgnk0mn71ysqDaHSZ+/1rG/XLaheBfJblB2nCmpM1B8jtpwx2WGrp0DV3Xndt13ZFpFzeTL7DHpw1N+OWu6/5zgb8zFw24sBPThktdfRHY///BaXMHzJbpqn64yPPTzvV3WGCdr3ddd2jahcZyX7yvjZ0y9R72Fpq49adpk8leK1V1r7QJTv86rX5uki2XcJr27Wxajw9cwesmPShWs++PmHl8UNr7fPIqYnw07Zf1nRdpo5Mkw8lJblNVk95Nk0lqZ8uwlNnX79CX+XPTvZi6rvtk2jCOV6fN4bKSybw/1f/7qJnlj+7/XbeKciYrPHb742eh+lqpE9Pes4MWKPcV2fhebk772Hd6+GJ/rnjgVMyvp809c6dF9vGLq9uVzfLt/t+rj5l+Qu39Z9ZbbP9X+vmwmvo7Ma3udpuKuUNaov2/u5VNFL6lbc3z6ceS7FpVD15inY8muUuSryzSlq4xYXvXnJrkmf2iLfnZA/BzQU8ZYJ6dlHbHiDdU1QvTEijPS/v1bufNDdp13eT2oAcl+fzUF+0TsrHL9SeXilFVH0hyWtrdQS5Ou2PG/dMPtei67n+q6hVJXt/3SDkx7Zfh26R14z6ya3edWamXpPXG+UQft0u7U8pOSf6mL9ODkjwpyfvTbgN7g7Q7llya5OSq2jktcXR0Ns7l8/tpPSw+li3no0meXVWHpt0Z6r5pd/yY9dW0ngOHpHVzv7zrui8tsN6i+ouyo9OGlryyf6+f
lNZb6Liu6466NjuyjH9N8o9VdXjaRNB7p02quZzz0n5tflRVfTFtONGZXdct1ZvqTlX1tn6bt0u7K9S6rus+sdLCdl23rqremTanzKvT3psNaRORPiDJs7uu+0banY2ek+S9/Xv4/bQhRoveXnyRfXxXfxyfn3YHltv1/846Iu3ORRckec8K9uPL/X68qO8lcVJaL4rnJ3nnatvQFjh2F/ORtCGS/1BVt0ibOPkBaXOivKzrugv69SY9hZ5SVUelHbdfXGDY07Tzknys2q3YJ3dfukHaeWRyDnxKkg/0vayOSavvW6UlRc7uuu7VA+zjUj6SNrHzW/p2cb20YZWXzaz3jbR5qJ5QVRf1+/P1rPzzYTXH1+Fpx+zxfcwfpt3J6HZZWYJ1a/hoksdV1ZfShtj9QTZNZCWt3fx5VT0yrQfZpV1/561VeEfaRN/vrKqXpc2rdKO0uy+9puu609OG7X0u7W6Gr09L9u2Slmy5bdd1T6h298DXJnlXX+Yd0ur5qizzWQuwPZCUAeZW13XnV9XD0m5Z++60btCvTRv3/8JrGf6EtKTM9BfC/05LsOyU5XscfKp//VP69c9Ou6Xp306V/9Cq+lq/zlPSEinfSUsYfHM1he267otVdUAf/6i0+Q1OSbv7z2n9at9Mu7X389PmAbk0ba6c+3Vdd05VXS8tifTEtG7/G9Iubh7ddd0HVlOeVfqbtN4qz0jrmXFi2pf6b82sd2TaXDOHZeMdlHZb5bbenPbr8OP6eQzSdd2xVfXWtIvs/1hiuNy1dVTahfufpB96kzZn0pLb67puQ1X9adp+fzzts/nxaXdCWczT0+YzelfaBc6/pyXgVusxSZ6aNrn25PbJZ6VN0nleX74rqup+aXcPemPaRe2/pE1yupKeLEmrg79L28c9+2384SLJjWPTjvO3z/7KvoSD09rTE9IuzNenDad48Qpffw1DHrtLbGNDtdtSH5aWNLlZWr08M+220pP1TuuTK09KO3avkzbE8qwlwp+Y1kPosCS3TrtA/199km0S98NVde+09/3ItOPm3LTzyrsG2MUldV33gz6RfHhaUuictHPF72bq9uRd111YVX+RVkcnprX33+6Tist+Pqzm+Oq6bn1V3TOt7RyRlig6NckDu6776JD7fy08Ne38P/ms+XDa5Nifm1nvFWkT/x6Z5IZpdXfAajbUdd2VVfV7afX5pP7fC5P8R/rhdV3XnV1Vd0ub5+awtDmiLkzrkTpJgp+b9hn5zLT2eHnaxNcP6rruC6spE8C2aHJbRgBgzvWJuRPSEm0LzQXyc6+qnpjW4+x2WzCBts2qqrPSboX8mLHLAgAsT08ZAGB01W5F/6tpvVveLyEDAGwPJGUAgHnwxrS5MU5KuzMLAMA2z/AlAAAAgBG4JTYAAADACCRlAAAAAEYgKQMAAAAwAkkZAAAAgBFIygAAAACMQFIGAAAAYARrlnpyw7l7Dna/7APX7jNUqLl0/IZja8h46n7lhqx79b5yQ9b7/a5z0GD1ftz6UweJM6/v37zW+1Auftx+g8Xa5aiTB4s19Dl+Huv+uut+cbBYVx7wvcFizWubd65ZuXls75c8et/BYu189CmDxdoezjXbQ90716zcvNb7UC476DcHi3XDYz87WKzt4Vzz81T3esoAAAAAjEBSBgAAAGAEkjIAAAAAI5CUAQAAABiBpAwAAADACCRlAAAAAEYgKQMAAAAwAkkZAAAAgBFIygAAAACMQFIGAAAAYASSMgAAAAAjkJQBAAAAGIGkDAAAAMAIJGUAAAAARiApAwAAADACSRkAAACAEUjKAAAAAIxgzVJPHrh2n8E29JCvXjhInH+7480GiTPvhqz7oRy3/tTBYs3j/iXzWa7tod4vfdS+g8U6cO1gobZ55x+y32CxbnHEyYPE2eWoYeLMu/1Pu2KwWCftveMgca484HuDxNleDHU+fdoZpw8SJ0let8deg8WaV4d884xB4hyx5yBh2Aw7H33K
YLGGur6YZ8414xjqc/qkvT87SBxW74bHDlf3e37+eoPFWoieMgAAAAAjkJQBAAAAGIGkDAAAAMAIJGUAAAAARiApAwAAADACSRkAAACAEUjKAAAAAIxAUgYAAABgBJIyAAAAACOQlAEAAAAYgaQMAAAAwAgkZQAAAABGICkDAAAAMAJJGQAAAIARSMoAAAAAjEBSBgAAAGAEkjIAAAAAI1iztTb0b3e82SBxjlt/6iBxkuTAtfsMFmue7X/aFYPE2V7qa94MWe/r/3L/wWIN6Ub/esrYRdjE0844fbBYr9tjr8FiDekWR5w8WKyh6mte62poJ+2949hF2KKG/Kwe0iO+du5gsY65w66DxNke2rzvbuPZ1ut+qOuLJHnqhsFCzaUhzzXnH7LfYLHm1Tx+Tp/7jPn8Hj+0O3xhmBTF137jqkHiJMk37/7TwWJlgXONnjIAAAAAI5CUAQAAABiBpAwAAADACCRlAAAAAEYgKQMAAAAwAkkZAAAAgBFIygAAAACMQFIGAAAAYASSMgAAAAAjkJQBAAAAGIGkDAAAAMAIJGUAAAAARiApAwAAADACSRkAAACAEUjKAAAAAIxAUgYAAABgBJIyAAAAACOQlAEAAAAYwZqxC7BaB67dZ7BYL/7WFwaLNc9O2nvHsYuwie2h7i948n6DxLn5m04eJE6SrH3lSYPFyt89Y7hYc+h1e+w1WKx7ffHywWLNqyHrayjHrT917CJsFY//+rcHifO22//KIHGSYT+rj98wWKgcc4ddhws2h+b1XDNke5hHbz37M2MXYVFD1v33n7L/IHFu+YYBv4vMqcPO/NxgsQ7d/R6DxRrKLY4Y7rtp3jBcqCE/9+fxvLXr4QMeO68a9nv8tl73W5qeMgAAAAAjkJQBAAAAGIGkDAAAAMAIJGUAAAAARiApAwAAADACSRkAAACAEUjKAAAAAIxAUgYAAABgBJIyAAAAACOQlAEAAAAYgaQMAAAAwAgkZQAAAABGICkDAAAAMAJJGQAAAIARSMoAAAAAjEBSBgAAAGAEkjIAAAAAI5CUAQAAABhBdV03dhkAAAAAtjt6ygAAAACMQFIGAAAAYASSMgAAAAAjkJQBAAAAGIGkDAAAAMAIJGUAAAAARiApAwAAADACSRkAAACAEUjKAAAAAIxAUgYAAABgBJIyAAAAACOQlAEAAAAYgaQMAAAAwAgkZQAAAABGsGapJ59z2sO7oTb0hbtu2/mf4zccW0PGU/crN2Tdbzh3z8Hq/cC1+wwVai4NWe/3u85Bg9X7s874yiBxXrXHnQaJM7R5rfd59IozPztYrF//lbMHPcdv63V/3PpTB4t1nV2/OZdt3rlm5bb19u5cM57t4Vwz1D7O6/dS55qVm9f2nqj71Vio7rftq3UAAACAOSUpAwAAADACSRkAAACAEUjKAAAAAIxAUgYAAABgBJIyAAAAACOQlAEAAAAYgaQMAAAAwAgkZQAAAABGICkDAAAAMAJJGQAAAIARSMoAAAAAjEBSBgAAAGAEkjIAAAAAI5CUAQAAABiBpAwAAADACCRlAAAAAEawZqknv3DX4XI2l3x4j0Hi7PyAMwaJM++GrPt5tP59dxy7CAs6cO0+Yxdhi9r3tCvHLsKCfvyw3xws1quGOdVsF4as953e99nBYg3l2bsPt3/HbxgsVJLkskfsO1isGx5zymCxhjLkuXTouh/Kq/a409hF2MS8frYOaR6/T87zuWYePfr0cwaLtT2ca+bxu+l11/3i2EXY4q5/4q0GifOT+5w3SJxk+2jvQ/p5Otds21f+AAAAAHNKUgYAAABgBJIyAAAAACOQlAEAAAAYgaQMAAAAwAgkZQAAAABGICkDAAAAMAJJGQAAAIARSMoAAAAAjEBSBgAAAGAEkjIAAAAAI5CUAQAAABiBpAwAAADACCRlAAAAAEYgKQMAAAAwAkkZAAAAgBFIygAAAACMYM3W2tDODzhja21qxS4+eL+xi7BVnPv+OwwS
Z9eHfm2QOEmy9mFfHSxWNgwXah4dt/7UwWIduHafwWINWe87ve+zwwWbQ/N6rhmy3ufxPDPPbnjMKWMXYROPPv2cwWIdvdetB4s1pFufcsPBYp2z72WDxRrKvH62Dlnv2Xf+vk/Os/t+6UeDxfrkr91gkDjzen7Y1j3tjNMHi/W6PQYLNbff439yn/MGibM9fLYO7aqP//IgcY7ea5AwW4WeMgAAAAAjkJQBAAAAGIGkDAAAAMAIJGUAAAAARiApAwAAADACSRkAAACAEUjKAAAAAIxAUgYAAABgBJIyAAAAACOQlAEAAAAYgaQMAAAAwAgkZQAAAABGICkDAAAAMAJJGQAAAIARSMoAAAAAjEBSBgAAAGAEkjIAAAAAI5CUAQAAABjBmrELsFrHrT91sFgHrh0sVPKPA8Ya2K4P/dogcYat+30GizWvLnjyfoPEGbSdsmLONavjPDOeSx697yBxjt7rlEHizLNz9r1s7CJsYnto8+p9PJ/8tRsMFmuoc83OR2/755rDzvzcYLEO3f0eg8R53R57DRJnnj3jjGG+iyTJ4XvcYZA4R+9160HizLth636wUD839JQBAAAAGIGkDAAAAMAIJGUAAAAARiApAwAAADACSRkAAACAEUjKAAAAAIxAUgYAAABgBJIyAAAAACOQlAEAAAAYgaQMAAAAwAgkZQAAAABGICkDAAAAMAJJGQAAAIARSMoAAAAAjEBSBgAAAGAEkjIAAAAAI5CUAQAAABiBpAwAAADACKrrurHLAAAAALDd0VMGAAAAYASSMgAAAAAjkJQBAAAAGIGkDAAAAMAIJGUAAAAARiApAwAAADCC/w8oMvuDOLdSBwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 1440x216 with 24 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Let's plot all the permutation matrices\n",
    "eye = torch.eye(n_sources)\n",
    "perms_one_hot = torch.stack([eye[:, perm] for perm in perms], dim=0)\n",
    "fig, axs = plt.subplots(2, len(perms)//2, figsize=(20, 3))\n",
    "fig.suptitle('One-hot permutation matrices', fontsize=16)\n",
    "for i in range(len(perms)):\n",
    "    col, line = divmod(i, 2)\n",
    "    axs[line, col].imshow((perms_one_hot[i]).data.numpy())\n",
    "    axs[line, col].set_axis_off()\n",
    "\n",
    "fig, axs = plt.subplots(2, len(perms)//2, figsize=(20, 3))\n",
    "fig.suptitle('Pairwise loss matrix multiplied by one-hot permutation matrices', fontsize=16)\n",
    "for i in range(len(perms)):\n",
    "    col, line = divmod(i, 2)\n",
    "    axs[line, col].imshow((perms_one_hot[i] * pairwise_losses[0]).data.numpy())\n",
    "    axs[line, col].set_axis_off()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
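    "Each matrix above keeps exactly one pairwise loss per source. Averaging these selected entries gives the loss of the corresponding permutation, and the permutation with the minimum loss is the one that is backpropagated."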
   ]
  },
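  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of this selection step in plain PyTorch. It is not Asteroid's own implementation, and `n_src`, `pw`, `identity`, `all_perms`, `masks` and `perm_losses` are hypothetical names chosen so as not to overwrite the notebook's variables. The matrix `pw` is a random stand-in for a pairwise loss matrix. Each permutation's loss is the mean of the `n_src` entries kept by its one-hot mask, and the smallest one is selected. The set of permutation matrices is closed under transposition, so the minimum does not depend on whether rows index targets or estimates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch: score every permutation from a pairwise loss matrix\n",
    "import itertools\n",
    "import torch\n",
    "\n",
    "n_src = 3\n",
    "# Local generator so the notebook's global RNG state is left untouched\n",
    "gen = torch.Generator().manual_seed(0)\n",
    "# Random stand-in for a pairwise loss matrix (row i paired with column j)\n",
    "pw = torch.rand(n_src, n_src, generator=gen)\n",
    "\n",
    "identity = torch.eye(n_src)\n",
    "all_perms = list(itertools.permutations(range(n_src)))\n",
    "# One-hot permutation matrices, built as in the plotting cell above\n",
    "masks = torch.stack([identity[:, list(p)] for p in all_perms], dim=0)\n",
    "# Loss of each permutation: mean of the n_src entries selected by its mask\n",
    "perm_losses = (masks * pw).sum(dim=(1, 2)) / n_src\n",
    "best = int(perm_losses.argmin())\n",
    "print(all_perms[best], perm_losses[best].item())"
   ]
  },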
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Timing the three approaches, for MSE and SI-SDR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "For MSE\n",
      "1. Naive approach (perm_avg)\n",
      "880 µs ± 2.76 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
      "2b. Pairwise losses, point by point (pw_pt)\n",
      "734 µs ± 75.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
      "2a. Pairwise loss matrix (pw_mtx)\n",
      "137 µs ± 834 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
     ]
    }
   ],
   "source": [
    "print(\"For MSE\")\n",
    "print(\"1. Naive approach (perm_avg)\")\n",
    "loss_func = PITLossWrapper(multisrc_mse, pit_from='perm_avg')\n",
    "%timeit best_loss = loss_func(sources, estimate_sources)\n",
    "print(\"2b. Pairwise losses, point by point (pw_pt)\")\n",
    "loss_func = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')\n",
    "%timeit best_loss = loss_func(sources, estimate_sources)\n",
    "print(\"2a. Pairwise loss matrix (pw_mtx)\")\n",
    "loss_func = PITLossWrapper(pairwise_mse, pit_from='pw_mtx')\n",
    "%timeit best_loss = loss_func(sources, estimate_sources)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "For SI-SDR\n",
      "1. Naive approach (perm_avg)\n",
      "3.41 ms ± 124 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
      "2b. Pairwise losses, point by point (pw_pt)\n",
      "2.22 ms ± 74.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
      "2a. Pairwise loss matrix (pw_mtx)\n",
      "257 µs ± 7.32 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
     ]
    }
   ],
   "source": [
    "from asteroid.losses import pairwise_neg_sisdr, singlesrc_neg_sisdr, multisrc_neg_sisdr\n",
    "print(\"For SI-SDR\")\n",
    "print(\"1. Naive approach (perm_avg)\")\n",
    "loss_func = PITLossWrapper(multisrc_neg_sisdr, pit_from='perm_avg')\n",
    "%timeit best_loss = loss_func(sources, estimate_sources)\n",
    "print(\"2b. Pairwise losses, point by point (pw_pt)\")\n",
    "loss_func = PITLossWrapper(singlesrc_neg_sisdr, pit_from='pw_pt')\n",
    "%timeit best_loss = loss_func(sources, estimate_sources)\n",
    "print(\"2a. Pairwise loss matrix (pw_mtx)\")\n",
    "loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')\n",
    "%timeit best_loss = loss_func(sources, estimate_sources)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Citations:\n",
    "Classic PIT-based speech separation:\n",
    "- [1] Yu, Dong et al. \"Permutation Invariant Training of Deep Models for Speaker-Independent Multi-Talker Speech Separation.\" 2017 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP).\n",
    "- [2] Kolbæk, Morten et al. \"Multi-talker Speech Separation with Utterance-level Permutation Invariant Training of Deep Recurrent Neural Networks.\" 2017.\n",
    "- [3] Luo, Yi, and Nima Mesgarani. \"Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation.\" IEEE/ACM Transactions on Audio, Speech, and Language Processing 27.8 (2019).\n",
    "- [4] Takahashi, Naoya et al. \"Recursive Speech Separation for Unknown Number of Speakers.\" Interspeech 2019.\n",
    "\n",
    "PIT-based environmental sound separation:\n",
    "- [5] Kavalerov, Ilya et al. \"Universal Sound Separation.\" 2019 IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (WASPAA).\n",
    "- [6] Tzinis, Efthymios et al. \"Improving Universal Sound Separation Using Sound Classification.\" 2019.\n",
    "\n",
    "PIT-based end-to-end diarization:\n",
    "- [7] Fujita, Yusuke et al. \"End-to-End Neural Speaker Diarization with Permutation-Free Objectives.\" Interspeech 2019.\n",
    "- [8] Fujita, Yusuke et al. \"End-to-End Neural Speaker Diarization with Self-attention.\" arXiv 2019.\n",
    "\n",
    "Papers on PIT alternatives:\n",
    "- [9] Fan, C. et al. \"Utterance-level Permutation Invariant Training with Discriminative Learning for Single Channel Speech Separation.\" 2018 11th International Symposium on Chinese Spoken Language Processing (ISCSLP).\n",
    "- [10] Yang, Gene-Ping et al. \"Interrupted and Cascaded Permutation Invariant Training for Speech Separation.\" 2019.\n",
    "- [11] Yousefi, Midia et al. \"Probabilistic Permutation Invariant Training for Speech Separation.\" Interspeech 2019.\n",
    "- [12] Tachibana, H. \"Towards Listening to 10 People Simultaneously: An Efficient Permutation Invariant Training of Audio Source Separation Using Sinkhorn's Algorithm.\" arXiv 2020."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
