{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.io.wavfile import read, write\n",
    "import torchaudio\n",
    "import torch\n",
    "from librosa.util import normalize\n",
    "from librosa.filters import mel as librosa_mel_fn\n",
    "import numpy as np\n",
    "import librosa\n",
    "from IPython.display import Audio\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "import soundfile as sf\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full-scale magnitude of signed 16-bit PCM; float waveforms are multiplied by this before the int16 cast.\n",
    "MAX_WAV_VALUE = 32768.0\n",
    "\n",
    "def load_wav(full_path):\n",
    "    \"\"\"Read the WAV file at ``full_path`` and return ``(samples, sampling_rate)``.\n",
    "\n",
    "    Thin wrapper around ``scipy.io.wavfile.read`` that swaps the returned pair\n",
    "    so that the sample array comes first.\n",
    "    \"\"\"\n",
    "    rate_and_samples = read(full_path)\n",
    "    return rate_and_samples[1], rate_and_samples[0]\n",
    "\n",
    "def dynamic_range_compression(x, C=1, clip_val=1e-5):\n",
    "    \"\"\"Floor ``x`` at ``clip_val``, scale by ``C`` and return the natural log (NumPy).\"\"\"\n",
    "    floored = np.clip(x, a_min=clip_val, a_max=None)\n",
    "    scaled = floored * C\n",
    "    return np.log(scaled)\n",
    "\n",
    "def dynamic_range_decompression(x, C=1):\n",
    "    \"\"\"Exponentiate ``x`` and divide the result by ``C`` (NumPy).\"\"\"\n",
    "    expanded = np.exp(x)\n",
    "    return expanded / C\n",
    "\n",
    "def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):\n",
    "    \"\"\"Floor tensor ``x`` at ``clip_val``, scale by ``C`` and return the natural log.\"\"\"\n",
    "    floored = torch.clamp(x, min=clip_val)\n",
    "    scaled = floored * C\n",
    "    return torch.log(scaled)\n",
    "\n",
    "def dynamic_range_decompression_torch(x, C=1):\n",
    "    \"\"\"Exponentiate tensor ``x`` and divide the result by ``C``.\"\"\"\n",
    "    expanded = torch.exp(x)\n",
    "    return expanded / C\n",
    "\n",
    "def spectral_normalize_torch(magnitudes):\n",
    "    \"\"\"Log-compress ``magnitudes`` with ``dynamic_range_compression_torch`` at its default settings.\"\"\"\n",
    "    return dynamic_range_compression_torch(magnitudes)\n",
    "\n",
    "def spectral_de_normalize_torch(magnitudes):\n",
    "    \"\"\"Expand ``magnitudes`` with ``dynamic_range_decompression_torch`` at its default settings.\"\"\"\n",
    "    return dynamic_range_decompression_torch(magnitudes)\n",
    "\n",
    "# Caches shared across calls so the mel filterbank and the Hann window are built once per\n",
    "# configuration and device; the keys are built inside mel_spectrogram.\n",
    "mel_basis = {}\n",
    "hann_window = {}\n",
    "\n",
    "def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):\n",
    "    \"\"\"Return the log-compressed mel spectrogram of a batch of waveforms.\n",
    "\n",
    "    Args:\n",
    "        y: float tensor laid out as (batch, samples); the unsqueeze(1)/pad/squeeze(1)\n",
    "            sequence below relies on that layout. Values outside [-1, 1] are only\n",
    "            reported on stdout, not modified.\n",
    "        n_fft: FFT size passed to ``torch.stft`` and to the mel filterbank.\n",
    "        num_mels: number of mel bands.\n",
    "        sampling_rate: sample rate used to build the mel filterbank.\n",
    "        hop_size: STFT hop length in samples.\n",
    "        win_size: length of the Hann analysis window.\n",
    "        fmin, fmax: frequency range of the mel filterbank.\n",
    "        center: forwarded to ``torch.stft``; the signal is already reflect-padded by\n",
    "            (n_fft - hop_size) / 2 samples on each side before the transform.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape (batch, num_mels, frames) holding log mel magnitudes.\n",
    "    \"\"\"\n",
    "    y_min, y_max = torch.min(y), torch.max(y)\n",
    "    if y_min < -1.:\n",
    "        print('min value is ', y_min)\n",
    "    if y_max > 1.:\n",
    "        print('max value is ', y_max)\n",
    "\n",
    "    global mel_basis, hann_window\n",
    "    # Key each cache on every parameter that shapes its entry. The lookup must use the same key\n",
    "    # that is stored, otherwise the filterbank and window are rebuilt on every call.\n",
    "    mel_key = '_'.join(str(v) for v in (sampling_rate, n_fft, num_mels, fmin, fmax, y.device))\n",
    "    window_key = str(win_size) + '_' + str(y.device)\n",
    "    if mel_key not in mel_basis:\n",
    "        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)\n",
    "        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)\n",
    "    if window_key not in hann_window:\n",
    "        hann_window[window_key] = torch.hann_window(win_size).to(y.device)\n",
    "\n",
    "    pad = int((n_fft-hop_size)/2)\n",
    "    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')\n",
    "    y = y.squeeze(1)\n",
    "\n",
    "    # return_complex must be given explicitly for real input on current PyTorch; view_as_real\n",
    "    # restores the trailing (real, imag) axis that the magnitude computation sums over.\n",
    "    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[window_key],\n",
    "                      center=center, pad_mode='reflect', normalized=False, onesided=True,\n",
    "                      return_complex=True)\n",
    "    spec = torch.view_as_real(spec)\n",
    "\n",
    "    # Magnitude with a small epsilon so the sqrt stays finite at exactly zero energy.\n",
    "    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))\n",
    "\n",
    "    spec = torch.matmul(mel_basis[mel_key], spec)\n",
    "    spec = spectral_normalize_torch(spec)\n",
    "\n",
    "    return spec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "45435 449 79147 11754 1232 13732\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 177029/177029 [00:00<00:00, 952852.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "99743\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Paths of the per-dataset metadata manifests (JSON files).\n",
    "ac_train_json = \"/home/v-xxxx/AUDIT_v2/medata_infos/ac_train.json\"\n",
    "ac_val_json = \"/home/v-xxxx/AUDIT_v2/medata_infos/ac_val.json\"\n",
    "audioset_sl_json = \"/home/v-xxxx/AUDIT_v2/medata_infos/audioset_sl.json\"\n",
    "bbc_json = \"/home/v-xxxx/AUDIT_v2/medata_infos/bbc.json\"\n",
    "soundbible_json = \"/home/v-xxxx/AUDIT_v2/medata_infos/soundbible.json\"\n",
    "fsd_10s_json = \"/home/v-xxxx/AUDIT_v2/medata_infos/fsd_10s.json\"\n",
    "fsd50k_10s_json = \"/home/v-xxxx/AUDIT_v2/medata_infos/fsd50k_10s.json\"\n",
    "vgg_json = \"/home/v-xxxx/AUDIT_v2/medata_infos/vggsound.json\"\n",
    "\n",
    "\n",
    "def _read_json(path):\n",
    "    \"\"\"Return the parsed contents of the JSON file at ``path``.\"\"\"\n",
    "    with open(path, \"r\") as f:\n",
    "        return json.load(f)\n",
    "\n",
    "\n",
    "ac_train_infos = _read_json(ac_train_json)\n",
    "ac_val_infos = _read_json(ac_val_json)\n",
    "audioset_sl_infos = _read_json(audioset_sl_json)\n",
    "bbc_infos = _read_json(bbc_json)\n",
    "soundbible_infos = _read_json(soundbible_json)\n",
    "fsd_10s_infos = _read_json(fsd_10s_json)\n",
    "fsd50k_10s_infos = _read_json(fsd50k_10s_json)\n",
    "vgg_infos = _read_json(vgg_json)\n",
    "\n",
    "# Keep only a random 20k-entry subset of the VGGSound manifest.\n",
    "np.random.shuffle(vgg_infos)\n",
    "vgg_infos = vgg_infos[:20*1000]\n",
    "\n",
    "print(len(ac_train_infos), len(ac_val_infos), len(audioset_sl_infos),\n",
    "      len(bbc_infos), len(soundbible_infos), len(fsd_10s_infos))\n",
    "\n",
    "eight_sets_infos = (bbc_infos + soundbible_infos + fsd_10s_infos + audioset_sl_infos\n",
    "                    + ac_train_infos + ac_val_infos + vgg_infos + fsd50k_10s_infos)\n",
    "\n",
    "# Keep entries whose caption has at most 8 space-separated tokens, then shuffle the pool.\n",
    "a_infos = [info for info in tqdm(eight_sets_infos) if len(info[\"caption\"].split(\" \")) <= 8]\n",
    "print(len(a_infos))\n",
    "np.random.shuffle(a_infos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instruction templates. Each one carries the literal placeholder \\{A\\} (backslashes included),\n",
    "# which is later replaced by a caption. Raw strings keep that exact text without relying on\n",
    "# Python's deprecated handling of the invalid escape sequences \"\\{\" and \"\\}\".\n",
    "add_templates = [r\"Add \\{A\\} in the background\", r\"Mix \\{A\\} in the background\",\n",
    "                 r\"Add: \\{A\\} in the background\", r\"Fuse: \\{A\\} in the background\",\n",
    "                 r\"Combine: \\{A\\} in the background\", r\"Add \\{A\\} in the audio\",\n",
    "                 r\"Mix \\{A\\} with the audio track\", r\"Mix: \\{A\\} with the audio track\"]\n",
    "# NOTE(review): \"Drop: \\{A\\}\" appears twice (first and last), so it is drawn twice as often\n",
    "# as the other drop templates -- confirm whether that weighting is intended.\n",
    "drop_templates = [r\"Drop: \\{A\\}\", r\"Drop \\{A\\}\", r\"Delete: \\{A\\}\", r\"Delete \\{A\\}\",\n",
    "                  r\"Omit \\{A\\} from the audio\", r\"Erase \\{A\\} from the track\", r\"Erase: \\{A\\}\",\n",
    "                  r\"Omit: \\{A\\}\", r\"Remove: \\{A\\}\", r\"Trim: \\{A\\}\", r\"Remove \\{A\\}\", r\"Drop: \\{A\\}\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directories for the synthesized mixtures: waveforms (.wav) and their mel spectrograms (.npy).\n",
    "add_wav_path = \"/blob/v-xxxx/AUDITDATA/add_0/wav\"\n",
    "add_mel_path = \"/blob/v-xxxx/AUDITDATA/add_0/mel\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Mixing configuration ----\n",
    "SAMPLE_RATE = 16000                   # every clip is resampled to this rate on load\n",
    "CLIP_SAMPLES = SAMPLE_RATE * 10       # every clip is forced to exactly 10 seconds\n",
    "ENERGY_FRAME = 1024                   # frame length used for the loudness estimate\n",
    "A_GAIN, B_GAIN = 0.88, 0.43           # weights of the original clip and of the added clip\n",
    "NUM_PAIRS = 50000                     # number of add/drop examples to generate\n",
    "INT16_MIN, INT16_MAX = -32768, 32767  # representable range of 16-bit PCM\n",
    "EDIT_JSON_DIR = \"/home/v-xxxx/AUDIT_v2/editing_medata_infos\"\n",
    "\n",
    "\n",
    "def _wav_path_of(info):\n",
    "    \"\"\"Derive an entry's waveform path from its mel path (/mel/ -> /wav/, .npy -> .wav).\"\"\"\n",
    "    return info[\"mel\"].replace(\"/mel/\", \"/wav/\").replace(\".npy\", \".wav\")\n",
    "\n",
    "\n",
    "def _peak_frame_level(wav, frame=ENERGY_FRAME):\n",
    "    \"\"\"Return the largest mean absolute amplitude over consecutive non-overlapping frames.\"\"\"\n",
    "    return max(np.mean(np.abs(wav[j: j+frame])) for j in range(0, len(wav), frame))\n",
    "\n",
    "\n",
    "def _mix_pair(a_wav, b_wav):\n",
    "    \"\"\"Mix ``b_wav`` into ``a_wav`` and return the float mixture clipped to [-1, 1].\n",
    "\n",
    "    ``a_wav`` is wrap-padded and ``b_wav`` zero-padded to ``CLIP_SAMPLES`` (both are then\n",
    "    truncated to that length). ``b_wav`` is rescaled so that its peak frame level matches\n",
    "    that of ``a_wav`` before the weighted sum.\n",
    "\n",
    "    Returns ``None`` when the pair cannot produce a meaningful mixture: an empty ``a_wav``\n",
    "    (wrap padding would raise) or a zero/NaN peak level on either side (the rescaling\n",
    "    would divide by zero or silence the added clip).\n",
    "    \"\"\"\n",
    "    if len(a_wav) == 0:\n",
    "        return None\n",
    "    a_wav = np.pad(a_wav, (0, max(0, CLIP_SAMPLES-len(a_wav))), 'wrap')[:CLIP_SAMPLES]\n",
    "    b_wav = np.pad(b_wav, (0, max(0, CLIP_SAMPLES-len(b_wav))), 'constant', constant_values=(0,0))[:CLIP_SAMPLES]\n",
    "    a_level, b_level = _peak_frame_level(a_wav), _peak_frame_level(b_wav)\n",
    "    if not (a_level > 0 and b_level > 0):\n",
    "        return None\n",
    "    b_wav = b_wav * a_level / b_level\n",
    "    return np.clip(a_wav * A_GAIN + b_wav * B_GAIN, -1, 1)\n",
    "\n",
    "\n",
    "# Make sure every output location exists before the long-running loop starts.\n",
    "os.makedirs(add_wav_path, exist_ok=True)\n",
    "os.makedirs(add_mel_path, exist_ok=True)\n",
    "os.makedirs(EDIT_JSON_DIR, exist_ok=True)\n",
    "\n",
    "json_lists = []\n",
    "drop_json_lists = []\n",
    "for i in tqdm(range(NUM_PAIRS)):\n",
    "    # Draw two distinct entries with different captions; redraw if the pair cannot be mixed.\n",
    "    c_wav = None\n",
    "    while c_wav is None:\n",
    "        a_id, b_id = 0, 0\n",
    "        while (a_id == b_id or a_infos[a_id][\"caption\"] == a_infos[b_id][\"caption\"]):\n",
    "            a_id = np.random.randint(0, len(a_infos))\n",
    "            b_id = np.random.randint(0, len(a_infos))\n",
    "        a_info, b_info = a_infos[a_id], a_infos[b_id]\n",
    "        a_wav, _ = librosa.load(_wav_path_of(a_info), sr=SAMPLE_RATE)\n",
    "        b_wav, _ = librosa.load(_wav_path_of(b_info), sr=SAMPLE_RATE)\n",
    "        c_wav = _mix_pair(a_wav, b_wav)\n",
    "\n",
    "    # Fill the placeholder with the lower-cased, period-free caption of the added clip.\n",
    "    b_caption = b_info[\"caption\"].lower().replace(\".\", \"\")\n",
    "    caption = np.random.choice(add_templates).replace(r\"\\{A\\}\", b_caption)\n",
    "    drop_caption = np.random.choice(drop_templates).replace(r\"\\{A\\}\", b_caption)\n",
    "\n",
    "    x = torch.FloatTensor(c_wav)\n",
    "    x = mel_spectrogram(x.unsqueeze(0), n_fft=1024, num_mels=80, sampling_rate=SAMPLE_RATE,\n",
    "                        hop_size=256, win_size=1024, fmin=0, fmax=8000)\n",
    "    spec = x.cpu().numpy()[0]\n",
    "\n",
    "    # A sample at +1.0 scales to 32768, one past the int16 maximum; clip before casting so it\n",
    "    # saturates instead of wrapping around.\n",
    "    pcm = np.clip(c_wav * MAX_WAV_VALUE, INT16_MIN, INT16_MAX).astype('int16')\n",
    "\n",
    "    out_wav = os.path.join(add_wav_path, \"add_{}\".format(str(i))+\".wav\")\n",
    "    out_mel = os.path.join(add_mel_path, \"add_{}\".format(str(i))+\".npy\")\n",
    "    write(out_wav, SAMPLE_RATE, pcm)\n",
    "    np.save(out_mel, spec)\n",
    "\n",
    "    # The same pair yields an \"add\" example (original -> mixture) and the reverse \"drop\" example.\n",
    "    json_lists.append({\"in_mel\": a_info[\"mel\"],\n",
    "                       \"out_mel\": out_mel,\n",
    "                       \"caption\": caption})\n",
    "    drop_json_lists.append({\"in_mel\": out_mel,\n",
    "                            \"out_mel\": a_info[\"mel\"],\n",
    "                            \"caption\": drop_caption})\n",
    "\n",
    "with open(os.path.join(EDIT_JSON_DIR, \"add_0.json\"), \"w\") as f:\n",
    "    json.dump(json_lists, f)\n",
    "with open(os.path.join(EDIT_JSON_DIR, \"drop_0.json\"), \"w\") as f:\n",
    "    json.dump(drop_json_lists, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "# in_mel_path = \"/blob/v-xxxx/DiffAudioImg/VGGSound/data/ac_train/mel/7TTYk6pgMNw_30000_40000.npy\"\n",
    "# out_mel_path = \"/blob/v-xxxx/AUDITDATA/add_0/mel/add_9.npy\"\n",
    "# in_mel = np.load(in_mel_path)\n",
    "# out_mel = np.load(out_mel_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "control",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
