{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import AutoencoderKL\n",
    "import torch\n",
    "import os\n",
    "import numpy as np\n",
    "import librosa\n",
    "from tqdm import tqdm\n",
    "import soundfile as sf\n",
    "import torchaudio\n",
    "import json\n",
    "from IPython.display import Audio\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pretrained VAE and freeze it: this notebook only runs inference.\n",
    "vae = AutoencoderKL.from_pretrained(\"/blob/v-xxxx/AudioEditingModel/VAE_GAN/checkpoint-40000/vae\")\n",
    "vae.requires_grad_(False)\n",
    "\n",
    "# Keep using the second GPU when it exists, but fall back to cuda:0 or the CPU\n",
    "# instead of crashing on hosts that expose fewer than two CUDA devices.\n",
    "if torch.cuda.device_count() > 1:\n",
    "    torch_device = torch.device(\"cuda:1\")\n",
    "elif torch.cuda.is_available():\n",
    "    torch_device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    torch_device = torch.device(\"cpu\")\n",
    "vae.to(torch_device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'mel': '/blob/v-xxxx/WavCaps/SoundBible/mel/1010.npy',\n",
       "  'caption': 'A spotted owl is making a sound.'},\n",
       " {'mel': '/blob/v-xxxx/WavCaps/SoundBible/mel/1005.npy',\n",
       "  'caption': 'A skillet hits someone in the head, great for cartoon or movie scene.'}]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Metadata manifest: a JSON list of records, each holding a mel path and a caption.\n",
    "train_json = \"/home/v-xxxx/AUDIT_v2/medata_infos/soundbible.json\"\n",
    "with open(train_json) as manifest_file:\n",
    "    train_list = json.load(manifest_file)\n",
    "\n",
    "# Show the first two records as a sanity check.\n",
    "train_list[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1232/1232 [01:54<00:00, 10.79it/s]\n"
     ]
    }
   ],
   "source": [
    "# Round-trip the mel of every sufficiently long clip through the VAE, save the\n",
    "# reconstruction under the sibling \"vae_mel\" directory, and write a manifest that\n",
    "# lists only the clips that were processed.\n",
    "SAMPLE_RATE = 16000       # rate the wav is resampled to for the length check\n",
    "MIN_WAV_SAMPLES = 159500  # clips with fewer samples at SAMPLE_RATE are skipped\n",
    "MEL_MAX_ROWS = 624        # leading rows (first axis) of the mel array fed to the VAE\n",
    "SEED = 0\n",
    "\n",
    "# posterior.sample() draws random noise; seed it so a re-run writes identical files.\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Resolve the output manifest up front so a bad path fails before the long loop.\n",
    "vae_json = train_json.replace(\"/medata_infos/\", \"/hifigan_ft_infos/\")\n",
    "if vae_json == train_json:\n",
    "    # str.replace was a no-op: writing would overwrite the input manifest.\n",
    "    raise ValueError(f\"{train_json!r} has no '/medata_infos/' component; refusing to overwrite it\")\n",
    "\n",
    "vae_lists = []\n",
    "ready_dirs = set()  # output directories already created in this run\n",
    "\n",
    "for info in tqdm(train_list):\n",
    "    mel_path = info[\"mel\"]\n",
    "    wav_path = mel_path.replace(\"/mel/\", \"/wav/\").replace(\".npy\", \".wav\")\n",
    "    vae_mel_path = mel_path.replace(\"/mel/\", \"/vae_mel/\")\n",
    "    if vae_mel_path == mel_path:\n",
    "        # Without a \"/mel/\" component np.save below would clobber the source mel.\n",
    "        raise ValueError(f\"{mel_path!r} has no '/mel/' component; refusing to overwrite it\")\n",
    "\n",
    "    # The wav is loaded only to filter out clips that are too short.\n",
    "    wav, sr = librosa.load(wav_path, sr=SAMPLE_RATE)\n",
    "    if len(wav) < MIN_WAV_SAMPLES:\n",
    "        continue\n",
    "    mel = np.load(mel_path)\n",
    "\n",
    "    # Crop along the first axis and prepend two singleton axes before encoding.\n",
    "    test_mel_tensor = torch.as_tensor(mel[:MEL_MAX_ROWS], dtype=torch.float32)[None, None]\n",
    "    test_mel_tensor = test_mel_tensor.to(torch_device)\n",
    "    with torch.no_grad():\n",
    "        posterior = vae.encode(test_mel_tensor).latent_dist\n",
    "        z = posterior.sample()\n",
    "        vae_output = vae.decode(z).sample\n",
    "    vae_res = vae_output[0][0].cpu().numpy()\n",
    "\n",
    "    # np.save does not create parent directories; make each one once.\n",
    "    out_dir = os.path.dirname(vae_mel_path)\n",
    "    if out_dir not in ready_dirs:\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "        ready_dirs.add(out_dir)\n",
    "    np.save(vae_mel_path, vae_res)\n",
    "    vae_lists.append(info)\n",
    "\n",
    "os.makedirs(os.path.dirname(vae_json), exist_ok=True)\n",
    "with open(vae_json, \"w\") as f:\n",
    "    json.dump(vae_lists, f)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "control",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
