{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.io.wavfile import read, write\n",
    "import torchaudio\n",
    "import torch\n",
    "from librosa.util import normalize\n",
    "from librosa.filters import mel as librosa_mel_fn\n",
    "import numpy as np\n",
    "import librosa\n",
    "import librosa.display\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "import soundfile as sf\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full-scale magnitude of signed 16-bit PCM (2**15); the conversion loop multiplies float audio\n",
    "# clipped to [-1, 1] by this value before casting it to int16.\n",
    "MAX_WAV_VALUE = 32768.0\n",
    "\n",
    "def load_wav(full_path):\n",
    "    \"\"\"Read the WAV file at ``full_path`` and return ``(samples, sampling_rate)``.\n",
    "\n",
    "    Thin wrapper around ``scipy.io.wavfile.read`` that swaps the order of its result.\n",
    "    \"\"\"\n",
    "    wav_contents = read(full_path)\n",
    "    rate, samples = wav_contents[0], wav_contents[1]\n",
    "    return samples, rate\n",
    "\n",
    "def dynamic_range_compression(x, C=1, clip_val=1e-5):\n",
    "    \"\"\"Return ``log(max(x, clip_val) * C)`` computed elementwise with NumPy.\n",
    "\n",
    "    Values below ``clip_val`` are raised to it first so the logarithm stays finite.\n",
    "    \"\"\"\n",
    "    floored = np.clip(x, a_min=clip_val, a_max=None)\n",
    "    scaled = floored * C\n",
    "    return np.log(scaled)\n",
    "\n",
    "def dynamic_range_decompression(x, C=1):\n",
    "    \"\"\"Return ``exp(x) / C`` computed elementwise with NumPy (inverse of the log compression).\"\"\"\n",
    "    expanded = np.exp(x)\n",
    "    return expanded / C\n",
    "\n",
    "def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):\n",
    "    \"\"\"Return ``log(max(x, clip_val) * C)`` computed elementwise on a torch tensor.\n",
    "\n",
    "    Values below ``clip_val`` are clamped up to it first so the logarithm stays finite.\n",
    "    \"\"\"\n",
    "    floored = torch.clamp(x, min=clip_val)\n",
    "    scaled = floored * C\n",
    "    return torch.log(scaled)\n",
    "\n",
    "def dynamic_range_decompression_torch(x, C=1):\n",
    "    \"\"\"Return ``exp(x) / C`` computed elementwise on a torch tensor (inverse of the log compression).\"\"\"\n",
    "    expanded = torch.exp(x)\n",
    "    return expanded / C\n",
    "\n",
    "def spectral_normalize_torch(magnitudes):\n",
    "    \"\"\"Log-compress ``magnitudes`` using ``dynamic_range_compression_torch`` with its defaults.\"\"\"\n",
    "    return dynamic_range_compression_torch(magnitudes)\n",
    "\n",
    "def spectral_de_normalize_torch(magnitudes):\n",
    "    \"\"\"Undo the log compression using ``dynamic_range_decompression_torch`` with its defaults.\"\"\"\n",
    "    return dynamic_range_decompression_torch(magnitudes)\n",
    "\n",
    "# Caches shared across calls to mel_spectrogram so the mel filterbank and the Hann window are\n",
    "# built once per parameter set and device instead of on every call.\n",
    "mel_basis = {}\n",
    "hann_window = {}\n",
    "\n",
    "def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):\n",
    "    \"\"\"Compute a log-compressed mel spectrogram of the waveform tensor ``y``.\n",
    "\n",
    "    Args:\n",
    "        y: float waveform tensor; the caller in this notebook passes shape (1, num_samples).\n",
    "            Values outside [-1, 1] are reported with a print but still processed.\n",
    "        n_fft: FFT size passed to ``torch.stft`` and to the mel filterbank.\n",
    "        num_mels: number of mel bands.\n",
    "        sampling_rate: sample rate used to build the mel filterbank.\n",
    "        hop_size: STFT hop length in samples.\n",
    "        win_size: STFT window length in samples (Hann window).\n",
    "        fmin: lowest frequency of the mel filterbank.\n",
    "        fmax: highest frequency of the mel filterbank.\n",
    "        center: forwarded to ``torch.stft``; the input is already reflect-padded by\n",
    "            ``(n_fft - hop_size) / 2`` samples on each side before the STFT.\n",
    "\n",
    "    Returns:\n",
    "        Tensor holding ``log(clamp(mel_filterbank @ magnitude, min=1e-5))``; for a\n",
    "        (batch, samples) input the mel axis is dim 1 and the frame axis is dim 2.\n",
    "    \"\"\"\n",
    "    if torch.min(y) < -1.:\n",
    "        print('min value is ', torch.min(y))\n",
    "    if torch.max(y) > 1.:\n",
    "        print('max value is ', torch.max(y))\n",
    "\n",
    "    global mel_basis, hann_window\n",
    "    device_name = str(y.device)\n",
    "    # Key on every argument that shapes the filterbank. The previous check looked up the bare\n",
    "    # fmax, which never matched the stored key, so the filterbank was rebuilt on every call.\n",
    "    mel_key = (sampling_rate, n_fft, num_mels, fmin, fmax, device_name)\n",
    "    if mel_key not in mel_basis:\n",
    "        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)\n",
    "        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)\n",
    "    # The window depends on its length as well as the device, so both are part of the key.\n",
    "    window_key = (win_size, device_name)\n",
    "    if window_key not in hann_window:\n",
    "        hann_window[window_key] = torch.hann_window(win_size).to(y.device)\n",
    "\n",
    "    pad = int((n_fft - hop_size) / 2)\n",
    "    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')\n",
    "    y = y.squeeze(1)\n",
    "\n",
    "    # return_complex must be given explicitly for real inputs on current PyTorch releases.\n",
    "    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[window_key],\n",
    "                      center=center, pad_mode='reflect', normalized=False, onesided=True,\n",
    "                      return_complex=True)\n",
    "\n",
    "    # view_as_real exposes (real, imag) as a trailing dimension of size 2, so the magnitude is\n",
    "    # computed exactly as it was from the legacy real-valued stft output.\n",
    "    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1)+(1e-9))\n",
    "\n",
    "    spec = torch.matmul(mel_basis[mel_key], spec)\n",
    "    spec = spectral_normalize_torch(spec)\n",
    "\n",
    "    return spec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# wav_path: directory of source WAV files read by the conversion loop.\n",
    "# sr_wav_path: output directory where the loop writes the resampled 16-bit WAV files.\n",
    "# sr_mel_path: output directory where the loop saves one mel spectrogram .npy per WAV file.\n",
    "wav_path = \"/blob/v-xxxx/DiffAudioImg/VGGSound/data/ac_train/wav\"\n",
    "sr_wav_path = \"/blob/v-xxxx/DiffAudioImg/VGGSound/data/ac_train/sr/wav\"\n",
    "sr_mel_path = \"/blob/v-xxxx/DiffAudioImg/VGGSound/data/ac_train/sr/mel\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "45435"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# os.listdir returns entries in arbitrary order; sort them so every run walks the files in the\n",
    "# same order and progress is comparable between runs.\n",
    "wav_files = sorted(os.listdir(wav_path))\n",
    "len(wav_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# File names already present in the output WAV directory; the conversion loop skips any input\n",
    "# whose name is in this set. A set is used so each membership test is a hash lookup.\n",
    "sr_wav_set = set(os.listdir(sr_wav_path))\n",
    "len(sr_wav_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/45435 [00:00<?, ?it/s]/opt/conda/envs/control/lib/python3.8/site-packages/torch/functional.py:632: UserWarning: stft will soon require the return_complex parameter be given for real inputs, and will further require that return_complex=True in a future PyTorch release. (Triggered internally at ../aten/src/ATen/native/SpectralOps.cpp:801.)\n",
      "  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]\n",
      "100%|██████████| 45435/45435 [2:20:44<00:00,  5.38it/s]  \n"
     ]
    }
   ],
   "source": [
    "# Representable range of the int16 samples written to disk.\n",
    "INT16_MIN, INT16_MAX = -32768, 32767\n",
    "# Spectrograms with fewer frames than this are wrap-padded along the frame axis up to this length.\n",
    "MIN_MEL_FRAMES = 624\n",
    "# (file_name, error) for every input that could not be converted, so failures stay visible.\n",
    "failed_files = []\n",
    "\n",
    "for file_name in tqdm(wav_files[:]):\n",
    "    if file_name in sr_wav_set:\n",
    "        continue\n",
    "    sr_wav_file = os.path.join(sr_wav_path, file_name)\n",
    "    try:\n",
    "        # Stage 1: load resampled to 8 kHz and store it as 16-bit PCM.\n",
    "        wav, sr = librosa.load(os.path.join(wav_path, file_name), sr=8000)\n",
    "        wav = np.clip(wav, -1, 1)\n",
    "        # +1.0 * 32768 does not fit in int16, so clamp after scaling instead of overflowing.\n",
    "        wav = np.clip(wav * MAX_WAV_VALUE, INT16_MIN, INT16_MAX)\n",
    "        wav = wav.astype('int16')\n",
    "        write(sr_wav_file, 8000, wav)\n",
    "\n",
    "        # Stage 2: reload the 8 kHz file resampled to 16 kHz and compute its mel spectrogram.\n",
    "        wav, sr = librosa.load(sr_wav_file, sr=16000)\n",
    "        wav = np.clip(wav, -1, 1)\n",
    "        x = torch.FloatTensor(wav)\n",
    "        x = mel_spectrogram(x.unsqueeze(0), n_fft=1024, num_mels=80, sampling_rate=16000,\n",
    "                        hop_size=256, win_size=1024, fmin=0, fmax=8000)\n",
    "        spec = x.cpu().numpy()[0]\n",
    "        if spec.shape[1] < MIN_MEL_FRAMES:\n",
    "            spec = np.pad(spec, ((0, 0), (0, MIN_MEL_FRAMES - spec.shape[1])), 'wrap')\n",
    "\n",
    "        # Overwrite the intermediate 8 kHz file with the final 16 kHz audio, then save the mel.\n",
    "        wav = np.clip(wav * MAX_WAV_VALUE, INT16_MIN, INT16_MAX)\n",
    "        wav = wav.astype('int16')\n",
    "        write(sr_wav_file, 16000, wav)\n",
    "        np.save(os.path.join(sr_mel_path, file_name.replace(\".wav\", \".npy\")), spec)\n",
    "\n",
    "    except BaseException as exc:\n",
    "        # The presence of the output WAV is what marks a file as done on the next run, so drop a\n",
    "        # possibly half-finished one; otherwise the file would be skipped without ever getting a mel.\n",
    "        if os.path.exists(sr_wav_file):\n",
    "            try:\n",
    "                os.remove(sr_wav_file)\n",
    "            except OSError:\n",
    "                pass\n",
    "        # KeyboardInterrupt / SystemExit must still stop the loop; only ordinary errors are skipped.\n",
    "        if not isinstance(exc, Exception):\n",
    "            raise\n",
    "        failed_files.append((file_name, repr(exc)))\n",
    "\n",
    "if failed_files:\n",
    "    print(len(failed_files), 'file(s) failed; details are in failed_files')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "control",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
