{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.io.wavfile import read, write\n",
    "import torchaudio\n",
    "import torch\n",
    "from librosa.util import normalize\n",
    "from librosa.filters import mel as librosa_mel_fn\n",
    "import numpy as np\n",
    "import librosa\n",
    "from IPython.display import Audio\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "import soundfile as sf\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_WAV_VALUE = 32768.0\n",
    "\n",
    "def load_wav(full_path):\n",
    "    sampling_rate, data = read(full_path)\n",
    "    return data, sampling_rate\n",
    "\n",
    "def dynamic_range_compression(x, C=1, clip_val=1e-5):\n",
    "    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)\n",
    "\n",
    "def dynamic_range_decompression(x, C=1):\n",
    "    return np.exp(x) / C\n",
    "\n",
    "def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):\n",
    "    return torch.log(torch.clamp(x, min=clip_val) * C)\n",
    "\n",
    "def dynamic_range_decompression_torch(x, C=1):\n",
    "    return torch.exp(x) / C\n",
    "\n",
    "def spectral_normalize_torch(magnitudes):\n",
    "    output = dynamic_range_compression_torch(magnitudes)\n",
    "    return output\n",
    "\n",
    "def spectral_de_normalize_torch(magnitudes):\n",
    "    output = dynamic_range_decompression_torch(magnitudes)\n",
    "    return output\n",
    "\n",
    "mel_basis = {}\n",
    "hann_window = {}\n",
    "\n",
    "def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):\n",
    "    if torch.min(y) < -1.:\n",
    "        print('min value is ', torch.min(y))\n",
    "    if torch.max(y) > 1.:\n",
    "        print('max value is ', torch.max(y))\n",
    "\n",
    "    global mel_basis, hann_window\n",
    "    if fmax not in mel_basis:\n",
    "        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)\n",
    "        mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)\n",
    "        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)\n",
    "\n",
    "    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')\n",
    "    y = y.squeeze(1)\n",
    "\n",
    "    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],\n",
    "                      center=center, pad_mode='reflect', normalized=False, onesided=True)\n",
    "\n",
    "    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))\n",
    "\n",
    "    spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)\n",
    "    spec = spectral_normalize_torch(spec)\n",
    "\n",
    "    return spec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"/home/v-xxxx/AUDIT_v2/medata_infos/fsd_short.json\", \"r\") as f:\n",
    "    fsd_short_infos = json.load(f)\n",
    "with open(\"/home/v-xxxx/AUDIT_v2/medata_infos/fsd50k_short.json\", \"r\") as f:\n",
    "    fsd50k_short_infos = json.load(f)\n",
    "b_wav_infos = fsd_short_infos + fsd50k_short_infos\n",
    "print(len(b_wav_infos))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ac_train_json = \"/home/v-xxxx/AUDIT_v2/medata_infos/ac_train.json\"\n",
    "ac_val_json = \"/home/v-xxxx/AUDIT_v2/medata_infos/ac_val.json\"\n",
    "audioset_sl_json = \"/home/v-xxxx/AUDIT_v2/medata_infos/audioset_sl.json\"\n",
    "bbc_json = \"/home/v-xxxx/AUDIT_v2/medata_infos/bbc.json\"\n",
    "soundbible_json = \"/home/v-xxxx/AUDIT_v2/medata_infos/soundbible.json\"\n",
    "fsd_10s_json = \"/home/v-xxxx/AUDIT_v2/medata_infos/fsd_10s.json\"\n",
    "vgg_json = \"/home/v-xxxx/AUDIT_v2/medata_infos/vggsound.json\"\n",
    "\n",
    "with open(ac_train_json, \"r\") as f:\n",
    "    ac_train_infos = json.load(f)\n",
    "with open(ac_val_json, \"r\") as f:\n",
    "    ac_val_infos = json.load(f)\n",
    "with open(audioset_sl_json, \"r\") as f:\n",
    "    audioset_sl_infos = json.load(f)\n",
    "with open(bbc_json, \"r\") as f:\n",
    "    bbc_infos = json.load(f)\n",
    "with open(soundbible_json, \"r\") as f:\n",
    "    soundbible_infos = json.load(f)\n",
    "with open(fsd_10s_json, \"r\") as f:\n",
    "    fsd_10s_infos = json.load(f)\n",
    "with open(vgg_json, \"r\") as f:\n",
    "    vgg_infos = json.load(f)\n",
    "np.random.shuffle(vgg_infos)\n",
    "vgg_infos = vgg_infos[:10*1000]\n",
    "\n",
    "print(len(ac_train_infos), len(ac_val_infos), len(audioset_sl_infos),\n",
    "      len(bbc_infos), len(soundbible_infos), len(fsd_10s_infos))\n",
    "\n",
    "seven_sets_infos = bbc_infos + soundbible_infos + fsd_10s_infos + audioset_sl_infos + ac_train_infos + ac_val_infos + vgg_infos\n",
    "a_infos = []\n",
    "for info in tqdm(seven_sets_infos[:]):\n",
    "    if len(info[\"caption\"].split(\" \")) <= 10:\n",
    "        a_infos.append(info)\n",
    "print(len(a_infos))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.shuffle(a_infos)\n",
    "np.random.shuffle(b_wav_infos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "json_lists = []\n",
    "remove_json_lists = []\n",
    "replace_in_wav_path = \"/blob/v-xxxx/AUDITDATA/replace_1/in/wav\"\n",
    "replace_in_mel_path = \"/blob/v-xxxx/AUDITDATA/replace_1/in/mel\"\n",
    "replace_out_wav_path = \"/blob/v-xxxx/AUDITDATA/replace_1/out/wav\"\n",
    "replace_out_mel_path = \"/blob/v-xxxx/AUDITDATA/replace_1/out/mel\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "replace_templates = [\"Replace \\{A\\} with \\{B\\}\", \"Replace \\{A\\} to \\{B\\}\",\n",
    "                     \"Swap out \\{A\\} with \\{B\\}\", \"Swap \\{A\\} with \\{B\\}\",\n",
    "                     \"Exchange \\{A\\} with \\{B\\}\", \"Substitute \\{A\\} with \\{B\\}\",\n",
    "                     \"Replace \\{A\\} with a new audio of \\{B\\}\",\n",
    "                     \"Swap \\{A\\} with a new sound of \\{B\\}\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "json_lists = []\n",
    "for i in tqdm(range(30000)[:]):\n",
    "\n",
    "    a_id = np.random.randint(0, len(a_infos))\n",
    "    b_id = np.random.randint(0, len(b_wav_infos))\n",
    "    c_id = np.random.randint(0, len(b_wav_infos))\n",
    "    \n",
    "    a_wav_path, a_caption = a_infos[a_id][\"mel\"].replace(\"/mel/\", \"/wav/\").replace(\".npy\", \".wav\"), a_infos[a_id][\"caption\"]\n",
    "    b_wav_path, b_caption = b_wav_infos[b_id][\"wav\"], b_wav_infos[b_id][\"caption\"]\n",
    "    c_wav_path, c_caption = b_wav_infos[c_id][\"wav\"], b_wav_infos[c_id][\"caption\"]\n",
    "    # print(a_caption, b_caption)\n",
    "\n",
    "    template = np.random.choice(replace_templates)\n",
    "    caption = template.replace(\"\\{A\\}\", b_caption.lower().replace(\".\", \"\")).replace(\"\\{B\\}\", c_caption.lower().replace(\".\", \"\"))\n",
    "\n",
    "    a_wav, sr = librosa.load(a_wav_path, sr=16000)\n",
    "    b_wav, sr = librosa.load(b_wav_path, sr=16000)\n",
    "    b_wav = b_wav[: min(len(b_wav), int(16000*3.5))]\n",
    "    c_wav, sr = librosa.load(c_wav_path, sr=16000)\n",
    "    c_wav = c_wav[: min(len(c_wav), int(16000*3.5))]\n",
    "\n",
    "    a_wav = np.pad(a_wav, (0, max(0, 16000*10-len(a_wav))), 'wrap')[:16000*10]\n",
    "    pre_pad_length = np.random.randint(16000*3, int(16000*6.5))\n",
    "    b_wav = np.pad(b_wav, (pre_pad_length, max(0, 16000*10-len(b_wav)-pre_pad_length)), 'constant', constant_values=(0, 0))[:16000*10]\n",
    "    c_wav = np.pad(c_wav, (pre_pad_length, max(0, 16000*10-len(c_wav)-pre_pad_length)), 'constant', constant_values=(0, 0))[:16000*10]\n",
    "\n",
    "    d_wav = a_wav + b_wav\n",
    "    d_wav = np.clip(d_wav, -1, 1)\n",
    "    e_wav = a_wav + c_wav\n",
    "    e_wav = np.clip(e_wav, -1, 1)\n",
    "\n",
    "    x = torch.FloatTensor(d_wav)\n",
    "    x = mel_spectrogram(x.unsqueeze(0), n_fft=1024, num_mels=80, sampling_rate=16000,\n",
    "                        hop_size=256, win_size=1024, fmin=0, fmax=8000)\n",
    "    d_spec = x.cpu().numpy()[0]\n",
    "    d_wav = d_wav * MAX_WAV_VALUE\n",
    "    d_wav = d_wav.astype('int16')\n",
    "    write(os.path.join(replace_in_wav_path, \"{}\".format(str(i))+\".wav\"), 16000, d_wav)\n",
    "    np.save(os.path.join(replace_in_mel_path, \"{}\".format(str(i))+\".npy\"), d_spec)\n",
    "\n",
    "    x = torch.FloatTensor(e_wav)\n",
    "    x = mel_spectrogram(x.unsqueeze(0), n_fft=1024, num_mels=80, sampling_rate=16000,\n",
    "                        hop_size=256, win_size=1024, fmin=0, fmax=8000)\n",
    "    e_spec = x.cpu().numpy()[0]\n",
    "    e_wav = e_wav * MAX_WAV_VALUE\n",
    "    e_wav = e_wav.astype('int16')\n",
    "    write(os.path.join(replace_out_wav_path, \"{}\".format(str(i))+\".wav\"), 16000, e_wav)\n",
    "    np.save(os.path.join(replace_out_mel_path, \"{}\".format(str(i))+\".npy\"), e_spec)\n",
    "\n",
    "    json_lists.append({\"in_mel\":os.path.join(replace_in_mel_path, \"{}\".format(str(i))+\".npy\"),\n",
    "                       \"out_mel\":os.path.join(replace_out_mel_path, \"{}\".format(str(i))+\".npy\"),\n",
    "                       \"caption\":caption})\n",
    " \n",
    "with open(\"/home/v-xxxx/AUDIT_v2/editing_medata_infos/replace_1.json\", \"w\") as f:\n",
    "    json.dump(json_lists, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_mel_path = \"/blob/v-xxxx/AUDITDATA/replace/in/mel/8.npy\"\n",
    "out_mel_path = \"/blob/v-xxxx/AUDITDATA/replace/out/mel/8.npy\"\n",
    "in_mel = np.load(in_mel_path)\n",
    "out_mel = np.load(out_mel_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Audio(in_mel_path.replace(\"/mel/\", \"/wav/\").replace(\".npy\", \".wav\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Audio(out_mel_path.replace(\"/mel/\", \"/wav/\").replace(\".npy\", \".wav\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(in_mel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(out_mel)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "control",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
