{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9b923db5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch notebook: load a saved `batch` / `outputs` pair with torch and check how\n",
    "# `source_masks` selects rows of `batch['sources']`.\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01b924ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run directory containing batch.pt and outputs.pt; change this one constant to\n",
    "# inspect a different run.\n",
    "RUN_DIR = 'outputs/2022-01-10/14-49-01'\n",
    "\n",
    "# map_location='cpu': the saved waveform tensor lived on 'cuda:1', so loading without\n",
    "# remapping fails on machines lacking that GPU. It also keeps batch and outputs on the\n",
    "# same device. NOTE: torch.load unpickles the file - only use it on trusted runs.\n",
    "batch = torch.load(f'{RUN_DIR}/batch.pt', map_location='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdc5efe8",
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = torch.load(f'{RUN_DIR}/outputs.pt', map_location='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "bc287d5d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.0112, -0.0131, -0.0125,  ...,  0.0531,  0.0468,  0.0414]],\n",
       "\n",
       "        [[-0.0159, -0.0182, -0.0171,  ...,  0.0740,  0.0637,  0.0568]],\n",
       "\n",
       "        [[-0.0159, -0.0182, -0.0171,  ...,  0.0740,  0.0637,  0.0568]],\n",
       "\n",
       "        ...,\n",
       "\n",
       "        [[ 0.0166,  0.0270,  0.0211,  ..., -0.0097, -0.0200, -0.0292]],\n",
       "\n",
       "        [[ 0.0327,  0.0404,  0.0386,  ...,  0.0054, -0.0060, -0.0182]],\n",
       "\n",
       "        [[ 0.0315,  0.0420,  0.0373,  ..., -0.0037, -0.0144, -0.0262]]],\n",
       "       device='cuda:1', requires_grad=True)"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 3-D waveform tensor stored under outputs['waveform'] (presumably (batch, 1, samples)\n",
    "# - TODO confirm). The recorded output below was produced while it lived on 'cuda:1'.\n",
    "outputs['waveform']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8af4a432",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edit a copy of the mask: assigning into batch['source_masks'] would mutate the\n",
    "# loaded batch, so every later cell (and any re-run) would see the edited mask and\n",
    "# the original could only be recovered by reloading batch.pt.\n",
    "source_masks = batch['source_masks'].clone()\n",
    "source_masks[:4] = 0  # treat the first 4 batch items as invalid for this check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45b574af",
   "metadata": {},
   "outputs": [],
   "source": [
    "source_masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60db6513",
   "metadata": {},
   "outputs": [],
   "source": [
    "# source_masks is (B, 1), one entry per batch item; flatten it to a (B,) boolean mask.\n",
    "# Count the True entries instead of summing the raw values, so valid_B always matches\n",
    "# what the boolean selection below keeps (a plain sum miscounts if the mask ever holds\n",
    "# values other than exactly 0 and 1).\n",
    "valid_mask = source_masks.bool().reshape(-1)\n",
    "valid_B = int(valid_mask.sum().item())\n",
    "audio_len = batch['sources'].shape[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ed97d71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boolean indexing on the batch dimension keeps whole rows, so there is no mask\n",
    "# broadcasting and no manual .view(valid_B, audio_len) to get wrong.\n",
    "masked_sources = batch['sources'][valid_mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78307aa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumes batch['sources'] is 2-D (B, audio_len) - TODO confirm for other datasets.\n",
    "assert masked_sources.shape == (valid_B, audio_len), masked_sources.shape\n",
    "masked_sources.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd1d005d",
   "metadata": {},
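   "source": [
    "**Result:** after zeroing the first 4 of the 16 mask entries, 12 rows of `batch['sources']` remain selected, so `masked_sources` has shape `(12, 160000)`."
   ]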
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jointist",
   "language": "python",
   "name": "jointist"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
