import tensorflow as tf
import numpy as np 
import scipy.io  
import argparse 
import struct
import errno
import time                       
import cv2
import os

import io
# Reduce TensorFlow log noise: C++ log level '3' and Python logging at FATAL.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf.logging.set_verbosity(tf.logging.FATAL)
# Custom op library exposing `nn_distance`, used by point_cloud_distance() below.
# NOTE(review): absolute, machine-specific path -- consider making it configurable.
nn_distance_module=tf.load_op_library('/mnt/hard0/ICLR_469/chamfer-distance/tf_nndistance_so.so')

'''
  parsing and configuration
'''
def parse_args():
  """Parse command-line options into an argparse.Namespace.

  After parsing, the style-layer, content-layer and style-image weight lists
  are normalised with normalize() and the output directory is created if it
  does not exist yet.
  """
  desc = "TensorFlow implementation of 'A Neural Algorithm for Artistic Style'"
  parser = argparse.ArgumentParser(description=desc)

  # options for single image
  parser.add_argument('--verbose', action='store_true',
    help='Boolean flag indicating if statements should be printed to the console.')

  parser.add_argument('--img_name', type=str,
    default='result',
    help='Filename of the output image.')

  parser.add_argument('--style_imgs_weights', nargs='+', type=float,
    default=[1.0],
    help='Interpolation weights of each of the style images. (example: 0.5 0.5)')

  parser.add_argument('--init_img_type', type=str,
    default='content',
    choices=['random', 'content', 'style', 'init'],
    help='Image used to initialize the network. (default: %(default)s)')

  parser.add_argument('--max_size', type=int,
    default=512,
    help='Maximum width or height of the input images. (default: %(default)s)')

  parser.add_argument('--content_weight', type=float,
    default=5e0,
    help='Weight for the content loss function. (default: %(default)s)')

  parser.add_argument('--style_weight', type=float,
    default=1e4,
    help='Weight for the style loss function. (default: %(default)s)')

  parser.add_argument('--tv_weight', type=float,
    default=1e-3,
    help='Weight for the total variational loss function. Set small (e.g. 1e-3). (default: %(default)s)')

  parser.add_argument('--content_loss_function', type=int,
    default=1,
    choices=[1, 2, 3],
    help='Different constants for the content layer loss function. (default: %(default)s)')

  parser.add_argument('--content_layers', nargs='+', type=str,
    default=['conv4_2'],
    help='VGG19 layers used for the content image. (default: %(default)s)')

  parser.add_argument('--style_layers', nargs='+', type=str,
    default=['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1'],
    help='VGG19 layers used for the style image. (default: %(default)s)')

  parser.add_argument('--content_layer_weights', nargs='+', type=float,
    default=[1.0],
    help='Contributions (weights) of each content layer to loss. (default: %(default)s)')

  parser.add_argument('--style_layer_weights', nargs='+', type=float,
    default=[0.2, 0.2, 0.2, 0.2, 0.2],
    help='Contributions (weights) of each style layer to loss. (default: %(default)s)')

  parser.add_argument('--original_colors', action='store_true',
    help='Transfer the style but not the colors.')

  parser.add_argument('--color_convert_type', type=str,
    default='yuv',
    choices=['yuv', 'ycrcb', 'luv', 'lab'],
    help='Color space for conversion to original colors (default: %(default)s)')

  parser.add_argument('--color_convert_time', type=str,
    default='after',
    choices=['after', 'before'],
    help='Time (before or after) to convert to original colors (default: %(default)s)')

  parser.add_argument('--style_mask', action='store_true',
    help='Transfer the style to masked regions.')

  parser.add_argument('--style_mask_imgs', nargs='+', type=str,
    default=None,
    help='Filenames of the style mask images (example: face_mask.png) (default: %(default)s)')

  parser.add_argument('--noise_ratio', type=float,
    default=1.0,
    help="Interpolation value between the content image and noise image if the network is initialized with 'random'.")

  parser.add_argument('--seed', type=int,
    default=0,
    help='Seed for the random number generator. (default: %(default)s)')

  parser.add_argument('--model_weights', type=str,
    default='imagenet-vgg-verydeep-19.mat',
    help='Weights and biases of the VGG-19 network.')

  parser.add_argument('--pooling_type', type=str,
    default='avg',
    choices=['avg', 'max'],
    help='Type of pooling in convolutional neural network. (default: %(default)s)')

  parser.add_argument('--device', type=str,
    default='/gpu:0',
    choices=['/gpu:0', '/cpu:0'],
    help='GPU or CPU mode.  GPU mode requires NVIDIA CUDA. (default|recommended: %(default)s)')

  parser.add_argument('--img_output_dir', type=str,
    default='./image_output',
    help='Relative or absolute directory path to output image and data.')

  # optimizations
  parser.add_argument('--optimizer', type=str,
    default='lbfgs',
    choices=['lbfgs', 'adam'],
    help='Loss minimization optimizer.  L-BFGS gives better results.  Adam uses less memory. (default|recommended: %(default)s)')

  parser.add_argument('--learning_rate', type=float,
    default=1e0,
    help='Learning rate parameter for the Adam optimizer. (default: %(default)s)')

  parser.add_argument('--max_iterations', type=int,
    default=1000,
    help='Max number of iterations for the Adam or L-BFGS optimizer. (default: %(default)s)')

  parser.add_argument('--print_iterations', type=int,
    default=50,
    help='Number of iterations between optimizer print statements. (default: %(default)s)')

  parser.add_argument('--view_num', type=int, default=3, help='.')
  args = parser.parse_args()

  # normalize weights
  args.style_layer_weights   = normalize(args.style_layer_weights)
  args.content_layer_weights = normalize(args.content_layer_weights)
  args.style_imgs_weights    = normalize(args.style_imgs_weights)

  # create directories for output
  # This parser defines no `--video` / `--video_output_dir`, so read them
  # defensively; accessing `args.video` directly raised AttributeError on
  # every run.
  if getattr(args, 'video', False):
    maybe_make_directory(args.video_output_dir)
  else:
    maybe_make_directory(args.img_output_dir)

  return args

'''
  pre-trained vgg19 convolutional neural network
  remark: layers are manually initialized for clarity.
'''

def build_model(input_img):
  """Build the VGG-19 convolutional stack on a trainable input variable.

  The input variable is zero-initialised with the same shape as `input_img`.
  Weights and biases come from the .mat file named by args.model_weights.

  Returns:
    (net, vgg_layers). `net` maps 'input', 'convG_S', 'reluG_S' and 'poolG'
    to tensors. `vgg_layers` is the raw layer array read from the .mat file.
  """
  if args.verbose: print('\nBUILDING VGG-19 NETWORK')
  net = {}
  batch, height, width, channels = input_img.shape
  print(batch, height, width, channels)
  if args.verbose: print('loading model weights...')
  vgg_layers = scipy.io.loadmat(args.model_weights)['layers'][0]
  if args.verbose: print('constructing layers...')
  net['input'] = tf.Variable(np.zeros((batch, height, width, channels), dtype=np.float32))

  # (group number, ((sub-layer number, index into vgg_layers), ...))
  # Each group is a run of conv+relu pairs followed by one pooling layer.
  architecture = (
    (1, ((1, 0), (2, 2))),
    (2, ((1, 5), (2, 7))),
    (3, ((1, 10), (2, 12), (3, 14), (4, 16))),
    (4, ((1, 19), (2, 21), (3, 23), (4, 25))),
    (5, ((1, 28), (2, 30), (3, 32), (4, 34))),
  )

  current = net['input']
  for group, convs in architecture:
    if args.verbose: print('LAYER GROUP {}'.format(group))
    for sub, layer_idx in convs:
      conv_key = 'conv{}_{}'.format(group, sub)
      relu_key = 'relu{}_{}'.format(group, sub)
      net[conv_key] = conv_layer(conv_key, current, W=get_weights(vgg_layers, layer_idx))
      net[relu_key] = relu_layer(relu_key, net[conv_key], b=get_bias(vgg_layers, layer_idx))
      current = net[relu_key]
    pool_key = 'pool{}'.format(group)
    net[pool_key] = pool_layer(pool_key, current)
    current = net[pool_key]

  return net, vgg_layers

