{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3bac134c-b48c-40c7-b09f-91a14fa50970",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eea05e25-78c1-4a7e-86cb-4f4646807877",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.backends.cudnn as cudnn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import numpy as np\n",
    "\n",
    "import os\n",
    "import argparse\n",
    "from tqdm.auto import tqdm\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "028130ee-f0d6-4814-84c6-0d4fc321feb5",
   "metadata": {},
   "source": [
    "## Try the vertical noises first"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "50c71ee3-1a79-43dc-b60a-47cc07e80f8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "best_acc = 0  # best test accuracy\n",
    "start_epoch = 0  # start from epoch 0 or last checkpoint epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ad81372b-7a5a-4d66-82c8-4bc524838375",
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_train_val(train, val_split):\n",
    "    train_len = int(len(train) * (1.0-val_split))\n",
    "    train, val = torch.utils.data.random_split(\n",
    "        train,\n",
    "        (train_len, len(train) - train_len),\n",
    "        generator=torch.Generator().manual_seed(42),\n",
    "    )\n",
    "    return train, val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 204,
   "id": "ae05f286-f2e3-479d-b021-09395b8e2e0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "res1 = 2048\n",
    "res2 = 128\n",
    "res = 256\n",
    "class AddWaveTransform:\n",
    "    def __call__(self, image):\n",
    "        magnitude = 3\n",
    "        frequency = 2.99433\n",
    "        waves0 = magnitude * torch.rand(1) * torch.cos(frequency*torch.arange(res2)).unsqueeze(0)\n",
    "        waves1 = magnitude * torch.rand(1) * torch.cos(frequency*torch.arange(res2)).unsqueeze(0)\n",
    "        waves2 = magnitude * torch.rand(1) * torch.cos(frequency*torch.arange(res2)).unsqueeze(0)\n",
    "        image[0, :, :] += waves0\n",
    "        image[1, :, :] += waves1\n",
    "        image[2, :, :] += waves2\n",
    "        return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "298c0932-8aa4-4869-8724-42e56894e765",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the path to store the CelebA dataset\n",
    "data_dir = './data/CelebA'\n",
    "\n",
    "# Define transformations for the dataset\n",
    "transform = transforms.Compose([\n",
    "    transforms.CenterCrop(178),  # Center crop to 178x178\n",
    "    transforms.Resize((res1,res2)),      # Resize to 128x128\n",
    "    transforms.ToTensor(),       # Convert to tensor\n",
    "    AddWaveTransform(),\n",
    "    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalize\n",
    "    transforms.Lambda(lambda x: x.view(3, res1*res2)),\n",
    "])\n",
    "\n",
    "# Load the CelebA dataset\n",
    "dataset = datasets.CelebA(root=data_dir, split='train', transform=transform, download=False)\n",
    "\n",
    "# Create DataLoader for the dataset\n",
    "dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)\n",
    "\n",
    "# Check the dataset\n",
    "print(f'Number of samples: {len(dataset)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 206,
   "id": "528069c4-9b68-42b9-b5f1-11e20e87dac2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to denormalize and plot images\n",
    "from PIL import Image\n",
    "def imshow(img):\n",
    "    img = img / 2 + 0.5  # Unnormalize\n",
    "    img = img.reshape(img.shape[:-1] + (res1, res2))\n",
    "    img = torch.clamp(img, 0, 1)\n",
    "    npimg = np.transpose(img.detach().numpy(), (1,2,0))\n",
    "    print(npimg.shape)\n",
    "    rescaled_image = Image.fromarray((npimg*225).astype(np.uint8)).resize((256, 256))\n",
    "    rescaled_image.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8904a465-d127-46c0-93c5-4c6cc0869891",
   "metadata": {},
   "outputs": [],
   "source": [
    "image, label = dataset[2]\n",
    "print(image.shape)\n",
    "imshow(image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 185,
   "id": "b77b9bb9-9205-4e8c-8eaf-c96361968e46",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "from einops import rearrange, repeat\n",
    "\n",
    "from src.models.nn import DropoutNd\n",
    "\n",
    "class S4DKernel_simple(nn.Module):\n",
    "    \"\"\"Generate convolution kernel from diagonal SSM parameters.\"\"\"\n",
    "\n",
    "    def __init__(self, d_model, N=1, lr=0.0001):\n",
    "        super().__init__()\n",
    "        H = d_model\n",
    "        log_dt = torch.rand(H) * (\n",
    "            math.log(1e-3) - math.log(1e-3)\n",
    "        ) + math.log(1e-3)\n",
    "\n",
    "        C = torch.randn(H, N, dtype=torch.cfloat)\n",
    "        self.C = nn.Parameter(torch.view_as_real(C))\n",
    "        self.register(\"log_dt\", log_dt, 0)\n",
    "\n",
    "        log_A_real = torch.log(0.5 * torch.ones(H, N))\n",
    "        A_imag = math.pi * repeat(torch.arange(N) - N // 2, 'n -> h n', h=H) * 10\n",
    "        self.register(\"log_A_real\", log_A_real, lr)\n",
    "        self.register(\"A_imag\", A_imag, lr)\n",
    "\n",
    "    def forward(self, L):\n",
    "        \"\"\"\n",
    "        returns: (..., c, L) where c is number of channels (default 1)\n",
    "        \"\"\"\n",
    "\n",
    "        # Materialize parameters\n",
    "        dt = torch.exp(self.log_dt) # (H)\n",
    "        C = torch.view_as_complex(self.C) # (H N)\n",
    "        A = -torch.exp(self.log_A_real) + 1j * self.A_imag # (H N)\n",
    "\n",
    "        # Vandermonde multiplication\n",
    "        dtA = A * dt.unsqueeze(-1)  # (H N)\n",
    "        K = dtA.unsqueeze(-1) * torch.arange(L, device=A.device) # (H N L)\n",
    "        C = C * (torch.exp(dtA)-1.) / A\n",
    "        K = 2 * torch.einsum('hn, hnl -> hl', C, torch.exp(K)).real\n",
    "\n",
    "        return K\n",
    "\n",
    "    def register(self, name, tensor, lr=None):\n",
    "        \"\"\"Register a tensor with a configurable learning rate and 0 weight decay\"\"\"\n",
    "\n",
    "        if lr == 0.0:\n",
    "            self.register_buffer(name, tensor)\n",
    "        else:\n",
    "            self.register_parameter(name, nn.Parameter(tensor))\n",
    "\n",
    "            optim = {\"weight_decay\": 0.0}\n",
    "            if lr is not None: optim[\"lr\"] = lr\n",
    "            setattr(getattr(self, name), \"_optim\", optim)\n",
    "\n",
    "L_image = res1*res2\n",
    "class S4D_simple(nn.Module):\n",
    "    def __init__(self, d_state = 512, L = L_image, d_output = 3, dropout=0.0, transposed=True, **kernel_args):\n",
    "        super().__init__()\n",
    "\n",
    "        self.n = d_state\n",
    "        self.d_output = d_output\n",
    "        self.d_model = 4\n",
    "        self.transposed = transposed\n",
    "        self.D = nn.Parameter(torch.randn(1))\n",
    "        self.encoder = nn.Linear(3, self.d_model)\n",
    "        self.decoder = nn.Linear(self.d_model, d_output)\n",
    "\n",
    "        # SSM Kernel\n",
    "        self.kernel = S4DKernel_simple(self.d_model, N=self.n, **kernel_args)\n",
    "\n",
    "        # Pointwise\n",
    "        self.activation = nn.GELU()\n",
    "        # dropout_fn = nn.Dropout2d # NOTE: bugged in PyTorch 1.11\n",
    "        dropout_fn = DropoutNd\n",
    "        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()\n",
    "\n",
    "        # position-wise output transform to mix features\n",
    "        self.output_linear = nn.Sequential(\n",
    "            # nn.Conv1d(self.d_model, 2*self.d_model, kernel_size=1),\n",
    "            nn.GELU(),\n",
    "        )\n",
    "\n",
    "    def forward(self, u, **kwargs): # absorbs return_output and transformer src mask\n",
    "        \"\"\" Input and output shape (B, H, L) \"\"\"\n",
    "        if not self.transposed: u = u.transpose(-1, -2)\n",
    "        L = u.size(-1)\n",
    "        u = u.transpose(-1,-2) # (B L 3)\n",
    "        u = self.encoder(u) # (B L H)\n",
    "        u = u.transpose(-1,-2) # (B H L)\n",
    "\n",
    "        # Compute SSM Kernel\n",
    "        k = self.kernel(L=L) # (H L)\n",
    "\n",
    "        # Convolution\n",
    "        k = nn.functional.pad(k,(0,L),'constant',0)\n",
    "        u = nn.functional.pad(u,(0,L),'constant',0)\n",
    "        k_f = torch.fft.fft(k) # (H L)\n",
    "        u_f = torch.fft.fft(u) # (B H L)\n",
    "        \n",
    "        y = torch.fft.ifft(u_f*k_f).real # (B H L)\n",
    "        y = y[...,0:L]\n",
    "        u = u[...,0:L]\n",
    "        \n",
    "        y = y.transpose(-1,-2) # (B L H)\n",
    "        y = self.decoder(y) # (B L d_output)\n",
    "        y = y.transpose(-1,-2) # (B d_output L)\n",
    "\n",
    "        return y, None # Return a dummy state to satisfy this repo's interface, but this can be modified"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 187,
   "id": "d469774c-a020-47f4-9ee6-4afb1ae45cb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = S4D_simple()\n",
    "model = model.to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 188,
   "id": "64516798-7cfc-4d04-85b5-136c5d8efbfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint = torch.load('./checkpoint/ckpt.pth')\n",
    "model.load_state_dict(checkpoint['model'])\n",
    "model = model.to('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47bb6f37-c324-47e6-ae30-d9eabe875ea6",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def plot_pics():\n",
    "    pbar = tqdm(enumerate(dataloader))\n",
    "    for batch_idx, (inputs, targets) in pbar:\n",
    "        outputs, _ = model(inputs)\n",
    "        for i in range(32):\n",
    "            print('Blurred Images')\n",
    "            imshow(inputs[i,:,:])\n",
    "            print('Output of the SSM')\n",
    "            imshow(outputs[i,:,:])\n",
    "        break\n",
    "\n",
    "plot_pics()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8b4ca309-dd4a-4766-91ad-67a04a194f7e",
   "metadata": {},
   "source": [
    "## Try the horizontal noises"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "id": "a68c5bad-38ef-42a7-8176-91746fb8e5d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "res1 = 2048\n",
    "res2 = 128\n",
    "res = 256\n",
    "class AddWaveTransform:\n",
    "    def __call__(self, image):\n",
    "        magnitude = 5\n",
    "        frequency = 2.99433\n",
    "        waves0 = magnitude * torch.rand(1) * torch.cos(frequency*torch.arange(res1)).unsqueeze(1)\n",
    "        waves1 = magnitude * torch.rand(1) * torch.cos(frequency*torch.arange(res1)).unsqueeze(1)\n",
    "        waves2 = magnitude * torch.rand(1) * torch.cos(frequency*torch.arange(res1)).unsqueeze(1)\n",
    "        image[0, :, :] += waves0\n",
    "        image[1, :, :] += waves1\n",
    "        image[2, :, :] += waves2\n",
    "        return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc0a14f5-a4ef-4d1c-92b0-e9ab583d5d0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the path to store the CelebA dataset\n",
    "data_dir = './data/CelebA'\n",
    "\n",
    "# Define transformations for the dataset\n",
    "transform = transforms.Compose([\n",
    "    transforms.CenterCrop(178),  # Center crop to 178x178\n",
    "    transforms.Resize((res1,res2)),      # Resize to 128x128\n",
    "    transforms.ToTensor(),       # Convert to tensor\n",
    "    AddWaveTransform(),\n",
    "    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalize\n",
    "    transforms.Lambda(lambda x: x.view(3, res1*res2)),\n",
    "])\n",
    "\n",
    "# Load the CelebA dataset\n",
    "dataset = datasets.CelebA(root=data_dir, split='train', transform=transform, download=False)\n",
    "\n",
    "# Create DataLoader for the dataset\n",
    "dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)\n",
    "\n",
    "# Check the dataset\n",
    "print(f'Number of samples: {len(dataset)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "328218d1-4da4-449c-9772-1cef9f4061a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def plot_pics():\n",
    "    pbar = tqdm(enumerate(dataloader))\n",
    "    for batch_idx, (inputs, targets) in pbar:\n",
    "        outputs, _ = model(inputs)\n",
    "        for i in range(32):\n",
    "            print('Blurred Images')\n",
    "            imshow(inputs[i,:,:])\n",
    "            print('Output of the SSM')\n",
    "            imshow(outputs[i,:,:])\n",
    "        break\n",
    "\n",
    "plot_pics()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "1db2a130",
   "metadata": {},
   "source": [
    "# For some rigirous numbers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1471e25-f8cf-4c2d-a272-b8ff47aa63a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "\n",
    "class AddWaveTransform:\n",
    "    def __call__(self, image):\n",
    "        waves0 = 0.5 * torch.rand(1) * torch.cos(2.99433*torch.arange(image.size(2))).unsqueeze(0)\n",
    "        waves1 = 0.5 * torch.rand(1) * torch.cos(2.99433*torch.arange(image.size(2))).unsqueeze(0)\n",
    "        waves2 = 0.5 * torch.rand(1) * torch.cos(2.99433*torch.arange(image.size(2))).unsqueeze(0)\n",
    "        image[0, :, :] += waves0\n",
    "        image[1, :, :] += waves1\n",
    "        image[2, :, :] += waves2\n",
    "        return image\n",
    "\n",
    "class AddWaveTransform_Horizontal:\n",
    "    def __call__(self, image):\n",
    "        waves0 = 0.5 * torch.rand(1) * torch.cos(2.99433*torch.arange(image.size(1))).unsqueeze(1)\n",
    "        waves1 = 0.5 * torch.rand(1) * torch.cos(2.99433*torch.arange(image.size(1))).unsqueeze(1)\n",
    "        waves2 = 0.5 * torch.rand(1) * torch.cos(2.99433*torch.arange(image.size(1))).unsqueeze(1)\n",
    "        image[0, :, :] += waves0\n",
    "        image[1, :, :] += waves1\n",
    "        image[2, :, :] += waves2\n",
    "        return image\n",
    "\n",
    "# Define the path to store the CelebA dataset\n",
    "data_dir = './data/CelebA'\n",
    "\n",
    "res1 = 2048\n",
    "res2 = 128\n",
    "\n",
    "# Define transformations for the dataset\n",
    "transform = transforms.Compose([\n",
    "    transforms.CenterCrop(178),  # Center crop to 178x178\n",
    "    transforms.Resize((res1,res2)),      # Resize to 128x128\n",
    "    transforms.ToTensor(),       # Convert to tensor\n",
    "    #AddWaveTransform(),\n",
    "    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalize\n",
    "    transforms.Lambda(lambda x: x.view(3, res1*res2)),\n",
    "])\n",
    "\n",
    "# Define the two transforms\n",
    "transform1 = transforms.Compose([\n",
    "    transforms.CenterCrop(178),  # Center crop to 178x178\n",
    "    transforms.Resize((res1,res2)),      # Resize to 128x128\n",
    "    transforms.ToTensor(),       # Convert to tensor\n",
    "    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalize\n",
    "    transforms.Lambda(lambda x: x.view(3, res1*res2)),\n",
    "])\n",
    "\n",
    "transform2 = transforms.Compose([\n",
    "    transforms.CenterCrop(178),  # Center crop to 178x178\n",
    "    transforms.Resize((res1,res2)),      # Resize to 128x128\n",
    "    transforms.ToTensor(),       # Convert to tensor\n",
    "    AddWaveTransform(),\n",
    "    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalize\n",
    "    transforms.Lambda(lambda x: x.view(3, res1*res2)),\n",
    "])\n",
    "\n",
    "transform3 = transforms.Compose([\n",
    "    transforms.CenterCrop(178),  # Center crop to 178x178\n",
    "    transforms.Resize((res1,res2)),      # Resize to 128x128\n",
    "    transforms.ToTensor(),       # Convert to tensor\n",
    "    AddWaveTransform_Horizontal(),\n",
    "    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalize\n",
    "    transforms.Lambda(lambda x: x.view(3, res1*res2)),\n",
    "])\n",
    "\n",
    "class DualTransformDataset(Dataset):\n",
    "    def __init__(self, dataset, transform1=None, transform2=None, transform3=None):\n",
    "        self.dataset = dataset\n",
    "        self.transform1 = transform1\n",
    "        self.transform2 = transform2\n",
    "        self.transform3 = transform3\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        image, label = self.dataset[idx]\n",
    "        input_image_vertical = self.transform2(image)\n",
    "        input_image_horizontal = self.transform3(image)\n",
    "        true_image = self.transform1(image)\n",
    "        return input_image_vertical, input_image_horizontal, true_image, label\n",
    "\n",
    "# Load the CelebA dataset\n",
    "dataset = datasets.CelebA(root=data_dir, split='train', download=False)\n",
    "\n",
    "# Create the dual transform dataset\n",
    "dual_transform_dataset = DualTransformDataset(dataset, transform1=transform1, transform2=transform2, transform3=transform3)\n",
    "\n",
    "# Create a DataLoader\n",
    "dataloader = DataLoader(dual_transform_dataset, batch_size=32, shuffle=True, num_workers=4)\n",
    "\n",
    "# Check the dataset\n",
    "print(f'Number of samples: {len(dataset)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "020d112e",
   "metadata": {},
   "outputs": [],
   "source": [
    "total = 0\n",
    "total_vertical = 0\n",
    "total_horizontal = 0\n",
    "total_noises_vertical = 0\n",
    "total_noises_horizontal = 0\n",
    "criterion = nn.MSELoss()\n",
    "model = model.to('cuda')\n",
    "pbar = tqdm(enumerate(dataloader))\n",
    "for batch_idx, (input_image_vertical, input_image_horizontal, true_image, targets) in pbar:\n",
    "    input_image_vertical, input_image_horizontal, true_image = input_image_vertical.to('cuda'), input_image_horizontal.to('cuda'), true_image.to('cuda')\n",
    "    \n",
    "    foutputs, _ = model(input_image_vertical)\n",
    "    loss = criterion(foutputs, true_image)\n",
    "    total_vertical += loss.item()\n",
    "\n",
    "    foutputs, _ = model(input_image_horizontal)\n",
    "    loss = criterion(foutputs, true_image)\n",
    "    total_horizontal += loss.item()\n",
    "    \n",
    "    total_noises_vertical += criterion(input_image_vertical, true_image).item()\n",
    "    total_noises_horizontal += criterion(input_image_horizontal, true_image).item()\n",
    "    \n",
    "    total += targets.size(0)\n",
    "    \n",
    "    pbar.set_description(\n",
    "        'Batch Idx: (%d/%d) | Loss_V: %.4f | Noise_V: %.4f | Loss_H: %.4f | Noise_H: %.4f' %\n",
    "        (batch_idx, len(dataloader), total_vertical/(batch_idx+1), total_noises_vertical/(batch_idx+1), total_horizontal/(batch_idx+1), total_noises_horizontal/(batch_idx+1))\n",
    "    )\n",
    "\n",
    "print('noises level = ' + str(total_noises / (batch_idx + 1)) + 'loss = ' + str(total_loss / (batch_idx + 1)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
