{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import time\n",
    "from notebook_init import *\n",
    "\n",
    "# Every image grid produced by this notebook is written below this directory.\n",
    "out_root = Path('out/1dim')\n",
    "out_root.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Turn autograd on globally; the visualization cells wrap generation in torch.no_grad().\n",
    "torch.autograd.set_grad_enabled(True)\n",
    "\n",
    "\n",
    "def rand():\n",
    "    # Random integer in [0, int32 max), e.g. for drawing fresh seeds.\n",
    "    return np.random.randint(np.iinfo(np.int32).max)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading ../models/checkpoints/stylegan2/stylegan2_ffhq_1024.pt\n"
     ]
    }
   ],
   "source": [
    "# Model selection: StyleGAN2 on the chosen dataset, instrumented at the 'style' layer.\n",
    "use_w = True\n",
    "dataset = 'ffhq'\n",
    "\n",
    "# Pass a previously created instrumented model back to the loader when one exists.\n",
    "# On a fresh kernel `inst` may not be bound yet (unless notebook_init exports it),\n",
    "# so it is looked up defensively instead of raising NameError on Restart & Run All.\n",
    "inst = get_instrumented_model('StyleGAN2', dataset, 'style', device, inst=globals().get('inst'), use_w=use_w)\n",
    "model = inst.model\n",
    "model.truncation = 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeds to visualize; each visualization cell below produces one image grid per seed.\n",
    "seeds_ffhq = [366745668]\n",
    "# Other seed sets that were tried (uncomment one to use it):\n",
    "# seeds_ffhq = [366745668, 327039870, 1302626592, 1235907362, 1150529896, 1881703227]\n",
    "# seeds_ffhq = [1957070232, 946307783, 327039870]\n",
    "# seeds_ffhq = [1302626592]#,327039870]\n",
    "# seeds_ffhq = [rand() for _ in range(100)]\n",
    "# seeds_car = [1839348078, 1150529896]#, 1881703227]\n",
    "# seeds_cat = [889979887,263281582]\n",
    "\n",
    "start_direction = 0      # index of the first direction to visualize\n",
    "n_direction = 10         # number of consecutive directions (one grid row each)\n",
    "num_frames = 7           # forwarded to create_strip_centered as num_frames\n",
    "layer_mode = 'all'       # argument to LayerMode(), which yields (layer_start, layer_end)\n",
    "perturb_intensity = 12   # forwarded to create_strip_centered; also part of the output filenames"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Local Basis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "366745668\n"
     ]
    }
   ],
   "source": [
    "# True: write each grid to out_root as a JPEG. False: display it inline instead.\n",
    "save_img = True\n",
    "\n",
    "for seed in seeds_ffhq:\n",
    "    print(seed)\n",
    "    rng = np.random.RandomState(seed)\n",
    "    noise, z, z_local_basis, z_sv, noise_basis = get_random_local_basis(model, rng)\n",
    "    # Transpose, insert a singleton axis at dim 1 and detach; lb_dir[i] is then used as one direction.\n",
    "    lb_dir = z_local_basis.t().unsqueeze(1).detach().to(device)\n",
    "    layer_start, layer_end = LayerMode(layer_mode)\n",
    "\n",
    "    direction_ids = range(start_direction, start_direction + n_direction)\n",
    "    with torch.no_grad():\n",
    "        # One horizontal strip of frames per direction.\n",
    "        strips = [\n",
    "            np.hstack(pad_frames(\n",
    "                create_strip_centered(inst, 'latent', 'style', [z], 0, lb_dir[i], 0, 1, 0, z, perturb_intensity, layer_start, layer_end, num_frames=num_frames, only_pos=0, only_neg=0)[0]\n",
    "            ))\n",
    "            for i in direction_ids\n",
    "        ]\n",
    "    # Stack the strips vertically: one row per direction.\n",
    "    grid = np.vstack(strips)\n",
    "\n",
    "    if not save_img:\n",
    "        plt.figure(figsize=(20,40))  # a smaller alternative: figsize=(10,20)\n",
    "        plt.imshow(grid)\n",
    "        plt.axis('off')\n",
    "        plt.show()\n",
    "    else:\n",
    "        out_path = out_root / f'LocalBasis_{dataset}_{seed}_{start_direction}to{start_direction+n_direction-1}_ptb{perturb_intensity}.jpg'\n",
    "        # Scale by 255 and cast to uint8 -- assumes grid values lie in [0, 1]; TODO confirm.\n",
    "        Image.fromarray(np.uint8(grid*255)).save(out_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GANSpace - Only for FFHQ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GANSpace global directions. Note: only the FFHQ file is provided.\n",
    "gs_path = './global_directions/ganspace_directions_ffhq_stylegan2.npy'\n",
    "gs_dir = torch.from_numpy(np.load(gs_path)).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "# True: write each grid to out_root as a JPEG. False: display it inline instead.\n",
    "save_img = True\n",
    "\n",
    "for seed in seeds_ffhq:\n",
    "    rng = np.random.RandomState(seed)\n",
    "    # Seeded standard-normal noise, reshaped to [1, 512].\n",
    "    noise = torch.from_numpy(\n",
    "            rng.standard_normal(512 * 1)\n",
    "            .reshape(1, 512)).float().to(model.device) #[N, 512]\n",
    "    # Bind `z` on every iteration. Previously `z` was assigned only when\n",
    "    # model.w_primary was true, so otherwise the loop silently reused a stale `z`\n",
    "    # left over from an earlier cell (or raised NameError on a fresh kernel).\n",
    "    # NOTE(review): the non-W fallback uses the raw noise as the latent -- confirm\n",
    "    # that the loaded directions are meaningful in that space before relying on it.\n",
    "    z = model.model.style(noise) if model.w_primary else noise\n",
    "    layer_start, layer_end = LayerMode(layer_mode)\n",
    "    strips = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # One horizontal strip of frames per GANSpace direction.\n",
    "        for i in range(start_direction, start_direction + n_direction):\n",
    "            batch_frames = create_strip_centered(inst, 'latent', 'style', [z], 0, gs_dir[i], 0, 1, 0, z, perturb_intensity, layer_start, layer_end, num_frames=num_frames, only_pos=0, only_neg=0)[0]\n",
    "            strips.append(np.hstack(pad_frames(batch_frames)))\n",
    "    # Stack the strips vertically: one row per direction.\n",
    "    grid = np.vstack(strips)\n",
    "\n",
    "    if save_img:\n",
    "        # Scale by 255 and cast to uint8 -- assumes grid values lie in [0, 1]; TODO confirm.\n",
    "        Image.fromarray(np.uint8(grid*255)).save(out_root / f'ganspace_{seed}_{start_direction}to{start_direction+n_direction-1}_ptb{perturb_intensity}.jpg')\n",
    "    else:\n",
    "        plt.figure(figsize=(20,40))\n",
    "        plt.imshow(grid)\n",
    "        plt.axis('off')\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SeFa - Only for FFHQ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SeFa global directions. Note: only the FFHQ file is provided.\n",
    "sf_path = './global_directions/sefa_directions_ffhq.npy'\n",
    "sf_dir = torch.from_numpy(np.load(sf_path)).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "# True: write each grid to out_root as a JPEG. False: display it inline instead.\n",
    "save_img = True\n",
    "\n",
    "for seed in seeds_ffhq:\n",
    "    rng = np.random.RandomState(seed)\n",
    "    # Seeded standard-normal noise, reshaped to [1, 512].\n",
    "    noise = torch.from_numpy(\n",
    "            rng.standard_normal(512 * 1)\n",
    "            .reshape(1, 512)).float().to(model.device) #[N, 512]\n",
    "    # Bind `z` on every iteration. Previously `z` was assigned only when\n",
    "    # model.w_primary was true, so otherwise the loop silently reused a stale `z`\n",
    "    # left over from an earlier cell (or raised NameError on a fresh kernel).\n",
    "    # NOTE(review): the non-W fallback uses the raw noise as the latent -- confirm\n",
    "    # that the loaded directions are meaningful in that space before relying on it.\n",
    "    z = model.model.style(noise) if model.w_primary else noise\n",
    "    layer_start, layer_end = LayerMode(layer_mode)\n",
    "    strips = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # One horizontal strip of frames per SeFa direction.\n",
    "        for i in range(start_direction, start_direction + n_direction):\n",
    "            batch_frames = create_strip_centered(inst, 'latent', 'style', [z], 0, sf_dir[i], 0, 1, 0, z, perturb_intensity, layer_start, layer_end, num_frames=num_frames, only_pos=0, only_neg=0)[0]\n",
    "            strips.append(np.hstack(pad_frames(batch_frames)))\n",
    "    # Stack the strips vertically: one row per direction.\n",
    "    grid = np.vstack(strips)\n",
    "\n",
    "    if save_img:\n",
    "        # Scale by 255 and cast to uint8 -- assumes grid values lie in [0, 1]; TODO confirm.\n",
    "        Image.fromarray(np.uint8(grid*255)).save(out_root / f'sefa_{seed}_{start_direction}to{start_direction+n_direction-1}_ptb{perturb_intensity}.jpg')\n",
    "    else:\n",
    "        plt.figure(figsize=(20,40))\n",
    "        plt.imshow(grid)\n",
    "        plt.axis('off')\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ganspace",
   "language": "python",
   "name": "ganspace"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
