{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "fid_save_ims.ipynb",
   "provenance": [],
   "collapsed_sections": [],
   "machine_shape": "hm",
   "authorship_tag": "ABX9TyNlPJusctE8UaAAZ4fydBmD"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "cells": [
  {
   "cell_type": "code",
   "metadata": {
    "id": "sykZ83w6yAPJ"
   },
   "source": [
    "#####################\n",
    "# ## COLAB SETUP ## #\n",
    "#####################\n",
    "\n",
    "from google.colab import drive\n",
    "drive.mount('/content/drive')"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "cm3_9YMVyHTA"
   },
   "source": [
    "!pip install -q tfds-nightly\n",
    "!pip install tensorflow_addons\n",
    "!tfds --version"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "AkDVxFJRx-0d"
   },
   "source": [
    "# log into gcloud project with bucket that contains tf files\n",
    "!gcloud auth login\n",
    "!gcloud config set project generative-experiments-tfrc\n",
    "# !gcloud config set project poised-lens-305218\n",
    "\n",
    "# give colab access to project\n",
    "from google.colab import auth\n",
    "auth.authenticate_user()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "bjo9C_4axzSJ"
   },
   "source": [
    "####################\n",
    "# ## PARAMETERS ## #\n",
    "####################\n",
    "\n",
    "config = {\n",
    "  # paths for connecting to cloud storage\n",
    "  #\"vm_path\": 'ebm_transformer_reu/',\n",
    "  \"project_id\": 'generative-experiments-tfrc',\n",
    "  \"gs_path\": 'gen-tfrc-uscentral1',\n",
    "  \"exp_name\": 'celeb_a_hybrid_new_1',\n",
    "  \"exp_dir\": '/content/drive/My Drive/Colab Notebooks/code_files/fid_tf2/fid_out',\n",
    "\n",
    "  # device type ('tpu' or 'gpu' or 'cpu')\n",
    "  \"device_type\": 'gpu',\n",
    "  # number of gpus if using gpu device\n",
    "  'num_gpus': 1,\n",
    "\n",
    "  # exp params\n",
    "  \"exp_type\": \"folder\",\n",
    "  \"num_fid_rounds\": 520,\n",
    "  \"batch_size\": 96,\n",
    "  \"image_dims\": [64, 64, 3], # cifar10: 32x32, celeb_a 64x64, imagenet: 128x128\n",
    "  \"split\": \"train\",\n",
    "  #\"transform_type\": \"train\",\n",
    "\n",
    "  # data type and augmentation parameters\n",
    "  \"data_type\": 'celeb_a', # cifar10, celeb_a, imagenet2012\n",
    "  \"random_crop\": False,\n",
    "\n",
    "  # NOTE(review): data_type and exp_name are celeb_a, but ebm_weights and gen_weights below\n",
    "  # point to tfrc_out/imagenet2012/ checkpoints -- confirm they match data_type and image_dims.\n",
    "  # ebm network\n",
    "  \"net_type\": 'ebm_sngan',\n",
    "  \"ebm_weights\": \"gs://gen-tfrc-uscentral1/tfrc_out/imagenet2012/nonconv_resnet_norm_5e-5_21-09-23-02-32-05/checkpoints/ebm_220000.ckpt\",\n",
    "  #\"ebm_weights\": \"gs://gen-tfrc-uscentral1/tfrc_out/imagenet2012/nonconv_resnet_21-09-18-21-01-16/checkpoints/ebm_230000.ckpt\",\n",
    "  #\"ebm_weights\": 'gs://gen-tfrc-uscentral1/tfrc_out/imagenet/conv_trans_21-06-13-22-58-04/checkpoints/ebm_60000.ckpt',\n",
    "\n",
    "  # langevin sampling parameters\n",
    "  \"mcmc_steps\": 200,\n",
    "  \"epsilon\": 3e-3,\n",
    "  \"mcmc_init\": \"coop\",\n",
    "  \"mcmc_temp\": 1e-7,\n",
    "  # clipping parameters\n",
    "  \"clip_langevin_grad\": False,\n",
    "  \"max_langevin_norm\": 0.1,\n",
    "\n",
    "  \"gen_type\": \"gen_sngan\",\n",
    "  \"z_sz\": 128,\n",
    "  \"truncation\": 1,\n",
    "  #\"fixed_gen\": False,\n",
    "  #\"gen_weights\": 'gs://gen-tfrc-uscentral1/tfrc_out/cifar10/nonconv_trans_no_aug_21-06-28-21-08-02/checkpoints/gen_60000.ckpt'\n",
    "  #\"gen_weights\": \"gs://gen-tfrc-uscentral1/tfrc_out/imagenet2012/nonconv_resnet_21-09-18-21-01-16/checkpoints/gen_225000.ckpt\",\n",
    "  \"gen_weights\": \"gs://gen-tfrc-uscentral1/tfrc_out/imagenet2012/nonconv_resnet_norm_5e-5_21-09-23-02-32-05/checkpoints/gen_220000.ckpt\"\n",
    "}"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Z6K4-XW-xgXz"
   },
   "source": [
    "# Save data images and model samples from the tf2 model as uint8 .npy records so the\n",
    "# original FID code can be used for evaluation.\n",
    "\n",
    "import os\n",
    "import sys\n",
    "sys.path.insert(0, '/content/drive/My Drive/Colab Notebooks/code_files/fid_tf2')\n",
    "from datetime import datetime\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "import importlib\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "\n",
    "from init import init_strategy, initialize_nets_and_optim, initialize_data\n",
    "from data import get_dataset\n",
    "from utils import setup_exp, plot_ims\n",
    "\n",
    "import argparse\n",
    "\n",
    "\n",
    "def to_uint8(images):\n",
    "  \"\"\"Map a numpy array from [-1, 1] to rounded uint8 values in [0, 255] (out-of-range values are clipped).\"\"\"\n",
    "  return np.rint(255 * (np.clip(images, -1, 1) + 1) / 2).astype(np.uint8)\n",
    "\n",
    "\n",
    "def save_samples(strategy, config, ebm, gen=None, train_iterator=None, save_str='samples.pdf'):\n",
    "  \"\"\"Write data batches and EBM Langevin samples to two .npy record files for FID.\n",
    "\n",
    "  Each of config['num_fid_rounds'] rounds draws one data batch from train_iterator and one\n",
    "  sample batch from the EBM, and writes each (rescaled with to_uint8) as a separate np.save\n",
    "  record to {exp_dir}/{exp_name}/numpy_out/images1.npy (data) and images2.npy (samples).\n",
    "  Both files are truncated when this function starts; to read them back, call np.load\n",
    "  repeatedly on a single open file handle. The first round's samples and their MCMC\n",
    "  initializations are plotted to images/{save_str} and images/init_{save_str}.\n",
    "\n",
    "  Args:\n",
    "    strategy: tf.distribute strategy used to run the sampler and gather results.\n",
    "    config: experiment config dict (keys from the PARAMETERS cell).\n",
    "    ebm: energy network, called as ebm(x, training=False).\n",
    "    gen: generator network exposing generate_latent_z (mcmc_init='coop'), or an\n",
    "      iterator of batches used as MCMC initialization (mcmc_init='data').\n",
    "    train_iterator: iterator yielding batches of data images, presumably in [-1, 1].\n",
    "    save_str: file name for the sample plots.\n",
    "\n",
    "  Raises:\n",
    "    ValueError: if config['mcmc_init'] is not 'data' or 'coop', or gen/train_iterator is None.\n",
    "  \"\"\"\n",
    "  # validate up front so nothing is written for an unusable configuration\n",
    "  if config['mcmc_init'] not in ('data', 'coop'):\n",
    "    raise ValueError('Invalid mcmc_init')\n",
    "  if gen is None or train_iterator is None:\n",
    "    raise ValueError('save_samples requires both gen and train_iterator')\n",
    "\n",
    "  # Langevin step size epsilon^2 / 2 (Python float, constant at trace time)\n",
    "  step_size = (config['epsilon'] ** 2) / 2\n",
    "\n",
    "  @tf.function\n",
    "  def langevin_update(images_in):\n",
    "    \"\"\"Run config['mcmc_steps'] Langevin updates; return (final samples, initial samples).\"\"\"\n",
    "    images_samp_init = tf.identity(images_in)\n",
    "    if config['mcmc_init'] == 'coop':\n",
    "      # images_in is a latent z batch here; the generator maps it to image space\n",
    "      images_samp_init = tf.identity(gen(images_samp_init))\n",
    "    images_samp = tf.identity(images_samp_init)\n",
    "\n",
    "    if config['mcmc_steps'] > 0:\n",
    "      for _ in tf.range(int(config['mcmc_steps'])):\n",
    "        with tf.GradientTape() as tape:\n",
    "          tape.watch(images_samp)\n",
    "          energy = tf.math.reduce_sum(ebm(images_samp, training=False)) / config['mcmc_temp']\n",
    "        grads = tape.gradient(energy, images_samp)\n",
    "        # optional per-sample clipping of the gradient norm over axes [1, 2, 3]\n",
    "        if config['clip_langevin_grad']:\n",
    "          grads = tf.clip_by_norm(grads, config['max_langevin_norm'] / step_size, axes=[1, 2, 3])\n",
    "        images_samp -= step_size * grads\n",
    "        # noise shape follows the actual per-replica batch instead of a notebook global\n",
    "        images_samp += config['epsilon'] * tf.random.normal(shape=tf.shape(images_samp))\n",
    "    return images_samp, images_samp_init\n",
    "\n",
    "  per_replica_batch_size = config['batch_size'] // strategy.num_replicas_in_sync\n",
    "  exp_path = os.path.join(config['exp_dir'], config['exp_name'])\n",
    "  data_path = Path(exp_path, 'numpy_out', 'images1.npy')\n",
    "  samp_path = Path(exp_path, 'numpy_out', 'images2.npy')\n",
    "\n",
    "  # open both files once in 'wb' mode: re-running overwrites stale output instead of appending to it\n",
    "  with data_path.open('wb') as data_file, samp_path.open('wb') as samp_file:\n",
    "    for i in tqdm(range(config['num_fid_rounds']), desc='FID batches'):\n",
    "      images_data = next(train_iterator)\n",
    "\n",
    "      if config['mcmc_init'] == 'data':\n",
    "        sample_init = next(gen)\n",
    "      else:  # 'coop': draw one latent batch and give each replica its slice\n",
    "        z_init_tf = gen.generate_latent_z(config['batch_size'])\n",
    "        def get_z_init(ctx):\n",
    "          rep_id = ctx.replica_id_in_sync_group\n",
    "          return z_init_tf[(rep_id*per_replica_batch_size):((rep_id+1)*per_replica_batch_size)]\n",
    "        sample_init = strategy.experimental_distribute_values_from_function(get_z_init)\n",
    "      images_sample, images_sample_init = strategy.run(langevin_update, args=(sample_init,))\n",
    "\n",
    "      # gather per-replica values into full batches before converting to numpy\n",
    "      images_data = strategy.gather(images_data, axis=0)\n",
    "      images_sample = strategy.gather(images_sample, axis=0)\n",
    "\n",
    "      if i == 0:\n",
    "        plot_ims(os.path.join(exp_path, 'images', save_str), images_sample)\n",
    "        plot_ims(os.path.join(exp_path, 'images', 'init_' + save_str),\n",
    "                 strategy.gather(images_sample_init, axis=0))\n",
    "\n",
    "      np.save(data_file, to_uint8(images_data.numpy()))\n",
    "      np.save(samp_file, to_uint8(images_sample.numpy()))\n",
    "\n",
    "\n",
    "###############\n",
    "# ## SETUP ## #\n",
    "###############\n",
    "\n",
    "# setup folders, save code, set seed and get device\n",
    "setup_exp(os.path.join(config['exp_dir'], config['exp_name']),\n",
    "          ['images', 'numpy_out'],\n",
    "          [os.path.join('/content/drive/My Drive/Colab Notebooks/code_files/fid_tf2', code_file) for code_file in ['fid_save_ims.ipynb', 'nets.py', 'utils.py', 'data.py', 'init.py']],\n",
    "          config['gs_path'])\n",
    "\n",
    "if config['device_type'] == 'tpu':\n",
    "  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])\n",
    "  tf.config.experimental_connect_to_cluster(resolver)\n",
    "  # This is the TPU initialization code that has to be at the beginning.\n",
    "  tf.tpu.experimental.initialize_tpu_system(resolver)\n",
    "  print(\"All devices: \", tf.config.list_logical_devices('TPU'))\n",
    "  # Set up TPU Distribution\n",
    "  strategy = tf.distribute.TPUStrategy(resolver)\n",
    "else:\n",
    "  strategy = init_strategy(config)\n",
    "\n",
    "\n",
    "##################################################\n",
    "# ## INITIALIZE NETS, DATA, PERSISTENT STATES ## #\n",
    "##################################################\n",
    "\n",
    "# load nets and optim\n",
    "ebm, _, gen, _ = initialize_nets_and_optim(config, strategy)\n",
    "ebm.trainable = False\n",
    "if gen is not None:\n",
    "  gen.trainable = False\n",
    "\n",
    "# test deterministic output of ebm\n",
    "with strategy.scope():\n",
    "  z_test = tf.random.normal(shape=[3]+config['image_dims'])\n",
    "  z_out_1 = ebm(z_test)\n",
    "  z_out_2 = ebm(z_test[0:2])\n",
    "z_out_1 = strategy.gather(z_out_1, axis=0)\n",
    "z_out_2 = strategy.gather(z_out_2, axis=0)\n",
    "print('EBM Determinism Test: ', z_out_1[0], z_out_2[0])\n",
    "\n",
    "# test deterministic output of gen\n",
    "if gen is not None:\n",
    "  with strategy.scope():\n",
    "    gen_z = gen.generate_latent_z(3)\n",
    "    gen_out_1 = gen(gen_z)\n",
    "    gen_out_2 = gen(gen_z[0:2])\n",
    "  gen_out_1 = strategy.gather(gen_out_1, axis=0)\n",
    "  gen_out_2 = strategy.gather(gen_out_2, axis=0)\n",
    "  print('Gen Determinism Test: ', tf.math.reduce_max(tf.math.abs(gen_out_1[0] - gen_out_2[0])))\n",
    "\n",
    "# iterator of data batches (written to images1.npy)\n",
    "train_iterator, _, _ = initialize_data(config, strategy)\n",
    "if config['mcmc_init'] == 'data':\n",
    "  # data initialization: gen is replaced by a second data iterator used as MCMC init\n",
    "  gen, _, _ = initialize_data(config, strategy)\n",
    "\n",
    "# save the data and model-sample banks as two .npy record files\n",
    "save_samples(strategy, config, ebm, gen, train_iterator)"
   ],
   "execution_count": null,
   "outputs": []
  }
 ]
}