{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StyleGAN2: Optimized CUDA op FusedLeakyReLU not available, using native PyTorch fallback.\n",
      "StyleGAN2: Optimized CUDA op UpFirDn2d not available, using native PyTorch fallback.\n"
     ]
    }
   ],
   "source": [
    "# Copyright 2020 Erik Härkönen. All rights reserved.\n",
    "# This file is licensed to you under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License. You may obtain a copy\n",
    "# of the License at http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "# Unless required by applicable law or agreed to in writing, software distributed under\n",
    "# the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS\n",
    "# OF ANY KIND, either express or implied. See the License for the specific language\n",
    "# governing permissions and limitations under the License.\n",
    "\n",
    "# Notebook setup. notebook_init presumably provides the shared helpers used below\n",
    "# (e.g. Path, makedirs, device, get_instrumented_model) - TODO confirm.\n",
    "%matplotlib inline\n",
    "from notebook_init import *\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "def rand():\n",
    "    # Random non-negative integer strictly below the int32 maximum (usable as a seed).\n",
    "    return np.random.randint(np.iinfo(np.int32).max)\n",
    "\n",
    "# Output directory for figures; created if it does not exist yet.\n",
    "outdir = Path('out/figures/teaser')\n",
    "makedirs(outdir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "from config import Config\n",
    "from decomposition import get_or_compute\n",
    "\n",
    "def setup_model(model_name, class_name, layer_name, components=512, n=1_000_000, batch_size=200):\n",
    "    '''Instrument `model_name`/`class_name` at `layer_name` and compute (or load\n",
    "    from cache) its PCA components.\n",
    "\n",
    "    Sets the notebook globals `inst`, `model`, `lat_comp`, `lat_mean` and `lat_std`.\n",
    "    `components`, `n` and `batch_size` are forwarded to `Config`; the defaults are\n",
    "    the values this notebook has always used. Returns the path of the .npz dump.\n",
    "    '''\n",
    "    global inst, model, lat_comp, lat_mean, lat_std\n",
    "\n",
    "    use_w = 'StyleGAN' in model_name\n",
    "    if layer_name in {f'style.{i}' for i in range(9)}:\n",
    "        # Mapping-network sub-layer: components are computed with use_w disabled.\n",
    "        use_w = False\n",
    "    # Passing the current global `inst` lets get_instrumented_model reuse it.\n",
    "    inst = get_instrumented_model(model_name, class_name, layer_name, device, use_w=use_w, inst=inst)\n",
    "    model = inst.model\n",
    "    pc_config = Config(components=components, n=n, batch_size=batch_size,\n",
    "        layer=layer_name, model=model_name, output_class=class_name, use_w=use_w)\n",
    "\n",
    "    dump_name = get_or_compute(pc_config, inst)\n",
    "    print(dump_name)\n",
    "\n",
    "    with np.load(dump_name) as data:\n",
    "        lat_comp = data['lat_comp']\n",
    "        lat_mean = data['lat_mean']\n",
    "        lat_std = data['lat_stdev']\n",
    "\n",
    "    return dump_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading ../models/checkpoints/stylegan2/stylegan2_ffhq-config-e_1024.pt\n",
      "tensor([-0.1992, -0.4934,  0.3345, -0.6153, -0.6739,  1.1336, -0.4815, -0.7157,\n",
      "         0.1023, -0.3542], device='cuda:0')\n",
      "torch.Size([1, 512])\n",
      "tensor([-0.0313, -0.3193, -0.1695, -0.2328, -0.2872, -0.5201, -0.2707, -0.0635,\n",
      "         1.0268,  0.7440], device='cuda:0')\n",
      "tensor([-0.1992, -0.4934,  0.3345, -0.6153, -0.6739,  1.1336, -0.4815, -0.7157,\n",
      "         0.1023, -0.3542], device='cuda:0')\n",
      "torch.Size([1, 512])\n",
      "tensor([-0.0313, -0.3193, -0.1695, -0.2328, -0.2872, -0.5201, -0.2707, -0.0635,\n",
      "         1.0268,  0.7440], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "''' Check Subnetwork Forward '''\n",
    "# Forwarding up to the sub-layer 'style.8' and forwarding through the whole\n",
    "# 'style' mapping network, from the same seeded noise, must give identical features.\n",
    "\n",
    "seed = 366745668\n",
    "model_name, class_name = 'StyleGAN2', 'ffhq-config-e'\n",
    "\n",
    "def forward_until_layer(layer_key, seed, noise_dim=512, b=1):\n",
    "    '''Run `model` up to `layer_key` on seeded standard-normal noise and return that layer's features.\n",
    "\n",
    "    (Re)instruments the global `inst`/`model` at `layer_key` with use_w=False, so that\n",
    "    later cells passing `inst=inst` reuse the same InstrumentedModel instance.\n",
    "    '''\n",
    "    global inst, model\n",
    "    inst = get_instrumented_model(model_name, class_name, layer_key, device, use_w=False, inst=inst)\n",
    "    model = inst.model\n",
    "\n",
    "    rng = np.random.RandomState(seed)\n",
    "    noise = torch.from_numpy(\n",
    "        rng.standard_normal(noise_dim * b).reshape(b, noise_dim)\n",
    "    ).float().to(model.device)\n",
    "    print(noise[0, :10])\n",
    "\n",
    "    inst.retain_layer(layer_key)\n",
    "    model.partial_forward(noise, layer_key)\n",
    "    # clone() defensively, in case the retained tensor is reused by a later forward pass.\n",
    "    feats = inst.retained_features()[layer_key].clone()\n",
    "    print(feats.shape)\n",
    "    print(feats[0, :10])\n",
    "    return feats\n",
    "\n",
    "z_sub = forward_until_layer('style.8', seed)\n",
    "z_full = forward_until_layer('style', seed)\n",
    "assert torch.allclose(z_sub, z_full), 'style.8 features differ from full style output'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute GANSpace directions on each sub-layer of the StyleGAN2 mapping network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Not cached\n",
      "[11.09 23:47] Computing stylegan2-ffhq-config-e_style.1_ipca_c512_n1000000.npz\n",
      "Reusing InstrumentedModel instance\n",
      "Layer key style.1\n",
      "Feature shape: torch.Size([1, 512])\n",
      "B=200, N=1000000, dims=512, N/dims=1953.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sampling latents: 100%|██████████| 5010/5010 [00:31<00:00, 158.38it/s]\n",
      "Fitting batches (NB=2000): 100%|##########| 500/500 [01:34<00:00,  5.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Performing least squares regression\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Collecting samples: 100%|##########| 5000/5000 [00:50<00:00, 98.45it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 0:03:35.778750\n",
      "/home/chjw1475/Research/localbasis_frechet_mean_220911/cache/components/stylegan2-ffhq-config-e_style.1_ipca_c512_n1000000.npz\n",
      "Not cached\n",
      "[11.09 23:51] Computing stylegan2-ffhq-config-e_style.2_ipca_c512_n1000000.npz\n",
      "Reusing InstrumentedModel instance\n",
      "Layer key style.2\n",
      "Feature shape: torch.Size([1, 512])\n",
      "B=200, N=1000000, dims=512, N/dims=1953.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sampling latents: 100%|██████████| 5010/5010 [00:33<00:00, 150.70it/s]\n",
      "Fitting batches (NB=2000): 100%|##########| 500/500 [01:27<00:00,  5.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Performing least squares regression\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Collecting samples: 100%|##########| 5000/5000 [00:45<00:00, 110.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 0:03:23.908581\n",
      "/home/chjw1475/Research/localbasis_frechet_mean_220911/cache/components/stylegan2-ffhq-config-e_style.2_ipca_c512_n1000000.npz\n",
      "Not cached\n",
      "[11.09 23:54] Computing stylegan2-ffhq-config-e_style.3_ipca_c512_n1000000.npz\n",
      "Reusing InstrumentedModel instance\n",
      "Layer key style.3\n",
      "Feature shape: torch.Size([1, 512])\n",
      "B=200, N=1000000, dims=512, N/dims=1953.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sampling latents: 100%|██████████| 5010/5010 [00:31<00:00, 159.62it/s]\n",
      "Fitting batches (NB=2000): 100%|##########| 500/500 [01:22<00:00,  6.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Performing least squares regression\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Collecting samples: 100%|##########| 5000/5000 [00:45<00:00, 110.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 0:03:15.699172\n",
      "/home/chjw1475/Research/localbasis_frechet_mean_220911/cache/components/stylegan2-ffhq-config-e_style.3_ipca_c512_n1000000.npz\n",
      "Not cached\n",
      "[11.09 23:58] Computing stylegan2-ffhq-config-e_style.4_ipca_c512_n1000000.npz\n",
      "Reusing InstrumentedModel instance\n",
      "Layer key style.4\n",
      "Feature shape: torch.Size([1, 512])\n",
      "B=200, N=1000000, dims=512, N/dims=1953.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sampling latents: 100%|██████████| 5010/5010 [00:31<00:00, 158.54it/s]\n",
      "Fitting batches (NB=2000): 100%|##########| 500/500 [01:22<00:00,  6.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Performing least squares regression\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Collecting samples: 100%|##########| 5000/5000 [00:44<00:00, 112.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 0:03:16.628121\n",
      "/home/chjw1475/Research/localbasis_frechet_mean_220911/cache/components/stylegan2-ffhq-config-e_style.4_ipca_c512_n1000000.npz\n",
      "Not cached\n",
      "[12.09 00:01] Computing stylegan2-ffhq-config-e_style.5_ipca_c512_n1000000.npz\n",
      "Reusing InstrumentedModel instance\n",
      "Layer key style.5\n",
      "Feature shape: torch.Size([1, 512])\n",
      "B=200, N=1000000, dims=512, N/dims=1953.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sampling latents: 100%|██████████| 5010/5010 [00:31<00:00, 156.84it/s]\n",
      "Fitting batches (NB=2000): 100%|##########| 500/500 [01:19<00:00,  6.29it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Performing least squares regression\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Collecting samples: 100%|##########| 5000/5000 [00:45<00:00, 110.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 0:03:14.750994\n",
      "/home/chjw1475/Research/localbasis_frechet_mean_220911/cache/components/stylegan2-ffhq-config-e_style.5_ipca_c512_n1000000.npz\n",
      "Not cached\n",
      "[12.09 00:04] Computing stylegan2-ffhq-config-e_style.6_ipca_c512_n1000000.npz\n",
      "Reusing InstrumentedModel instance\n",
      "Layer key style.6\n",
      "Feature shape: torch.Size([1, 512])\n",
      "B=200, N=1000000, dims=512, N/dims=1953.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sampling latents: 100%|██████████| 5010/5010 [00:32<00:00, 153.72it/s]\n",
      "Fitting batches (NB=2000): 100%|##########| 500/500 [01:21<00:00,  6.13it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Performing least squares regression\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Collecting samples: 100%|##########| 5000/5000 [00:43<00:00, 114.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 0:03:17.923077\n",
      "/home/chjw1475/Research/localbasis_frechet_mean_220911/cache/components/stylegan2-ffhq-config-e_style.6_ipca_c512_n1000000.npz\n",
      "Not cached\n",
      "[12.09 00:08] Computing stylegan2-ffhq-config-e_style.7_ipca_c512_n1000000.npz\n",
      "Reusing InstrumentedModel instance\n",
      "Layer key style.7\n",
      "Feature shape: torch.Size([1, 512])\n",
      "B=200, N=1000000, dims=512, N/dims=1953.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sampling latents: 100%|██████████| 5010/5010 [00:30<00:00, 164.28it/s]\n",
      "Fitting batches (NB=2000): 100%|##########| 500/500 [01:19<00:00,  6.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Performing least squares regression\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Collecting samples: 100%|##########| 5000/5000 [00:41<00:00, 121.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 0:03:10.832288\n",
      "/home/chjw1475/Research/localbasis_frechet_mean_220911/cache/components/stylegan2-ffhq-config-e_style.7_ipca_c512_n1000000.npz\n",
      "Not cached\n",
      "[12.09 00:11] Computing stylegan2-ffhq-config-e_style.8_ipca_c512_n1000000.npz\n",
      "Reusing InstrumentedModel instance\n",
      "Layer key style.8\n",
      "Feature shape: torch.Size([1, 512])\n",
      "B=200, N=1000000, dims=512, N/dims=1953.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sampling latents: 100%|██████████| 5010/5010 [00:32<00:00, 153.01it/s]\n",
      "Fitting batches (NB=2000): 100%|##########| 500/500 [01:21<00:00,  6.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Performing least squares regression\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Collecting samples: 100%|##########| 5000/5000 [00:42<00:00, 118.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 0:03:17.125200\n",
      "/home/chjw1475/Research/localbasis_frechet_mean_220911/cache/components/stylegan2-ffhq-config-e_style.8_ipca_c512_n1000000.npz\n"
     ]
    },
    {
     "ename": "AssertionError",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-12-7bf7be054e56>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlayer_idx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m9\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0msetup_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'StyleGAN2'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'ffhq-config-e'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34mf'style.{layer_idx}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0mseeds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m6293435\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2105448342\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;31m# + [rand() for _ in range(1)]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAssertionError\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Compute (or load from cache) GANSpace components for the StyleGAN2 FFHQ (config-e)\n",
    "# mapping-network sub-layers style.1 ... style.8 (style.0 is not included here).\n",
    "for layer_idx in range(1, 9):\n",
    "    setup_model('StyleGAN2', 'ffhq-config-e', f'style.{layer_idx}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "''' Rename '''\n",
    "# Export act_comp of every cached StyleGAN2-FFHQ dump for a 'style.<i>' sub-layer as\n",
    "# ganspace_directions_ffhq_stylegan2_style-<i>.npy in the same directory.\n",
    "\n",
    "save_dir = '../cache/components/'\n",
    "for basis_name in sorted(os.listdir(save_dir)):\n",
    "    # Only .npz dumps: the .npy files written below live in the same directory and\n",
    "    # also contain 'stylegan2', so they must be skipped for the cell to be re-runnable.\n",
    "    if not (basis_name.startswith('stylegan2-ffhq') and basis_name.endswith('.npz')):\n",
    "        continue\n",
    "    layer_name = basis_name.split('_')[1]  # e.g. 'style.3'\n",
    "    prefix, _, layer_idx = layer_name.partition('.')\n",
    "    # Skip dumps of layers other than 'style.<digits>' (e.g. the full 'style' layer).\n",
    "    if prefix != 'style' or not layer_idx.isdigit():\n",
    "        continue\n",
    "\n",
    "    basis_path = os.path.join(save_dir, basis_name)\n",
    "    with np.load(basis_path) as data:\n",
    "        # Since use_w = False for these layers, use act_comp instead of lat_comp.\n",
    "        # A distinct name avoids clobbering the global lat_comp set by setup_model.\n",
    "        act_comp = data['act_comp']\n",
    "    np.save(os.path.join(save_dir, f'ganspace_directions_ffhq_stylegan2_style-{layer_idx}.npy'), act_comp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "local_basis",
   "language": "python",
   "name": "local_basis"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
