# Copyright (C) 2021 Project Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Beam pipeline for analyzing CodeNet traces.
"""

import logging
import os

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions

from core.data import codenet
from core.data.beam_execution import utils

import fire


def load_trace(split_and_index, ids):
  """Loads the trace for one indexed (problem, submission) id pair.

  Args:
    split_and_index: A two-element iterable. It is unpacked, which enforces
      its length, but its values are not otherwise used here.
    ids: A two-element iterable of (problem_id, submission_id).

  Returns:
    Whatever codenet.load_trace returns for the given ids.
  """
  # Unpacked only for the shape check; neither value is needed to load.
  _split, _index = split_and_index
  problem_id, submission_id = ids
  loaded = codenet.load_trace(problem_id, submission_id)
  return loaded


def count_variables_in_trace(trace):
  """Returns the number of distinct keys across trace.value_strs_dicts.

  Args:
    trace: An object with a `value_strs_dicts` iterable of dict-like objects.

  Returns:
    The count of unique keys seen in any of the dicts (0 if there are none).
  """
  distinct_names = {
      name
      for step_values in trace.value_strs_dicts
      for name in step_values.keys()
  }
  return len(distinct_names)


def variable_serialization_lengths_in_trace(trace):
  """Returns the length of every value string recorded in the trace.

  One entry is produced per (dict, key) occurrence, so a key that appears in
  several dicts contributes one length per appearance.

  Args:
    trace: An object with a `value_strs_dicts` iterable of dict-like objects
      whose values support len().

  Returns:
    A list of lengths, in dict order and then in each dict's iteration order.
  """
  # Only the values are needed, so iterate .values() rather than .items().
  return [
      len(value_str)
      for value_strs_dict in trace.value_strs_dicts
      for value_str in value_strs_dict.values()
  ]


def get_types_in_trace(trace):
  """Returns every value found in trace.type_strs_dicts as one flat list.

  Repeated values are kept, so the result can be fed to a per-element count.

  Args:
    trace: An object with a `type_strs_dicts` iterable of dict-like objects.

  Returns:
    A list of the dicts' values, in dict order and then in each dict's
    iteration order.
  """
  return [
      type_str
      for step_types in trace.type_strs_dicts
      for type_str in step_types.values()
  ]


def run_analysis_pipeline(splits='all', **flags):
  """Builds and runs the Beam pipeline that summarizes traces.

  Traces are loaded for each id pair produced from the split id files, falsy
  load results are filtered out, and six per-element count tables are written
  with beam.io.WriteToText under the output directory.

  Args:
    splits: Passed to utils.get_split_ids_files to select the split id files.
    **flags: Forwarded to PipelineOptions.from_dictionary. Must include
      'output', the directory that each output filename is joined onto.

  Raises:
    KeyError: If 'output' is missing from flags.
  """
  pipeline_options = PipelineOptions.from_dictionary(flags)
  # Always enable Beam's save_main_session setup option for this pipeline.
  pipeline_options.view_as(SetupOptions).save_main_session = True

  split_ids_files = utils.get_split_ids_files(splits)
  out_dir = flags['output']

  with beam.Pipeline(options=pipeline_options) as p:
    # 1. Compute the collection of traces shared by every analysis below.
    traces = (
        p
        | 'SplitIdFiles' >> beam.Create(split_ids_files)
        | 'SubmissionIds' >> beam.FlatMapTuple(
            utils.get_indexed_problem_and_submission_ids)
        | 'Reshuffle' >> beam.Reshuffle()
        | 'LoadTraces' >> beam.MapTuple(load_trace)
        | 'Filter' >> beam.Filter(lambda x: x)
    )

    # 2. Each analysis is (step label, per-trace extraction, output filename).
    # The extracted elements are counted per distinct value and written out.
    analyses = [
        # Histogram of trace lengths
        ('TraceLengths',
         beam.Map(lambda trace: len(trace.linenos)),
         'trace-lengths.txt'),
        # Histogram of source character lengths
        ('SourceCharLengths',
         beam.Map(lambda trace: len(trace.python_source)),
         'source-char-lengths.txt'),
        # Histogram of source line lengths
        ('SourceLineLengths',
         beam.Map(lambda trace: len(trace.python_source.split('\n'))),
         'source-line-lengths.txt'),
        # Number of variables in each trace
        ('VarsPerTrace',
         beam.Map(count_variables_in_trace),
         'vars-per-trace.txt'),
        # Lengths of value-serializations
        ('ValueLengths',
         beam.FlatMap(variable_serialization_lengths_in_trace),
         'value-lengths.txt'),
        # Counts of value types
        ('ValueTypes',
         beam.FlatMap(get_types_in_trace),
         'type-counts.txt'),
    ]
    for label, extract, filename in analyses:
      _ = (
          traces
          | label >> extract
          | 'Count' + label >> beam.combiners.Count.PerElement()
          | 'Write' + label >> beam.io.WriteToText(
              os.path.join(out_dir, filename))
      )


if __name__ == '__main__':
  # Set the root logger to INFO so INFO-level log records are not filtered out.
  logging.getLogger().setLevel(logging.INFO)
  # Hand run_analysis_pipeline to Fire, which calls it with arguments taken
  # from the command line. The function requires an 'output' entry among its
  # keyword flags.
  fire.Fire(run_analysis_pipeline)
