{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.models as models\n",
    "\n",
    "import time\n",
    "import numpy as np\n",
    "import math\n",
    "import pandas as pd\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "import resnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side length of the square input images; the per-layer input sizes below\n",
    "# are derived from it.\n",
    "im_size = 224\n",
    "\n",
    "def filter_input_sizes(model):\n",
    "    \"\"\"Collect conv kernels as TF tensors together with an input size each.\n",
    "\n",
    "    Every parameter whose name contains 'conv' is permuted with\n",
    "    (2, 3, 1, 0), copied to host memory and converted to a TensorFlow\n",
    "    tensor. Each such parameter is also given a side length taken from the\n",
    "    module-level `im_size`, halved at fixed positions (12, 22, 32, 42, 52)\n",
    "    in the sequence of matching parameters.\n",
    "\n",
    "    NOTE(review): the permutation presumably maps PyTorch's\n",
    "    (out, in, kH, kW) layout to (kH, kW, in, out), and the position bounds\n",
    "    presumably mirror one specific architecture -- confirm before reusing\n",
    "    with another model.\n",
    "\n",
    "    Args:\n",
    "        model: object exposing `named_parameters()`.\n",
    "\n",
    "    Returns:\n",
    "        (param_dict, size_dict): two dicts keyed by parameter name, holding\n",
    "        the converted kernel and the assigned side length respectively.\n",
    "    \"\"\"\n",
    "    stage_bounds = (12, 22, 32, 42, 52)\n",
    "    stage_sizes = [im_size, im_size//2, im_size//4,\n",
    "                   im_size//8, im_size//16, im_size//32]\n",
    "\n",
    "    param_dict, size_dict = {}, {}\n",
    "    conv_index = 0\n",
    "    for name, param in model.named_parameters():\n",
    "        if 'conv' not in name:\n",
    "            continue\n",
    "        kernel = param.permute(2, 3, 1, 0).contiguous().clone()\n",
    "        param_dict[name] = tf.convert_to_tensor(kernel.detach().cpu().numpy())\n",
    "        # Stage = number of bounds this parameter's position has reached.\n",
    "        stage = sum(conv_index >= bound for bound in stage_bounds)\n",
    "        size_dict[name] = stage_sizes[stage]\n",
    "        conv_index += 1\n",
    "    return param_dict, size_dict\n",
    "\n",
    "def Clip_OperatorNorm(conv, inp_shape, clip_to):\n",
    "    \"\"\"Cap the singular values of a kernel's per-frequency matrices.\n",
    "\n",
    "    The kernel is zero-padded to `inp_shape`, transformed with a 2-D FFT,\n",
    "    decomposed with one SVD per frequency, its singular values limited to\n",
    "    `clip_to`, then transformed back and cropped to the original shape.\n",
    "\n",
    "    Args:\n",
    "        conv: TF tensor with a static 4-D shape; presumably laid out as\n",
    "            (height, width, in_channels, out_channels) -- confirm with caller.\n",
    "        inp_shape: pair giving the size the first two kernel axes are padded to.\n",
    "        clip_to: upper bound applied to every singular value.\n",
    "\n",
    "    Returns:\n",
    "        (clipped_kernel, norm): the real-valued kernel cropped to the shape\n",
    "        of `conv`, and the largest singular value found before clipping.\n",
    "    \"\"\"\n",
    "    # Swaps the first two axes with the last two; applying it twice restores\n",
    "    # the original order.\n",
    "    swap_axes = [2, 3, 0, 1]\n",
    "    kernel_shape = conv.get_shape().as_list()\n",
    "    pad_rows = inp_shape[0] - kernel_shape[0]\n",
    "    pad_cols = inp_shape[1] - kernel_shape[1]\n",
    "    pad_spec = tf.constant([[0, 0], [0, 0], [0, pad_rows], [0, pad_cols]])\n",
    "\n",
    "    # fft2d works on the two innermost axes, hence the transpose first.\n",
    "    kernel_c64 = tf.cast(tf.transpose(conv, perm=swap_axes), tf.complex64)\n",
    "    spectrum = tf.signal.fft2d(tf.pad(kernel_c64, pad_spec))\n",
    "\n",
    "    # One SVD per frequency: the matrix axes are moved back to the end.\n",
    "    sing_vals, left, right = tf.linalg.svd(\n",
    "        tf.transpose(spectrum, perm=swap_axes))\n",
    "    norm = tf.reduce_max(sing_vals)\n",
    "\n",
    "    capped_diag = tf.linalg.diag(\n",
    "        tf.cast(tf.minimum(sing_vals, clip_to), tf.complex64))\n",
    "    rebuilt = tf.matmul(left, tf.matmul(capped_diag, right, adjoint_b=True))\n",
    "\n",
    "    padded_kernel = tf.math.real(\n",
    "        tf.signal.ifft2d(tf.transpose(rebuilt, perm=swap_axes)))\n",
    "    cropped = tf.slice(tf.transpose(padded_kernel, perm=swap_axes),\n",
    "                       [0] * len(kernel_shape), kernel_shape)\n",
    "    return cropped, norm\n",
    "\n",
    "def clip_all_convs(param_dict, size_dict, clip_to=1.0):\n",
    "    \"\"\"Run Clip_OperatorNorm on every kernel and return the results.\n",
    "\n",
    "    The clipped kernels and norms are collected and returned rather than\n",
    "    discarded, so a caller can inspect them or write them back to a model.\n",
    "    `param_dict` itself is not modified.\n",
    "\n",
    "    Args:\n",
    "        param_dict: mapping from parameter name to TF kernel tensor.\n",
    "        size_dict: mapping from the same names to the side length of the\n",
    "            square input used for that kernel.\n",
    "        clip_to: bound forwarded to Clip_OperatorNorm (default 1.0).\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each name to the (clipped_kernel, norm) pair returned\n",
    "        by Clip_OperatorNorm.\n",
    "    \"\"\"\n",
    "    clipped = {}\n",
    "    for name, conv_filter in param_dict.items():\n",
    "        side = size_dict[name]\n",
    "        clipped[name] = Clip_OperatorNorm(conv_filter, (side, side), clip_to)\n",
    "    return clipped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def SVD_Conv_Tensor(conv, inp_shape):\n",
    "    \"\"\"Return one norm per frequency of a zero-padded kernel's 2-D FFT.\n",
    "\n",
    "    NOTE(review): despite the name, no SVD is performed here. The result is\n",
    "    tf.linalg.norm over the last two axes with its default `ord`, i.e. a\n",
    "    single value per frequency rather than the full set of singular\n",
    "    values -- confirm this is the intended quantity.\n",
    "\n",
    "    Args:\n",
    "        conv: TF tensor with a static 4-D shape; presumably laid out as\n",
    "            (height, width, in_channels, out_channels) -- confirm with caller.\n",
    "        inp_shape: pair giving the size the first two kernel axes are padded to.\n",
    "\n",
    "    Returns:\n",
    "        Tensor indexed by the two padded frequency axes.\n",
    "    \"\"\"\n",
    "    kernel_shape = conv.get_shape().as_list()\n",
    "    pad_rows = inp_shape[0] - kernel_shape[0]\n",
    "    pad_cols = inp_shape[1] - kernel_shape[1]\n",
    "    pad_spec = tf.constant([[0, 0], [0, 0], [0, pad_rows], [0, pad_cols]])\n",
    "\n",
    "    # fft2d works on the two innermost axes, hence the transpose first.\n",
    "    kernel_c64 = tf.cast(tf.transpose(conv, perm=[2, 3, 0, 1]), tf.complex64)\n",
    "    spectrum = tf.signal.fft2d(tf.pad(kernel_c64, pad_spec))\n",
    "\n",
    "    # Put the per-frequency matrices back on the last two axes.\n",
    "    per_frequency = tf.transpose(spectrum, perm=[2, 3, 0, 1])\n",
    "    return tf.linalg.norm(per_frequency, axis=(2, 3))\n",
    "        \n",
    "def calc_all_convs(param_dict, size_dict):\n",
    "    \"\"\"Run SVD_Conv_Tensor on every kernel and return the results.\n",
    "\n",
    "    The per-layer values are collected and returned rather than discarded,\n",
    "    so a caller can actually use them.\n",
    "\n",
    "    Args:\n",
    "        param_dict: mapping from parameter name to TF kernel tensor.\n",
    "        size_dict: mapping from the same names to the side length of the\n",
    "            square input used for that kernel.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each name to the tensor returned by SVD_Conv_Tensor.\n",
    "    \"\"\"\n",
    "    results = {}\n",
    "    for name, conv_filter in param_dict.items():\n",
    "        side = size_dict[name]\n",
    "        results[name] = SVD_Conv_Tensor(conv_filter, (side, side))\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the network, wrap it in DataParallel and move it to the GPU.\n",
    "# (models.resnet34(pretrained=False) was tried as an alternative backbone.)\n",
    "model = torch.nn.DataParallel(resnet.__dict__['resnet32']()).cuda()\n",
    "\n",
    "# Synthetic images of side im_size; only used to drive the timing loops.\n",
    "train_dataset = datasets.FakeData(\n",
    "    size=10000,\n",
    "    image_size=(3, im_size, im_size),\n",
    "    num_classes=10,\n",
    "    transform=transforms.ToTensor(),\n",
    ")\n",
    "\n",
    "# Snapshot the conv kernels as TF tensors; no autograd tracking is needed.\n",
    "with torch.no_grad():\n",
    "    param_dict, size_dict = filter_input_sizes(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffled mini-batches of 32 images, loaded by 8 worker processes.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=32,\n",
    "    shuffle=True,\n",
    "    num_workers=8,\n",
    "    pin_memory=False,\n",
    ")\n",
    "\n",
    "# SGD with momentum and no weight decay.\n",
    "optimizer = torch.optim.SGD(\n",
    "    model.parameters(),\n",
    "    lr=0.1,\n",
    "    momentum=0.9,\n",
    "    weight_decay=0.,\n",
    ")\n",
    "\n",
    "criterion = nn.CrossEntropyLoss().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "54.3857\n"
     ]
    }
   ],
   "source": [
    "from spectral_norm import ConvFilterNorm\n",
    "\n",
    "class SpectralModule(nn.Module):\n",
    "    \"\"\"Holds one ConvFilterNorm for every parameter with more than 2 dims.\n",
    "\n",
    "    Calling the module evaluates every ConvFilterNorm and returns the\n",
    "    results stacked into a single tensor.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        \"\"\"Wrap each >2-D parameter of `model` in a ConvFilterNorm.\"\"\"\n",
    "        super(SpectralModule, self).__init__()\n",
    "        self.spectra = nn.ModuleList([\n",
    "            ConvFilterNorm(param)\n",
    "            for name, param in model.named_parameters()\n",
    "            if len(param.shape) > 2\n",
    "        ])\n",
    "\n",
    "    def forward(self):\n",
    "        \"\"\"Return the stacked outputs of all ConvFilterNorm sub-modules.\n",
    "\n",
    "        The values are combined with torch.stack so they stay attached to\n",
    "        the autograd graph. Building the result with torch.Tensor(list)\n",
    "        would copy the numbers into a new tensor without history, and a\n",
    "        loss computed from it would send no gradient to the filters.\n",
    "\n",
    "        Returns:\n",
    "            Tensor with one entry per wrapped parameter; an empty tensor\n",
    "            when no parameter was wrapped.\n",
    "        \"\"\"\n",
    "        # as_tensor leaves tensors untouched and converts plain numbers.\n",
    "        sigma_list = [torch.as_tensor(s()) for s in self.spectra]\n",
    "        if not sigma_list:\n",
    "            # torch.stack rejects an empty list.\n",
    "            return torch.zeros(0)\n",
    "        return torch.stack(sigma_list)\n",
    "\n",
    "# Time one pass over train_loader with the spectral penalty in the loss.\n",
    "# CUDA kernels are launched asynchronously, so the device is synchronized\n",
    "# before each clock read; otherwise queued GPU work is left out of the total.\n",
    "torch.cuda.synchronize()\n",
    "start_time = time.time()\n",
    "spectral_module = SpectralModule(model)\n",
    "for i, (X, y) in enumerate(train_loader):\n",
    "    X = X.cuda()\n",
    "    y = y.cuda()\n",
    "    \n",
    "    output = model(X)\n",
    "    ce_loss = criterion(output, y)\n",
    "    # Sum of the per-filter values from SpectralModule, weighted by 8e-4.\n",
    "    all_sigma = spectral_module()\n",
    "    spectral_loss = all_sigma.sum()\n",
    "    loss = ce_loss + 8e-4*spectral_loss\n",
    "\n",
    "    # compute gradient and do SGD step\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "torch.cuda.synchronize()\n",
    "print('{:.4f}'.format(time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "49.4059\n"
     ]
    }
   ],
   "source": [
    "# Time one pass over train_loader with the plain cross-entropy loss.\n",
    "# CUDA kernels are launched asynchronously, so the device is synchronized\n",
    "# before each clock read; otherwise queued GPU work is left out of the total.\n",
    "torch.cuda.synchronize()\n",
    "start_time = time.time()\n",
    "for i, (X, y) in enumerate(train_loader):\n",
    "    X = X.cuda()\n",
    "    y = y.cuda()\n",
    "    \n",
    "    output = model(X)\n",
    "    loss = criterion(output, y)\n",
    "\n",
    "    # compute gradient and do SGD step\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "torch.cuda.synchronize()\n",
    "print('{:.4f}'.format(time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "156.7245\n"
     ]
    }
   ],
   "source": [
    "# Time one pass over train_loader with clip_all_convs run every 100 steps.\n",
    "# CUDA kernels are launched asynchronously, so the device is synchronized\n",
    "# before each clock read; otherwise queued GPU work is left out of the total.\n",
    "torch.cuda.synchronize()\n",
    "start_time = time.time()\n",
    "for i, (X, y) in enumerate(train_loader):\n",
    "    X = X.cuda()\n",
    "    y = y.cuda()\n",
    "    \n",
    "    output = model(X)\n",
    "    loss = criterion(output, y)\n",
    "\n",
    "    # compute gradient and do SGD step\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "    if i % 100 == 0:\n",
    "        # NOTE: param_dict holds copies of the kernels taken before training\n",
    "        # and nothing is written back to the model here, so this only\n",
    "        # measures the cost of the clipping computation.\n",
    "        clip_all_convs(param_dict, size_dict)\n",
    "torch.cuda.synchronize()\n",
    "print('{:.4f}'.format(time.time() - start_time))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37_tensorflow",
   "language": "python",
   "name": "conda-env-py37_tensorflow-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
