# Copyright 2018 The SPL Authors.
#
# All rights reserved.
#
# This is the code for reproducing results of the paper.
"""Apache beam inference to create tfrecords with pseudo labels."""


import csv
import json
import os
import threading
import time

from absl import app
from absl import flags
from absl import logging

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
import numpy as np
import tensorflow as tf

flags.DEFINE_string('input_path', None, 'Input file pattern.')
flags.DEFINE_string('model_path', None, 'Model path.')
flags.DEFINE_string('output_path', None,
                    'Output directory for tfrecords, model and stats.')
flags.DEFINE_string('proto_key_set', 'image/encoded,image/label,image/prob',
                    'Proto key name for <image,label,probability>.')
flags.DEFINE_integer('num_worker', 32, 'Number of worker to execute.')
flags.DEFINE_string(
    'direct_running_mode', 'multi_processing',
    'DirectRunner running mode: in_memory, multi_threading or '
    'multi_processing.')
flags.DEFINE_boolean('process_together', False,
                     'Do all steps in one apache program.')

FLAGS = flags.FLAGS
# Log inference progress every VERBOSE_INTERVAL examples (counted per
# InferenceFn instance).
VERBOSE_INTERVAL = 5000
# Number of classes used by main() when only rewriting tfrecords.
TEST_NUMCLASSES = 14


