"""Runs Deepwriting model."""
import argparse

import deepwriting_data_lib_external
import deepwriting_model_lib_external
import numpy as np

def parse_args(argv=None):
  """Parses the command-line flags for this script.

  Args:
    argv: Optional list of argument strings; defaults to sys.argv[1:].

  Returns:
    An argparse.Namespace holding the parsed flags.
  """
  parser = argparse.ArgumentParser(description='Prepare deepwriting data')
  parser.add_argument('--input_npz', type=str, help='Original data input npz')
  # The next three paths are passed straight to open(); without them the
  # script used to crash with a TypeError only after loading the model.
  parser.add_argument(
      '--cleaned_npz', type=str, required=True, help='Cleaned data input npz')
  parser.add_argument('--saved_model', type=str, help='Path to Saved model')
  parser.add_argument(
      '--corrupted_labels', type=str, required=True,
      help='Corrupted labels txt')
  # Default to a single pass; an omitted flag previously hit range(None).
  parser.add_argument(
      '--repeats', type=int, default=1, help='Number of times to repeat')
  parser.add_argument(
      '--output_npz', type=str, required=True, help='Path to output npz')
  parser.add_argument('--similarity', type=int, help='Similarity (0 or 1)')
  return parser.parse_args(argv)


def load_cleaned_data(path):
  """Loads the 'arr_0' array from an npz file.

  NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
  which can execute arbitrary code. Only pass files from a trusted source.

  Args:
    path: Path to the npz file.

  Returns:
    The array stored under the 'arr_0' key.

  Raises:
    ValueError: If the file has no 'arr_0' entry.
  """
  with np.load(path, allow_pickle=True) as npz:
    if 'arr_0' not in npz.files:
      raise ValueError(f'{path} has no arr_0 entry (found {npz.files})')
    return npz['arr_0']


def load_corrupted_labels(path):
  """Reads one label per line, dropping each line's first token.

  Each line is stripped, split on single spaces, and everything after the
  first token is re-joined with single spaces. A line with only one token
  yields ''.

  Args:
    path: Path to the text file.

  Returns:
    A list of label strings, one per line of the file.
  """
  with open(path, 'r') as f:
    return [' '.join(line.strip().split(' ')[1:]) for line in f]


def save_results(path, rows):
  """Writes (sample, result, label) triples to `path` under key 'arr_0'.

  The rows are packed into an explicit (len(rows), 3) object array. Letting
  NumPy infer the shape from ragged tuples raises ValueError on
  NumPy >= 1.24. Reading the file back requires allow_pickle=True.

  Args:
    path: Output file path, written exactly as given (no suffix is added).
    rows: Sequence of 3-tuples (sample, result, label).
  """
  packed = np.empty((len(rows), 3), dtype=object)
  for i, (sample, result, label) in enumerate(rows):
    # Scalar-position assignment stores each object as-is, even if it is
    # itself an array.
    packed[i, 0] = sample
    packed[i, 1] = result
    packed[i, 2] = label
  with open(path, 'wb') as f:
    # Pass only the array: savez_compressed stores every keyword argument
    # as an extra named array, so options must not be passed as kwargs.
    np.savez_compressed(f, packed)


def main(argv=None):
  """Runs the model on every cleaned sample with its corrupted label.

  Each sample is processed `--repeats` times, and the collected
  (sample, result, label) triples are saved to `--output_npz`.

  Args:
    argv: Optional list of argument strings; defaults to sys.argv[1:].

  Raises:
    ValueError: If there are fewer corrupted labels than cleaned samples.
  """
  args = parse_args(argv)

  print('Loading Deepwriting original datasets...')
  validation_dataset = deepwriting_data_lib_external.get_dataset(args.input_npz)

  print('Loading Deepwriting saved model...')
  model, inference_model, session = deepwriting_model_lib_external.load_models(
      model_dir=args.saved_model,
      validation_dataset=validation_dataset)

  print('Loading cleaned validation data for Spelling Correction')
  valid_data = load_cleaned_data(args.cleaned_npz)

  print('Loading file of corrupted labels')
  corrupted_labels = load_corrupted_labels(args.corrupted_labels)
  total = len(valid_data)
  # Samples are paired with labels by position; fail before the long loop
  # rather than with an IndexError part-way through it.
  if len(corrupted_labels) < total:
    raise ValueError(
        f'{args.corrupted_labels} has {len(corrupted_labels)} labels but '
        f'{args.cleaned_npz} has {total} samples')

  data = []
  for _ in range(args.repeats):
    for idx, (sample, label) in enumerate(zip(valid_data, corrupted_labels)):
      print(f'{idx}/{total}')
      result = deepwriting_model_lib_external.run_model(
          model=model,
          inference_model=inference_model,
          sess=session,
          validation_dataset=validation_dataset,
          valid_data_sample=sample,
          new_text=label,
          similarity=args.similarity)
      data.append((sample, result, label))

  save_results(args.output_npz, data)


if __name__ == '__main__':
  main()