def conv_layer(layer_name, layer_input, W):
  """Convolve `layer_input` with the fixed kernel `W` (stride 1, SAME padding)."""
  out = tf.nn.conv2d(layer_input, W, strides=[1, 1, 1, 1], padding='SAME')
  if args.verbose:
    message = '--{} | shape={} | weights_shape={}'.format(
      layer_name, out.get_shape(), W.get_shape())
    print(message)
  return out

def relu_layer(layer_name, layer_input, b):
  """Add the bias `b` to `layer_input`, then apply a ReLU."""
  biased = layer_input + b
  activated = tf.nn.relu(biased)
  if args.verbose:
    message = '--{} | shape={} | bias_shape={}'.format(
      layer_name, activated.get_shape(), b.get_shape())
    print(message)
  return activated

def pool_layer(layer_name, layer_input):
  """Apply 2x2, stride-2, SAME-padded pooling selected by args.pooling_type.

  Raises:
    ValueError: if args.pooling_type is neither 'avg' nor 'max'. Previously
      this surfaced as an UnboundLocalError on `pool`.
  """
  ksize = [1, 2, 2, 1]
  strides = [1, 2, 2, 1]
  if args.pooling_type == 'avg':
    pool = tf.nn.avg_pool(layer_input, ksize=ksize, strides=strides, padding='SAME')
  elif args.pooling_type == 'max':
    pool = tf.nn.max_pool(layer_input, ksize=ksize, strides=strides, padding='SAME')
  else:
    raise ValueError('Unsupported pooling_type: {!r}'.format(args.pooling_type))
  if args.verbose:
    print('--{}   | shape={}'.format(layer_name, pool.get_shape()))
  return pool

def get_weights(vgg_layers, i):
  """Return the convolution kernel stored for VGG layer `i` as a tf.constant."""
  layer_params = vgg_layers[i][0][0][2][0]
  return tf.constant(layer_params[0])

def get_bias(vgg_layers, i):
  """Return the bias of VGG layer `i`, flattened to 1-D, as a tf.constant."""
  raw_bias = vgg_layers[i][0][0][2][0][1]
  flat_bias = np.reshape(raw_bias, (raw_bias.size,))
  return tf.constant(flat_bias)

'''
  'a neural algorithm for artistic style' loss functions
'''
def content_layer_loss(p, x):
  """Squared-error content loss between target features `p` and live features `x`.

  With M = h*w (spatial positions of `p`) and N = its channel count, the
  scale K is selected by args.content_loss_function:
    1 -> 1 / (2 * sqrt(N) * sqrt(M))
    2 -> 1 / (N * M)
    3 -> 1 / 2

  Raises:
    ValueError: for any other args.content_loss_function. Previously this
      surfaced as an UnboundLocalError on `K`.
  """
  _, h, w, d = p.get_shape()
  M = h.value * w.value
  N = d.value
  if args.content_loss_function == 1:
    K = 1. / (2. * N**0.5 * M**0.5)
  elif args.content_loss_function == 2:
    K = 1. / (N * M)
  elif args.content_loss_function == 3:
    K = 1. / 2.
  else:
    raise ValueError(
      'Unsupported content_loss_function: {!r}'.format(args.content_loss_function))
  loss = K * tf.reduce_sum(tf.pow((x - p), 2))
  return loss

def style_layer_loss(a, x):
  """Gram-matrix style loss between target features `a` and live features `x`.

  Returns sum((G(x) - G(a))**2) / (4 * N**2 * M**2), where N is the channel
  count of `a`, M = h*w, and G is gram_matrix.
  """
  batch, height, width, depth = a.get_shape()
  area = height.value * width.value
  channels = depth.value
  target_gram = gram_matrix(a, batch, area, channels)
  current_gram = gram_matrix(x, batch, area, channels)
  scale = 1. / (4 * channels**2 * area**2)
  return scale * tf.reduce_sum(tf.pow((current_gram - target_gram), 2))

def gram_matrix(x, batch, area, depth):
  """Return the per-sample Gram matrix F^T F, with F = x reshaped to (batch, area, depth)."""
  features = tf.reshape(x, (batch, area, depth))
  features_t = tf.transpose(features, (0, 2, 1))
  return tf.matmul(features_t, features)

def mask_style_layer(a, x, mask_img):
  """Multiply feature maps `a` and `x` by a spatial mask loaded from `mask_img`.

  The grayscale mask is resized to the spatial size of `a` by get_mask_image
  and broadcast over the batch and channel axes.

  Returns:
    The masked pair (a, x).
  """
  _, h, w, _ = a.get_shape()
  mask = get_mask_image(mask_img, w.value, h.value)
  # (h, w) -> (1, h, w, 1). Broadcasting replaces the former loop that stacked
  # one copy per channel, and the tf.stack on a single tensor, which was a no-op.
  mask = tf.reshape(tf.convert_to_tensor(mask), (1, h.value, w.value, 1))
  a = tf.multiply(a, mask)
  x = tf.multiply(x, mask)
  return a, x

def sum_masked_style_losses(sess, net, style_imgs):
  """Weighted style loss with each style image's features masked.

  For every (style image, image weight, mask file) triple:
  - the image is assigned to net['input'] and target activations are captured;
  - targets and live activations are masked with mask_style_layer;
  - per-layer losses are weighted by args.style_layer_weights and averaged
    over the layers.
  The per-image losses are then weighted by args.style_imgs_weights and
  averaged over style_imgs.

  NOTE(review): sum_style_losses feeds style_imgs as a single batch, whereas
  this function iterates over it element-wise. Confirm the expected shape
  when style_mask is enabled.
  """
  total = 0.
  triples = zip(style_imgs, args.style_imgs_weights, args.style_mask_imgs)
  for style_img, style_weight, mask_path in triples:
    sess.run(net['input'].assign(style_img))
    per_image = 0.
    for layer_name, layer_weight in zip(args.style_layers, args.style_layer_weights):
      live = net[layer_name]
      target = tf.convert_to_tensor(sess.run(live))
      target, live = mask_style_layer(target, live, mask_path)
      per_image += style_layer_loss(target, live) * layer_weight
    per_image /= float(len(args.style_layers))
    total += (per_image * style_weight)
  total /= float(len(style_imgs))
  return total

def sum_style_losses(sess, net, style_imgs):
  """Weighted style loss with the whole `style_imgs` batch as one input.

  The batch is assigned to net['input'] and target activations are captured
  for each layer in args.style_layers. Per-layer losses are weighted by
  args.style_layer_weights and averaged over layers. The result is scaled by
  the first style-image weight and divided by len(style_imgs).
  """
  img_weights = args.style_imgs_weights
  print(style_imgs.shape, img_weights, net['input'].shape)
  accumulated = 0.
  # Wrapping style_imgs in a list means this loop runs at most once.
  for batch, batch_weight in zip([style_imgs], img_weights):
    print(style_imgs.shape, batch.shape)
    sess.run(net['input'].assign(batch))
    layer_sum = 0.
    for layer_name, layer_weight in zip(args.style_layers, args.style_layer_weights):
      live = net[layer_name]
      target = tf.convert_to_tensor(sess.run(live))
      layer_sum += style_layer_loss(target, live) * layer_weight
    layer_sum /= float(len(args.style_layers))
    accumulated += (layer_sum * batch_weight)
  # NOTE(review): this divides by the batch length although the loop ran once;
  # presumably intended as a per-view average -- confirm.
  accumulated /= float(len(style_imgs))
  return accumulated

def sum_content_losses(sess, net, content_img):
  """Weighted content loss, averaged over args.content_layers.

  `content_img` is assigned to net['input'] so that target activations can
  be captured. Each layer's content_layer_loss is then weighted by
  args.content_layer_weights.
  """
  sess.run(net['input'].assign(content_img))
  total = 0.
  for layer_name, layer_weight in zip(args.content_layers, args.content_layer_weights):
    live = net[layer_name]
    target = tf.convert_to_tensor(sess.run(live))
    total += content_layer_loss(target, live) * layer_weight
  total /= float(len(args.content_layers))
  return total

