# Copyright (C) 2021 Project Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Beam pipeline that runs pylint over CodeNet submissions and records the output.
"""

import functools
import logging
import os
import subprocess

import apache_beam as beam
from apache_beam.io import tfrecordio
from apache_beam.io.gcp import gcsio
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions

import fire

from core.data import codenet
from core.data import codenet_paths
from core.data import data_io
from core.data import tokenization
from core.data import trace_features
from core.data.beam_execution import utils
from scripts import process_codenet  # TODO: migrate.


# Module-level alias of codenet_paths.DEFAULT_TOKENIZER_PATH; not referenced by
# the code visible in this file.
DEFAULT_TOKENIZER_PATH = codenet_paths.DEFAULT_TOKENIZER_PATH
# Shorthand used by run_pylint to capture the linter's stdout and stderr.
PIPE = subprocess.PIPE


def run_pylint(split_and_index, ids, out_path=codenet_paths.OUT_ROOT):
  """Lints one submission with pylint and stores both output streams.

  Args:
    split_and_index: Two-item sequence; it is unpacked, but neither item is
      used by this function.
    ids: Two-item sequence `(problem_id, submission_id)`.
    out_path: Root handed to the `codenet` path helpers that decide where the
      lint output files are written.

  Returns:
    The raw bytes that pylint wrote to stdout.
  """
  # Neither value is used below; the unpacking still requires exactly two items.
  _split, _index = split_and_index
  problem_id, submission_id = ids

  source_bytes = codenet.get_submission_data(
      problem_id, submission_id)['python_source'].encode('utf-8')

  # The source is fed through stdin; submission_id is the positional argument
  # that accompanies --from-stdin.
  command = ['python3', '-m', 'pylint', '--from-stdin', submission_id]
  completed = subprocess.run(
      command, input=source_bytes, stdout=PIPE, stderr=PIPE)

  stdout_path = codenet.get_lint_path(
      problem_id, submission_id, out_path=out_path)
  stderr_path = codenet.get_lint_stderr_path(
      problem_id, submission_id, out_path=out_path)

  # Write stdout first, then stderr, each to its own file via the GCS client.
  gcs = gcsio.GcsIO()
  for destination, payload in (
      (stdout_path, completed.stdout),
      (stderr_path, completed.stderr),
  ):
    with gcs.open(destination, 'wb') as sink:
      sink.write(payload)

  return completed.stdout


def run_pipeline(splits='all', **flags):
  """Builds and runs the Beam pipeline that lints submissions with pylint.

  Args:
    splits: Forwarded to `utils.get_split_ids_files` to select the id files.
    **flags: Pipeline options given as keyword arguments. Must include
      'output', which is used both as `out_path` for `run_pylint` and as the
      path passed to `WriteToTFRecord`.

  Raises:
    KeyError: If `flags` has no 'output' entry.
  """
  pipeline_options = PipelineOptions.from_dictionary(flags)
  pipeline_options.view_as(SetupOptions).save_main_session = True

  split_ids_files = utils.get_split_ids_files(splits)
  output_path = flags['output']
  lint_fn = functools.partial(run_pylint, out_path=output_path)

  with beam.Pipeline(options=pipeline_options) as p:
    id_files = p | 'SplitIdFiles' >> beam.Create(split_ids_files)
    submissions = id_files | 'SubmissionIds' >> beam.FlatMapTuple(
        utils.get_indexed_problem_and_submission_ids)
    shuffled = submissions | 'Reshuffle' >> beam.Reshuffle()
    # Each element is unpacked into run_pylint's two positional arguments.
    lint_output = shuffled | 'Lint' >> beam.MapTuple(lint_fn)
    _ = lint_output | 'Write' >> tfrecordio.WriteToTFRecord(
        output_path, file_name_suffix='.tfrecord.gz')


if __name__ == '__main__':
  # Set the root logger to INFO, then hand command-line parsing to Fire, which
  # invokes run_pipeline.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  fire.Fire(run_pipeline)
