import argparse
import csv
import json


# Operation keywords marking a sentence that contains at least one move.
# Sentences containing none of them (the initial box descriptions) are skipped.
_MOVE_OPS = ['Move', 'Remove', 'Put']

# Values accepted for the `format` argument of reformat_predict_move.
_SUPPORTED_FORMATS = ("jsonl", "tsv")


def _has_move_op(sentence):
  """Return True if `sentence` contains any _MOVE_OPS keyword (substring match)."""
  return any(op in sentence for op in _MOVE_OPS)


def _drop_final_clause(sentence):
  """Drop the last ". "-separated clause (the trailing description); result ends in "."."""
  return ". ".join(sentence.split(". ")[:-1]) + "."


def _drop_query_and_final_move(prefix):
  """Drop the last two ". "-separated pieces of `prefix` (final move and "Box n" fragment)."""
  return ". ".join(prefix.split(". ")[:-2]) + "."


def _final_move(prefix):
  """Return the second-to-last ". "-separated piece of `prefix` (the final move).

  Raises:
    IndexError: If `prefix` has fewer than two ". "-separated pieces.
  """
  return prefix.split(". ")[-2]


def reformat_predict_move(input_file, output_file, format):
  """Convert data into predicting the next move instead of the contents of a box.

  Records whose sentence contains no operation from _MOVE_OPS (the initial
  descriptions) are dropped. For the remaining records:
  - the last ". "-separated clause (the description) is removed from the sentence;
  - the last two ". "-separated pieces (the final move and the trailing
    "Box n" fragment) are removed from the prefix.

  The whole input is read before the output is opened, so `input_file` and
  `output_file` may be the same path.

  Args:
    input_file: Path to the input dataset. Every record needs "sentence" and
      "prefix" fields.
    output_file: Path the converted dataset is written to.
    format: "jsonl" or "tsv". The name shadows the builtin but is kept for
      keyword-argument compatibility with existing callers.
      - "jsonl": one JSON object per line (blank lines are skipped). Output
        objects additionally carry "masked_content", the final move that was
        removed from the prefix.
      - "tsv": tab-separated with a header row. Output has only the
        "sentence" and "prefix" columns.

  Returns:
    `output_file`, unchanged.

  Raises:
    ValueError: If `format` is not one of "jsonl" or "tsv" (checked before any
      file is opened).
  """
  if format not in _SUPPORTED_FORMATS:
    raise ValueError(
        f"Unsupported format {format!r}; expected one of {_SUPPORTED_FORMATS}.")

  if format == "jsonl":
    # Parse everything up front so malformed input fails before the output
    # (possibly the same file) is truncated.
    with open(input_file, "r", encoding="utf-8") as f:
      records = [json.loads(line) for line in f if line.strip()]

    with open(output_file, "w", encoding="utf-8") as f:
      for record in records:
        sentence = record["sentence"]
        if not _has_move_op(sentence):
          continue
        prefix = record["prefix"]
        new_record = {
            "sentence": _drop_final_clause(sentence),
            "prefix": _drop_query_and_final_move(prefix),
            # The move removed from the prefix becomes the prediction target.
            "masked_content": _final_move(prefix),
        }
        f.write(json.dumps(new_record) + "\n")

  else:  # format == "tsv"
    # newline="" is required by the csv module for correct line handling.
    with open(input_file, "r", encoding="utf-8", newline="") as f:
      reader = csv.DictReader(f, delimiter="\t")
      rows = [
          {
              "sentence": _drop_final_clause(row["sentence"]),
              "prefix": _drop_query_and_final_move(row["prefix"]),
          }
          for row in reader
          if _has_move_op(row["sentence"])
      ]

    with open(output_file, "w", encoding="utf-8", newline="") as f:
      writer = csv.DictWriter(f, fieldnames=["sentence", "prefix"], delimiter="\t")
      writer.writeheader()
      writer.writerows(rows)

  return output_file
      

def main(argv=None):
  """Parse command-line arguments and run reformat_predict_move.

  Args:
    argv: Argument strings to parse; None means argparse reads sys.argv[1:].

  Raises:
    ValueError: If the input path does not contain "condensed".
    SystemExit: On invalid arguments (raised by argparse), e.g. an
      unsupported --format.
  """
  parser = argparse.ArgumentParser(
      description="Convert a dataset from box-content prediction to next-move prediction.")
  parser.add_argument("--input_file", help="The path to the .jsonl or .tsv file to be converted.", required=True)
  parser.add_argument("--output_file", help="The output path.", required=True)
  parser.add_argument("--format", choices=["jsonl", "tsv"], default="tsv",
                      help="The format of the input file.")
  args = parser.parse_args(argv)

  # Only the condensed dataset avoids duplicate entries after conversion.
  if "condensed" not in args.input_file:
    raise ValueError("If you don't use the condensed dataset, you will have duplicate entries!")

  output_path = reformat_predict_move(args.input_file, args.output_file, args.format)
  print(f"Wrote converted file to {output_path}")


if __name__ == "__main__":
  main()