'''
  utilities and i/o
'''
def read_image(path):
  """Load the colour image at `path` as float32 and return preprocess() of it."""
  bgr = cv2.imread(path, cv2.IMREAD_COLOR)
  check_image(bgr, path)
  return preprocess(bgr.astype(np.float32))

def write_image(path, img):
  """Undo preprocess() on `img` and save the result to `path` with cv2.imwrite."""
  restored = postprocess(img)
  cv2.imwrite(path, restored)

def preprocess(img):
  """Turn an (h, w, 3) BGR image into a mean-subtracted (1, h, w, 3) RGB batch.

  The input is copied first, so the caller's array is not modified.
  """
  mean_rgb = np.array([123.68, 116.779, 103.939]).reshape((1, 1, 1, 3))
  # bgr -> rgb, then add a leading batch axis
  batched = np.copy(img)[..., ::-1][np.newaxis, :, :, :]
  batched -= mean_rgb
  return batched

def postprocess(img):
  """Inverse of preprocess(): mean-subtracted (1, h, w, 3) RGB -> (h, w, 3) uint8 BGR.

  Only the first element of the batch is returned. Values are clipped to
  [0, 255] before the uint8 cast.
  """
  mean_rgb = np.array([123.68, 116.779, 103.939]).reshape((1, 1, 1, 3))
  restored = np.copy(img)
  restored += mean_rgb
  first_view = restored[0]
  as_uint8 = np.clip(first_view, 0, 255).astype('uint8')
  # rgb -> bgr
  return as_uint8[..., ::-1]


def read_weights_file(path):
  """Read a text weights file and return a binary (h, w, 3) float32 mask.

  Format: the first line is "<width> <height>", followed by one
  whitespace-separated row of values per line. Values below 255 become 0.0
  and all others (including NaN) become 1.0. The single-channel mask is
  replicated to 3 channels. Blank lines are ignored, and rows not present in
  the file stay 0.0.
  """
  # `with` closes the file; it was previously never closed.
  with open(path) as f:
    lines = f.readlines()
  header = [int(tok) for tok in lines[0].split()]
  w = header[0]
  h = header[1]
  vals = np.zeros((h, w), dtype=np.float32)
  rows = [line for line in lines[1:] if line.strip()]
  for i, line in enumerate(rows):
    vals[i] = np.asarray(line.split(), dtype=np.float32)
  # Vectorised threshold; this replaces the per-element Python lambda.
  vals = np.where(vals < 255., 0., 1.).astype(np.float32)
  # expand to 3 channels
  return np.dstack([vals] * 3)

def normalize(weights):
  """Scale `weights` so they sum to 1.

  Returns a list of floats. If the sum is not positive, returns a list of
  zeros of the same length instead.
  """
  total = sum(weights)
  if not total > 0.:
    return [0.] * len(weights)
  return [float(w) / total for w in weights]

def maybe_make_directory(dir_path):
  """Create `dir_path`, including parents, unless something already exists there.

  Creation is attempted directly rather than checked first, so a directory
  created concurrently between a check and makedirs no longer raises. An
  existing path is left untouched, as before.
  """
  try:
    os.makedirs(dir_path)
  except FileExistsError:
    pass

def check_image(img, path):
  """Raise OSError(ENOENT) naming `path` when `img` is None (the image failed to load)."""
  if img is not None:
    return
  raise OSError(errno.ENOENT, "No such file", path)

'''
  rendering -- where the magic happens
'''
def stylize(content_img, style_imgs, init_img, depth_imgs, cams,
            warp0, warp2, list_world_pts0, list_world_pts2, list_occ_mask, warp3_4=None, frame=None):
  """Build the graph, assemble the total loss, optimise it, and save the results.

  The total loss sums these terms:
  - content loss, weighted by args.content_weight;
  - style loss (masked if args.style_mask), weighted by args.style_weight;
  - total variation of net['input'], weighted by args.tv_weight;
  - the two terms returned by sum_warp_losses5, weighted 1000 and 10.
  After optimisation, net['input'] is written with write_image_output_mvs in
  "original" mode, and net['occ_masked'] is evaluated and written in "warp"
  mode.

  NOTE(review): depth_imgs, cams, warp2, list_world_pts0, list_world_pts2,
  warp3_4 and frame are not used in this body.
  """
  # with tf.device(args.device), tf.Session() as sess:
  with tf.device(args.device), tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    # setup network
    tick = time.time()
    net, vgg_layers = build_model(init_img)
    tock = time.time()
    # NOTE(review): this first timing only covers build_model, despite the message.
    print('Single image elapsed time: {}'.format(tock - tick))

    tick = time.time()
    # style loss
    if args.style_mask:
      L_style = sum_masked_style_losses(sess, net, style_imgs)
    else:
      L_style = sum_style_losses(sess, net, style_imgs)
    
    # content loss
    L_content = sum_content_losses(sess, net, content_img)
    
    # denoising loss
    L_tv = tf.image.total_variation(net['input'])

    # Multi-view consistency terms; sum_warp_losses5 is defined elsewhere.
    L_warp, L_stroke = sum_warp_losses5(sess, net, warp0, 'input', 'output0', list_occ_mask, vgg_layers)

    # loss weights
    alpha = args.content_weight
    beta  = args.style_weight
    theta = args.tv_weight
    
    # total loss
    # NOTE(review): the warp/stroke weights (1000, 10) are hard-coded, not CLI options.
    L_total  = alpha * L_content
    L_total += beta  * L_style
    L_total += theta * L_tv
    L_total += 1000 * L_warp
    L_total += 10 * L_stroke

    # optimization algorithm
    optimizer = get_optimizer(L_total)

    if args.optimizer == 'adam':
      minimize_with_adam(sess, net, optimizer, init_img, L_total)
    elif args.optimizer == 'lbfgs':
      minimize_with_lbfgs(sess, net, optimizer, init_img)

    # The optimised image lives in the input variable.
    output_img = sess.run(net['input'])

    tock = time.time()
    print('Single image elapsed time: {}'.format(tock - tick))
    if args.original_colors:
      output_img = convert_to_original_colors(np.copy(content_img), output_img)



    write_image_output_mvs(output_img, content_img, style_imgs, init_img, mode="original")
    # NOTE(review): build_model does not create 'occ_masked'; presumably
    # sum_warp_losses5 adds it to `net` -- confirm.
    vis_occ_masked = sess.run(net['occ_masked'])
    write_image_output_mvs(vis_occ_masked, content_img, style_imgs, init_img, mode="warp")
    # write_image_output_mvs(output_img, content_img, style_imgs, init_img, mode="occ")

def minimize_with_lbfgs(sess, net, optimizer, init_img):
  """Initialise all variables, load `init_img` into the input, and run `optimizer`.

  `optimizer` is the ScipyOptimizerInterface built by get_optimizer; its
  minimize() runs the full L-BFGS optimisation in `sess`.
  """
  if args.verbose:
    print('\nMINIMIZING LOSS USING: L-BFGS OPTIMIZER')
  sess.run(tf.global_variables_initializer())
  sess.run(net['input'].assign(init_img))
  optimizer.minimize(sess)

def minimize_with_adam(sess, net, optimizer, init_img, loss):
  """Run args.max_iterations Adam steps on `loss`, starting from `init_img`.

  When args.verbose is set and args.print_iterations is non-zero, the loss is
  printed every args.print_iterations steps.
  """
  if args.verbose: print('\nMINIMIZING LOSS USING: ADAM OPTIMIZER')
  train_op = optimizer.minimize(loss)
  init_op = tf.global_variables_initializer()
  sess.run(init_op)
  sess.run(net['input'].assign(init_img))
  # Check verbosity and a non-zero interval before the modulo. The old
  # `iterations % args.print_iterations` raised ZeroDivisionError for
  # --print_iterations 0 even without --verbose.
  report = args.verbose and args.print_iterations != 0
  for iterations in range(args.max_iterations):
    sess.run(train_op)
    if report and iterations % args.print_iterations == 0:
      curr_loss = loss.eval()
      print("At iterate {}\tf=  {}".format(iterations, curr_loss))

