"""Tools for reading deepwriting data."""
import argparse
import os

import numpy as np
import deepwriting_data_lib_external



def collect_line_data(dataset):
  """Gathers the usable line samples of one Deepwriting dataset.

  A sample is skipped when `get_sample_label` returns no ink for it, or when
  `non_alphabet` flags its label.

  Args:
    dataset: Dataset object returned by
      `deepwriting_data_lib_external.get_dataset`. This function reads its
      `num_samples`, `data_dict['subject_labels']` and `undo_preprocess`.

  Returns:
    A list with one tuple per kept sample:
    `(idx, subject_label, ink, dataset.undo_preprocess(ink), label)`.
  """
  rows = []
  total = dataset.num_samples
  for idx in range(total):
    if idx % 100 == 0:
      print(f"{idx}/{total}")
    correct_label, ink = deepwriting_data_lib_external.get_sample_label(
        idx, dataset, use_regex=True, return_ink=True)
    if ink is None:
      continue
    if deepwriting_data_lib_external.non_alphabet(correct_label, dataset):
      continue
    subj = dataset.data_dict['subject_labels'][idx]
    rows.append((idx, subj, ink, dataset.undo_preprocess(ink), correct_label))
  return rows


def rows_to_object_array(rows):
  """Packs sample tuples into a `(len(rows), 5)` object array.

  The tuples mix scalars, strings and arrays, so letting NumPy infer an array
  from the raw list relies on implicit ragged-array creation, which recent
  NumPy releases reject with a ValueError. Filling a pre-allocated object
  array cell by cell stores every value untouched.

  Args:
    rows: Sequence of 5-tuples as produced by `collect_line_data`.

  Returns:
    A NumPy array of dtype `object` with one row per tuple.
  """
  packed = np.empty((len(rows), 5), dtype=object)
  for row_idx, row in enumerate(rows):
    for col_idx, value in enumerate(row):
      # Scalar indexing stores `value` as-is, even when it is itself an array.
      packed[row_idx, col_idx] = value
  return packed


def main():
  """Parses the command line and writes the cleaned train/valid line files."""
  parser = argparse.ArgumentParser(description='Prepare deepwriting data')
  # Both paths are mandatory: without them os.path.join() would fail later
  # with an unhelpful TypeError on None.
  parser.add_argument('--input_dir', type=str, required=True,
                      help='Path to input npz')
  parser.add_argument('--output_dir', type=str, required=True,
                      help='Path to save output npz')
  args = parser.parse_args()

  print('Loading Deepwriting original datasets...')
  train_dataset = deepwriting_data_lib_external.get_dataset(
      os.path.join(args.input_dir, 'deepwriting_training.npz'))
  validation_dataset = deepwriting_data_lib_external.get_dataset(
      os.path.join(args.input_dir, 'deepwriting_validation.npz'))

  # Create the destination up front so a missing directory is not discovered
  # only after all samples have been processed. An empty string means the
  # current directory, which os.makedirs() cannot handle.
  if args.output_dir:
    os.makedirs(args.output_dir, exist_ok=True)

  for dataset, name in ((train_dataset, 'train'),
                        (validation_dataset, 'valid')):
    # Saving line data
    data = collect_line_data(dataset)
    print(
        f'Prepared line data: {len(data)} of {dataset.num_samples} labels'
        f' in {name}, saving...')
    output_path = os.path.join(args.output_dir,
                               f'deepwriting_{name}_lines_cleaned.npz')
    with open(output_path, 'wb') as f:
      # savez_compressed() has no `overwrite` option; extra keyword arguments
      # are stored as additional named arrays, so none is passed here. The
      # data is saved under the default key `arr_0`.
      np.savez_compressed(f, rows_to_object_array(data))


if __name__ == '__main__':
  main()
