{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import diffusers\n",
    "from diffusers import AutoencoderKL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_id = \"/blob/v-xxxx/DiffAudioImg/VGGSound/data/ac_train/mel/-0A1_JR5f34_16000_26000.npy\"\n",
    "test_mel = np.load(test_id)[:,:624]\n",
    "print(test_mel.shape)\n",
    "plt.imshow(test_mel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vae = AutoencoderKL.from_pretrained(\"/blob/v-xxxx/AUDITPLUS/VAEGAN/checkpoint-24000\")\n",
    "vae.requires_grad_(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_mel_tensor = np.expand_dims(np.expand_dims(test_mel, 0), 0)\n",
    "test_mel_tensor = torch.Tensor(test_mel_tensor)\n",
    "test_mel_tensor.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    posterior = vae.encode(test_mel_tensor).latent_dist\n",
    "    z = posterior.sample()\n",
    "    vae_output = vae.decode(z).sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vae_res = vae_output[0][0].numpy()\n",
    "print(vae_res.shape)\n",
    "plt.imshow(vae_res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "abs_diff = abs(vae_res - test_mel)\n",
    "print(abs_diff.max(), abs_diff.min(), abs_diff.mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Audio\n",
    "from shutil import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_wav_id = test_id.replace(\"/mel/\", \"/wav/\").replace(\".npy\", \".wav\")\n",
    "Audio(test_wav_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "copy(test_id, '/home/v-xxxx/AUDIT_v2/test_vae_mel/'+'test_0.npy')\n",
    "np.save('/home/v-xxxx/AUDIT_v2/test_vae_mel/'+'test_1.npy', vae_res)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "control",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
