# Copyright (C) 2021 Project Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Beam pipeline for tracing CodeNet submissions in parallel.
"""

import logging

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions

from core.data import codenet
from core.data.beam_execution import utils

import fire


def run_codenet_submissions(splits='all', **flags):
  """Builds and runs the Beam pipeline that traces CodeNet submissions.

  Args:
    splits: Split selector forwarded unchanged to
      `utils.get_split_ids_files`. Defaults to 'all'.
    **flags: Beam pipeline options. They are converted to `PipelineOptions`
      with `PipelineOptions.from_dictionary`.
  """
  options = PipelineOptions.from_dictionary(flags)
  # Ask Beam to save the __main__ session so that workers can unpickle names
  # defined in this module.
  options.view_as(SetupOptions).save_main_session = True

  ids_files = utils.get_split_ids_files(splits)

  # Leaving the `with` block runs the pipeline.
  with beam.Pipeline(options=options) as pipeline:
    id_files = pipeline | 'IdFiles' >> beam.Create(ids_files)
    # Each ids file expands into many elements.
    submission_ids = id_files | 'SubmissionIds' >> beam.FlatMap(
        utils.get_problem_and_submission_ids)
    # Reshuffle prevents fusion with the upstream FlatMap, so the per-element
    # tracing work below can be redistributed across workers.
    shuffled = submission_ids | 'Reshuffle' >> beam.Reshuffle()
    # MapTuple unpacks each element into positional arguments.
    # NOTE(review): elements are presumably (problem_id, submission_id)
    # pairs. Confirm against utils.get_problem_and_submission_ids.
    _ = shuffled | 'Run' >> beam.MapTuple(codenet.run_for_traces)


if __name__ == '__main__':
  # Lower the root logger's threshold so INFO-level messages are emitted.
  logging.root.setLevel(logging.INFO)
  # Expose run_codenet_submissions as a command-line entry point.
  fire.Fire(run_codenet_submissions)