def get_optimizer(loss):
  """Create the optimizer selected by args.optimizer.

  'lbfgs' gives a tf.contrib.opt.ScipyOptimizerInterface (L-BFGS-B) bound to
  `loss` and capped at args.max_iterations. Progress display is enabled only
  when args.verbose is set.
  'adam' gives tf.train.AdamOptimizer(args.learning_rate). `loss` is bound
  later, in minimize_with_adam.
  """
  disp = args.print_iterations if args.verbose else 0
  if args.optimizer == 'lbfgs':
    scipy_options = {'maxiter': args.max_iterations, 'disp': disp}
    optimizer = tf.contrib.opt.ScipyOptimizerInterface(
      loss, method='L-BFGS-B', options=scipy_options)
  elif args.optimizer == 'adam':
    optimizer = tf.train.AdamOptimizer(args.learning_rate)
  return optimizer


def write_image_output(output_img, content_img, style_imgs, init_img):
  """Save the output, content, init and style images plus a meta_data.txt.

  Everything goes into <args.img_output_dir>/<args.img_name>/.

  NOTE(review): args.content_img and args.style_imgs are not defined by the
  parse_args in this file; confirm they are set elsewhere.
  """
  out_dir = os.path.join(args.img_output_dir, args.img_name)
  maybe_make_directory(out_dir)
  img_path = os.path.join(out_dir, args.img_name+'.png')
  content_path = os.path.join(out_dir, 'content.png')
  init_path = os.path.join(out_dir, 'init.png')

  write_image(img_path, output_img)
  write_image(content_path, content_img)
  write_image(init_path, init_img)
  for index, style_img in enumerate(style_imgs):
    path = os.path.join(out_dir, 'style_'+str(index)+'.png')
    write_image(path, style_img)

  # save the configuration settings; `with` closes the file even if a write raises
  out_file = os.path.join(out_dir, 'meta_data.txt')
  with open(out_file, 'w') as f:
    f.write('image_name: {}\n'.format(args.img_name))
    f.write('content: {}\n'.format(args.content_img))
    for index, (style_img, weight) in enumerate(zip(args.style_imgs, args.style_imgs_weights)):
      f.write('styles['+str(index)+']: {} * {}\n'.format(weight, style_img))
    if args.style_mask_imgs is not None:
      for index, mask in enumerate(args.style_mask_imgs):
        f.write('style_masks['+str(index)+']: {}\n'.format(mask))
    f.write('init_type: {}\n'.format(args.init_img_type))
    f.write('content_weight: {}\n'.format(args.content_weight))
    f.write('style_weight: {}\n'.format(args.style_weight))
    f.write('tv_weight: {}\n'.format(args.tv_weight))
    f.write('content_layers: {}\n'.format(args.content_layers))
    f.write('style_layers: {}\n'.format(args.style_layers))
    f.write('optimizer_type: {}\n'.format(args.optimizer))
    f.write('max_iterations: {}\n'.format(args.max_iterations))
    f.write('max_image_size: {}\n'.format(args.max_size))

def write_image_output_mvs(output_img, content_img, style_imgs, init_img, mode="original"):
  """Save multi-view results into <args.img_output_dir>/<args.img_name>/.

  mode="original" writes, for each view k = 1..V (V = len(output_img)):
  - <img_name><k>.jpg from output_img;
  - content<k>.png from content_img;
  - init<k>.png from init_img.
  It also writes style_<i>.png for each style image and a meta_data.txt with
  the run settings.

  mode="warp" expects V*V images ordered source-major. Image (i-1)*V + (j-1)
  is written as <img_name><i>to<j>_warp.png, or as <img_name><i>to<i>.png on
  the diagonal.

  Any other mode only creates the directory. For V = 3 the file set matches
  the former hand-unrolled version.
  """
  out_dir = os.path.join(args.img_output_dir, args.img_name)
  maybe_make_directory(out_dir)
  if mode == "original":
    num_views = len(output_img)
    for k in range(1, num_views + 1):
      write_image(os.path.join(out_dir, args.img_name + str(k) + '.jpg'), output_img[k - 1:k])
    for k in range(1, num_views + 1):
      write_image(os.path.join(out_dir, 'content' + str(k) + '.png'), content_img[k - 1:k])
    for k in range(1, num_views + 1):
      write_image(os.path.join(out_dir, 'init' + str(k) + '.png'), init_img[k - 1:k])

    for index, style_img in enumerate(style_imgs):
      path = os.path.join(out_dir, 'style_' + str(index) + '.png')
      write_image(path, style_img[np.newaxis, :, :, :])

    # save the configuration settings; `with` closes the file even if a write raises
    out_file = os.path.join(out_dir, 'meta_data.txt')
    with open(out_file, 'w') as f:
      f.write('image_name: {}\n'.format(args.img_name))
      f.write('content: {}\n'.format(args.content_img))
      for index, (style_img, weight) in enumerate(zip(args.style_imgs, args.style_imgs_weights)):
        f.write('styles[' + str(index) + ']: {} * {}\n'.format(weight, style_img))
      if args.style_mask_imgs is not None:
        for index, mask in enumerate(args.style_mask_imgs):
          f.write('style_masks[' + str(index) + ']: {}\n'.format(mask))
      f.write('init_type: {}\n'.format(args.init_img_type))
      f.write('content_weight: {}\n'.format(args.content_weight))
      f.write('style_weight: {}\n'.format(args.style_weight))
      f.write('tv_weight: {}\n'.format(args.tv_weight))
      f.write('content_layers: {}\n'.format(args.content_layers))
      f.write('style_layers: {}\n'.format(args.style_layers))
      f.write('optimizer_type: {}\n'.format(args.optimizer))
      f.write('max_iterations: {}\n'.format(args.max_iterations))
      f.write('max_image_size: {}\n'.format(args.max_size))

  elif mode == "warp":
    num_views = int(round(len(output_img) ** 0.5))
    for i in range(1, num_views + 1):
      for j in range(1, num_views + 1):
        # Identity pairs have no "_warp" suffix.
        suffix = '.png' if i == j else '_warp.png'
        name = args.img_name + str(i) + 'to' + str(j) + suffix
        idx = (i - 1) * num_views + (j - 1)
        write_image(os.path.join(out_dir, name), output_img[idx:idx + 1])

'''
  image loading and processing
'''
def get_init_image(init_type, content_img, style_imgs, init_img, frame=None):
  """Choose the network's starting image according to `init_type`.

  - 'content' -> content_img
  - 'style'   -> style_imgs[0]
  - 'random'  -> get_noise_image(args.noise_ratio, content_img)
  - 'init'    -> init_img
  - 'prev' / 'prev_warped' (video only) -> get_prev_frame(frame) /
    get_prev_warped_frame(frame), which are defined elsewhere.
  Any other value returns None.
  """
  if init_type == 'content':
    return content_img
  if init_type == 'style':
    return style_imgs[0]
  if init_type == 'random':
    return get_noise_image(args.noise_ratio, content_img)
  if init_type == 'init':
    return init_img
  # only for video frames
  if init_type == 'prev':
    return get_prev_frame(frame)
  if init_type == 'prev_warped':
    return get_prev_warped_frame(frame)
  return None

def get_content_frame(frame):
  """Read video frame `frame` (number zero-padded to 4 digits) from args.video_input_dir."""
  padded = str(frame).zfill(4)
  file_name = args.content_frame_frmt.format(padded)
  return read_image(os.path.join(args.video_input_dir, file_name))

def get_content_image(content_img):
  """Load `content_img` from args.content_img_dir and return it preprocessed.

  Images whose longer side exceeds args.max_size are shrunk with INTER_AREA
  so that side equals args.max_size. The aspect ratio is kept.
  """
  path = os.path.join(args.content_img_dir, content_img)
  img = cv2.imread(path, cv2.IMREAD_COLOR)  # bgr image
  check_image(img, path)
  img = img.astype(np.float32)
  height, width, _ = img.shape
  limit = args.max_size
  # Portrait and too tall: clamp the height.
  if height > width and height > limit:
    width = width * (float(limit) / float(height))
    img = cv2.resize(img, dsize=(int(width), limit), interpolation=cv2.INTER_AREA)
  # Landscape or square and too wide: clamp the width.
  if width > limit:
    height = height * (float(limit) / float(width))
    img = cv2.resize(img, dsize=(limit, int(height)), interpolation=cv2.INTER_AREA)
  return preprocess(img)

