{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "from collections import OrderedDict\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import skimage\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from PIL import Image\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from torchvision.transforms import Resize, Compose, ToTensor, Normalize\n",
    "\n",
    "def get_mgrid(sidelen, dim=2):\n",
    "    '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.\n",
    "\n",
    "    sidelen: int, number of samples along every axis.\n",
    "    dim: int, number of axes of the grid.\n",
    "\n",
    "    Returns a tensor of shape (sidelen ** dim, dim); the last coordinate varies fastest.\n",
    "    '''\n",
    "    axis = torch.linspace(-1, 1, steps=sidelen)\n",
    "    # Pin indexing='ij' explicitly: it is the layout the implicit default produces today,\n",
    "    # it silences the deprecation warning, and it keeps the row-major ordering stable\n",
    "    # if the library default ever changes.\n",
    "    mesh = torch.meshgrid(*([axis] * dim), indexing='ij')\n",
    "    return torch.stack(mesh, dim=-1).reshape(-1, dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SineLayer(nn.Module):\n",
    "    '''Linear layer whose output is scaled by omega_0 and passed through sin(x) * sqrt(|x|).\n",
    "\n",
    "    See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.\n",
    "\n",
    "    If is_first=True, omega_0 is a frequency factor which simply multiplies the activations\n",
    "    before the nonlinearity. Different signals may require different omega_0 in the first\n",
    "    layer - this is a hyperparameter.\n",
    "\n",
    "    If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude\n",
    "    of activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5).\n",
    "\n",
    "    in_features: int, size of each input sample.\n",
    "    out_features: int, size of each output sample.\n",
    "    bias: bool, whether the linear map has a bias term.\n",
    "    is_first: bool, selects the weight initialisation range (see init_weights).\n",
    "    omega_0: float, factor applied to the linear output before the nonlinearity.\n",
    "    '''\n",
    "\n",
    "    # Lower bound applied to |x| before the square root is taken.\n",
    "    _ABS_FLOOR = 1e-43\n",
    "\n",
    "    def __init__(self, in_features, out_features, bias=True,\n",
    "                 is_first=False, omega_0=30):\n",
    "        super().__init__()\n",
    "        self.omega_0 = omega_0\n",
    "        self.is_first = is_first\n",
    "\n",
    "        self.in_features = in_features\n",
    "        self.linear = nn.Linear(in_features, out_features, bias=bias)\n",
    "\n",
    "        self.init_weights()\n",
    "\n",
    "    def init_weights(self):\n",
    "        '''Re-draws the linear weights uniformly in [-bound, bound]; the bias is not touched.'''\n",
    "        if self.is_first:\n",
    "            bound = 1 / self.in_features\n",
    "        else:\n",
    "            bound = np.sqrt(6 / self.in_features) / self.omega_0\n",
    "        with torch.no_grad():\n",
    "            self.linear.weight.uniform_(-bound, bound)\n",
    "\n",
    "    def _activation(self, x):\n",
    "        '''Computes sin(x) * sqrt(max(|x|, floor)); shared by both forward variants.'''\n",
    "        # The floor is a plain constant built on x's own device and dtype: it needs no\n",
    "        # gradient, and it must not drag a CPU tensor into an op that runs on the GPU.\n",
    "        floor = torch.full((), self._ABS_FLOOR, dtype=x.dtype, device=x.device)\n",
    "        return torch.sin(x) * torch.sqrt(torch.maximum(torch.abs(x), floor))\n",
    "\n",
    "    def forward(self, input):\n",
    "        '''Returns the activation of omega_0 * linear(input).'''\n",
    "        return self._activation(self.omega_0 * self.linear(input))\n",
    "\n",
    "    def forward_with_intermediate(self, input):\n",
    "        '''Returns (activation, pre-activation); for visualization of activation distributions.\n",
    "\n",
    "        The activation is the one forward applies, so the visualised values match what the\n",
    "        layer actually computes.\n",
    "        '''\n",
    "        intermediate = self.omega_0 * self.linear(input)\n",
    "        return self._activation(intermediate), intermediate\n",
    "    \n",
    "    \n",
    "class Siren(nn.Module):\n",
    "    '''Stack of SineLayers, optionally ending in a plain linear layer.\n",
    "\n",
    "    in_features: int, size of the input coordinates.\n",
    "    hidden_features: int, width of every hidden layer.\n",
    "    hidden_layers: int, number of SineLayers placed after the first one.\n",
    "    out_features: int, size of the output.\n",
    "    outermost_linear: bool, if True the last layer is nn.Linear instead of a SineLayer.\n",
    "    first_omega_0: omega_0 passed to the first SineLayer.\n",
    "    hidden_omega_0: omega_0 passed to every other SineLayer; also scales the init range\n",
    "        of the final linear layer.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False, \n",
    "                 first_omega_0=30, hidden_omega_0=30.):\n",
    "        super().__init__()\n",
    "\n",
    "        # Collect the layers in a local list; self.net is only ever assigned the nn.Sequential.\n",
    "        layers = [SineLayer(in_features, hidden_features,\n",
    "                            is_first=True, omega_0=first_omega_0)]\n",
    "\n",
    "        for _ in range(hidden_layers):\n",
    "            layers.append(SineLayer(hidden_features, hidden_features,\n",
    "                                    is_first=False, omega_0=hidden_omega_0))\n",
    "\n",
    "        if outermost_linear:\n",
    "            final_linear = nn.Linear(hidden_features, out_features)\n",
    "\n",
    "            # Same uniform range a non-first SineLayer uses for its weights.\n",
    "            bound = np.sqrt(6 / hidden_features) / hidden_omega_0\n",
    "            with torch.no_grad():\n",
    "                final_linear.weight.uniform_(-bound, bound)\n",
    "\n",
    "            layers.append(final_linear)\n",
    "        else:\n",
    "            layers.append(SineLayer(hidden_features, out_features,\n",
    "                                    is_first=False, omega_0=hidden_omega_0))\n",
    "\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, coords):\n",
    "        '''Returns (output, coords), where coords is a detached copy of the input that requires grad.'''\n",
    "        coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input\n",
    "        output = self.net(coords)\n",
    "        return output, coords\n",
    "\n",
    "    def forward_with_activations(self, coords, retain_grad=False):\n",
    "        '''Returns not only model output, but also intermediate activations.\n",
    "        Only used for visualizing activations later!\n",
    "\n",
    "        The result is an OrderedDict: 'input' first, then for every SineLayer its\n",
    "        pre-activation followed by its output, and for any other layer its output.\n",
    "        Keys are '<layer class>_<running index>'. With retain_grad=True every stored\n",
    "        tensor keeps its .grad after a backward pass.\n",
    "        '''\n",
    "        activations = OrderedDict()\n",
    "\n",
    "        activation_count = 0\n",
    "        x = coords.clone().detach().requires_grad_(True)\n",
    "        activations['input'] = x\n",
    "        for layer in self.net:\n",
    "            if isinstance(layer, SineLayer):\n",
    "                x, intermed = layer.forward_with_intermediate(x)\n",
    "\n",
    "                if retain_grad:\n",
    "                    x.retain_grad()\n",
    "                    intermed.retain_grad()\n",
    "\n",
    "                # A SineLayer contributes two entries: pre-activation here, output below.\n",
    "                activations['_'.join((str(layer.__class__), '%d' % activation_count))] = intermed\n",
    "                activation_count += 1\n",
    "            else:\n",
    "                x = layer(x)\n",
    "\n",
    "                if retain_grad:\n",
    "                    x.retain_grad()\n",
    "\n",
    "            activations['_'.join((str(layer.__class__), '%d' % activation_count))] = x\n",
    "            activation_count += 1\n",
    "\n",
    "        return activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def laplace(y, x):\n",
    "    '''Laplacian of y w.r.t. x, computed as the divergence of the gradient of y.'''\n",
    "    return divergence(gradient(y, x), x)\n",
    "\n",
    "\n",
    "def divergence(y, x):\n",
    "    '''Sums, over the last axis of y, the derivative of component i of y w.r.t. component i of x.\n",
    "\n",
    "    The autograd graph is kept (create_graph=True) so the result can be differentiated again.\n",
    "    '''\n",
    "    total = 0.\n",
    "    for idx in range(y.shape[-1]):\n",
    "        component = y[..., idx]\n",
    "        d_component = torch.autograd.grad(component, x, torch.ones_like(component), create_graph=True)[0]\n",
    "        # keep only the matching input coordinate, as a trailing axis of size 1\n",
    "        total = total + d_component[..., idx:idx+1]\n",
    "    return total\n",
    "\n",
    "\n",
    "def gradient(y, x, grad_outputs=None):\n",
    "    '''Gradient of y w.r.t. x, keeping the graph so higher-order derivatives can be taken.\n",
    "\n",
    "    grad_outputs: optional tensor weighting the entries of y; defaults to all ones.\n",
    "    '''\n",
    "    weights = torch.ones_like(y) if grad_outputs is None else grad_outputs\n",
    "    (grad,) = torch.autograd.grad(y, [x], grad_outputs=weights, create_graph=True)\n",
    "    return grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cameraman_tensor(sidelength):\n",
    "    '''Returns skimage's brick image (despite the function name) as a normalized tensor.\n",
    "\n",
    "    The image is resized with Resize(sidelength), converted with ToTensor and normalized\n",
    "    with mean 0.5 and std 0.5.\n",
    "    '''\n",
    "    pipeline = Compose([\n",
    "        Resize(sidelength),\n",
    "        ToTensor(),\n",
    "        Normalize(torch.Tensor([0.5]), torch.Tensor([0.5])),\n",
    "    ])\n",
    "    return pipeline(Image.fromarray(skimage.data.brick()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ImageFitting(Dataset):\n",
    "    '''Dataset with a single item: the whole image as one (coords, pixels) pair.'''\n",
    "\n",
    "    def __init__(self, sidelength):\n",
    "        super().__init__()\n",
    "        image = get_cameraman_tensor(sidelength)\n",
    "        # channel axis last, then flattened to one value per row\n",
    "        channels_last = image.permute(1, 2, 0)\n",
    "        self.pixels = channels_last.view(-1, 1)\n",
    "        self.coords = get_mgrid(sidelength, 2)\n",
    "\n",
    "    def __len__(self):\n",
    "        return 1\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # only index 0 exists; anything above it ends iteration\n",
    "        if idx > 0:\n",
    "            raise IndexError\n",
    "\n",
    "        return self.coords, self.pixels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when one is present, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "cameraman = ImageFitting(256)\n",
    "# Pinned host memory is only requested when batches will be copied to a CUDA device.\n",
    "dataloader = DataLoader(cameraman, batch_size=1, pin_memory=(device.type == 'cuda'), num_workers=0)\n",
    "\n",
    "img_siren = Siren(in_features=2, out_features=1, hidden_features=256, \n",
    "                  hidden_layers=3, outermost_linear=True)\n",
    "# Last expression of the cell: moves the model and displays its structure.\n",
    "img_siren.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_steps = 501 # The whole image is one batch, so this is 501 full-batch gradient descent steps (0..500).\n",
    "steps_til_summary = 5\n",
    "\n",
    "optim = torch.optim.Adam(lr=1e-4, params=img_siren.parameters())\n",
    "\n",
    "# Follow the model onto whatever device it lives on instead of assuming CUDA.\n",
    "device = next(img_siren.parameters()).device\n",
    "\n",
    "model_input, ground_truth = next(iter(dataloader))\n",
    "model_input, ground_truth = model_input.to(device), ground_truth.to(device)\n",
    "\n",
    "# Side length of the square image, recovered from the number of coordinates in the batch.\n",
    "sidelength = int(round(model_input.shape[1] ** 0.5))\n",
    "\n",
    "for step in range(total_steps):\n",
    "    model_output, coords = img_siren(model_input)    \n",
    "    loss = ((model_output - ground_truth)**2).mean()\n",
    "    \n",
    "    # Summarise every steps_til_summary steps (a plain == comparison would fire only once).\n",
    "    if step % steps_til_summary == 0:\n",
    "        print(\"Step %d, Total loss %0.6f\" % (step, loss.item()))\n",
    "        img_grad = gradient(model_output, coords)\n",
    "\n",
    "        fig, ax = plt.subplots(figsize=(6,6))\n",
    "        ax.imshow(img_grad.norm(dim=-1).detach().cpu().view(sidelength, sidelength).numpy())\n",
    "        ax.set_title(\"Gradients of SPDER Representation at Step \" + str(step))\n",
    "        ax.axis('off')  # Hide axes for clean presentation\n",
    "        plt.show()\n",
    "        plt.close(fig)  # many summaries are drawn; release each figure once it is shown\n",
    "\n",
    "    optim.zero_grad()\n",
    "    loss.backward()\n",
    "    optim.step()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumes ground_truth is a torch tensor of shape [1, 256*256, 1]; it is reshaped to a 2-D image.\n",
    "\n",
    "gt = ground_truth.detach().cpu().numpy().reshape(256, 256)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6,6))\n",
    "ax.imshow(gt, cmap='gray')\n",
    "ax.set_title(\"Ground Truth\")\n",
    "ax.axis('off')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Finite-difference gradients of the image; the first result is taken as y, the second as x\n",
    "gt_grad_y, gt_grad_x = np.gradient(gt)\n",
    "\n",
    "# Per-pixel gradient magnitude\n",
    "gt_grad_mag = np.sqrt(gt_grad_x**2 + gt_grad_y**2)\n",
    "\n",
    "# One row of three panels: the ground truth image and its two directional gradients\n",
    "fig, axes = plt.subplots(1,3, figsize=(18,6))\n",
    "\n",
    "panels = [\n",
    "    (gt, \"Ground Truth\"),\n",
    "    (gt_grad_x, \"Gradient in x-direction\"),\n",
    "    (gt_grad_y, \"Gradient in y-direction\"),\n",
    "]\n",
    "for ax, (image, title) in zip(axes, panels):\n",
    "    ax.imshow(image, cmap='gray')\n",
    "    ax.set_title(title)\n",
    "    ax.axis('off')\n",
    "\n",
    "plt.show()\n",
    "\n",
    "# The gradient magnitude gets a figure of its own\n",
    "plt.figure(figsize=(6,6))\n",
    "plt.imshow(gt_grad_mag, cmap='gray')\n",
    "plt.title(\"Gradients of Ground Truth\")\n",
    "plt.axis('off')\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
