# Copyright (C) 2021 Project Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Beam pipeline for comparing CodeNet metadata status with computed status.
"""

import logging

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions

from core.data import codenet
from core.data import error_kinds
from core.data.beam_execution import utils

import fire


def _get_statuses(problem_id, submission_id):
  """Fetches the CodeNet-reported and locally computed error statuses.

  Args:
    problem_id: Problem identifier, forwarded to the `codenet` helpers.
    submission_id: Submission identifier, forwarded to the `codenet` helpers.

  Returns:
    A `(codenet_error_status, our_error_status)` pair of booleans. In both
    positions True indicates an error (incl. timeout).
  """
  # CodeNet side: the metadata 'status' field counts as an error only when it
  # is one of these two values.
  submission_metadata = codenet.get_submission_metadata(
      problem_id, submission_id)
  codenet_error_labels = ('Runtime Error', 'Time Limit Exceeded')
  codenet_error_status = submission_metadata['status'] in codenet_error_labels

  # Our side: every error kind other than the ones listed here counts as an
  # error.
  computed_kind = codenet.get_submission_error_kind(problem_id, submission_id)
  non_error_kinds = (
      error_kinds.NO_ERROR,
      error_kinds.NO_ERROR_WITH_STDERR,
      error_kinds.NO_DATA,
  )
  our_error_status = computed_kind not in non_error_kinds

  return codenet_error_status, our_error_status


def _check_matches(codenet_error_status, our_error_status):
  matches = codenet_error_status == our_error_status
  results = [
      ('matches', matches),
      ('raw', (codenet_error_status, our_error_status)),
  ]
  if codenet_error_status:
    results.append(
        ('codenet error', matches),
    )
  else:
    results.append(
        ('no codenet error', matches),
    )
  if our_error_status:
    results.append(
        ('our error', matches),
    )
  else:
    results.append(
        ('no error', matches),
    )
  return results


def run_check_matches(num_problems=4053, **flags):
  """Runs the status-comparison pipeline over a range of problem ids.

  Problem ids are generated as `p00001` up to (and excluding) the id numbered
  `num_problems`; each id is expanded with `utils.get_submission_ids`, the
  statuses of every resulting element are compared, and the number of
  occurrences of each comparison record is written as text.

  NOTE(review): the id range starts at 1 and stops before `num_problems`, so
  `num_problems - 1` ids are processed -- confirm this is the intended range.

  Args:
    num_problems: Exclusive upper bound on the problem number.
    **flags: Pipeline options, forwarded to `PipelineOptions.from_dictionary`.
      Must contain 'output', which is passed to `beam.io.WriteToText`.

  Raises:
    KeyError: If 'output' is missing from `flags`.
  """
  problem_ids = [
      f'p{problem_number:05d}'
      for problem_number in range(1, num_problems)
  ]

  options = PipelineOptions.from_dictionary(flags)
  setup_options = options.view_as(SetupOptions)
  setup_options.save_main_session = True

  with beam.Pipeline(options=options) as pipeline:
    submission_ids = (
        pipeline
        | 'ProblemIds' >> beam.Create(problem_ids)
        | 'SubmissionIds' >> beam.FlatMap(utils.get_submission_ids)
    )
    statuses = (
        submission_ids
        | 'Reshuffle' >> beam.Reshuffle()
        | 'Statuses' >> beam.MapTuple(_get_statuses)
    )
    counts = (
        statuses
        | 'Matches' >> beam.FlatMapTuple(_check_matches)
        | 'Count' >> beam.combiners.Count.PerElement()
    )
    _ = counts | 'Write' >> beam.io.WriteToText(flags['output'])


def run_check_matches_for_splits(splits='all', **flags):
  """Runs the status-comparison pipeline over the ids listed for `splits`.

  The id files returned by `utils.get_split_ids_files(splits)` seed the
  pipeline; each is expanded with `utils.get_problem_and_submission_ids`, the
  statuses of every resulting element are compared, and the number of
  occurrences of each comparison record is written as text.

  Args:
    splits: Split selector, forwarded to `utils.get_split_ids_files`.
    **flags: Pipeline options, forwarded to `PipelineOptions.from_dictionary`.
      Must contain 'output', which is passed to `beam.io.WriteToText`.

  Raises:
    KeyError: If 'output' is missing from `flags`.
  """
  options = PipelineOptions.from_dictionary(flags)
  setup_options = options.view_as(SetupOptions)
  setup_options.save_main_session = True

  id_files = utils.get_split_ids_files(splits)

  with beam.Pipeline(options=options) as pipeline:
    submission_ids = (
        pipeline
        | 'IdFiles' >> beam.Create(id_files)
        | 'SubmissionIds' >> beam.FlatMap(utils.get_problem_and_submission_ids)
    )
    statuses = (
        submission_ids
        | 'Reshuffle' >> beam.Reshuffle()
        | 'Statuses' >> beam.MapTuple(_get_statuses)
    )
    counts = (
        statuses
        | 'Matches' >> beam.FlatMapTuple(_check_matches)
        | 'Count' >> beam.combiners.Count.PerElement()
    )
    _ = counts | 'Write' >> beam.io.WriteToText(flags['output'])


if __name__ == '__main__':
  # Raise the root logger to INFO before handing control to Fire, which
  # exposes this module's functions as command-line subcommands.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  fire.Fire()