def get_image(content_img):
  """Load `content_img` from args.init_img_dir and return it preprocessed.

  Images whose longer side exceeds args.max_size are shrunk with INTER_AREA
  so that side equals args.max_size. The aspect ratio is kept.
  """
  print(args.init_img_dir, content_img)
  path = os.path.join(args.init_img_dir, content_img)
  bgr = cv2.imread(path, cv2.IMREAD_COLOR)
  check_image(bgr, path)
  bgr = bgr.astype(np.float32)
  h, w, _ = bgr.shape
  mx = args.max_size
  # At most one of the two resizes can apply.
  if h > w and h > mx:
    new_w = (float(mx) / float(h)) * w
    bgr = cv2.resize(bgr, dsize=(int(new_w), mx), interpolation=cv2.INTER_AREA)
  elif w > mx:
    new_h = (float(mx) / float(w)) * h
    bgr = cv2.resize(bgr, dsize=(mx, int(new_h)), interpolation=cv2.INTER_AREA)
  return preprocess(bgr)

def get_style_images(content_img):
  """Load each file in args.style_imgs from args.style_imgs_dir.

  Each image is resized to the spatial size of `content_img`, which is a
  (1, h, w, d) batch, and preprocessed. The results are returned as a list.
  """
  _, content_h, content_w, _ = content_img.shape
  loaded = []
  for file_name in args.style_imgs:
    path = os.path.join(args.style_imgs_dir, file_name)
    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    check_image(bgr, path)
    resized = cv2.resize(bgr.astype(np.float32), dsize=(content_w, content_h),
                         interpolation=cv2.INTER_AREA)
    loaded.append(preprocess(resized))
  return loaded

def get_noise_image(noise_ratio, content_img):
  """Blend uniform noise in [-20, 20) with `content_img`.

  NumPy's global RNG is re-seeded with args.seed first, so the result is
  reproducible. Returns noise_ratio * noise + (1 - noise_ratio) * content_img.
  """
  np.random.seed(args.seed)
  noise = np.random.uniform(-20., 20., content_img.shape)
  noise = noise.astype(np.float32)
  return noise_ratio * noise + (1. - noise_ratio) * content_img

def get_mask_image(mask_img, width, height):
  """Load a grayscale mask from args.content_img_dir, resized and scaled to max 1.0.

  Returns a float32 array of shape (height, width). An all-black mask is
  returned as zeros: the former unconditional division produced NaNs,
  which then propagated into the style loss.
  """
  path = os.path.join(args.content_img_dir, mask_img)
  img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
  check_image(img, path)
  img = cv2.resize(img, dsize=(width, height), interpolation=cv2.INTER_AREA)
  img = img.astype(np.float32)
  mx = np.amax(img)
  if mx > 0:
    img /= mx
  return img

def get_content_weights(frame, prev_frame):
  """Return the "forward" weights mask for the frame pair (prev_frame, frame).

  Both the forward file (format args: prev_frame, frame) and the backward
  file (format args: frame, prev_frame) are read from args.video_input_dir.
  Only the forward mask is returned.
  """
  fmt = args.content_weights_frmt
  forward_name = fmt.format(str(prev_frame), str(frame))
  backward_name = fmt.format(str(frame), str(prev_frame))
  forward = read_weights_file(os.path.join(args.video_input_dir, forward_name))
  # NOTE(review): the backward mask is read but discarded; this call only
  # fails if that file is missing or malformed.
  read_weights_file(os.path.join(args.video_input_dir, backward_name))
  return forward

def warp_image(src, flow):
  """Warp `src` with a dense displacement field using cv2.remap.

  `flow` is indexed (channel, y, x). cv2.remap takes channel 0 as the x map
  and channel 1 as the y map. Each map is the pixel coordinate plus the
  displacement. Sampling is cubic, and out-of-range pixels use
  BORDER_TRANSPARENT.
  """
  _, h, w = flow.shape
  flow_map = np.zeros(flow.shape, dtype=np.float32)
  # Vectorised equivalent of the former per-row / per-column Python loops.
  flow_map[1] = flow[1] + np.arange(h, dtype=np.float32)[:, np.newaxis]
  flow_map[0] = flow[0] + np.arange(w, dtype=np.float32)[np.newaxis, :]
  # remap pixels to optical flow
  dst = cv2.remap(
    src, flow_map[0], flow_map[1],
    interpolation=cv2.INTER_CUBIC, borderMode=cv2.BORDER_TRANSPARENT)
  return dst

def convert_to_original_colors(content_img, stylized_img):
  """Keep the stylised first colour channel and the content image's other two.

  Both inputs are preprocessed batches. The two batches are paired view by
  view, and each pair is processed as follows:
  - both views are converted to args.color_convert_type;
  - channel 0 is taken from the stylised view and channels 1-2 from the
    content view;
  - the merged image is converted back and preprocessed.
  The result has one entry per view. Previously postprocess() kept only
  view 0, so a multi-view batch collapsed to a single image.

  Raises:
    ValueError: for an unknown args.color_convert_type.
  """
  conversions = {
    'yuv': (cv2.COLOR_BGR2YUV, cv2.COLOR_YUV2BGR),
    'ycrcb': (cv2.COLOR_BGR2YCR_CB, cv2.COLOR_YCR_CB2BGR),
    'luv': (cv2.COLOR_BGR2LUV, cv2.COLOR_LUV2BGR),
    'lab': (cv2.COLOR_BGR2LAB, cv2.COLOR_LAB2BGR),
  }
  try:
    cvt_type, inv_cvt_type = conversions[args.color_convert_type]
  except KeyError:
    raise ValueError(
      'Unsupported color_convert_type: {!r}'.format(args.color_convert_type)) from None

  converted = []
  for i in range(min(len(content_img), len(stylized_img))):
    content_bgr = postprocess(content_img[i:i + 1])
    stylized_bgr = postprocess(stylized_img[i:i + 1])
    content_cvt = cv2.cvtColor(content_bgr, cvt_type)
    stylized_cvt = cv2.cvtColor(stylized_bgr, cvt_type)
    c1, _, _ = cv2.split(stylized_cvt)
    _, c2, c3 = cv2.split(content_cvt)
    merged = cv2.merge((c1, c2, c3))
    dst = cv2.cvtColor(merged, inv_cvt_type).astype(np.float32)
    converted.append(preprocess(dst))
  return np.concatenate(converted, axis=0)





def occluison_mask(list_points):
  """Build pairwise 0/1 occlusion masks between the point sets of all views.

  Args:
    list_points: per-view point tensors indexed by view (args.view_num entries).
      NOTE(review): presumably the (num_points, C) tensors appended to
      list_world_pts by warp_grid_oog -- confirm.

  Returns:
    list_ref, with args.view_num entries. list_ref[r] holds 2 * args.view_num
    masks, and entries 2*v and 2*v+1 describe the pair (r, v):
    - 2*v: 1.0 where a point of view r lies within `thr` (Euclidean) of its
      matched point in view v, else 0.0;
    - 2*v+1: the same test for view v's points against view r.
    Matches come from the nn_distance indices. For v == r both entries are
    all-ones tensors of shape (1, 20480). Pairs with v < r reuse the masks
    already stored in list_ref[v], swapped.
  """
  list_ref = []
  for ref in range(args.view_num):
    list_view = []
    # Pairs with earlier views were computed already; reuse them swapped so
    # that entry 2*view refers to this ref's points.
    for view in range(0, ref):
      list_view.append(list_ref[view][2 * ref + 1])
      list_view.append(list_ref[view][2 * ref])
    # A view paired with itself: everything is visible.
    # NOTE(review): 20480 is hard-coded; presumably the number of points per view.
    list_view.append(tf.ones((1, 20480)))
    list_view.append(tf.ones((1, 20480)))
    for view in range(ref+1, args.view_num):
      rgbxyz_ref = list_points[ref]
      rgbxyz_view = list_points[view]
      dist_ref, idx_view4ref, dist_view, idx_ref4view = point_cloud_distance(rgbxyz_ref, rgbxyz_view)

      # NOTE(review): tf.stop_gradient returns a new tensor and these results
      # are discarded, so the two calls below have no effect.
      tf.stop_gradient(idx_ref4view)
      tf.stop_gradient(idx_view4ref)


      # Matched point in the other view for every point.
      xyz_view4ref = tf.gather(rgbxyz_view, idx_view4ref)
      xyz_ref4view = tf.gather(rgbxyz_ref, idx_ref4view)

      # Distance threshold, in the units of the point coordinates.
      thr = 2
      dist_ref2view = tf.math.sqrt(tf.reduce_sum(tf.pow(rgbxyz_ref - xyz_view4ref, 2), 2))
      dist_view2ref = tf.math.sqrt(tf.reduce_sum(tf.pow(rgbxyz_view - xyz_ref4view, 2), 2))
      occ_mask_ref = tf.cast((dist_ref2view < thr), tf.float32)
      occ_mask_view = tf.cast((dist_view2ref < thr), tf.float32)

      list_view.append(occ_mask_ref)
      list_view.append(occ_mask_view)
    list_ref.append(list_view)

  return list_ref