def create_pipeline(preprocess_image,
                    model_creator,
                    num_classes,
                    input_path,
                    output_path,
                    model_path,
                    reader=beam.io.ReadFromTFRecord,
                    writter=beam.io.WriteToTFRecord,
                    **kwargs):
  """Create a beam inference pipeline builder.

  The returned `pipeline(root, mode='all', suffix='')` function attaches to
  `root` a read -> inference branch, followed by a TFRecord write branch
  (mode 'all' or 'write') and/or a dataset-statistics branch (mode 'all' or
  'stat').

  Args:
    preprocess_image: Callable invoked as `preprocess_image(image_bytes=...)`;
      its result is indexed with ['image'] before being fed to the model.
    model_creator: Model factory, or None. When None, no inference runs and
      examples pass through unchanged; statistics then use the probabilities
      already stored in the examples, if any.
    num_classes: Number of classes (side length of the confusion matrix).
    input_path: Input TFRecord file pattern.
    output_path: Directory receiving the TFRecords and the statistics files.
    model_path: Directory from which `model.h5` is loaded by the workers.
    reader: Beam transform used to read the tf.Examples.
    writter: Beam transform used to write the tf.Examples.
    **kwargs: Ignored.

  Returns:
    A function `pipeline(root, mode='all', suffix='')`.
  """
  del kwargs

  def _save_array_to_csv(input_matrix, output_csv_path):
    """Write each row of a 2-D array as one comma-separated line."""
    with tf.io.gfile.GFile(output_csv_path, 'w') as fid:
      writer = csv.writer(fid, delimiter=',')
      # tolist() yields plain Python scalars, written exactly like the numpy
      # scalars they replace (e.g. '5').
      writer.writerows(input_matrix.tolist())

  def _save_stats(conf_mtrix, output_json_path):
    """Save per-class and mean accuracy derived from a confusion matrix.

    Rows of `conf_mtrix` are predictions and columns are labels, so
    `conf_mtrix[:, c]` holds all predictions for examples labelled `c`.
    A class with no labelled examples gets a NaN accuracy and is left out of
    the mean, instead of turning the whole mean into NaN.
    """
    class_count = np.sum(conf_mtrix, axis=0)
    hit = np.diag(conf_mtrix)
    # Empty classes give 0/0; silence the warning and keep the NaN.
    with np.errstate(divide='ignore', invalid='ignore'):
      class_accuracy = np.true_divide(hit, class_count)
    present = class_count > 0
    if np.any(present):
      mean_accuracy = float(np.mean(class_accuracy[present]))
    else:
      mean_accuracy = float('nan')
    stats = {
        'class_accuracy': class_accuracy.tolist(),
        'mean_accuracy': mean_accuracy,
    }
    with tf.io.gfile.GFile(output_json_path, 'w') as fid:
      json.dump(stats, fid, indent=4)

  class DatasetStatisticsFn(beam.CombineFn):
    """Accumulate a confusion matrix, then save it with accuracy stats."""

    def __init__(self, num_class, output_path):
      self._num_class = int(num_class)
      self._output_path = output_path
      self.image_key, self.label_key, self.prob_key = FLAGS.proto_key_set.split(
          ',')

    def create_accumulator(self):
      # A one-element list wrapping a (num_class, num_class) int64 matrix;
      # rows are predictions and columns are labels.
      return [np.zeros((self._num_class, self._num_class), dtype='int64')]

    def add_input(self, statistics, element):
      feature = element.features.feature
      label = feature[self.label_key].int64_list.value[0]
      if model_creator is not None or self.prob_key in feature:
        # Prediction is the arg-max of the stored class probabilities.
        prob = np.array(feature[self.prob_key].float_list.value)
        pred = int(np.argmax(prob))
      else:
        # Tfrecords read without probabilities: count the label as the
        # prediction.
        logging.warning('probe_key not exists in tfExample')
        pred = label
      statistics[0][pred, label] += 1
      return statistics

    def merge_accumulators(self, accumulators):
      merged = self.create_accumulator()
      for accumulator in accumulators:
        merged[0] += accumulator[0]
      return merged

    def extract_output(self, accumulator):
      # NOTE(review): the matrix is written as CSV despite the '.json' file
      # name; the name is left unchanged here — confirm whether consumers
      # rely on it before renaming.
      _save_array_to_csv(
          accumulator[0],
          os.path.join(self._output_path, 'confusion_matrix.json'))
      _save_stats(accumulator[0], os.path.join(self._output_path, 'stats.json'))
      return accumulator

  class InferenceFn(beam.DoFn):
    """Attach softmax class probabilities to each tf.Example."""
    session_lock = threading.Lock()

    def __init__(self, model_path):
      self._model_path = model_path
      self.image_key, self.label_key, self.prob_key = FLAGS.proto_key_set.split(
          ',')
      self._model = None
      self._count = 0
      self._start_time = time.time()

    def start_bundle(self, **kwargs):
      # Beam may reuse a DoFn instance for many bundles; load the model only
      # once per instance instead of re-reading model.h5 every bundle.
      if self._model is None:
        self._load_inference_model()

    def _load_inference_model(self):
      if model_creator is not None:
        self._model = tf.keras.models.load_model(
            os.path.join(self._model_path, 'model.h5'))
      else:
        logging.warning('model_creator is None. Only rewirte tfrecords')

    def process(self, keyed_serialized_example):
      return self._run_inference_and_generate_detections(
          keyed_serialized_example)

    def preprocess(self, image_bytes):
      return preprocess_image(image_bytes=image_bytes)

    def _run_inference_and_generate_detections(self, keyed_serialized_example):
      """Return the example with probabilities set (or unchanged, no model)."""
      if model_creator is None:
        return [keyed_serialized_example]
      feature = keyed_serialized_example.features.feature
      image_bytes = feature[self.image_key].bytes_list.value[0]
      image = self.preprocess(image_bytes)['image']
      logits = self._model.predict(tf.expand_dims(image, 0), steps=1)
      prob_np = tf.nn.softmax(tf.cast(logits, tf.float32)).numpy()
      # Replace, rather than append to, any probabilities already stored so
      # an input that already carries `prob_key` does not end up with a
      # concatenated vector.
      prob_values = feature[self.prob_key].float_list.value
      del prob_values[:]
      prob_values.extend(prob_np.ravel().tolist())
      self._count += 1
      if self._count % VERBOSE_INTERVAL == 0:
        logging.info(
            'Finished {} inference tf examples. Time cost {:.2f}min'.format(
                self._count, (time.time() - self._start_time) / 60))
      return [keyed_serialized_example]

  def pipeline(root, mode='all', suffix=''):
    """Attach read -> inference -> write and/or statistics steps to `root`.

    Args:
      root: The beam.Pipeline to extend.
      mode: 'all' (write + stats), 'write' (write only) or 'stat' (stats only).
      suffix: Appended to every transform label.
    """
    filenames = tf.io.gfile.glob(input_path)
    num_shards = len(filenames)
    file_pattern = os.path.basename(input_path).rstrip('*')
    output_pattern = os.path.join(output_path, file_pattern)
    # Reads every input file once; the count is only used for logging.
    nums = sum(1 for filename in filenames
               for _ in tf.compat.v1.io.tf_record_iterator(filename))
    logging.info('-' * 100)
    logging.info('Process {} TF examples and save to {} tfrecords'.format(
        nums, num_shards))
    logging.info('-' * 100)
    examples = (
        root
        | ('ReadExample{}'.format(suffix)) >> reader(
            input_path, coder=beam.coders.ProtoCoder(tf.train.Example))
        | ('inference{}'.format(suffix)) >> beam.ParDo(InferenceFn(model_path)))
    if mode in ('all', 'write'):
      _ = examples | ('WriteExample{}'.format(suffix)) >> writter(
          output_pattern,
          coder=beam.coders.ProtoCoder(tf.train.Example),
          num_shards=num_shards)
    if mode in ('all', 'stat'):
      _ = examples | (
          'GetDatasetStatistics{}'.format(suffix)) >> beam.CombineGlobally(
              DatasetStatisticsFn(num_classes, output_path))

  return pipeline


