{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms, models\n",
    "from torch.nn.functional import conv2d\n",
    "\n",
    "import numpy as np\n",
    "import math\n",
    "import os\n",
    "import time\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "SEED = 0\n",
    "# The power-iteration estimators below start from random vectors;\n",
    "# seed torch so reruns start from the same vectors.\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "resnet18 = models.resnet18(pretrained=True).cuda()\n",
    "\n",
    "def singular_values_tf(conv, inp_shape):\n",
    "    \"\"\"Exact spectral norm of a circular 2-D convolution, computed with TensorFlow.\n",
    "\n",
    "    Args:\n",
    "        conv: kernel with spatial axes first and channel axes last; the notebook\n",
    "            passes a PyTorch weight permuted to (kh, kw, c_in, c_out). Must expose\n",
    "            `get_shape()` (e.g. a tf.Tensor).\n",
    "        inp_shape: (H, W) input size; the kernel is zero-padded to it before the FFT.\n",
    "\n",
    "    Returns:\n",
    "        (elapsed_seconds, max_norm), where max_norm is a scalar tf.Tensor holding the\n",
    "        largest 2-norm over the per-frequency channel matrices.\n",
    "    \"\"\"\n",
    "    t0 = time.time()\n",
    "    # fft2d transforms the two innermost axes, so move the spatial axes last.\n",
    "    spatial_last = tf.cast(tf.transpose(conv, perm=[2, 3, 0, 1]), tf.complex64)\n",
    "    kh, kw = conv.get_shape().as_list()[:2]\n",
    "\n",
    "    pad_spec = tf.constant([[0, 0], [0, 0],\n",
    "                            [0, inp_shape[0] - kh],\n",
    "                            [0, inp_shape[1] - kw]])\n",
    "    spectrum = tf.signal.fft2d(tf.pad(spatial_last, pad_spec))\n",
    "\n",
    "    # Frequency axes first again, leaving one channel matrix per frequency.\n",
    "    per_frequency = tf.transpose(spectrum, perm=[2, 3, 0, 1])\n",
    "    largest = tf.reduce_max(tf.norm(per_frequency, ord=2, axis=(2, 3)))\n",
    "    return time.time() - t0, largest\n",
    "\n",
    "def singular_values_np(conv, inp_shape):\n",
    "    \"\"\"Exact spectral norm of a circular 2-D convolution, computed with NumPy.\n",
    "\n",
    "    Args:\n",
    "        conv: ndarray with spatial axes first, (kh, kw, c_a, c_b); the notebook\n",
    "            passes (kh, kw, c_in, c_out).\n",
    "        inp_shape: (H, W); the FFT zero-pads (or crops) axes 0 and 1 to this size.\n",
    "\n",
    "    Returns:\n",
    "        (elapsed_seconds, max_norm): max_norm is the largest 2-norm over the\n",
    "        (c_a, c_b) matrices found at every frequency.\n",
    "    \"\"\"\n",
    "    tic = time.time()\n",
    "    spectrum = np.fft.fft2(conv, s=inp_shape, axes=(0, 1))\n",
    "    largest = np.linalg.norm(spectrum, ord=2, axis=(2, 3)).max()\n",
    "    return time.time() - tic, largest\n",
    "\n",
    "def l2_normalize(tensor, eps=1e-12):\n",
    "    \"\"\"Return `tensor` divided by its overall L2 (Frobenius) norm.\n",
    "\n",
    "    The norm is clamped below at `eps`, so an all-zero tensor comes back as zeros\n",
    "    instead of NaNs. The norm is converted to a Python float, which copies it to\n",
    "    the host when `tensor` lives on a GPU.\n",
    "    \"\"\"\n",
    "    magnitude = max(float(torch.sqrt(torch.sum(tensor * tensor))), eps)\n",
    "    return tensor / magnitude\n",
    "\n",
    "def real_spectral_norm(conv_filter, shape, num_iters=50):\n",
    "    \"\"\"Estimate the spectral norm of a zero-padded conv operator by power iteration.\n",
    "\n",
    "    Alternates F.conv_transpose2d (the adjoint) with F.conv2d on random\n",
    "    (1, C, H, W) inputs. The operator is a stride-1 convolution with 'same'\n",
    "    padding for odd kernel sizes; the real layer's stride is not modelled.\n",
    "\n",
    "    Args:\n",
    "        conv_filter: weight tensor of shape (c_out, c_in, kh, kw).\n",
    "        shape: (H, W) spatial size of the input the operator acts on.\n",
    "        num_iters: number of power-iteration steps.\n",
    "\n",
    "    Returns:\n",
    "        (elapsed_seconds, sigma) where sigma is a 0-dim tensor estimating the\n",
    "        largest singular value.\n",
    "    \"\"\"\n",
    "    start_time = time.time()\n",
    "    H, W = shape\n",
    "    c_out, c_in, kh, kw = conv_filter.shape\n",
    "    # Pad each spatial dim separately so non-square odd kernels also keep H x W.\n",
    "    pad = ((kh - 1) // 2, (kw - 1) // 2)\n",
    "    # The notebook passes nn.Parameters; avoid building an autograd graph that is\n",
    "    # never used and would inflate the measured time.\n",
    "    with torch.no_grad():\n",
    "        u = l2_normalize(conv_filter.new_empty(1, c_out, H, W).normal_(0, 1))\n",
    "        v = l2_normalize(conv_filter.new_empty(1, c_in, H, W).normal_(0, 1))\n",
    "        for _ in range(num_iters):\n",
    "            v = l2_normalize(F.conv_transpose2d(u, conv_filter, padding=pad))\n",
    "            u = l2_normalize(F.conv2d(v, conv_filter, padding=pad))\n",
    "        sigma = torch.sum(u * F.conv2d(v, conv_filter, padding=pad))\n",
    "    if sigma.is_cuda:\n",
    "        # CUDA launches are asynchronous; wait so the timing includes the final conv.\n",
    "        torch.cuda.synchronize(sigma.device)\n",
    "    total_time = time.time() - start_time\n",
    "    return total_time, sigma\n",
    "\n",
    "def our_bounds_np(conv_filter):\n",
    "    \"\"\"sqrt(h*w) times the smallest spectral norm of four 2-D unfoldings of a kernel.\n",
    "\n",
    "    Each unfolding regroups the (out_ch, in_ch, h, w) kernel into a matrix; the\n",
    "    scaled minimum is meant as an upper bound on the conv layer's spectral norm\n",
    "    (presumably the Singla & Feizi Fantastic Four bound - TODO confirm).\n",
    "\n",
    "    Args:\n",
    "        conv_filter: ndarray of shape (out_ch, in_ch, h, w).\n",
    "\n",
    "    Returns:\n",
    "        (elapsed_seconds, min_norm) with min_norm a NumPy scalar.\n",
    "    \"\"\"\n",
    "    start_time = time.time()\n",
    "    out_ch, in_ch, h, w = conv_filter.shape\n",
    "    scale = math.sqrt(h * w)\n",
    "\n",
    "    # (axis permutation, or None to use the kernel as is; target matrix shape)\n",
    "    unfoldings = (\n",
    "        ((0, 2, 1, 3), (out_ch * h, in_ch * w)),\n",
    "        ((0, 3, 1, 2), (out_ch * w, in_ch * h)),\n",
    "        (None, (out_ch, in_ch * h * w)),\n",
    "        ((0, 2, 3, 1), (out_ch * h * w, in_ch)),\n",
    "    )\n",
    "    bounds = []\n",
    "    for axes, matrix_shape in unfoldings:\n",
    "        arranged = conv_filter if axes is None else np.transpose(conv_filter, axes=axes)\n",
    "        matrix = np.reshape(arranged, matrix_shape)\n",
    "        bounds.append(scale * np.linalg.norm(matrix, ord=2, axis=(0, 1)))\n",
    "\n",
    "    min_norm = np.amin(np.stack(bounds, axis=0))\n",
    "    return time.time() - start_time, min_norm\n",
    "\n",
    "def our_bounds_tf(conv_filter):\n",
    "    \"\"\"sqrt(h*w) times the smallest spectral norm of four 2-D unfoldings (TensorFlow).\n",
    "\n",
    "    Each unfolding regroups the (out_ch, in_ch, h, w) kernel into a matrix; the\n",
    "    scaled minimum is meant as an upper bound on the conv layer's spectral norm\n",
    "    (presumably the Singla & Feizi Fantastic Four bound - TODO confirm).\n",
    "\n",
    "    Args:\n",
    "        conv_filter: tensor-like of shape (out_ch, in_ch, h, w).\n",
    "\n",
    "    Returns:\n",
    "        (elapsed_seconds, min_norm) with min_norm a scalar tf.Tensor.\n",
    "    \"\"\"\n",
    "    start_time = time.time()\n",
    "    out_ch, in_ch, h, w = conv_filter.shape\n",
    "    scale = math.sqrt(h * w)\n",
    "\n",
    "    # (transpose permutation, or None to use the kernel as is; target matrix shape)\n",
    "    unfoldings = [\n",
    "        ([0, 2, 1, 3], [out_ch * h, in_ch * w]),\n",
    "        ([0, 3, 1, 2], [out_ch * w, in_ch * h]),\n",
    "        (None, [out_ch, in_ch * h * w]),\n",
    "        ([0, 2, 3, 1], [out_ch * h * w, in_ch]),\n",
    "    ]\n",
    "    bounds = []\n",
    "    for perm, matrix_shape in unfoldings:\n",
    "        arranged = conv_filter if perm is None else tf.transpose(conv_filter, perm=perm)\n",
    "        matrix = tf.reshape(arranged, matrix_shape)\n",
    "        bounds.append(scale * tf.norm(matrix, ord=2, axis=(0, 1)))\n",
    "\n",
    "    min_norm = tf.reduce_min(tf.stack(bounds, axis=0))\n",
    "    return time.time() - start_time, min_norm\n",
    "\n",
    "def _top_singular_value(matrix, num_iters):\n",
    "    \"\"\"Power-iteration estimate of the largest singular value of a 2-D tensor.\n",
    "\n",
    "    Starts from a random unit vector on `matrix`'s device/dtype and returns\n",
    "    ||matrix @ u|| for the final unit vector u, which never exceeds the true value.\n",
    "    \"\"\"\n",
    "    u = F.normalize(torch.randn(matrix.shape[1], device=matrix.device, dtype=matrix.dtype), dim=0)\n",
    "    v = F.normalize(torch.mv(matrix, u), dim=0)\n",
    "    for _ in range(num_iters):\n",
    "        u = F.normalize(torch.mv(matrix.t(), v), dim=0)\n",
    "        v = F.normalize(torch.mv(matrix, u), dim=0)\n",
    "    # v is matrix @ u normalised, so this dot product is ||matrix @ u|| (up to eps).\n",
    "    return torch.dot(v, torch.mv(matrix, u))\n",
    "\n",
    "\n",
    "def our_bounds_ch(conv_filter, num_iters=50):\n",
    "    \"\"\"PyTorch power-iteration version of the four-unfolding spectral-norm bound.\n",
    "\n",
    "    Estimates the top singular value of four reshapings of the kernel and returns\n",
    "    sqrt(h*w) times the smallest, meant as an upper bound on the conv layer's\n",
    "    spectral norm (power iteration slightly under-estimates each term).\n",
    "\n",
    "    Args:\n",
    "        conv_filter: tensor of shape (out_ch, in_ch, h, w), on any device.\n",
    "        num_iters: number of power-iteration steps per unfolding.\n",
    "\n",
    "    Returns:\n",
    "        (elapsed_seconds, min_norm) with min_norm a Python float.\n",
    "    \"\"\"\n",
    "    start_time = time.time()\n",
    "    out_ch, in_ch, h, w = conv_filter.shape\n",
    "    # The notebook passes nn.Parameters; no gradient is needed here.\n",
    "    weight = conv_filter.detach()\n",
    "    # Row/column orderings differ from the NumPy version, but permuting rows or\n",
    "    # columns, or transposing, leaves the singular values unchanged.\n",
    "    matrices = (\n",
    "        weight.transpose(1, 2).reshape(out_ch * h, in_ch * w),\n",
    "        weight.transpose(1, 3).reshape(out_ch * w, in_ch * h),\n",
    "        weight.reshape(out_ch, in_ch * h * w),\n",
    "        weight.transpose(0, 1).reshape(in_ch, out_ch * h * w),\n",
    "    )\n",
    "    sigmas = torch.stack([_top_singular_value(m, num_iters) for m in matrices])\n",
    "    # .item() copies to the host, so pending GPU work finishes before timing stops.\n",
    "    min_norm = math.sqrt(h * w) * sigmas.min().item()\n",
    "    total_time = time.time() - start_time\n",
    "    return total_time, min_norm\n",
    "\n",
    "def conv_power_iteration(conv_filter, u_list=None, v_list=None, num_iters=50):\n",
    "    \"\"\"Four-unfolding spectral-norm bound, iterating directly on the 4-D kernel.\n",
    "\n",
    "    Instead of materialising reshaped matrices, each matrix-vector product is a\n",
    "    broadcast multiply with the kernel followed by a sum over the matching axes.\n",
    "\n",
    "    Args:\n",
    "        conv_filter: tensor of shape (out_ch, in_ch, h, w).\n",
    "        u_list, v_list: optional sequences of four tensors used as warm starts,\n",
    "            shaped like the defaults built below. They are updated through\n",
    "            `.data`, so the caller's tensors hold the refined vectors afterwards.\n",
    "        num_iters: number of power-iteration steps.\n",
    "\n",
    "    Returns:\n",
    "        (elapsed_seconds, min_norm): sqrt(h*w) times the smallest of the four\n",
    "        estimated singular values, as a Python float.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if u_list or v_list does not hold exactly four tensors.\n",
    "    \"\"\"\n",
    "    start_time = time.time()\n",
    "    out_ch, in_ch, h, w = conv_filter.shape\n",
    "    weight = conv_filter.data\n",
    "    # Per unfolding: (u shape, v shape, axes summed to get v, axes summed to get u).\n",
    "    specs = (\n",
    "        ((1, in_ch, 1, w), (out_ch, 1, h, 1), (1, 3), (0, 2)),\n",
    "        ((1, in_ch, h, 1), (out_ch, 1, 1, w), (1, 2), (0, 3)),\n",
    "        ((1, in_ch, h, w), (out_ch, 1, 1, 1), (1, 2, 3), (0,)),\n",
    "        ((out_ch, 1, h, w), (1, in_ch, 1, 1), (0, 2, 3), (1,)),\n",
    "    )\n",
    "\n",
    "    def _random_unit(shape):\n",
    "        # Random start on the kernel's device/dtype, scaled to unit norm.\n",
    "        return l2_normalize(torch.randn(shape, device=weight.device, dtype=weight.dtype))\n",
    "\n",
    "    # Use caller-supplied vectors when given (warm start), otherwise random ones.\n",
    "    if u_list is None:\n",
    "        u_list = [_random_unit(spec[0]) for spec in specs]\n",
    "    if v_list is None:\n",
    "        v_list = [_random_unit(spec[1]) for spec in specs]\n",
    "    if len(u_list) != 4 or len(v_list) != 4:\n",
    "        raise ValueError(\"u_list and v_list must each contain four tensors\")\n",
    "\n",
    "    for _step in range(num_iters):\n",
    "        for u, v, spec in zip(u_list, v_list, specs):\n",
    "            v_axes, u_axes = spec[2], spec[3]\n",
    "            v.data = l2_normalize((weight * u.data).sum(v_axes, keepdim=True))\n",
    "            u.data = l2_normalize((weight * v.data).sum(u_axes, keepdim=True))\n",
    "\n",
    "    # Bilinear form v^T M u for each unfolding.\n",
    "    sigmas = torch.stack([torch.sum(weight * u.data * v.data) for u, v in zip(u_list, v_list)])\n",
    "    min_norm = math.sqrt(h * w) * sigmas.min().item()\n",
    "    total_time = time.time() - start_time\n",
    "    return total_time, min_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[64, 3, 7, 7] (224, 224)\n",
      "our: 28.8947, 0.2373\n",
      "real: 15.8215, 0.0372\n",
      "WARNING:tensorflow:From /anaconda/envs/py37_tensorflow/lib/python3.7/site-packages/tensorflow/python/ops/linalg_ops.py:721: setdiff1d (from tensorflow.python.ops.array_ops) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "This op will be removed after the deprecation date. Please switch to tf.sets.difference().\n",
      "exact tf: 15.9164, 0.9458\n",
      "exact np: 15.9164, 0.5979\n",
      "[64, 64, 3, 3] (224, 224)\n",
      "our: 9.3270, 0.0344\n",
      "real: 5.9705, 0.0377\n",
      "exact tf: 6.0060, 8.3601\n",
      "exact np: 6.0060, 24.3969\n",
      "[64, 64, 3, 3] (224, 224)\n",
      "our: 6.2902, 0.0359\n",
      "real: 5.3196, 0.0383\n",
      "exact tf: 5.3414, 9.2406\n",
      "exact np: 5.3414, 24.3861\n",
      "[64, 64, 3, 3] (224, 224)\n",
      "our: 8.7121, 0.0343\n",
      "real: 6.9658, 0.0372\n",
      "exact tf: 6.9997, 9.1890\n",
      "exact np: 6.9997, 24.9708\n",
      "[64, 64, 3, 3] (224, 224)\n",
      "our: 5.3942, 0.0346\n",
      "real: 3.7930, 0.0370\n",
      "exact tf: 3.8162, 9.1697\n",
      "exact np: 3.8162, 24.4463\n",
      "[128, 64, 3, 3] (112, 112)\n",
      "our: 5.8978, 0.0338\n",
      "real: 4.6765, 0.0562\n",
      "exact tf: 4.7070, 3.0651\n",
      "exact np: 4.7070, 37.0092\n",
      "[128, 128, 3, 3] (112, 112)\n",
      "our: 7.2012, 0.0347\n",
      "real: 5.6888, 0.0888\n",
      "exact tf: 5.7223, 9.2648\n",
      "exact np: 5.7223, 31.9692\n",
      "[128, 128, 3, 3] (112, 112)\n",
      "our: 6.7763, 0.0351\n",
      "real: 4.3806, 0.0886\n",
      "exact tf: 4.4096, 9.3261\n",
      "exact np: 4.4096, 31.5828\n",
      "[128, 128, 3, 3] (112, 112)\n",
      "our: 7.5713, 0.0346\n",
      "real: 4.8574, 0.0895\n",
      "exact tf: 4.8851, 9.3140\n",
      "exact np: 4.8851, 32.0184\n",
      "[256, 128, 3, 3] (56, 56)\n",
      "our: 8.4541, 0.0354\n",
      "real: 7.3604, 0.1552\n",
      "exact tf: 7.3948, 3.8356\n",
      "exact np: 7.3948, 12.0333\n",
      "[256, 256, 3, 3] (56, 56)\n",
      "our: 8.0359, 0.0352\n",
      "real: 6.5486, 0.2650\n",
      "exact tf: 6.5829, 11.2795\n",
      "exact np: 6.5829, 29.5707\n",
      "[256, 256, 3, 3] (56, 56)\n",
      "our: 7.5683, 0.0346\n",
      "real: 6.3298, 0.2713\n",
      "exact tf: 6.3609, 11.2443\n",
      "exact np: 6.3609, 29.8919\n",
      "[256, 256, 3, 3] (56, 56)\n",
      "our: 9.1662, 0.0351\n",
      "real: 7.6363, 0.2729\n",
      "exact tf: 7.6767, 11.2369\n",
      "exact np: 7.6767, 30.0118\n",
      "[512, 256, 3, 3] (28, 28)\n",
      "our: 11.0030, 0.0340\n",
      "real: 9.9406, 0.5328\n",
      "exact tf: 9.9906, 5.4060\n",
      "exact np: 9.9906, 11.9969\n",
      "[512, 512, 3, 3] (28, 28)\n",
      "our: 10.4516, 0.0354\n",
      "real: 9.0391, 1.0422\n",
      "exact tf: 9.0939, 15.7642\n",
      "exact np: 9.0939, 33.0677\n",
      "[512, 512, 3, 3] (28, 28)\n",
      "our: 18.3737, 0.0347\n",
      "real: 17.5088, 1.0437\n",
      "exact tf: 17.5995, 15.8242\n",
      "exact np: 17.5995, 33.0822\n",
      "[512, 512, 3, 3] (28, 28)\n",
      "our: 7.5955, 0.0347\n",
      "real: 7.4351, 1.0408\n",
      "exact tf: 7.4794, 15.7958\n",
      "exact np: 7.4793, 32.7415\n"
     ]
    }
   ],
   "source": [
    "# Input resolution assumed for each conv layer, keyed by its number of output\n",
    "# channels; any other width (the 64-channel layers) uses 224x224.\n",
    "INPUT_SHAPE_BY_OUT_CHANNELS = {512: (28, 28), 256: (56, 56), 128: (112, 112)}\n",
    "\n",
    "for name, param in resnet18.named_parameters():\n",
    "    if 'conv' in name:\n",
    "        out_channels = param.shape[0]\n",
    "        inp_shape = INPUT_SHAPE_BY_OUT_CHANNELS.get(out_channels, (224, 224))\n",
    "\n",
    "        # (out, in, kh, kw) -> (kh, kw, in, out): the FFT-based functions treat\n",
    "        # axes 0 and 1 as the spatial axes.\n",
    "        param_np = param.detach().permute(2, 3, 1, 0).contiguous().cpu().numpy()\n",
    "        param_tf = tf.convert_to_tensor(param_np)\n",
    "\n",
    "        print(list(param.shape), inp_shape)\n",
    "\n",
    "        our_time, our_bound = our_bounds_ch(param)\n",
    "        print(\"our: {:.4f}, {:.4f}\".format(our_bound, our_time))\n",
    "\n",
    "        # Same input size as the exact FFT methods, so the estimates are comparable.\n",
    "        real_time, real_singular = real_spectral_norm(param, inp_shape)\n",
    "        print(\"real: {:.4f}, {:.4f}\".format(real_singular, real_time))\n",
    "\n",
    "        tf_time, tf_singular = singular_values_tf(param_tf, inp_shape)\n",
    "        print(\"exact tf: {:.4f}, {:.4f}\".format(tf_singular, tf_time))\n",
    "\n",
    "        np_time, np_singular = singular_values_np(param_np, inp_shape)\n",
    "        print(\"exact np: {:.4f}, {:.4f}\".format(np_singular, np_time))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37_tensorflow",
   "language": "python",
   "name": "conda-env-py37_tensorflow-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