def point_cloud_distance(xyz1, xyz2):
  """Run the custom nn_distance op between point sets `xyz1` and `xyz2`.

  Each input gets a leading batch axis of size 1 before the call. Returns the
  op's four outputs as a tuple (dist1, idx1, dist2, idx2). Presumably these
  are, for each point of xyz1, the distance to and index of its nearest
  point in xyz2, and the same for xyz2 -- see the op's documentation.
  """
  batched1 = tf.expand_dims(xyz1, 0)
  batched2 = tf.expand_dims(xyz2, 0)
  d_fwd, i_fwd, d_bwd, i_bwd = nn_distance_module.nn_distance(batched1, batched2)
  return d_fwd, i_fwd, d_bwd, i_bwd

def _unpack_camera(inp_cams, index):
  """Split camera ``index`` of ``inp_cams`` into (intrinsic, R, t) float32 tensors.

  Slab 0 of the camera supplies ``R`` (rows 0-2, cols 0-2) and ``t``
  (rows 0-2, col 3). Slab 1 supplies the 3x3 intrinsic matrix in its top-left
  corner.
  """
  cam = tf.cast(tf.squeeze(tf.slice(inp_cams, [index, 0, 0, 0], [1, -1, -1, -1]), axis=0), 'float32')
  intrinsic = tf.squeeze(tf.slice(cam, [1, 0, 0], [1, 3, 3]), axis=0)
  R = tf.squeeze(tf.slice(cam, [0, 0, 0], [1, 3, 3]), axis=0)
  t = tf.squeeze(tf.slice(cam, [0, 0, 3], [1, 3, 1]), axis=0)
  return intrinsic, R, t


def warp_grid_oog(view_num, inp_cams, inp_depths, inp_shape, layer):
  """Build sampling grids that warp every view onto every reference view.

  Each target pixel of reference view ``ref`` is back-projected to world space
  with that view's depth map and camera. It is then re-projected into every
  view ``view`` (including ``ref`` itself).

  Args:
    view_num: number of views; entries ``0 .. view_num - 1`` of ``inp_cams``
      and ``inp_depths`` are used.
    inp_cams: per-view cameras indexed ``[view, slab, row, col]`` (see
      ``_unpack_camera`` for the layout).
    inp_depths: per-view depth maps indexed ``[view, row, col]``.
    inp_shape: target ``[batch, height, width, channels]``; only indices 1
      and 2 are read.
    layer: ``'input'`` selects the full-resolution coordinate transform; any
      other value uses the feature-map offset.

  Returns:
    ``(warp_map, list_world_pts)``.
    ``warp_map`` has shape ``(view_num, view_num, H*W, 2)`` and holds the x/y
    sampling coordinates for each (ref, view) pair.
    ``list_world_pts[ref]`` is the ``(H*W, 3)`` world-space point cloud of
    reference view ``ref``.
  """
  # Ratio of depth-map width to target width. The target pixel centres are
  # scaled by it before back-projection. Presumably the intrinsics are expressed
  # at the depth-map resolution; TODO confirm.
  depth_shape = tf.shape(inp_depths)
  inp_scale = tf.cast(depth_shape[2], 'float32') / tf.cast(inp_shape[2], 'float32')

  # Resample the depth maps to the target resolution. `inp_scale` is a graph
  # tensor, so a Python-level `if inp_scale != 1` cannot inspect its value (in
  # TF1 graph mode it compares object identity and is always true). Resizing
  # unconditionally is equivalent: a nearest resize to the same size is the
  # identity.
  inp_depths = tf.squeeze(
      tf.compat.v1.image.resize(tf.expand_dims(inp_depths, 3), [inp_shape[1], inp_shape[2]],
                                method='nearest', align_corners=False), 3)

  # These do not depend on the (ref, view) pair, so build them once.
  pixel_grids = tf.reshape(get_pixel_grids_nx(inp_shape[1], inp_shape[2], inp_scale), (3, -1))
  cameras = [_unpack_camera(inp_cams, i) for i in range(view_num)]

  list_ref = []
  list_world_pts = []
  for ref in range(view_num):
    intrinsic_t, R_t, t_t = cameras[ref]
    depth_t = tf.squeeze(tf.slice(inp_depths, [ref, 0, 0], [1, -1, -1]), axis=0)

    # Back-project the reference pixels to world space (depends only on ref).
    rays = tf.matmul(tf.linalg.inv(intrinsic_t), pixel_grids)
    cam_points = rays * tf.reshape(depth_t, (1, -1))
    world_points = tf.matmul(tf.linalg.inv(R_t), cam_points - t_t)
    list_world_pts.append(tf.transpose(world_points, perm=[1, 0]))
    num_world_points = tf.shape(world_points)[1]

    list_view = []
    for view in range(view_num):
      intrinsic_s, R_s, t_s = cameras[view]
      # Project the world points into `view`.
      transformed_pts = tf.matmul(R_s, world_points) + t_s
      x, y, z = transformed_pts[0], transformed_pts[1], transformed_pts[2]
      normal_uv = tf.stack([x / z, y / z, tf.ones_like(x)], axis=-1)
      uv = tf.matmul(normal_uv, tf.transpose(intrinsic_s, (1, 0)))[:, :2]

      grid = tf.reshape(uv, (num_world_points, 1, 2))
      if layer == 'input':
        # Hard-coded 4x scale and offset for the full-resolution input
        # (presumably 4x the camera resolution; TODO confirm).
        x_warped = 4 * grid[..., 0] - 0.125
        y_warped = 4 * grid[..., 1] - 0.125
      else:
        x_warped = grid[..., 0] - 0.5
        y_warped = grid[..., 1] - 0.5
      list_view.append(tf.concat((x_warped, y_warped), 1))
    list_ref.append(tf.stack(list_view, 0))

  warp_map = tf.stack(list_ref, 0)
  return warp_map, list_world_pts

def get_pixel_grids_nx(height, width, inv_scale):
  """Return flattened homogeneous pixel-centre coordinates for a height x width grid.

  Column ``i`` / row ``j`` is placed at ``(i + 0.5) * inv_scale`` /
  ``(j + 0.5) * inv_scale``. The result is one 1-D tensor of length
  ``3 * height * width``: all x coordinates (row-major), then all y
  coordinates, then ones.
  """
  half_step = 0.5 * inv_scale
  xs = tf.linspace(half_step, tf.cast(width, 'float32') * inv_scale - half_step, width)
  ys = tf.linspace(half_step, tf.cast(height, 'float32') * inv_scale - half_step, height)
  mesh_x, mesh_y = tf.meshgrid(xs, ys)
  flat_x = tf.reshape(mesh_x, [-1])
  flat_y = tf.reshape(mesh_y, [-1])
  return tf.concat([flat_x, flat_y, tf.ones_like(flat_x)], 0)



