{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "33eecf77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importing the libraries \n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn import metrics\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1d2e3d3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import *\n",
    "from torchvision import transforms\n",
    "import torchvision\n",
    "from tqdm import tqdm\n",
    "from torchvision.utils import save_image\n",
    "to_pil_image = transforms.ToPILImage()\n",
    "\n",
    "def image_to_vid(images, path='../outputs/generated_images.gif'):\n",
    "    \"\"\"Write a sequence of image tensors to an animated GIF.\n",
    "\n",
    "    Args:\n",
    "        images: iterable of image tensors, each converted with `to_pil_image`.\n",
    "        path: output file; defaults to the path the notebook originally hard-coded.\n",
    "    \"\"\"\n",
    "    # `imageio` is not imported by the import cells above, so the original\n",
    "    # version raised NameError; import it where it is actually needed.\n",
    "    import imageio\n",
    "    frames = [np.array(to_pil_image(img)) for img in images]\n",
    "    imageio.mimsave(path, frames)\n",
    "def _save_images_as(images, prefix, epoch):\n",
    "    \"\"\"Save `images` (moved to CPU) with torchvision's `save_image` to f'{prefix}{epoch}.jpg'.\"\"\"\n",
    "    save_image(images.cpu(), f\"{prefix}{epoch}.jpg\")\n",
    "\n",
    "def save_reconstructed_imagesVAE(recon_images, epoch):\n",
    "    \"\"\"Save VAE reconstructions for `epoch` as 'VAE{epoch}.jpg'.\"\"\"\n",
    "    _save_images_as(recon_images, \"VAE\", epoch)\n",
    "\n",
    "def save_reconstructed_imagesEVAE(recon_images, epoch):\n",
    "    \"\"\"Save EVAE reconstructions for `epoch` as 'EVAE{epoch}.jpg'.\"\"\"\n",
    "    _save_images_as(recon_images, \"EVAE\", epoch)\n",
    "\n",
    "def save_reconstructed_imagesREAL(images, epoch):\n",
    "    \"\"\"Save the real (input) images for `epoch` as 'REAL{epoch}.jpg'.\"\"\"\n",
    "    _save_images_as(images, \"REAL\", epoch)\n",
    "def _save_loss_plot(curves, ylabel, path):\n",
    "    \"\"\"Plot per-epoch loss curves, save the figure to `path`, then show it.\n",
    "\n",
    "    Args:\n",
    "        curves: sequence of (values, color, label) triples, plotted in order.\n",
    "        ylabel: y-axis label.\n",
    "        path: image file the figure is saved to (before it is shown).\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(10, 7))\n",
    "    for values, color, label in curves:\n",
    "        ax.plot(values, color=color, label=label)\n",
    "    ax.set_xlabel('Epochs')\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.legend()\n",
    "    fig.savefig(path)\n",
    "    plt.show()\n",
    "    # Release the figure so repeated calls do not accumulate open figures.\n",
    "    plt.close(fig)\n",
    "\n",
    "def save_train_loss_plot(train_loss, valid_loss):\n",
    "    \"\"\"Plot the per-epoch training losses of the VAE and the EVAE; saved to 'lossMNIST.jpg'.\n",
    "\n",
    "    Note: despite its name, `valid_loss` is drawn as the EVAE *training* loss\n",
    "    (label 'train loss EVAE'); the parameter name is kept for existing callers.\n",
    "    \"\"\"\n",
    "    _save_loss_plot(\n",
    "        [(train_loss, 'orange', 'train loss VAE'),\n",
    "         (valid_loss, 'red', 'train loss EVAE')],\n",
    "        'Total loss (Reconstruction+B*I(K))/M',\n",
    "        'lossMNIST.jpg',\n",
    "    )\n",
    "\n",
    "def save_valid_loss_plot(valid_lossVAE, valid_lossEVAE):\n",
    "    \"\"\"Plot the per-epoch validation reconstruction losses of the VAE and the EVAE; saved to 'lossMNIST_valid.jpg'.\"\"\"\n",
    "    _save_loss_plot(\n",
    "        [(valid_lossVAE, 'orange', 'validation loss VAE'),\n",
    "         (valid_lossEVAE, 'red', 'validation loss EVAE')],\n",
    "        'Reconstruction loss',\n",
    "        'lossMNIST_valid.jpg',\n",
    "    )\n",
    "\n",
    "from torchvision.utils import make_grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "02d7c481",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Diffusion:\n",
    "    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, device=\"cuda\"):\n",
    "        self.noise_steps = noise_steps\n",
    "        self.beta_start = beta_start\n",
    "        self.beta_end = beta_end\n",
    "        self.img_size = img_size\n",
    "        self.device = device\n",
    "\n",
    "        self.beta = self.prepare_noise_schedule().to(device)\n",
    "        self.alpha = 1. - self.beta\n",
    "        self.alpha_hat = torch.cumprod(self.alpha, dim=0)\n",
    "\n",
    "    def prepare_noise_schedule(self):\n",
    "        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)\n",
    "\n",
    "    def noise_images(self, x, t):\n",
    "        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None]\n",
    "        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None]\n",
    "        Ɛ = torch.randn_like(x)\n",
    "\n",
    "        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ,  Ɛ\n",
    "\n",
    "    def sample_timesteps(self, n):\n",
    "        return torch.randint(low=1, high=self.noise_steps, size=(n,))\n",
    "\n",
    "    def sample(self, model, n):\n",
    "        #logging.info(f\"Sampling {n} new images....\")\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            x = torch.randn((n,8)).to(self.device)\n",
    "            \n",
    "            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):\n",
    "                t = (torch.ones(n) * i).long().to(self.device)\n",
    "                predicted_noise = model(x,t)\n",
    "                alpha = self.alpha[t][:, None]\n",
    "                alpha_hat = self.alpha_hat[t][:, None]\n",
    "                beta = self.beta[t][:, None]\n",
    "                if i > 1:\n",
    "                    noise = torch.randn_like(x)\n",
    "                else:\n",
    "                    noise = torch.zeros_like(x)\n",
    "                #x=x+0.01*(noise-predicted_noise)\n",
    "                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise\n",
    "        model.train()\n",
    "        #x = (x.clamp(-1, 1) + 1) / 2\n",
    "#         x = (x * 255).type(torch.uint8)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "107cb0ab-56e9-4737-a6ed-e74dffa6b713",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "65fac90a-fd94-4a88-b774-284de0135cca",
   "metadata": {},
   "outputs": [],
   "source": [
    "ModuleType = Union[str, Callable[..., nn.Module]]\n",
    "\n",
    "class SiLU(nn.Module):\n",
    "    def forward(self, x):\n",
    "        return x * torch.sigmoid(x)\n",
    "\n",
    "def timestep_embedding(timesteps, dim, max_period=10000):\n",
    "    \"\"\"\n",
    "    Create sinusoidal timestep embeddings.\n",
    "\n",
    "    :param timesteps: a 1-D Tensor of N indices, one per batch element.\n",
    "                      These may be fractional.\n",
    "    :param dim: the dimension of the output.\n",
    "    :param max_period: controls the minimum frequency of the embeddings.\n",
    "    :return: an [N x dim] Tensor of positional embeddings.\n",
    "    \"\"\"\n",
    "    half = dim // 2\n",
    "    freqs = torch.exp(\n",
    "        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half\n",
    "    ).to(device=timesteps.device)\n",
    "    args = timesteps[:, None].float() * freqs[None]\n",
    "    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n",
    "    if dim % 2:\n",
    "        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n",
    "    return embedding\n",
    "\n",
    "def _is_glu_activation(activation: ModuleType):\n",
    "    \"\"\"Return True if `activation` names or is a gated-linear-unit variant.\n",
    "\n",
    "    A string qualifies when it ends with 'GLU'; otherwise `activation` must be\n",
    "    the ReGLU or GEGLU class itself.\n",
    "    \"\"\"\n",
    "    if isinstance(activation, str) and activation.endswith('GLU'):\n",
    "        return True\n",
    "    return activation in [ReGLU, GEGLU]\n",
    "\n",
    "\n",
    "def _all_or_none(values):\n",
    "    assert all(x is None for x in values) or all(x is not None for x in values)\n",
    "\n",
    "def reglu(x):\n",
    "    \"\"\"The ReGLU activation function from [1].\n",
    "    References:\n",
    "        [1] Noam Shazeer, \"GLU Variants Improve Transformer\", 2020\n",
    "    \"\"\"\n",
    "    assert x.shape[-1] % 2 == 0\n",
    "    a, b = x.chunk(2, dim=-1)\n",
    "    return a * F.relu(b)\n",
    "\n",
    "\n",
    "def geglu(x):\n",
    "    \"\"\"The GEGLU activation function from [1].\n",
    "    References:\n",
    "        [1] Noam Shazeer, \"GLU Variants Improve Transformer\", 2020\n",
    "    \"\"\"\n",
    "    assert x.shape[-1] % 2 == 0\n",
    "    a, b = x.chunk(2, dim=-1)\n",
    "    return a * F.gelu(b)\n",
    "\n",
    "class ReGLU(nn.Module):\n",
    "    \"\"\"Module form of `reglu` (Shazeer, \"GLU Variants Improve Transformer\", 2020).\n",
    "\n",
    "    The last dimension is halved: an input of shape (..., 2k) yields (..., k).\n",
    "\n",
    "    Examples:\n",
    "        .. testcode::\n",
    "\n",
    "            module = ReGLU()\n",
    "            x = torch.randn(3, 4)\n",
    "            assert module(x).shape == (3, 2)\n",
    "    \"\"\"\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = reglu(x)\n",
    "        return out\n",
    "\n",
    "\n",
    "class GEGLU(nn.Module):\n",
    "    \"\"\"Module form of `geglu` (Shazeer, \"GLU Variants Improve Transformer\", 2020).\n",
    "\n",
    "    The last dimension is halved: an input of shape (..., 2k) yields (..., k).\n",
    "\n",
    "    Examples:\n",
    "        .. testcode::\n",
    "\n",
    "            module = GEGLU()\n",
    "            x = torch.randn(3, 4)\n",
    "            assert module(x).shape == (3, 2)\n",
    "    \"\"\"\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = geglu(x)\n",
    "        return out\n",
    "\n",
    "def _make_nn_module(module_type: ModuleType, *args) -> nn.Module:\n",
    "    \"\"\"Instantiate a module from a spec.\n",
    "\n",
    "    'ReGLU' / 'GEGLU' map to the local GLU modules (positional args are ignored);\n",
    "    any other string is looked up on ``torch.nn`` and called with `args`; a\n",
    "    non-string is treated as a factory and called with `args`.\n",
    "    \"\"\"\n",
    "    if not isinstance(module_type, str):\n",
    "        return module_type(*args)\n",
    "    if module_type == 'ReGLU':\n",
    "        return ReGLU()\n",
    "    if module_type == 'GEGLU':\n",
    "        return GEGLU()\n",
    "    return getattr(nn, module_type)(*args)\n",
    "\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    \"\"\"The MLP model used in [gorishniy2021revisiting].\n",
    "\n",
    "    The following scheme describes the architecture:\n",
    "\n",
    "    .. code-block:: text\n",
    "\n",
    "          MLP: (in) -> Block -> ... -> Block -> Linear -> (out)\n",
    "        Block: (in) -> Linear -> Activation -> Dropout -> (out)\n",
    "\n",
    "    Examples:\n",
    "        .. testcode::\n",
    "\n",
    "            x = torch.randn(4, 2)\n",
    "            module = MLP.make_baseline(x.shape[1], [3, 5], 0.1, 1)\n",
    "            assert module(x).shape == (len(x), 1)\n",
    "\n",
    "    References:\n",
    "        * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, \"Revisiting Deep Learning Models for Tabular Data\", 2021\n",
    "    \"\"\"\n",
    "\n",
    "    class Block(nn.Module):\n",
    "        \"\"\"The main building block of `MLP`: Linear -> Activation -> Dropout.\"\"\"\n",
    "\n",
    "        def __init__(\n",
    "            self,\n",
    "            *,\n",
    "            d_in: int,\n",
    "            d_out: int,\n",
    "            bias: bool,\n",
    "            activation: ModuleType,\n",
    "            dropout: float,\n",
    "        ) -> None:\n",
    "            super().__init__()\n",
    "            self.linear = nn.Linear(d_in, d_out, bias)\n",
    "            self.activation = _make_nn_module(activation)\n",
    "            self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "        def forward(self, x):\n",
    "            return self.dropout(self.activation(self.linear(x)))\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        *,\n",
    "        d_in: int,\n",
    "        d_layers: List[int],\n",
    "        dropouts: Union[float, List[float]],\n",
    "        activation: Union[str, Callable[[], nn.Module]],\n",
    "        d_out: int,\n",
    "    ) -> None:\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            d_in: the input size.\n",
    "            d_layers: output sizes of the hidden blocks (may be empty).\n",
    "            dropouts: one dropout rate shared by all blocks (int or float), or one rate per block.\n",
    "            activation: activation spec for every block; 'ReGLU'/'GEGLU' are rejected\n",
    "                because they halve the width each block's Linear produces.\n",
    "            d_out: output size of the final linear head.\n",
    "\n",
    "        Note:\n",
    "            `make_baseline` is the recommended constructor.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        # A bare number (including an int such as 0) means \"same rate for every block\".\n",
    "        if isinstance(dropouts, (int, float)):\n",
    "            dropouts = [float(dropouts)] * len(d_layers)\n",
    "        assert len(d_layers) == len(dropouts)\n",
    "        assert activation not in ['ReGLU', 'GEGLU']\n",
    "\n",
    "        self.blocks = nn.ModuleList(\n",
    "            [\n",
    "                MLP.Block(\n",
    "                    d_in=d_layers[i - 1] if i else d_in,\n",
    "                    d_out=d,\n",
    "                    bias=True,\n",
    "                    activation=activation,\n",
    "                    dropout=dropout,\n",
    "                )\n",
    "                for i, (d, dropout) in enumerate(zip(d_layers, dropouts))\n",
    "            ]\n",
    "        )\n",
    "        self.head = nn.Linear(d_layers[-1] if d_layers else d_in, d_out)\n",
    "\n",
    "    @classmethod\n",
    "    def make_baseline(\n",
    "        cls: Type['MLP'],\n",
    "        d_in: int,\n",
    "        d_layers: List[int],\n",
    "        dropout: float,\n",
    "        d_out: int,\n",
    "    ) -> 'MLP':\n",
    "        \"\"\"Create a \"baseline\" `MLP`.\n",
    "\n",
    "        This variation of MLP was used in [gorishniy2021revisiting]. Features:\n",
    "\n",
    "        * :code:`Activation` = :code:`ReLU`\n",
    "        * all linear layers except for the first one and the last one are of the same dimension\n",
    "        * the dropout rate is the same for all dropout layers\n",
    "\n",
    "        Args:\n",
    "            d_in: the input size\n",
    "            d_layers: the dimensions of the linear layers. If there are more than two\n",
    "                layers, then all of them except for the first and the last ones must\n",
    "                have the same dimension. Valid examples: :code:`[]`, :code:`[8]`,\n",
    "                :code:`[8, 16]`, :code:`[2, 2, 2, 2]`, :code:`[1, 2, 2, 4]`. Invalid\n",
    "                example: :code:`[1, 2, 3, 4]`.\n",
    "            dropout: the dropout rate for all hidden layers (int or float)\n",
    "            d_out: the output size\n",
    "        Returns:\n",
    "            an instance of `cls`\n",
    "\n",
    "        References:\n",
    "            * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, \"Revisiting Deep Learning Models for Tabular Data\", 2021\n",
    "        \"\"\"\n",
    "        assert isinstance(dropout, (int, float)), 'dropout must be a number'\n",
    "        if len(d_layers) > 2:\n",
    "            assert len(set(d_layers[1:-1])) == 1, (\n",
    "                'if d_layers contains more than two elements, then'\n",
    "                ' all elements except for the first and the last ones must be equal.'\n",
    "            )\n",
    "        # Build via `cls` (not `MLP`) so subclasses get instances of themselves,\n",
    "        # as ResNet.make_baseline already does.\n",
    "        return cls(\n",
    "            d_in=d_in,\n",
    "            d_layers=d_layers,  # type: ignore\n",
    "            dropouts=float(dropout),\n",
    "            activation='ReLU',\n",
    "            d_out=d_out,\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Cast `x` to float32 and run it through the blocks, then the linear head.\"\"\"\n",
    "        x = x.float()\n",
    "        for block in self.blocks:\n",
    "            x = block(x)\n",
    "        x = self.head(x)\n",
    "        return x\n",
    "class MLPDiffusion(nn.Module):\n",
    "    \"\"\"Noise-prediction network: projects x to `dim_t`, adds the time (and optional\n",
    "    label) embedding, and maps back to `d_in` features with a baseline MLP.\n",
    "\n",
    "    Args:\n",
    "        d_in: number of input (and output) features.\n",
    "        num_classes: > 0 for categorical labels (nn.Embedding); 0 for a scalar\n",
    "            conditioning value (nn.Linear(1, dim_t)).\n",
    "        is_y_cond: whether a label embedding is created and used.\n",
    "        rtdl_params: keyword arguments for `MLP.make_baseline`; 'd_in'/'d_out'\n",
    "            are filled in here. The caller's dict is not modified.\n",
    "        dim_t: width of the time embedding and of the MLP input.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_in, num_classes, is_y_cond, rtdl_params, dim_t = 128):\n",
    "        super().__init__()\n",
    "        self.dim_t = dim_t\n",
    "        self.num_classes = num_classes\n",
    "        self.is_y_cond = is_y_cond\n",
    "\n",
    "        # Copy so that writing d_in/d_out does not mutate the caller's dict.\n",
    "        mlp_params = dict(rtdl_params)\n",
    "        mlp_params['d_in'] = dim_t\n",
    "        mlp_params['d_out'] = d_in\n",
    "        self.mlp = MLP.make_baseline(**mlp_params)\n",
    "\n",
    "        if self.num_classes > 0 and is_y_cond:\n",
    "            self.label_emb = nn.Embedding(self.num_classes, dim_t)\n",
    "        elif self.num_classes == 0 and is_y_cond:\n",
    "            self.label_emb = nn.Linear(1, dim_t)\n",
    "\n",
    "        self.proj = nn.Linear(d_in, dim_t)\n",
    "        self.time_embed = nn.Sequential(\n",
    "            nn.Linear(dim_t, dim_t),\n",
    "            nn.SiLU(),\n",
    "            nn.Linear(dim_t, dim_t)\n",
    "        )\n",
    "\n",
    "    def forward(self, x, timesteps, y=None):\n",
    "        \"\"\"Predict noise for `x` at `timesteps`, optionally conditioned on `y`.\"\"\"\n",
    "        emb = self.time_embed(timestep_embedding(timesteps, self.dim_t))\n",
    "        if self.is_y_cond and y is not None:\n",
    "            if self.num_classes > 0:\n",
    "                y = y.squeeze()\n",
    "            else:\n",
    "                # `Tensor.resize` is deprecated; reshape to (batch, 1) instead.\n",
    "                y = y.reshape(y.size(0), 1).float()\n",
    "            emb = emb + F.silu(self.label_emb(y))\n",
    "        x = self.proj(x) + emb\n",
    "        return self.mlp(x)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9a061ccf-4c54-4ba5-abdf-fe352eaa8c9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResNet(nn.Module):\n",
    "    \"\"\"The ResNet model used in [gorishniy2021revisiting].\n",
    "    The following scheme describes the architecture:\n",
    "    .. code-block:: text\n",
    "        ResNet: (in) -> Linear -> Block -> ... -> Block -> Head -> (out)\n",
    "                 |-> Norm -> Linear -> Activation -> Dropout -> Linear -> Dropout ->|\n",
    "                 |                                                                  |\n",
    "         Block: (in) ------------------------------------------------------------> Add -> (out)\n",
    "          Head: (in) -> Norm -> Activation -> Linear -> (out)\n",
    "    Examples:\n",
    "        .. testcode::\n",
    "            x = torch.randn(4, 2)\n",
    "            module = ResNet.make_baseline(\n",
    "                d_in=x.shape[1],\n",
    "                n_blocks=2,\n",
    "                d_main=3,\n",
    "                d_hidden=4,\n",
    "                dropout_first=0.25,\n",
    "                dropout_second=0.0,\n",
    "                d_out=1\n",
    "            )\n",
    "            assert module(x).shape == (len(x), 1)\n",
    "    References:\n",
    "        * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, \"Revisiting Deep Learning Models for Tabular Data\", 2021\n",
    "    \"\"\"\n",
    "\n",
    "    class Block(nn.Module):\n",
    "        \"\"\"The main building block of `ResNet`: a normalized two-layer MLP with an optional residual add.\"\"\"\n",
    "\n",
    "        def __init__(\n",
    "            self,\n",
    "            *,\n",
    "            d_main: int,\n",
    "            d_hidden: int,\n",
    "            bias_first: bool,\n",
    "            bias_second: bool,\n",
    "            dropout_first: float,\n",
    "            dropout_second: float,\n",
    "            normalization: ModuleType,\n",
    "            activation: ModuleType,\n",
    "            skip_connection: bool,\n",
    "        ) -> None:\n",
    "            super().__init__()\n",
    "            self.normalization = _make_nn_module(normalization, d_main)\n",
    "            self.linear_first = nn.Linear(d_main, d_hidden, bias_first)\n",
    "            self.activation = _make_nn_module(activation)\n",
    "            self.dropout_first = nn.Dropout(dropout_first)\n",
    "            self.linear_second = nn.Linear(d_hidden, d_main, bias_second)\n",
    "            self.dropout_second = nn.Dropout(dropout_second)\n",
    "            self.skip_connection = skip_connection\n",
    "\n",
    "        def forward(self, x):\n",
    "            x_input = x\n",
    "            x = self.normalization(x)\n",
    "            x = self.linear_first(x)\n",
    "            x = self.activation(x)\n",
    "            x = self.dropout_first(x)\n",
    "            x = self.linear_second(x)\n",
    "            x = self.dropout_second(x)\n",
    "            if self.skip_connection:\n",
    "                x = x_input + x\n",
    "            return x\n",
    "\n",
    "    class Head(nn.Module):\n",
    "        \"\"\"The final module of `ResNet`: Norm -> Activation -> Linear.\"\"\"\n",
    "\n",
    "        def __init__(\n",
    "            self,\n",
    "            *,\n",
    "            d_in: int,\n",
    "            d_out: int,\n",
    "            bias: bool,\n",
    "            normalization: ModuleType,\n",
    "            activation: ModuleType,\n",
    "        ) -> None:\n",
    "            super().__init__()\n",
    "            self.normalization = _make_nn_module(normalization, d_in)\n",
    "            self.activation = _make_nn_module(activation)\n",
    "            self.linear = nn.Linear(d_in, d_out, bias)\n",
    "\n",
    "        def forward(self, x):\n",
    "            if self.normalization is not None:\n",
    "                x = self.normalization(x)\n",
    "            x = self.activation(x)\n",
    "            x = self.linear(x)\n",
    "            return x\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        *,\n",
    "        d_in: int,\n",
    "        n_blocks: int,\n",
    "        d_main: int,\n",
    "        d_hidden: int,\n",
    "        dropout_first: float,\n",
    "        dropout_second: float,\n",
    "        normalization: ModuleType,\n",
    "        activation: ModuleType,\n",
    "        d_out: int,\n",
    "    ) -> None:\n",
    "        \"\"\"\n",
    "        Note:\n",
    "            `make_baseline` is the recommended constructor. If `d_main` is None\n",
    "            it defaults to `d_in`.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        # Resolve the d_main default *before* building the first layer; the\n",
    "        # original built nn.Linear(d_in, d_main) first, which fails for None.\n",
    "        if d_main is None:\n",
    "            d_main = d_in\n",
    "        self.first_layer = nn.Linear(d_in, d_main)\n",
    "        self.blocks = nn.Sequential(\n",
    "            *[\n",
    "                ResNet.Block(\n",
    "                    d_main=d_main,\n",
    "                    d_hidden=d_hidden,\n",
    "                    bias_first=True,\n",
    "                    bias_second=True,\n",
    "                    dropout_first=dropout_first,\n",
    "                    dropout_second=dropout_second,\n",
    "                    normalization=normalization,\n",
    "                    activation=activation,\n",
    "                    skip_connection=True,\n",
    "                )\n",
    "                for _ in range(n_blocks)\n",
    "            ]\n",
    "        )\n",
    "        self.head = ResNet.Head(\n",
    "            d_in=d_main,\n",
    "            d_out=d_out,\n",
    "            bias=True,\n",
    "            normalization=normalization,\n",
    "            activation=activation,\n",
    "        )\n",
    "\n",
    "    @classmethod\n",
    "    def make_baseline(\n",
    "        cls: Type['ResNet'],\n",
    "        *,\n",
    "        d_in: int,\n",
    "        n_blocks: int,\n",
    "        d_main: int,\n",
    "        d_hidden: int,\n",
    "        dropout_first: float,\n",
    "        dropout_second: float,\n",
    "        d_out: int,\n",
    "    ) -> 'ResNet':\n",
    "        \"\"\"Create a \"baseline\" `ResNet`.\n",
    "        This variation of ResNet was used in [gorishniy2021revisiting]. Features:\n",
    "        * :code:`Activation` = :code:`ReLU`\n",
    "        * :code:`Norm` = :code:`BatchNorm1d`\n",
    "        Args:\n",
    "            d_in: the input size\n",
    "            n_blocks: the number of Blocks\n",
    "            d_main: the input size (or, equivalently, the output size) of each Block\n",
    "            d_hidden: the output size of the first linear layer in each Block\n",
    "            dropout_first: the dropout rate of the first dropout layer in each Block.\n",
    "            dropout_second: the dropout rate of the second dropout layer in each Block.\n",
    "            d_out: the output size of the head\n",
    "        References:\n",
    "            * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, \"Revisiting Deep Learning Models for Tabular Data\", 2021\n",
    "        \"\"\"\n",
    "        return cls(\n",
    "            d_in=d_in,\n",
    "            n_blocks=n_blocks,\n",
    "            d_main=d_main,\n",
    "            d_hidden=d_hidden,\n",
    "            dropout_first=dropout_first,\n",
    "            dropout_second=dropout_second,\n",
    "            normalization='BatchNorm1d',\n",
    "            activation='ReLU',\n",
    "            d_out=d_out,\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Cast `x` to float32, project to `d_main`, apply the blocks and the head.\"\"\"\n",
    "        x = x.float()\n",
    "        x = self.first_layer(x)\n",
    "        x = self.blocks(x)\n",
    "        x = self.head(x)\n",
    "        return x\n",
    "class ResNetDiffusion(nn.Module):\n",
    "    \"\"\"Noise-prediction network: adds the time (and optional label) embedding\n",
    "    directly to `x` and runs a baseline ResNet mapping `d_in` features to `d_in`.\n",
    "\n",
    "    Args:\n",
    "        d_in: number of input (and output) features.\n",
    "        num_classes: if > 0, labels `y` are embedded with nn.Embedding.\n",
    "        rtdl_params: keyword arguments for `ResNet.make_baseline`; 'd_in'/'d_out'\n",
    "            are filled in here. The caller's dict is not modified.\n",
    "        dim_t: width of the time embedding.\n",
    "\n",
    "    NOTE(review): there is no input projection, so `x + emb` only broadcasts\n",
    "    when d_in == dim_t (or one of them is 1) -- confirm this is intended.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_in, num_classes, rtdl_params, dim_t = 256):\n",
    "        super().__init__()\n",
    "        self.dim_t = dim_t\n",
    "        self.num_classes = num_classes\n",
    "\n",
    "        # Copy so that writing d_in/d_out does not mutate the caller's dict.\n",
    "        resnet_params = dict(rtdl_params)\n",
    "        resnet_params['d_in'] = d_in\n",
    "        resnet_params['d_out'] = d_in\n",
    "        self.resnet = ResNet.make_baseline(**resnet_params)\n",
    "\n",
    "        if self.num_classes > 0:\n",
    "            self.label_emb = nn.Embedding(self.num_classes, dim_t)\n",
    "\n",
    "        self.time_embed = nn.Sequential(\n",
    "            nn.Linear(dim_t, dim_t),\n",
    "            nn.SiLU(),\n",
    "            nn.Linear(dim_t, dim_t)\n",
    "        )\n",
    "\n",
    "    def forward(self, x, timesteps, y=None):\n",
    "        \"\"\"Predict noise for `x` at `timesteps`, optionally conditioned on class labels `y`.\"\"\"\n",
    "        emb = self.time_embed(timestep_embedding(timesteps, self.dim_t))\n",
    "        if y is not None and self.num_classes > 0:\n",
    "            emb = emb + self.label_emb(y.squeeze())\n",
    "        x = x + emb\n",
    "        return self.resnet(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e816c1d",
   "metadata": {},
   "source": [
    "## EVAE\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6f379ab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class EVAE(nn.Module):\n",
    "    def __init__(self,image_channel=1,kernel_size=4,latent_dim=64,init_channel=32,m=100,B=1):\n",
    "        super(EVAE, self).__init__()\n",
    "\n",
    "        self.bm=np.power(m,-2/9)#nn.Parameter(torch.ones(1),requires_grad=True)#np.power(m,-2/9)\n",
    "        self.latent_dim=latent_dim\n",
    "        self.B=B\n",
    "        # encoder\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Conv2d(image_channel,out_channels=init_channel,kernel_size=kernel_size,stride=2,padding=1),            \n",
    "            nn.BatchNorm2d(init_channel),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "\n",
    "            nn.Conv2d(init_channel,out_channels=init_channel*2,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.BatchNorm2d(init_channel*2),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "\n",
    "            nn.Conv2d(init_channel*2,out_channels=init_channel*4,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.BatchNorm2d(init_channel*4),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            nn.Conv2d(init_channel*4,out_channels=init_channel*8,kernel_size=kernel_size,stride=2,padding=0)   ,            \n",
    "            nn.BatchNorm2d(init_channel*8),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            \n",
    "            #nn.AdaptiveAvgPool2d((100,init_channel*8)),\n",
    "            nn.Flatten( start_dim=1)\n",
    "            \n",
    "            #nn.Linear(init_channel*8,init_channel*8) #hidden\n",
    "            \n",
    "        )\n",
    "        \n",
    "        \n",
    "\n",
    "        ## fully connected laeyer for learning representation\n",
    "        self.fully_connected_layer_a=nn.Linear(init_channel*8,latent_dim)\n",
    "        self.fully_connected_layer_b=nn.Linear(init_channel*8,latent_dim)\n",
    "        #self.fully_connected_layer_beta=nn.Linear(init_channel*8,latent_dim)\n",
    "        # self.prior=nn.Sequential(\n",
    "        # nn.Linear(latent_dim,latent_dim),\n",
    "        # #nn.BatchNorm2d(latent_dim,affine=False),\n",
    "        # nn.ReLU(inplace=True),\n",
    "        # nn.Linear(latent_dim,latent_dim),\n",
    "        # #nn.BatchNorm2d(latent_dim,affine=False),\n",
    "        # #nn.Tanh()\n",
    "        # )\n",
    "\n",
    "        #self.fully_connected_layer=nn.Linear(latent_dim,init_channel*8)      \n",
    "        ## decoder\n",
    "        \n",
    "                # encoder\n",
    "        self.decoder = nn.Sequential(\n",
    "            \n",
    "\n",
    "            \n",
    "            nn.ConvTranspose2d(latent_dim,out_channels=init_channel*4,kernel_size=kernel_size,stride=1,padding=0),            \n",
    "            nn.BatchNorm2d(init_channel*4,affine=False),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            nn.ConvTranspose2d(init_channel*4,out_channels=init_channel*2,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.BatchNorm2d(init_channel*2,affine=False),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "\n",
    "            nn.ConvTranspose2d(init_channel*2,out_channels=init_channel*1,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.BatchNorm2d(init_channel*1,affine=False) ,\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            \n",
    "\n",
    "            nn.ConvTranspose2d(init_channel*1,out_channels=image_channel,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "        \n",
    "        \n",
    " \n",
    "\n",
    "    def reparametrize(self,mu,log_r):\n",
    "        \n",
    "        coef= self.bm #b(m)\n",
    "        n=log_r.size(dim=0)\n",
    "        d=log_r.size(dim=1)\n",
    "        B=self.B\n",
    "        #beta=torch.exp(0.5*log_beta)\n",
    "        uniform_k=torch.rand(n*d,3).to(device)\n",
    "        U_k=2*uniform_k-1\n",
    "        \n",
    "        uniform_p=torch.rand(n,d).to(device)\n",
    "        U_p=2*uniform_p-1  #uniform prior\n",
    "        #z=self.prior(U_p)\n",
    "        median=torch.median(U_k,1)[0].reshape(n,d) #sample from E kernel\n",
    "        sample=coef*(torch.exp(0.5*log_r)*median+mu)+U_p*torch.exp(0.5*log_r)\n",
    "        return sample#,U_p/2\n",
    "    def forward(self,x):\n",
    "        #encoding\n",
    "\n",
    "        hidden =self.encoder(x)\n",
    "        \n",
    "        #hidden=torch.flatten(x, start_dim=1)\n",
    "     \n",
    "        mean=self.fully_connected_layer_a(hidden)\n",
    "        log_r=self.fully_connected_layer_b(hidden)\n",
    "        #log_beta=self.fully_connected_layer_beta(hidden)\n",
    "        #rand_ind=torch.randint(hidden.size(0), (hidden.size(0),))\n",
    "        #mean=mean[rand_ind]\n",
    "        zs=self.reparametrize(mean,log_r) #uniform distribution\n",
    "\n",
    "      \n",
    "        z=zs.view(-1,self.latent_dim,1,1)\n",
    "        #decoding\n",
    "        reconstruction=self.decoder(z)\n",
    "\n",
    "        return reconstruction,mean,log_r,zs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4c602b50-8be1-4940-9aa7-836a3ab633ab",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([86, 82, 49, 81, 52, 93, 32,  3, 82, 45, 73, 46, 91,  0, 26, 85, 60, 21,\n",
       "        25, 16,  7, 68, 78, 44, 42, 74, 33, 48, 76, 89, 79, 39,  1, 54, 53, 27,\n",
       "         8, 12, 27, 61,  6, 76, 95, 36, 77, 23, 14, 54, 79, 25,  0, 72,  7, 31,\n",
       "        11, 33, 84, 58, 74, 80, 63, 20,  4, 78, 41, 77, 91, 68, 50,  9,  9, 38,\n",
       "        53, 64, 94, 90,  3, 84, 50, 54, 32, 41, 98,  7, 53,  6, 73, 78, 40, 28,\n",
       "        65, 85,  9, 41, 14, 11, 27, 49, 47, 92])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch cell: 100 integers drawn uniformly (with replacement) from [0, 100).\n",
    "# NOTE(review): looks like leftover exploration -- confirm nothing depends on it before removing.\n",
    "torch.randint(100, (100,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "79c70590",
   "metadata": {},
   "outputs": [],
   "source": [
    "def final_lossEVAE(bce_loss,log_r):\n",
    "    M=log_r.size(dim=1)\n",
    "    N=log_r.size(dim=0)\n",
    "   # m=torch.exp(0.5*log_beta)\n",
    "    return bce_loss#+2*3/5* torch.exp(torch.logsumexp(torch.logsumexp(-0.5*log_r , dim = 1), dim = 0))#/N\n",
    "\n",
    "def model_trainEVAE(model,dataloader,dataset,device,optimizer,criterion):\n",
    "    \"\"\"Run one training epoch of an EVAE.\n",
    "\n",
    "    Args:\n",
    "        model: called on a batch, returns (reconstruction, mean, log_r, z).\n",
    "        dataloader: yields batches whose first element is the input tensor.\n",
    "        dataset: kept for interface compatibility; the progress-bar length now\n",
    "            comes from ``len(dataloader)``.\n",
    "        device: device each batch is moved to.\n",
    "        optimizer: stepped once per batch.\n",
    "        criterion: reconstruction loss, called as ``criterion(reconstruction, data)``.\n",
    "    Returns:\n",
    "        (mean loss per batch, latent samples `z` of the last batch).\n",
    "    Raises:\n",
    "        ValueError: if `dataloader` yields no batches.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    counter = 0\n",
    "    z = None\n",
    "    # len(dataloader) includes a trailing partial batch, unlike len(dataset) / batch_size.\n",
    "    for data in tqdm(dataloader, total=len(dataloader)):\n",
    "        counter += 1\n",
    "        data = data[0].to(device)\n",
    "        optimizer.zero_grad()\n",
    "        reconstruction, _, log_r, z = model(data)\n",
    "        bce_loss = criterion(reconstruction, data)\n",
    "        loss = final_lossEVAE(bce_loss, log_r)\n",
    "        loss.backward()\n",
    "        running_loss += loss.item()\n",
    "        optimizer.step()\n",
    "    if counter == 0:\n",
    "        raise ValueError(\"dataloader yielded no batches\")\n",
    "    train_loss = running_loss / counter\n",
    "    return train_loss, z\n",
    "\n",
    "def model_validateEVAE(model,dataloader,dataset,device,optimizer,criterion):\n",
    "    \"\"\"Evaluate ``model`` and return the mean per-batch reconstruction loss.\n",
    "\n",
    "    Args:\n",
    "        model: EVAE whose forward returns the reconstruction as its first output.\n",
    "        dataloader: yields ``(images, labels)`` batches; labels are ignored.\n",
    "        dataset: unused; accepted for signature compatibility.\n",
    "        device: device each batch is moved to.\n",
    "        optimizer: unused (nothing is updated during evaluation); accepted for\n",
    "            signature compatibility.\n",
    "        criterion: reconstruction loss, called as ``criterion(reconstruction, images)``.\n",
    "\n",
    "    Returns:\n",
    "        ``(valid_loss, recon_images)``: average loss over batches, and the\n",
    "        reconstructions of the last full-size batch (or of the only batch when\n",
    "        the dataset is smaller than one batch).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``dataloader`` yields no batches.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    running_loss=0.0\n",
    "    counter=0\n",
    "    recon_images=None\n",
    "    with torch.no_grad():\n",
    "        for data in tqdm(dataloader,total=len(dataloader)):\n",
    "            counter+=1\n",
    "            images=data[0].to(device)\n",
    "            reconstruction,*_=model(images)\n",
    "            running_loss+=criterion(reconstruction,images).item()\n",
    "            # Keep the most recent full-size batch so the saved grid has a\n",
    "            # constant number of images; fall back to a partial batch if needed.\n",
    "            if recon_images is None or reconstruction.size(0)==dataloader.batch_size:\n",
    "                recon_images=reconstruction\n",
    "    if counter==0:\n",
    "        raise ValueError(\"dataloader yielded no batches\")\n",
    "    valid_loss=running_loss/counter\n",
    "    return valid_loss,recon_images"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a445ce36",
   "metadata": {},
   "source": [
    "## Dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0f8a070e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset MNIST\n",
       "    Number of datapoints: 60000\n",
       "    Root location: ./\n",
       "    Split: Train"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One-off download of the MNIST training split into the working directory;\n",
    "# later cells load it from disk with download=False.\n",
    "torchvision.datasets.MNIST(root='./', train=True, download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a8de03d4-0658-4a30-8158-f51211cc002a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2048"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch arithmetic (evaluates to 2048).\n",
    "32*64"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "788b980d",
   "metadata": {},
   "source": [
    "# Training the EVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "b72a8856",
   "metadata": {},
   "outputs": [],
   "source": [
    "device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Seed torch before building the model so weight init and shuffling are\n",
    "# reproducible under Restart & Run All.\n",
    "SEED=0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "latent_dim=8\n",
    "# B and m are EVAE constructor hyper-parameters (see the EVAE class).\n",
    "EVAEmodel=EVAE(latent_dim=latent_dim,B=10,m=100).to(device)\n",
    "lr=0.0003\n",
    "epochs=50\n",
    "batch_size=100\n",
    "\n",
    "# Resize MNIST images to 32x32, then convert to float tensors.\n",
    "transform=transforms.Compose([transforms.Resize((32,32)),transforms.ToTensor()])\n",
    "\n",
    "# training set\n",
    "train_set=torchvision.datasets.MNIST(root='./',train=True,download=False,transform=transform)\n",
    "train_loader=torch.utils.data.DataLoader(train_set,batch_size=batch_size,shuffle=True)\n",
    "\n",
    "# test set: batch order does not change the averaged validation loss, so keep it\n",
    "# fixed and the reconstruction grid saved each epoch shows the same digits.\n",
    "test_set=torchvision.datasets.MNIST(root='./',train=False,download=False,transform=transform)\n",
    "test_loader=torch.utils.data.DataLoader(test_set,batch_size=batch_size,shuffle=False)\n",
    "\n",
    "optimizerEVAE=optim.Adam(EVAEmodel.parameters(),lr=lr)\n",
    "# Summed (not averaged) binary cross-entropy over pixels and batch.\n",
    "criterion=nn.BCELoss(reduction=\"sum\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "576ebcb9-e05d-4e83-bb4f-d4bb879ae972",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the EVAE for `epochs` epochs, recording mean losses per epoch and saving\n",
    "# one reconstruction grid per epoch (EVAE<epoch>.jpg).\n",
    "\n",
    "# VAE-baseline histories: the VAE model is disabled in this notebook; presumably\n",
    "# kept for the loss-plot helpers defined at the top.\n",
    "grid_imagesVAE=[]\n",
    "train_lossVAE=[]\n",
    "valid_lossVAE=[]\n",
    "\n",
    "grid_imagesEVAE=[]\n",
    "train_lossEVAE=[]\n",
    "valid_lossEVAE=[]\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    print(f\"Epoch {epoch+1} of {epochs}\")\n",
    "\n",
    "    # z: latent codes of the last training batch (inspected in a later cell).\n",
    "    train_epoch_lossEVAE,z=model_trainEVAE(EVAEmodel,train_loader,train_set,device,optimizerEVAE,criterion)\n",
    "    valid_epoch_lossEVAE,recon_imagesEVAE=model_validateEVAE(EVAEmodel,test_loader,test_set,device,optimizerEVAE,criterion)\n",
    "    train_lossEVAE.append(train_epoch_lossEVAE)\n",
    "    valid_lossEVAE.append(valid_epoch_lossEVAE)\n",
    "\n",
    "    save_reconstructed_imagesEVAE(recon_imagesEVAE,epoch+1)\n",
    "\n",
    "    print(f\"train loss:{train_epoch_lossEVAE:.4f}\")\n",
    "    print(f\"valid loss:{valid_epoch_lossEVAE:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67d3ce25-de81-497a-b526-39781ef2d856",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the latent code of the third sample in the last EVAE training batch.\n",
    "z[2]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "11c0a6d9-4116-4c06-8877-e1762b9cd3eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up a diffusion model over the EVAE latent vectors.\n",
    "device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Hyper-parameters for the latent diffusion model.\n",
    "lr=0.0002\n",
    "epochs=50\n",
    "# NOTE(review): train_loader is not rebuilt in this cell, so this batch_size does\n",
    "# not affect the diffusion training loop; it is only picked up when the sampling\n",
    "# cell rebuilds train_loader.\n",
    "batch_size=32\n",
    "\n",
    "# Denoiser over latent_dim-dimensional latents, with a latent_dim-sized time embedding.\n",
    "unet = ResNetDiffusion(\n",
    "    d_in=latent_dim,\n",
    "    num_classes=0,\n",
    "    dim_t=latent_dim,\n",
    "    rtdl_params={\n",
    "        'd_in': latent_dim,\n",
    "        'n_blocks': 4,\n",
    "        'd_main': 256,\n",
    "        'd_hidden': 256,\n",
    "        'dropout_first': 0.0,\n",
    "        'dropout_second': 0.0,\n",
    "        'd_out': latent_dim,\n",
    "    },\n",
    ").to(device)\n",
    "\n",
    "# NOTE(review): img_size=16 while latent_dim=8 - confirm what Diffusion uses img_size for.\n",
    "diffusion = Diffusion(img_size=16, device=device)\n",
    "\n",
    "optimizer30=optim.AdamW(unet.parameters(),lr=lr)\n",
    "\n",
    "# This rebinds the BCE `criterion` used by the EVAE cells; re-run the EVAE setup\n",
    "# cell before training the EVAE again.\n",
    "criterion=torch.nn.MSELoss(reduction=\"sum\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45464f07-2ace-43d6-a66b-a3439178b92d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Train the diffusion denoiser on latents produced by the (frozen) EVAE.\n",
    "train_loss=[]  # mean per-batch loss of each epoch\n",
    "\n",
    "EVAEmodel.eval()\n",
    "for epoch in range(epochs):\n",
    "    print(f\"Epoch {epoch+1} of {epochs}\")\n",
    "\n",
    "    unet.train()\n",
    "    running_loss=0.0\n",
    "    counter=0\n",
    "\n",
    "    for data in tqdm(train_loader,total=len(train_loader)):\n",
    "        counter+=1\n",
    "        images=data[0].to(device)\n",
    "        optimizer30.zero_grad()\n",
    "\n",
    "        # The EVAE is only a fixed feature extractor here (optimizer30 updates the\n",
    "        # unet only). Without no_grad, autograd records the EVAE forward pass and\n",
    "        # loss.backward() can accumulate .grad on the EVAE weights.\n",
    "        with torch.no_grad():\n",
    "            *_,latent=EVAEmodel(images)\n",
    "        t=diffusion.sample_timesteps(latent.shape[0]).to(device)\n",
    "        x_t,noise=diffusion.noise_images(latent.view(-1,latent_dim),t)\n",
    "        predicted_noise=unet(x_t,t)\n",
    "        loss=criterion(predicted_noise,noise)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer30.step()\n",
    "        running_loss+=loss.item()\n",
    "\n",
    "    epoch_loss=running_loss/counter\n",
    "    train_loss.append(epoch_loss)\n",
    "    print(epoch_loss)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cea7c1e3-a80e-4a43-bfbd-85dc6d927795",
   "metadata": {},
   "source": [
    "### Generate sample images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "1bf9769f-04f3-4644-9e40-2265a638b7bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "999it [00:01, 722.22it/s]00:00<?, ?it/s]\n",
      "  0%|          | 7/1875 [00:01<06:19,  4.92it/s]\n"
     ]
    }
   ],
   "source": [
    "# Sample latents with the trained diffusion model, decode them with the EVAE,\n",
    "# and save them next to a reference batch of real images.\n",
    "EVAEmodel.eval()\n",
    "unet.eval()\n",
    "diffusion = Diffusion(img_size=16, device=device)\n",
    "\n",
    "# Unshuffled loader so the reference batch of real images is the same on every run.\n",
    "# NOTE: this rebinds the global train_set/train_loader (batch_size from the diffusion\n",
    "# setup cell, no shuffling); the FID cell below iterates this loader too.\n",
    "train_set=torchvision.datasets.MNIST(root='./',train=True,download=False,transform=transform)\n",
    "train_loader=torch.utils.data.DataLoader(train_set,batch_size=batch_size,shuffle=False)\n",
    "\n",
    "REAL_BATCH_INDEX=7  # which (unshuffled) batch is saved as the real-image reference\n",
    "N_SAMPLES=100       # number of latents drawn from the diffusion model\n",
    "\n",
    "with torch.no_grad():\n",
    "    for i,data in tqdm(enumerate(train_loader),total=len(train_loader)):\n",
    "        if i!=REAL_BATCH_INDEX:\n",
    "            continue\n",
    "        real_images=data[0].to(device)\n",
    "\n",
    "        state1=diffusion.sample(unet,N_SAMPLES)\n",
    "        # Decoder input layout matches EVAE.forward: (batch, latent_dim, 1, 1).\n",
    "        z=state1.view(-1,latent_dim,1,1)\n",
    "        reconstruction=EVAEmodel.decoder(z)\n",
    "\n",
    "        # Real images go through the REAL helper (REAL0.jpg), generated ones through\n",
    "        # the EVAE helper (EVAE0.jpg).\n",
    "        save_reconstructed_imagesREAL(real_images,0)\n",
    "        save_reconstructed_imagesEVAE(reconstruction,0)\n",
    "\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "839ce8fb-ceb3-4ffb-88d4-b191ae0a3551",
   "metadata": {},
   "source": [
    "### FID (Fréchet Inception Distance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a346abb-c6e5-4e16-aaf7-16848bdee53a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FID between real training images and EVAE decodings of diffusion samples.\n",
    "# NOTE(review): the reference set is the training split - confirm this is intended\n",
    "# rather than test_set.\n",
    "from torcheval.metrics import FrechetInceptionDistance\n",
    "\n",
    "EVAEmodel.eval()\n",
    "unet.eval()\n",
    "diffusion = Diffusion(img_size=16, device=device)\n",
    "fidEVAE = FrechetInceptionDistance(device=device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    for data in tqdm(train_loader,total=len(train_loader)):\n",
    "        real_images=data[0].to(device)\n",
    "        # One generated image per real image keeps both sets the same size.\n",
    "        state1=diffusion.sample(unet,real_images.size(0))\n",
    "        # latent_dim (not a hard-coded 8) keeps this in sync with the EVAE config.\n",
    "        state2=EVAEmodel.decoder(state1.view(-1,latent_dim,1,1))\n",
    "        # The metric is fed 3-channel images: repeat the single grayscale channel.\n",
    "        fidEVAE.update(real_images.repeat(1,3,1,1),is_real=True)\n",
    "        fidEVAE.update(state2.repeat(1,3,1,1),is_real=False)\n",
    "\n",
    "lossEVAE=fidEVAE.compute()\n",
    "\n",
    "print(f\"FIDEVAE: {float(lossEVAE)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d848ec-b89e-4217-b316-be70176fafbe",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
