{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StyleGAN2: Optimized CUDA op FusedLeakyReLU not available, using native PyTorch fallback.\n",
      "StyleGAN2: Optimized CUDA op UpFirDn2d not available, using native PyTorch fallback.\n"
     ]
    }
   ],
   "source": [
    "# Copyright 2020 Erik Härkönen. All rights reserved.\n",
    "# This file is licensed to you under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License. You may obtain a copy\n",
    "# of the License at http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "# Unless required by applicable law or agreed to in writing, software distributed under\n",
    "# the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS\n",
    "# OF ANY KIND, either express or implied. See the License for the specific language\n",
    "# governing permissions and limitations under the License.\n",
    "\n",
    "# Teaser: sequence of 3 interesting edits\n",
    "%matplotlib inline\n",
    "from notebook_init import *\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "def rand():\n",
    "    '''Return a random integer in [0, int32 max), e.g. for use as a seed.'''\n",
    "    return np.random.randint(np.iinfo(np.int32).max)\n",
    "\n",
    "# Create the output directory if it does not already exist.\n",
    "outdir = Path('out/figures/teaser')\n",
    "makedirs(outdir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def setup_model(model_name, class_name, layer_name):\n",
    "    '''Instrument a model at `layer_name` and load its cached component dump.\n",
    "\n",
    "    Nothing is returned; the results are published through the notebook-level\n",
    "    globals `inst`, `model`, `lat_comp`, `lat_mean` and `lat_std`.\n",
    "\n",
    "    Args:\n",
    "        model_name: Model family name, e.g. 'StyleGAN'.\n",
    "        class_name: Output class name, e.g. 'ffhq'.\n",
    "        layer_name: Layer key passed to `get_instrumented_model` and `Config`.\n",
    "    '''\n",
    "    global inst, model, lat_comp, lat_mean, lat_std\n",
    "\n",
    "    style_sublayers = {f'style.{i}' for i in range(9)}\n",
    "    mapping_sublayers = {f'g_mapping.dense{i}' for i in range(8)}\n",
    "\n",
    "    # use_w follows the model name unless the target is one of the sub-layer keys above.\n",
    "    use_w = 'StyleGAN' in model_name\n",
    "    is_style_sublayer = layer_name in style_sublayers\n",
    "    if is_style_sublayer or layer_name in mapping_sublayers:\n",
    "        use_w = False\n",
    "        if not is_style_sublayer:\n",
    "            # Debug echo, emitted for the g_mapping.dense* keys only.\n",
    "            print(use_w)\n",
    "\n",
    "    # The previous global `inst` is handed back in so it can be reused.\n",
    "    inst = get_instrumented_model(\n",
    "        model_name, class_name, layer_name, device, use_w=use_w, inst=inst)\n",
    "    model = inst.model\n",
    "    print('w_primary : ', model.w_primary)\n",
    "\n",
    "    pc_config = Config(\n",
    "        components=300,\n",
    "        n=1_000_000,\n",
    "        batch_size=200,\n",
    "        layer=layer_name,\n",
    "        model=model_name,\n",
    "        output_class=class_name,\n",
    "        use_w=use_w,\n",
    "    )\n",
    "    dump_name = get_or_compute(pc_config, inst)\n",
    "    print(dump_name)\n",
    "\n",
    "    # Copy the three arrays out of the archive into the globals.\n",
    "    with np.load(dump_name) as archive:\n",
    "        lat_comp = archive['lat_comp']\n",
    "        lat_mean = archive['lat_mean']\n",
    "        lat_std = archive['lat_stdev']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-0.1992, -0.4934,  0.3345, -0.6153, -0.6739,  1.1336, -0.4815, -0.7157,\n",
      "         0.1023, -0.3542], device='cuda:0')\n",
      "torch.Size([1, 512])\n",
      "tensor([ 0.1570,  0.4542, -0.6847, -0.9636,  0.7698, -0.4679, -0.7773, -0.6020,\n",
      "        -0.7998, -0.6603], device='cuda:0')\n",
      "tensor([-0.1992, -0.4934,  0.3345, -0.6153, -0.6739,  1.1336, -0.4815, -0.7157,\n",
      "         0.1023, -0.3542], device='cuda:0')\n",
      "torch.Size([1, 512])\n",
      "tensor([ 0.1570,  0.4542, -0.1369, -0.1927,  0.7698, -0.0936, -0.1555, -0.1204,\n",
      "        -0.1600, -0.1321], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# Check subnetwork forward: feed the same seeded noise vector through the model\n",
    "# up to each layer key in turn and print the first ten retained activations.\n",
    "\n",
    "seed = 366745668\n",
    "model_name, class_name = 'StyleGAN', 'ffhq'\n",
    "use_w = False  # deliberately overrides the usual `'StyleGAN' in model_name` default\n",
    "noise_dim, b = 512, 1\n",
    "\n",
    "# One pass per layer key; the steps are identical, only the key differs.\n",
    "for layer_key in ['g_mapping.dense7', 'g_mapping']:\n",
    "    layer_name = layer_key\n",
    "    inst = get_instrumented_model(model_name, class_name, layer_name, device, use_w=use_w, inst=inst)\n",
    "    model = inst.model\n",
    "\n",
    "    # Re-seed on every pass so each layer key receives exactly the same input.\n",
    "    rng = np.random.RandomState(seed)\n",
    "    noise = torch.from_numpy(\n",
    "            rng.standard_normal(noise_dim * b)\n",
    "            .reshape(b, noise_dim)).float().to(model.device)\n",
    "    print(noise[0, :10])\n",
    "\n",
    "    # Run the partial forward pass and read back the activation retained at layer_key.\n",
    "    inst.retain_layer(layer_key)\n",
    "    model.partial_forward(noise, layer_key)\n",
    "    z = inst.retained_features()[layer_key]\n",
    "    print(z.shape)\n",
    "    print(z[0, :10])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Find GANSpace directions on a subnetwork of the mapping network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n",
      "w_primary :  False\n",
      "Not cached\n",
      "[19.05 14:23] Computing stylegan-ffhq_g_mapping.dense7_ipca_c300_n1000000.npz\n",
      "Reusing InstrumentedModel instance\n",
      "Layer key g_mapping.dense7\n",
      "Feature shape: torch.Size([1, 512])\n",
      "B=200, N=1000000, dims=512, N/dims=1953.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sampling latents: 100%|███████████████████████████████████████████████████████████| 5010/5010 [00:16<00:00, 310.67it/s]\n",
      "Fitting batches (NB=2000): 100%|#####################################################| 500/500 [02:02<00:00,  4.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Performing least squares regression\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Collecting samples: 100%|#########################################################| 5000/5000 [00:31<00:00, 158.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 0:03:23.164914\n",
      "C:\\Users\\Choi\\Desktop\\Research\\localbasis_RPCA_220419\\cache\\components\\stylegan-ffhq_g_mapping.dense7_ipca_c300_n1000000.npz\n"
     ]
    },
    {
     "ename": "AssertionError",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_28212\\3259257944.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      6\u001b[0m     \u001b[0msetup_model\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'StyleGAN'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'ffhq'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34mf'g_mapping.dense{layer_idx}'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      7\u001b[0m     \u001b[1;31m#setup_model('StyleGAN2', 'ffhq', 'style.8')\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 8\u001b[1;33m \u001b[1;32massert\u001b[0m \u001b[1;32mFalse\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      9\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     10\u001b[0m \u001b[0mseeds\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;36m6293435\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m2105448342\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;31m# + [rand() for _ in range(1)]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mAssertionError\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# StyleGAN faces: run setup_model for the selected mapping-network dense layers.\n",
    "# Alternative targets tried previously:\n",
    "#   setup_model('StyleGAN2', 'ffhq', 'style')\n",
    "#   setup_model('StyleGAN2', 'ffhq', 'style.8')\n",
    "\n",
    "MAPPING_LAYER_INDICES = [7]  # use range(0, 8) to process every dense layer\n",
    "\n",
    "# Explicit switch replacing a bare `assert False` that used to abort the cell here.\n",
    "# NOTE(review): perform_edit is not defined in any visible cell of this notebook;\n",
    "# confirm it is available before enabling.\n",
    "RUN_EDITS = False\n",
    "\n",
    "for layer_idx in MAPPING_LAYER_INDICES:\n",
    "    setup_model('StyleGAN', 'ffhq', f'g_mapping.dense{layer_idx}')\n",
    "\n",
    "if RUN_EDITS:\n",
    "    seeds = [6293435, 2105448342] # + [rand() for _ in range(1)]\n",
    "    print(seeds)\n",
    "    edits = ['wrinkles', 'white_hair', 'in_awe', 'overexposed']\n",
    "    perform_edit(seeds, edits, True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(300, 1, 512)\n"
     ]
    },
    {
     "ename": "AssertionError",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_28212\\936383889.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     12\u001b[0m         \u001b[0mlat_comp\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'act_comp'\u001b[0m\u001b[1;33m]\u001b[0m  \u001b[1;31m## Since use_w = False, use act_comp instead of lat_comp\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     13\u001b[0m         \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlat_comp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 14\u001b[1;33m         \u001b[1;32massert\u001b[0m \u001b[1;32mFalse\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     15\u001b[0m         \u001b[1;31m#np.save(os.path.join(save_dir, f'ganspace_directions_ffhq_stylegan2_style-{layer_idx}.npy'), lat_comp)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     16\u001b[0m         \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msave\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mos\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msave_dir\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbasis_name\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;34m'npy'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlat_comp\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mAssertionError\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Rename/export: save the 'act_comp' array of every cached .npz basis as a\n",
    "# .npy file with the same base name.\n",
    "\n",
    "save_dir = '../cache/components/'\n",
    "\n",
    "# Dry run by default (replaces a bare `assert False` that aborted on the first file):\n",
    "# shapes are reported but nothing is written. Set to True to write the .npy files.\n",
    "WRITE_NPY = False\n",
    "\n",
    "for basis_name in sorted(os.listdir(save_dir)):\n",
    "    stem, ext = os.path.splitext(basis_name)\n",
    "    if ext != '.npz':\n",
    "        # Skip non-archives, e.g. .npy files written by an earlier run of this cell.\n",
    "        continue\n",
    "\n",
    "    basis_path = os.path.join(save_dir, basis_name)\n",
    "    with np.load(basis_path) as data:\n",
    "        # Since use_w = False, use act_comp instead of lat_comp\n",
    "        lat_comp = data['act_comp']\n",
    "    print(basis_name, lat_comp.shape)\n",
    "\n",
    "    if WRITE_NPY:\n",
    "        np.save(os.path.join(save_dir, stem + '.npy'), lat_comp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rename/export (StyleGAN2 only): save the 'act_comp' array of each cached\n",
    "# .npz basis as ganspace_directions_ffhq_stylegan2_style-<layer index>.npy.\n",
    "\n",
    "save_dir = '../cache/components/'\n",
    "\n",
    "# Dry run by default (replaces a bare `assert False` that aborted on the first file):\n",
    "# shapes are reported but nothing is written. Set to True to write the .npy files.\n",
    "WRITE_NPY = False\n",
    "\n",
    "for basis_name in sorted(os.listdir(save_dir)):\n",
    "    if 'stylegan2' not in basis_name or not basis_name.endswith('.npz'):\n",
    "        continue\n",
    "\n",
    "    # Assumes names of the form '<model>_<layer>_...' - TODO confirm; a layer name\n",
    "    # that itself contains '_' would be cut short by this split.\n",
    "    name_parts = basis_name.split('_')\n",
    "    if len(name_parts) < 2:\n",
    "        continue\n",
    "    layer_name = name_parts[1]\n",
    "\n",
    "    # Trailing digits of the layer name (handles multi-digit indices); skip if none.\n",
    "    layer_idx = layer_name[len(layer_name.rstrip('0123456789')):]\n",
    "    if not layer_idx:\n",
    "        continue\n",
    "\n",
    "    basis_path = os.path.join(save_dir, basis_name)\n",
    "    with np.load(basis_path) as data:\n",
    "        # Since use_w = False, use act_comp instead of lat_comp\n",
    "        lat_comp = data['act_comp']\n",
    "    print(basis_name, lat_comp.shape)\n",
    "\n",
    "    if WRITE_NPY:\n",
    "        np.save(os.path.join(save_dir, f'ganspace_directions_ffhq_stylegan2_style-{layer_idx}.npy'), lat_comp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "local_basis",
   "language": "python",
   "name": "local_basis"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
