{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from torchvision import transforms, datasets\n",
    "\n",
    "from AutoEncoder_models import CnnAEDeepDoubleDescent\n",
    "from create_data_loader import normalize_by_norm, LoadDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Create dataset\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "scenario = 'sample_noise'\n",
    "source = 'mnist'\n",
    "check_snr_db = -4\n",
    "add_noise = True"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def mnist_data(scenario, add_noise=True, snr_db=check_snr_db, source='mnist'):\n",
    "    \"\"\"Load 100 images and build the clean / noisy tensors used for evaluation.\n",
    "\n",
    "    Args:\n",
    "        scenario: Experiment name. For any value other than 'domain_shift' the\n",
    "            data is passed through ``normalize_by_norm`` and may get noise added;\n",
    "            for 'domain_shift' the data is returned as loaded, without noise.\n",
    "        add_noise: Whether to add Gaussian noise (ignored for 'domain_shift').\n",
    "        snr_db: Noise level in dB; the normalized noise is divided by\n",
    "            ``10 ** (snr_db / 20)`` before being added.\n",
    "        source: 'mnist' takes the first 100 MNIST training images; any other\n",
    "            value loads 100 samples from ./data/mnistm through ``LoadDataset``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(data, noisy_data, max_value)`` where ``max_value`` is the mean,\n",
    "        over samples, of each sample's maximum value in ``noisy_data``. For\n",
    "        'mnist' both tensors get a channel axis repeated 3 times.\n",
    "    \"\"\"\n",
    "    snr = 10 ** (snr_db / 20)\n",
    "    if source == 'mnist':\n",
    "        # Scale by 255 -- presumably `.data` holds raw 0-255 pixels (TODO confirm).\n",
    "        data = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor()).data / 255.0\n",
    "        data = data[:100]\n",
    "    else:\n",
    "        data = LoadDataset(root=\"./data/mnistm\", n_samples=100).data\n",
    "\n",
    "    if scenario != 'domain_shift':\n",
    "        data = normalize_by_norm(data)\n",
    "        if add_noise:\n",
    "            # Draw noise with exactly the data's shape. The previous 3-axis size\n",
    "            # only matched (N, H, W) data and could not be added to data that\n",
    "            # carries a channel axis.\n",
    "            noise = torch.randn(size=tuple(data.shape))\n",
    "            noise = normalize_by_norm(noise) / snr\n",
    "            noisy_data = data + noise\n",
    "        else:\n",
    "            noisy_data = data\n",
    "    else:\n",
    "        # Previously `noisy_data` was never assigned on this path, so the\n",
    "        # function raised UnboundLocalError for scenario='domain_shift'.\n",
    "        noisy_data = data\n",
    "\n",
    "    # Per-sample peak value; its mean is the peak used for PSNR.\n",
    "    max_value_for_psnr = [torch.max(s).item() for s in noisy_data]\n",
    "\n",
    "    if source == 'mnist':\n",
    "        # Grayscale (N, H, W) -> (N, 3, H, W) by repeating the single channel.\n",
    "        data = data[:, None, :, :].repeat(1, 3, 1, 1)\n",
    "        noisy_data = noisy_data[:, None, :, :].repeat(1, 3, 1, 1)\n",
    "    return data, noisy_data, np.mean(max_value_for_psnr)\n",
    "\n",
    "\n",
    "def calc_psnr(img_1, img_2, max_value_psnr):\n",
    "    mse = ((img_1 - img_2)**2).mean().item()\n",
    "    if mse == 0:\n",
    "        psnr = np.inf\n",
    "    else:\n",
    "        psnr = 10 * np.log10(max_value_psnr ** 2 / mse)\n",
    "    return psnr"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Build the clean / noisy evaluation tensors and the PSNR peak value from the config above.\n",
    "data, noisy_data, max_value_psnr = mnist_data(\n",
    "    scenario=scenario,\n",
    "    add_noise=add_noise,\n",
    "    snr_db=check_snr_db,\n",
    "    source=source,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Load model\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Hyper-parameters that identify which trained checkpoint to load.\n",
    "latent_dim, channel = 500, 60\n",
    "noise_per, snr_db = 0.5, -15\n",
    "\n",
    "model = CnnAEDeepDoubleDescent(latent_dim=latent_dim, channels=channel)\n",
    "weights_path = f\"mnist_sample_noise_models/model_noise_per_{noise_per}_snr_{snr_db}_latent_dim_{latent_dim}_channel_{channel}.pt\"\n",
    "# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.\n",
    "state_dict = torch.load(weights_path, map_location=torch.device('cpu'))\n",
    "model.load_state_dict(state_dict)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Plot the input image (noisy when `add_noise` is True)\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Show the selected input image and save it under images/mnist_image_results/.\n",
    "sample_idx = 0\n",
    "sample = noisy_data[sample_idx]\n",
    "plt.imshow(sample[0], cmap='gray')\n",
    "plt.grid(False)\n",
    "plt.axis('off')\n",
    "\n",
    "# savefig does not create missing directories, so make sure the target exists.\n",
    "os.makedirs('images/mnist_image_results', exist_ok=True)\n",
    "# File names use 'inf' as the SNR tag when no noise was added.\n",
    "snr_label = check_snr_db if add_noise else 'inf'\n",
    "plt.savefig(f'images/mnist_image_results/image_{sample_idx}_snr_{snr_label}.jpg', dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Plot the reconstructed image\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Run the autoencoder on the selected sample, then show and save the reconstruction.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    # `sample` stays batched on purpose: the PSNR cell indexes sample[0].\n",
    "    sample = sample.reshape(1, 3, 28, 28)\n",
    "    # Presumably the first element of the model output is the reconstruction -- TODO confirm.\n",
    "    recon_sample = model(sample)[0]\n",
    "\n",
    "plt.imshow(recon_sample[0][0], cmap='gray')\n",
    "plt.grid(False)\n",
    "plt.axis('off')\n",
    "\n",
    "# savefig does not create missing directories, so make sure the target exists.\n",
    "os.makedirs('images/mnist_image_results', exist_ok=True)\n",
    "# File names use 'inf' as the SNR tag when no noise was added.\n",
    "snr_label = check_snr_db if add_noise else 'inf'\n",
    "plt.savefig(f'images/mnist_image_results/reconstructed_image_{sample_idx}_snr_{snr_label}_latent_{latent_dim}_channel_{channel}.jpg', dpi=300, bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "#### PSNR value\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# PSNR between the model input and its reconstruction; displayed as the cell output.\n",
    "calc_psnr(img_1=sample[0], img_2=recon_sample[0], max_value_psnr=max_value_psnr)"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