def _vgg_to_relu3_4(layer_input, vgg_layers):
  """Run ``layer_input`` through VGG conv1_1 .. relu3_4 using ``vgg_layers`` weights."""
  feat = conv_layer('conv1_1', layer_input, W=get_weights(vgg_layers, 0))
  feat = relu_layer('relu1_1', feat, b=get_bias(vgg_layers, 0))
  feat = conv_layer('conv1_2', feat, W=get_weights(vgg_layers, 2))
  feat = relu_layer('relu1_2', feat, b=get_bias(vgg_layers, 2))
  feat = pool_layer('pool1', feat)

  feat = conv_layer('conv2_1', feat, W=get_weights(vgg_layers, 5))
  feat = relu_layer('relu2_1', feat, b=get_bias(vgg_layers, 5))
  feat = conv_layer('conv2_2', feat, W=get_weights(vgg_layers, 7))
  feat = relu_layer('relu2_2', feat, b=get_bias(vgg_layers, 7))
  feat = pool_layer('pool2', feat)

  feat = conv_layer('conv3_1', feat, W=get_weights(vgg_layers, 10))
  feat = relu_layer('relu3_1', feat, b=get_bias(vgg_layers, 10))
  feat = conv_layer('conv3_2', feat, W=get_weights(vgg_layers, 12))
  feat = relu_layer('relu3_2', feat, b=get_bias(vgg_layers, 12))
  feat = conv_layer('conv3_3', feat, W=get_weights(vgg_layers, 14))
  feat = relu_layer('relu3_3', feat, b=get_bias(vgg_layers, 14))
  feat = conv_layer('conv3_4', feat, W=get_weights(vgg_layers, 16))
  return relu_layer('relu3_4', feat, b=get_bias(vgg_layers, 16))


