# Copyright (C) 2021 Project Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Beam pipeline for analyzing CodeNet traces.
"""

import functools
import logging
import os

import apache_beam as beam
from apache_beam.io import tfrecordio
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions

from core.data import codenet
from core.data import codenet_paths
from core.data import data_io
from core.data import tokenization
from core.data import trace_features
from core.data.beam_execution import utils
from scripts import process_codenet  # TODO: migrate.

import fire

DEFAULT_TOKENIZER_PATH = codenet_paths.DEFAULT_TOKENIZER_PATH


def make_example(split_and_index, ids, tokenizer):
  """Builds the example for a single CodeNet submission.

  Args:
    split_and_index: A 2-tuple `(split, index)`. It is unpacked but neither
      value is otherwise used here.
    ids: A 2-tuple `(problem_id, submission_id)`.
    tokenizer: Tokenizer forwarded to
      `process_codenet.process_codenet_single_example`.

  Returns:
    The result of `data_io.to_tf_example` for the submission (with its trace
    attached), or None when `process_codenet_single_example` returns a falsy
    value.
  """
  _split, _index = split_and_index
  problem_id, submission_id = ids

  # The trace is loaded first, before the submission is processed.
  submission_trace = codenet.load_trace(problem_id, submission_id)
  problem = process_codenet.process_codenet_single_example(
      tokenizer, problem_id, submission_id)

  if not problem:
    return None
  return data_io.to_tf_example(problem, trace=submission_trace)


def render_trace(split_and_index, ids):
  """Renders one submission's trace to a text file under `out/`.

  Args:
    split_and_index: A 2-tuple `(split, index)`. It is unpacked but neither
      value is otherwise used here.
    ids: A 2-tuple `(problem_id, submission_id)`.

  Nothing is written when `codenet.load_trace` returns a falsy value.
  """
  _split, _index = split_and_index
  problem_id, submission_id = ids

  trace = codenet.load_trace(problem_id, submission_id)
  if not trace:
    return

  destination = f'out/trace-{problem_id}-{submission_id}.txt'
  with open(destination, 'w') as out_file:

    def emit_line(line=''):
      # Print-like callback: every call produces one newline-terminated line.
      out_file.write(line + '\n')

    trace_features.render(trace, print_fn=emit_line)


def run_dataset_pipeline(splits='all', tokenizer_path=DEFAULT_TOKENIZER_PATH, **flags):
  """Runs a Beam pipeline that writes CodeNet submissions as TFRecord examples.

  The ids produced by `utils.get_indexed_problem_and_submission_ids` for the
  requested splits are turned into examples with `make_example`. Falsy results
  are dropped; the remaining examples are serialized and written with
  `WriteToTFRecord`.

  Args:
    splits: Forwarded to `utils.get_split_ids_files` to select the id files to
      read. Defaults to 'all'.
    tokenizer_path: Path handed to `tokenization.load_tokenizer`.
    **flags: Pipeline options, forwarded as-is to
      `PipelineOptions.from_dictionary`. Must contain `output`, which is used
      as the file path prefix of the written files (suffix '.tfrecord.gz').

  Raises:
    KeyError: If `output` is missing from `flags`.
  """
  # Fail fast with an actionable message, before any id files are resolved or
  # the tokenizer is loaded.
  if 'output' not in flags:
    raise KeyError(
        "run_dataset_pipeline requires an 'output' flag, "
        'e.g. --output=<file path prefix>')
  output_path = flags['output']

  pipeline_options = PipelineOptions.from_dictionary(flags)
  # Per the Beam docs, save_main_session pickles the main session so that
  # module-level names are available where the pipeline functions execute.
  pipeline_options.view_as(SetupOptions).save_main_session = True

  split_ids_files = utils.get_split_ids_files(splits)
  tokenizer = tokenization.load_tokenizer(path=tokenizer_path)
  # Bind the tokenizer once; MapTuple supplies (split_and_index, ids).
  make_example_fn = functools.partial(make_example, tokenizer=tokenizer)

  with beam.Pipeline(options=pipeline_options) as p:
    _ = (
        p
        | 'SplitIdFiles' >> beam.Create(split_ids_files)
        | 'SubmissionIds' >> beam.FlatMapTuple(utils.get_indexed_problem_and_submission_ids)
        # NOTE(review): presumably here to redistribute the ids across workers
        # before the per-submission work below -- confirm.
        | 'Reshuffle' >> beam.Reshuffle()
        | 'MakeExamples' >> beam.MapTuple(make_example_fn)
        # make_example returns None for submissions it cannot convert.
        | 'Filter' >> beam.Filter(lambda x: x)
        | 'SerializeToString' >> beam.Map(lambda x: x.SerializeToString())
        | 'Write' >> tfrecordio.WriteToTFRecord(output_path, file_name_suffix='.tfrecord.gz')
    )


if __name__ == '__main__':
  # Surface INFO-level messages by lowering the root logger's threshold.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  # Expose run_dataset_pipeline's arguments as command-line flags.
  fire.Fire(run_dataset_pipeline)
