{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.io.wavfile import read, write\n",
    "import torchaudio\n",
    "import torch\n",
    "from librosa.util import normalize\n",
    "from librosa.filters import mel as librosa_mel_fn\n",
    "import numpy as np\n",
    "import librosa\n",
    "import librosa.display\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "import soundfile as sf\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_WAV_VALUE = 32768.0\n",
    "\n",
    "def load_wav(full_path):\n",
    "    sampling_rate, data = read(full_path)\n",
    "    return data, sampling_rate\n",
    "\n",
    "def dynamic_range_compression(x, C=1, clip_val=1e-5):\n",
    "    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)\n",
    "\n",
    "def dynamic_range_decompression(x, C=1):\n",
    "    return np.exp(x) / C\n",
    "\n",
    "def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):\n",
    "    return torch.log(torch.clamp(x, min=clip_val) * C)\n",
    "\n",
    "def dynamic_range_decompression_torch(x, C=1):\n",
    "    return torch.exp(x) / C\n",
    "\n",
    "def spectral_normalize_torch(magnitudes):\n",
    "    output = dynamic_range_compression_torch(magnitudes)\n",
    "    return output\n",
    "\n",
    "def spectral_de_normalize_torch(magnitudes):\n",
    "    output = dynamic_range_decompression_torch(magnitudes)\n",
    "    return output\n",
    "\n",
    "mel_basis = {}\n",
    "hann_window = {}\n",
    "\n",
    "def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):\n",
    "    if torch.min(y) < -1.:\n",
    "        print('min value is ', torch.min(y))\n",
    "    if torch.max(y) > 1.:\n",
    "        print('max value is ', torch.max(y))\n",
    "\n",
    "    global mel_basis, hann_window\n",
    "    if fmax not in mel_basis:\n",
    "        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)\n",
    "        mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)\n",
    "        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)\n",
    "\n",
    "    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')\n",
    "    y = y.squeeze(1)\n",
    "\n",
    "    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],\n",
    "                      center=center, pad_mode='reflect', normalized=False, onesided=True)\n",
    "\n",
    "    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))\n",
    "\n",
    "    spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)\n",
    "    spec = spectral_normalize_torch(spec)\n",
    "\n",
    "    return spec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "wav_path = \"/blob/v-xxxx/DiffAudioImg/VGGSound/data/vggsound/wav\"\n",
    "inpaint_wav_path = \"/blob/v-xxxx/DiffAudioImg/VGGSound/data/vggsound/inpaint/wav\"\n",
    "inpaint_mel_path = \"/blob/v-xxxx/DiffAudioImg/VGGSound/data/vggsound/inpaint/mel\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "177056"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wav_lists = os.listdir(wav_path)\n",
    "len(wav_lists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "177048"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inpaint_wav_list = os.listdir(inpaint_wav_path)\n",
    "inpaint_wav_set = set(inpaint_wav_list)\n",
    "len(inpaint_wav_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "177044\n"
     ]
    }
   ],
   "source": [
    "with open(\"/home/v-xxxx/AUDIT_v2/editing_medata_infos/inpaint_vgg.json\", \"r\") as f:\n",
    "    vgg_inpaint_infos = json.load(f)\n",
    "print(len(vgg_inpaint_infos))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 176044/176044 [45:33<00:00, 64.41it/s]  \n"
     ]
    }
   ],
   "source": [
    "for info in tqdm(vgg_inpaint_infos[1000:]):\n",
    "\n",
    "    mel = np.load(info[\"in_mel\"])\n",
    "    if mel.shape[1] < 624:\n",
    "        # print(mel.shape)\n",
    "        mel = np.pad(mel, ((0, 0), (0, 624 - mel.shape[1])), 'wrap')\n",
    "        np.save(info[\"in_mel\"], mel)\n",
    "\n",
    "    # if file_name in inpaint_wav_set:\n",
    "    #     continue\n",
    "    # wav, sr = librosa.load(os.path.join(wav_path, file_name), sr=16000)\n",
    "    # wav_len = len(wav)\n",
    "    # start = np.random.randint(int(wav_len * 0.25), int(wav_len * 0.4))\n",
    "    # end = np.random.randint(int(wav_len * 0.6), int(wav_len * 0.75))\n",
    "    # wav[start : end] = 0\n",
    "    # x = torch.FloatTensor(wav)\n",
    "    # # print(len(x))\n",
    "    # x = mel_spectrogram(x.unsqueeze(0), n_fft=1024, num_mels=80, sampling_rate=16000,\n",
    "    #                 hop_size=256, win_size=1024, fmin=0, fmax=8000)\n",
    "    # # print(x.shape)\n",
    "    # spec = x.cpu().numpy()[0]\n",
    "    # # print(spec.shape)\n",
    "    # wav = wav * MAX_WAV_VALUE\n",
    "    # wav = wav.astype('int16')\n",
    "    # write(os.path.join(inpaint_wav_path, file_name), 16000, wav)\n",
    "    # np.save(os.path.join(inpaint_mel_path, file_name.replace(\".wav\", \".npy\")), spec)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "control",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