def _direct_pipeline_options():
  """Build DirectRunner options from --num_worker and --direct_running_mode."""
  return PipelineOptions([
      '--direct_num_workers',
      str(FLAGS.num_worker), '--direct_running_mode',
      FLAGS.direct_running_mode
  ])


def run(process_image, model_creator, num_classes, **kwargs):
  """Run inference over FLAGS.input_path and write pseudo-labelled tfrecords.

  Args:
    process_image: Image preprocessing callable forwarded to create_pipeline.
    model_creator: Model factory, or None to only rewrite tfrecords.
    num_classes: Number of classes for the confusion matrix.
    **kwargs: Must contain 'checkpoint_model_name' when model_creator is set;
      forwarded to create_pipeline.
  """
  if not tf.io.gfile.isdir(FLAGS.output_path):
    tf.io.gfile.makedirs(FLAGS.output_path)
  # The converted Keras model is written into output_path, so that is the
  # directory the inference workers must load 'model.h5' from.
  keras_model_dir = FLAGS.output_path
  if model_creator is not None:
    # Restore the checkpoint and re-save it as a Keras .h5 file that every
    # worker process can load on its own.
    model = model_creator()
    name = kwargs['checkpoint_model_name']
    checkpoint = tf.train.Checkpoint(**{name: model})
    checkpoint.restore(FLAGS.model_path).expect_partial()
    tf.keras.models.save_model(model.keras_model,
                               os.path.join(keras_model_dir, 'model.h5'))

  if FLAGS.process_together:
    p = beam.Pipeline(options=_direct_pipeline_options())
    pipeline = create_pipeline(process_image, model_creator, num_classes,
                               FLAGS.input_path, FLAGS.output_path,
                               keras_model_dir, **kwargs)
    pipeline(p)
    p.run().wait_until_finish()
    return

  # Apache Beam has internal bugs processing really large tfrecord files, so
  # use separate pipelines:
  # 1. run inference and write each input tfrecord file on its own, skipping
  #    files whose output already exists;
  # 2. gather statistics over the written files.
  for i, single_f in enumerate(sorted(tf.io.gfile.glob(FLAGS.input_path))):
    output_f = os.path.join(FLAGS.output_path, os.path.basename(single_f))
    if tf.io.gfile.glob(output_f + '*'):
      logging.info('tfrecord file {} existed'.format(output_f))
      continue
    p = beam.Pipeline(options=_direct_pipeline_options())
    pipeline = create_pipeline(process_image, model_creator, num_classes,
                               single_f, FLAGS.output_path, keras_model_dir,
                               **kwargs)
    pipeline(p, mode='write', suffix=str(i))
    p.run().wait_until_finish()
    logging.info('Processed file {}'.format(output_f))

  # Statistics pass: pass model_creator=None so no inference runs again; the
  # probabilities written above are read back from the output tfrecords.
  # The writer appends a shard suffix to each output name, so match
  # '<basename>*' (this also works when --input_path has no trailing '*').
  stat_input_path = os.path.join(
      FLAGS.output_path,
      os.path.basename(FLAGS.input_path).rstrip('*') + '*')
  p = beam.Pipeline(options=_direct_pipeline_options())
  pipeline = create_pipeline(process_image, None, num_classes,
                             stat_input_path, FLAGS.output_path,
                             keras_model_dir, **kwargs)
  pipeline(p, mode='stat', suffix='stat')
  p.run().wait_until_finish()


def main(_):
  """Rewrite tfrecords without running inference (used for debugging)."""

  def _identity(x):
    return x

  # Passing model_creator=None disables inference in run().
  run(_identity, None, TEST_NUMCLASSES)


if __name__ == '__main__':
  # Every path flag must be provided on the command line.
  for _required_flag in ('input_path', 'model_path', 'output_path'):
    flags.mark_flag_as_required(_required_flag)
  app.run(main)
