{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import time\n",
    "from notebook_init import *\n",
    "\n",
    "# Turn autograd on globally for the rest of the notebook.\n",
    "torch.autograd.set_grad_enabled(True)\n",
    "\n",
    "# Output directory for this notebook; created up front if it is missing.\n",
    "out_root = Path('out/iter')\n",
    "makedirs(out_root, exist_ok=True)\n",
    "\n",
    "\n",
    "def rand():\n",
    "    '''Return a random integer in [0, int32 max) drawn from NumPy's global RNG.'''\n",
    "    upper_bound = np.iinfo(np.int32).max\n",
    "    return np.random.randint(upper_bound)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading ../models/checkpoints/stylegan2/stylegan2_ffhq_1024.pt\n"
     ]
    }
   ],
   "source": [
    "dataset = 'ffhq'\n",
    "use_w = True\n",
    "\n",
    "# NOTE(review): `inst` is read on the right-hand side before this cell assigns it,\n",
    "# so it must already exist (presumably exported by `notebook_init`) - confirm.\n",
    "inst = get_instrumented_model(\n",
    "    'StyleGAN2', dataset, 'style', device,\n",
    "    inst=inst, use_w=use_w,\n",
    ")\n",
    "\n",
    "# Keep a direct handle on the wrapped model and set its truncation to 1.0.\n",
    "model = inst.model\n",
    "model.truncation = 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeds to process; the rendering cell builds one grid per seed.\n",
    "seeds_ffhq = [\n",
    "    366745668,\n",
    "    # Previously listed, currently disabled:\n",
    "    # 327039870, 1302626592, 1235907362, 1150529896, 1881703227\n",
    "]\n",
    "# Alternative seed lists kept for reference:\n",
    "#   seeds_car = [1839348078, 1150529896]  (also 1881703227)\n",
    "#   seeds_cat = [889979887, 263281582]\n",
    "#   seeds_ffhq = [rand() for _ in range(50)]\n",
    "\n",
    "# Direction indices rendered: start_direction .. start_direction + n_direction - 1.\n",
    "start_direction, n_direction = 0, 10\n",
    "\n",
    "# Forwarded to create_strip_iter as `num_frames`.\n",
    "num_frames = 7\n",
    "# Forwarded to create_strip_iter doubled (num_steps*2); also recorded in the file name.\n",
    "num_steps = 12\n",
    "# Passed to LayerMode(...) to obtain (layer_start, layer_end).\n",
    "layer_mode = 'all'\n",
    "# Forwarded to create_strip_iter as `sigma`.\n",
    "perturb_intensity = 12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "366745668\n",
      "Getting iter trvs took 1.1073486804962158\n",
      "Getting iter trvs took 1.1689033508300781\n",
      "Getting iter trvs took 1.222881555557251\n",
      "Getting iter trvs took 1.1012294292449951\n",
      "Getting iter trvs took 1.1423900127410889\n",
      "Getting iter trvs took 1.1574020385742188\n",
      "Getting iter trvs took 1.1662263870239258\n",
      "Getting iter trvs took 1.1923329830169678\n",
      "Getting iter trvs took 1.1449370384216309\n",
      "Getting iter trvs took 1.0948066711425781\n"
     ]
    }
   ],
   "source": [
    "# When True each grid is written under `out_root`; otherwise it is displayed inline.\n",
    "save_img = True\n",
    "\n",
    "# The layer range depends only on `layer_mode`, so resolve it once instead of per seed.\n",
    "layer_start, layer_end = LayerMode(layer_mode)\n",
    "last_direction = start_direction + n_direction - 1\n",
    "\n",
    "for seed in seeds_ffhq:\n",
    "    print(seed)\n",
    "    rng = np.random.RandomState(seed)\n",
    "    # Only `noise` is consumed below; the other values stay bound for inspection.\n",
    "    noise, z, z_local_basis, z_sv, noise_basis = get_random_local_basis(model, rng)\n",
    "\n",
    "    strips = []\n",
    "    with torch.no_grad():\n",
    "        for trvs_idx in range(start_direction, start_direction + n_direction):\n",
    "            # NOTE(review): twice the configured `num_steps` is passed here, while\n",
    "            # the file name below records the undoubled value - confirm intent.\n",
    "            batch_frames, z_batch = create_strip_iter(\n",
    "                inst, 'latent', 'style', trvs_idx,\n",
    "                sigma=perturb_intensity,\n",
    "                layer_start=layer_start, layer_end=layer_end,\n",
    "                random_state=None, noise=noise, compare_basis=True,\n",
    "                num_frames=num_frames, num_steps=num_steps * 2,\n",
    "                only_pos=0, scale=1, verbose=0,\n",
    "            )\n",
    "            # One horizontal strip per direction, built from the first entry of `batch_frames`.\n",
    "            strips.append(np.hstack(pad_frames(batch_frames[0])))\n",
    "    # Stack the per-direction strips vertically into a single image.\n",
    "    grid = np.vstack(strips)\n",
    "\n",
    "    if save_img:\n",
    "        file_name = (\n",
    "            f'iter_trvs_{dataset}_{seed}_{start_direction}to{last_direction}'\n",
    "            f'_step{num_steps}_ptb{perturb_intensity}_{layer_mode}.jpg'\n",
    "        )\n",
    "        # Clip before the cast so out-of-range floats saturate instead of wrapping\n",
    "        # around in uint8 (assumes frames are floats nominally in [0, 1] - TODO confirm).\n",
    "        pixels = np.uint8(np.clip(grid, 0.0, 1.0) * 255)\n",
    "        Image.fromarray(pixels).save(out_root / file_name)\n",
    "    else:\n",
    "        fig, ax = plt.subplots(figsize=(20, 40))\n",
    "        ax.imshow(grid)\n",
    "        ax.axis('off')\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "local_basis1",
   "language": "python",
   "name": "local_basis1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
