"""Script for picking certain number of samples.
"""

import argparse
import time
import logging
import collections
import tensorflow.compat.v1 as tf

parser = argparse.ArgumentParser(
    description="Eval sample picker for BERT.")
parser.add_argument(
    '--input_tfrecord',
    type=str,
    default='',
    help='Input tfrecord path')
parser.add_argument(
    '--output_tfrecord',
    type=str,
    default='',
    help='Output tfrecord path')
parser.add_argument(
    '--num_examples_to_pick',
    type=int,
    default=10000,
    help='Number of examples to pick')
parser.add_argument(
    '--max_seq_length',
    type=int,
    default=512,
    help='The maximum number of tokens within a sequence.')
parser.add_argument(
    '--max_predictions_per_seq',
    type=int,
    default=76,
    help='The maximum number of predictions within a sequence.')
args = parser.parse_args()

max_seq_length = args.max_seq_length
max_predictions_per_seq = args.max_predictions_per_seq
logging.basicConfig(level=logging.INFO)

def decode_record(record):
  """Decodes a record to a TensorFlow example."""
  name_to_features = {
      "input_ids":
          tf.FixedLenFeature([max_seq_length], tf.int64),
      "input_mask":
          tf.FixedLenFeature([max_seq_length], tf.int64),
      "segment_ids":
          tf.FixedLenFeature([max_seq_length], tf.int64),
      "masked_lm_positions":
          tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
      "masked_lm_ids":
          tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
      "masked_lm_weights":
          tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
      "next_sentence_labels":
          tf.FixedLenFeature([1], tf.int64),
  }

  example = tf.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32.
  for name in list(example.keys()):
    t = example[name]
    if t.dtype == tf.int64:
      t = tf.to_int32(t)
    example[name] = t

  return example


def create_int_feature(values):
  feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
  return feature


def create_float_feature(values):
  feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
  return feature


if __name__ == '__main__':
  tic = time.time()
  tf.enable_eager_execution()

  d = tf.data.TFRecordDataset(args.input_tfrecord)
  num_examples = 0
  records = []
  for record in d:
    records.append(record)
    num_examples += 1

  writer = tf.python_io.TFRecordWriter(args.output_tfrecord)
  i = 0
  pick_ratio = num_examples / args.num_examples_to_pick
  num_examples_picked = 0
  for i in range(args.num_examples_to_pick):
    example = decode_record(records[int(i * pick_ratio)])
    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(
        example["input_ids"].numpy())
    features["input_mask"] = create_int_feature(
        example["input_mask"].numpy())
    features["segment_ids"] = create_int_feature(
        example["segment_ids"].numpy())
    features["masked_lm_positions"] = create_int_feature(
        example["masked_lm_positions"].numpy())
    features["masked_lm_ids"] = create_int_feature(
        example["masked_lm_ids"].numpy())
    features["masked_lm_weights"] = create_float_feature(
        example["masked_lm_weights"].numpy())
    features["next_sentence_labels"] = create_int_feature(
        example["next_sentence_labels"].numpy())

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
    num_examples_picked += 1

  writer.close()
  toc = time.time()
  logging.info("Picked %d examples out of %d samples in %.2f sec",
               num_examples_picked, num_examples, toc - tic)