def sum_warp_losses5(sess, net, warp_map, layer, output_key, list_occ_mask, vgg_layers):
  """Compute multi-view consistency losses for ``net[layer]``.

  For every ordered pair (ref, view), the activations of ``view`` are resampled
  with ``warp_map[ref, view]`` through ``interpolate_batch2``. Each pair with
  ``ref != view`` adds two terms:

  * ``warp_layer_loss`` between the reference activations and the warped ones,
    weighted by the occlusion mask.
  * ``warp_layer_loss2`` between ``net['relu3_4']`` of the reference and relu3_4
    features recomputed from a blend. The blend takes the warped values where
    the mask is 1 and the reference values where it is 0.

  The stroke term reshapes activations to 3 channels, so it only works for
  3-channel activations (the ``'input'`` layer).

  Side effects: stores the warped images in ``net[output_key]``. When
  ``layer == 'input'``, it also stores the mask-multiplied warps in
  ``net['occ_masked']``.

  Args:
    sess: unused; kept for interface compatibility.
    net: dict of graph tensors. Reads ``net[layer]``, ``net['relu3_4']`` and,
      when ``layer != 'input'``, ``net['input']``.
    warp_map: sampling coordinates indexed ``[ref, view, point, (x, y)]``.
    layer: key of the activations to warp.
    output_key: key under which the stacked warped images are stored in ``net``.
    list_occ_mask: nested masks; ``list_occ_mask[view][2 * ref + 1]`` is used
      for the pair (ref, view).
    vgg_layers: VGG weights, read through ``get_weights`` / ``get_bias``.

  Returns:
    ``(warp_loss, stroke_loss)``.
  """
  x = net[layer]
  _, h, w, _ = x.get_shape()
  # Plain ints: TF1 Dimension objects reject true division ('/') under Python 3,
  # and keeping the width in its own name avoids it being shadowed later.
  height, width = h.value, w.value
  num_views = args.view_num

  warp_loss = 0.
  stroke_loss = 0.
  list_warped = []
  list_occ2 = []
  list_save = []
  list_save_masked = []

  if layer != 'input':
    # The RGB input is resized to 128x160 and warped for the saved images.
    # The warp coordinates for non-input layers are in that frame.
    small_input = tf.compat.v1.image.resize(net['input'], [128, 160], method='nearest', align_corners=False)

  for ref in range(num_views):
    for view in range(num_views):
      grid_x = tf.reshape(tf.slice(warp_map, [ref, view, 0, 0], [1, 1, -1, 1]), (1, -1))
      grid_y = tf.reshape(tf.slice(warp_map, [ref, view, 0, 1], [1, 1, -1, 1]), (1, -1))
      warped = interpolate_batch2(tf.slice(x, [view, 0, 0, 0], [1, -1, -1, -1]), grid_x, grid_y)
      list_warped.append(warped)

      if layer == 'input':
        list_save.append(tf.reshape(warped, (1, height, width, 3)))
        # The masks are stored at a quarter of the input resolution; upsample them.
        occ_mask = tf.reshape(list_occ_mask[view][2 * ref + 1], (1, height // 4, width // 4, 1))
        occ_mask = tf.image.resize(occ_mask, size=(height, width), method='nearest', align_corners=False)
        occ_mask = tf.reshape(occ_mask, (1, height * width, 1))
        list_save_masked.append(tf.reshape(warped * occ_mask, (1, height, width, 3)))
      else:
        warped_rgb = interpolate_batch2(tf.slice(small_input, [view, 0, 0, 0], [1, -1, -1, -1]),
                                        grid_x, grid_y)
        list_save.append(tf.reshape(warped_rgb, (1, height, width, 3)))
        occ_mask = tf.reshape(list_occ_mask[view][2 * ref + 1], (1, -1, 1))
      list_occ2.append(occ_mask)

  for ref in range(num_views):
    ref_act = tf.slice(x, [ref, 0, 0, 0], [1, -1, -1, -1])
    ref_flat = tf.reshape(ref_act, (1, -1, 3))
    ref_relu3_4 = tf.slice(net['relu3_4'], [ref, 0, 0, 0], [1, -1, -1, -1])
    for view in range(num_views):
      if ref == view:
        continue
      pair = num_views * ref + view
      warped = list_warped[pair]
      mask = list_occ2[pair]

      warp_loss += warp_layer_loss(ref_flat, warped, mask)

      # Where the mask is 0, fall back to the reference's own values.
      mask_img = tf.reshape(mask, (1, height, width, 1))
      warped_img = tf.reshape(warped, (1, height, width, 3))
      blended = mask_img * warped_img + (1 - mask_img) * ref_act
      stroke_loss += warp_layer_loss2(ref_relu3_4, _vgg_to_relu3_4(blended, vgg_layers))

  net[output_key] = tf.concat(list_save, 0)
  if layer == 'input':
    net['occ_masked'] = tf.concat(list_save_masked, 0)
  return warp_loss, stroke_loss


def interpolate_batch2(image, x, y):
  """Bilinearly sample ``image`` at the continuous coordinates (``x``, ``y``).

  Args:
    image: 4-D tensor indexed ``[batch, row, col, channel]``.
    x, y: sample coordinates, rank 2 ``[batch, point]``. The batch-index tensor
      below has ``height * width`` entries per batch item and is concatenated
      with the coordinates. Each batch item must therefore hold exactly
      ``height * width`` points.

  Returns:
    Tensor ``[batch, height * width, channels]`` of interpolated values.

  Notes:
    Coordinates are shifted by -0.5 before sampling. The corner indices are
    clipped into the image, and the weights are then computed from the clipped
    corners. When both corners along an axis clip to the same index, their two
    weights are exact negatives applied to the same pixel value. Such samples
    (outside the image along that axis) therefore evaluate to approximately 0.
  """
  image_shape = tf.shape(image)
  batch_size = image_shape[0]
  height = image_shape[1]
  width = image_shape[2]

  # image coordinate to pixel coordinate
  x = x - 0.5
  y = y - 0.5
  # Integer corners of the 2x2 neighbourhood around each sample.
  x0 = tf.cast(tf.floor(x), 'int32')
  x1 = x0 + 1
  y0 = tf.cast(tf.floor(y), 'int32')
  y1 = y0 + 1
  max_y = tf.cast(height - 1, dtype='int32')
  max_x = tf.cast(width - 1, dtype='int32')
  # Clamp the corners into the image so gather_nd stays in bounds.
  x0 = tf.clip_by_value(x0, 0, max_x)
  x1 = tf.clip_by_value(x1, 0, max_x)
  y0 = tf.clip_by_value(y0, 0, max_y)
  y1 = tf.clip_by_value(y1, 0, max_y)
  b = tf.expand_dims(repeat_int_batch(tf.range(batch_size), height * width), 2)  # B x HW

  # Full (batch, row, col) indices for the four corners.
  indices_a = tf.concat([b, tf.stack([y0, x0], axis=2)], axis=2)
  indices_b = tf.concat([b, tf.stack([y0, x1], axis=2)], axis=2)
  indices_c = tf.concat([b, tf.stack([y1, x0], axis=2)], axis=2)
  indices_d = tf.concat([b, tf.stack([y1, x1], axis=2)], axis=2)

  pixel_values_a = tf.gather_nd(image, indices_a)
  pixel_values_b = tf.gather_nd(image, indices_b)
  pixel_values_c = tf.gather_nd(image, indices_c)
  pixel_values_d = tf.gather_nd(image, indices_d)

  # Bilinear weights, computed from the clipped corner coordinates (see Notes).
  x0 = tf.cast(x0, 'float32')
  x1 = tf.cast(x1, 'float32')
  y0 = tf.cast(y0, 'float32')
  y1 = tf.cast(y1, 'float32')
  area_a = tf.expand_dims(((y1 - y) * (x1 - x)), 2)
  area_b = tf.expand_dims(((y1 - y) * (x - x0)), 2)
  area_c = tf.expand_dims(((y - y0) * (x1 - x)), 2)
  area_d = tf.expand_dims(((y - y0) * (x - x0)), 2)
  output = tf.add_n([area_a * pixel_values_a,
                     area_b * pixel_values_b,
                     area_c * pixel_values_c,
                     area_d * pixel_values_d])
  return output

def repeat_int_batch(x, num_repeats):
  """Repeat every element of ``x`` ``num_repeats`` times along a new last axis.

  ``x`` is flattened into a column first. The result has shape
  ``(number of elements in x, num_repeats)``, and row ``i`` is filled with the
  ``i``-th element of ``x``.

  Args:
    x: tensor to repeat (``interpolate_batch2`` passes ``tf.range(batch_size)``).
    num_repeats: scalar int (Python int or int32 tensor).
  """
  column = tf.reshape(x, shape=(-1, 1))
  # tf.tile expresses the repetition directly and, unlike multiplying by an
  # int32 ones matrix, does not tie the op to a matmul on integer data.
  return tf.tile(column, tf.stack([1, num_repeats]))

def warp_layer_loss(p, x, occ):
  """Return the masked, normalised squared error between ``p`` and ``x``.

  Args:
    p: reference tensor of rank 3; dimension 1 gives ``M`` and dimension 2
      gives ``N`` (both must be statically known).
    x: tensor compared against ``p`` (broadcast-compatible).
    occ: mask multiplied into the difference before squaring.

  Returns:
    ``K * sum(((x - p) * occ) ** 2)``, where ``K`` depends on
    ``args.content_loss_function``:
    1 -> ``1 / (2 * sqrt(N) * sqrt(M))``; 2 -> ``1 / (N * M)``; 3 -> ``1 / 2``.

  Raises:
    ValueError: if ``args.content_loss_function`` is not 1, 2 or 3.
  """
  _, hw, d = p.get_shape()
  M = hw.value
  N = d.value
  option = args.content_loss_function
  if option == 1:
    K = 1. / (2. * N**0.5 * M**0.5)
  elif option == 2:
    K = 1. / (N * M)
  elif option == 3:
    K = 1. / 2.
  else:
    # Previously this fell through and raised UnboundLocalError on `K`.
    raise ValueError('unsupported content_loss_function: {}'.format(option))
  loss = K * tf.reduce_sum(tf.pow((x - p) * occ, 2))
  return loss


# Input files for the DTU scan6 views rendered by render_single_image, as
# (content image, camera file, depth map, initialisation image).
# NOTE(review): the first content image is read from `Images/` while the other
# two come from `Rectified/`; confirm this is intentional.
_SCAN6_VIEW_FILES = (
    ("./data/dtu/Images/scan6/rect_002_3_r5000.png",
     "./data/dtu/Cameras/00000001_cam.txt",
     "./data/dtu/Depths/scan6/depth_map_0001.pfm",
     "./dtu/Rectified/scan6_train/rect_002_3_r5000.png"),
    ("./data/dtu/Rectified/scan6/rect_003_3_r5000.png",
     "./data/dtu/Cameras/00000002_cam.txt",
     "./data/dtu/Depths/scan6/depth_map_0002.pfm",
     "./dtu/Rectified/scan6_train/rect_003_3_r5000.png"),
    ("./data/dtu/Rectified/scan6/rect_009_3_r5000.png",
     "./data/dtu/Cameras/00000008_cam.txt",
     "./data/dtu/Depths/scan6/depth_map_0008.pfm",
     "./dtu/Rectified/scan6_train/rect_009_3_r5000.png"),
)

# Style image; other styles that have been used include
# ./styles4/antimonocromatismo.jpg, ./candy/candy.jpg, ./la_muse/la_muse.jpg
_STYLE_IMAGE_FILE = "./styles3/starry-night.jpg"


def _imread_float32(path):
  """Read an image with OpenCV as float32; raise IOError if it cannot be read."""
  img = cv2.imread(path)
  # cv2.imread reports failure by returning None instead of raising.
  if img is None:
    raise IOError("could not read image: {}".format(path))
  return img.astype(np.float32)


def _load_cam(path, num_depth, interval_scale):
  """Load a DTU camera file via io.load_cam_dtu, closing the file afterwards."""
  with open(path) as cam_file:
    return io.load_cam_dtu(cam_file,
                           num_depth=num_depth,
                           interval_scale=interval_scale).astype(np.float32)


def render_single_image():
  """Load the scan6 views, precompute warps and occlusion masks, and run `stylize`.

  The warp grids (512x640 'input' and 128x160 'relu3_4'), world points and
  occlusion masks are evaluated once in a throw-away session. They are then
  passed as arrays to `stylize`, which runs in a fresh graph.

  NOTE(review): `io` must be the project's helper module providing
  `load_cam_dtu` / `load_pfm`. Under Python 3 the standard-library `io` is
  already imported at start-up, so `import io` would bind that module instead;
  confirm the interpreter and import path.
  """
  num_virtual_plane = 128
  interval_scale = 1.6

  contents, cam_list, depth_list, inits = [], [], [], []
  for rgb_path, cam_path, depth_path, init_path in _SCAN6_VIEW_FILES:
    contents.append(preprocess(_imread_float32(rgb_path)))
    cam_list.append(_load_cam(cam_path, num_virtual_plane, interval_scale))
    depth_list.append(io.load_pfm(depth_path)[0].astype(np.float32))
    inits.append(preprocess(_imread_float32(init_path)))

  style = _imread_float32(_STYLE_IMAGE_FILE)
  style = cv2.resize(style, dsize=(640, 512), interpolation=cv2.INTER_AREA)
  style = preprocess(style)

  content_img = np.concatenate(contents, 0)
  style_img = np.concatenate([style] * len(contents), 0)
  init_img = np.concatenate(inits, 0)
  cams = np.stack(cam_list)
  depths = np.stack(depth_list, 0)

  # Evaluate the geometry once; the session is closed on exit.
  with tf.Session() as pre_sess:
    warp0, list_world_pts0 = warp_grid_oog(args.view_num, cams, depths, [args.view_num, 512, 640, 3], 'input')
    warp2, list_world_pts2 = warp_grid_oog(args.view_num, cams, depths, [args.view_num, 128, 160, 3], 'relu3_4')
    list_occ_mask = occluison_mask(list_world_pts2)
    warp0, warp2, list_world_pts0, list_world_pts2, list_occ_mask = pre_sess.run(
        [warp0, warp2, list_world_pts0, list_world_pts2, list_occ_mask])

  with tf.Graph().as_default():
    print('\n---- RENDERING SINGLE IMAGE ----\n')
    # NOTE(review): init_img is always the precomputed images above;
    # args.init_img_type is not consulted here.
    stylize(content_img, style_img, init_img, depths, cams, warp0, warp2,
            list_world_pts0, list_world_pts2, list_occ_mask)

def main():
  """Command-line entry point: parse options into the global `args`, then render."""
  # Other functions in this module (e.g. warp_layer_loss, occluison_mask)
  # read their options from this module-level global.
  global args
  cli_options = parse_args()
  args = cli_options
  render_single_image()

# Run only when executed as a script, not when imported.
if __name__ == '__main__':
  main()
