# Copyright (C) 2021 Project Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This is a pipeline for processing CodeNet data.
"""

from typing import Any

import dataclasses
import functools
import json
import logging
import numpy as np
import os
import sys

import apache_beam as beam
from apache_beam.io.gcp import gcsio
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions

from python_graphs import control_flow
from python_graphs import program_graph

from core.data import codenet
from core.data.beam_execution import program_graph_analysis
from core.data.beam_execution import unparser_patch  # pylint: disable=unused-import
from core.data.beam_execution import utils

import fire


def cyclomatic_complexity(control_flow_graph):
  """Returns the cyclomatic complexity (E - N + 2P) of a control flow graph.

  The start block and every enter block reported by the graph each seed the
  traversal and each count as one connected component (P). Every block reached
  by following `exits_from_end` counts as a node (N), and every exit followed
  out of a visited block counts as an edge (E).

  Args:
    control_flow_graph: An object exposing `start_block` and
      `get_enter_blocks()`, whose blocks expose `exits_from_end`.

  Returns:
    The value E - N + 2 * P.
  """
  roots = [control_flow_graph.start_block]
  roots.extend(control_flow_graph.get_enter_blocks())

  # Blocks are tracked by identity rather than by equality.
  visited = {id(root) for root in roots}
  pending = list(roots)
  edge_count = 0

  while pending:
    current = pending.pop()
    for successor in current.exits_from_end:
      edge_count += 1
      key = id(successor)
      if key in visited:
        continue
      visited.add(key)
      pending.append(successor)

  return edge_count - len(visited) + 2 * len(roots)


def tb_as_str(traceback, limit=20):
  """Summarizes a traceback as comma-separated 'file:line' entries.

  Entries are listed from the outermost frame inward, one per traceback level.

  Args:
    traceback: A traceback object (e.g. `sys.exc_info()[2]`), or None.
    limit: The maximum number of traceback levels to include.

  Returns:
    A string such as 'a.py:10, b.py:42', or '' if there is nothing to report.
  """
  entries = []
  while traceback is not None and len(entries) < limit:
    filename = os.path.basename(traceback.tb_frame.f_code.co_filename)
    # Use tb_lineno (where the exception passed through this level) rather than
    # the frame's f_lineno, which keeps advancing while the frame is still
    # running, e.g. inside the `except` handler that called this function.
    entries.append(f'{filename}:{traceback.tb_lineno}')
    traceback = traceback.tb_next
  return ', '.join(entries)


def get_percentiles(data, percentiles, integer_valued=True):
  """Returns a dict of percentiles of the data.

  Args:
    data: An unsorted list of datapoints.
    percentiles: A list of ints or floats in the range [0, 100] representing the
      percentiles to compute.
    integer_valued: Whether or not the values are all integers. If so,
      interpolate to the nearest datapoint (instead of computing a fractional
      value between the two nearest datapoints).

  Returns:
    A dict mapping each element of percentiles to the computed result, cast to
    int when integer_valued is set and to float otherwise.
  """
  # Ensure integer datapoints for cleaner binning if necessary.
  method = 'nearest' if integer_valued else 'linear'
  # NumPy 1.22 renamed the `interpolation` keyword to `method` and deprecated
  # the old spelling; pick whichever one the installed version understands.
  if np.lib.NumpyVersion(np.__version__) >= '1.22.0':
    keyword = 'method'
  else:
    keyword = 'interpolation'
  results = np.percentile(data, percentiles, **{keyword: method})
  cast = int if integer_valued else float
  return {
      percentile: cast(value)
      for percentile, value in zip(percentiles, results)
  }


def analyze_graph(graph):
  """Collects size, degree, and distance statistics for a program graph.

  Args:
    graph: A ProgramGraph to analyze.

  Returns:
    A dict with the keys 'num_nodes', 'num_edges', 'ast_height', 'degrees',
    'in_degrees', 'out_degrees', 'diameter', 'max_betweenness', 'max_degree'
    and 'mean_degree'. The three degree entries are percentile dicts;
    'diameter' and 'max_betweenness' are None for graphs with 5000 or more
    nodes.
  """
  percentile_points = [10, 25, 50, 75, 90]

  node_count = program_graph_analysis.num_nodes(graph)
  edge_count = program_graph_analysis.num_edges(graph)
  height = program_graph_analysis.graph_ast_height(graph)

  all_degrees = program_graph_analysis.degrees(graph)
  largest_degree = int(np.max(all_degrees))
  average_degree = float(np.mean(all_degrees))
  degree_summary = get_percentiles(all_degrees, percentile_points)
  in_degree_summary = get_percentiles(
      program_graph_analysis.in_degrees(graph), percentile_points)
  out_degree_summary = get_percentiles(
      program_graph_analysis.out_degrees(graph), percentile_points)

  # Diameter and betweenness are only computed for graphs under 5000 nodes
  # (presumably because they are expensive on large graphs).
  graph_diameter = None
  top_betweenness = None
  if node_count < 5000:
    graph_diameter = program_graph_analysis.diameter(graph)
    top_betweenness = program_graph_analysis.max_betweenness(graph)

  return {
      'num_nodes': node_count,
      'num_edges': edge_count,
      'ast_height': height,
      'degrees': degree_summary,
      'in_degrees': in_degree_summary,
      'out_degrees': out_degree_summary,
      'diameter': graph_diameter,
      'max_betweenness': top_betweenness,
      'max_degree': largest_degree,
      'mean_degree': average_degree,
  }


@dataclasses.dataclass
class GraphData:
  """Per-submission analysis record built by make_graph_data.

  Values that could not be computed for a submission are stored as None
  (pg_edge_counts is an empty list instead).
  """
  # Source text of the submission, as returned by codenet.read.
  source: Any
  # 'ExceptionType at <traceback summary>' if building the control flow graph
  # failed, else None.
  cfg_error: Any
  # len(cfg.nodes) of the control flow graph, or None without a graph.
  cfg_num_nodes: Any
  # len(cfg.blocks) of the control flow graph, or None without a graph.
  cfg_num_blocks: Any
  # 'ExceptionType at <traceback summary>' if building the program graph
  # failed, else None.
  pg_error: Any
  # len(pg.nodes) of the program graph, or None without a graph.
  pg_num_nodes: Any
  # List of (edge type, count) pairs from get_pg_edge_counts.
  pg_edge_counts: Any
  # Result dict of analyze_graph, or None without a program graph.
  pg_analysis: Any
  # Cyclomatic complexity of the control flow graph, or None if it failed.
  cc: Any
  # 'ExceptionType at <traceback summary>' if computing cc failed, else None.
  cc_error: Any
  # Number of non-empty lines in the source.
  length: Any
  # Submission metadata from codenet.get_submission_metadata; its 'status'
  # entry is read by get_cc_and_length.
  metadata: Any


def _describe_exception(error, **tb_kwargs):
  """Formats an exception as '<ExceptionType> at <traceback summary>'."""
  summary = tb_as_str(error.__traceback__, **tb_kwargs)
  return f'{type(error).__name__} at {summary}'


def make_graph_data(problem_id, submission_id):
  """Builds the GraphData record for a single CodeNet submission.

  Reads the submission's source and metadata, then tries in turn to build its
  control flow graph, build its program graph, and compute its cyclomatic
  complexity. A failure in any of those three steps is recorded as an error
  string on the record instead of being raised.

  Args:
    problem_id: The id of the problem the submission belongs to.
    submission_id: The id of the submission.

  Returns:
    A GraphData holding the source, the graph statistics, and any errors.
  """
  source = codenet.read(codenet.get_python_path(problem_id, submission_id))
  metadata = codenet.get_submission_metadata(problem_id, submission_id)

  # Control flow graph.
  cfg = None
  cfg_error = None
  try:
    cfg = control_flow.get_control_flow_graph(source)
  except Exception as error:
    cfg_error = _describe_exception(error)
  cfg_num_nodes = None if cfg is None else len(cfg.nodes)
  cfg_num_blocks = None if cfg is None else len(cfg.blocks)

  # Program graph; its traceback summary is kept shorter than the others.
  pg = None
  pg_error = None
  try:
    pg = program_graph.get_program_graph(source)
  except Exception as error:
    pg_error = _describe_exception(error, limit=7)
  pg_analysis = None if pg is None else analyze_graph(pg)
  pg_edge_counts = get_pg_edge_counts(pg)
  pg_num_nodes = None if pg is None else len(pg.nodes)

  # Cyclomatic complexity. cfg may be None here; the failure that causes is
  # recorded in cc_error like any other.
  cc = None
  cc_error = None
  try:
    cc = cyclomatic_complexity(cfg)
  except Exception as error:
    cc_error = _describe_exception(error)

  # Count only the non-empty lines.
  length = sum(1 for line in source.split('\n') if line)

  return GraphData(
      source=source,
      cfg_error=cfg_error,
      cfg_num_nodes=cfg_num_nodes,
      cfg_num_blocks=cfg_num_blocks,
      pg_error=pg_error,
      pg_num_nodes=pg_num_nodes,
      pg_edge_counts=pg_edge_counts,
      pg_analysis=pg_analysis,
      cc=cc,
      cc_error=cc_error,
      length=length,
      metadata=metadata,
  )


def get_pg_edge_counts(pg):
  """Counts how many edges of each type a program graph contains.

  Args:
    pg: A program graph exposing `edges`, each with a `type`, or None.

  Returns:
    A list of (edge type, count) pairs in order of first appearance, or an
    empty list when pg is None.
  """
  if pg is None:
    return []

  tally = {}
  for edge in pg.edges:
    edge_type = edge.type
    tally[edge_type] = tally.get(edge_type, 0) + 1
  return [(edge_type, count) for edge_type, count in tally.items()]


def get_cfg_error(graph_data):
  """Returns [cfg_error] if the record has a CFG error, otherwise None.

  Shaped for use with beam.FlatMap, which emits nothing for records without
  an error.
  """
  error = graph_data.cfg_error
  return [error] if error else None


def get_pg_error(graph_data):
  """Returns [pg_error] if the record has a program graph error, else None.

  Shaped for use with beam.FlatMap, which emits nothing for records without
  an error.
  """
  error = graph_data.pg_error
  return [error] if error else None


def add_pg_error_as_key(graph_data):
  """Keys a record by its program graph error.

  Returns:
    [(pg_error, graph_data)] if the record has a program graph error,
    otherwise None.
  """
  error = graph_data.pg_error
  if not error:
    return None
  return [(error, graph_data)]


def get_cc_and_length(graph_data):
  """Returns the (cc, length, status) triple of a record without a CFG error.

  Returns:
    A one-element list holding (cc, length, metadata['status']) when the
    record has no CFG error, otherwise None.
  """
  if graph_data.cfg_error:
    return None
  row = (graph_data.cc, graph_data.length, graph_data.metadata['status'])
  return [row]


def get_cc_error(graph_data):
  """Returns [cc_error] if computing cyclomatic complexity failed, else None.

  Shaped for use with beam.FlatMap, which emits nothing for records without
  an error.
  """
  error = graph_data.cc_error
  return [error] if error else None


def first(*args):
  """Returns the first positional argument.

  Note that a single iterable argument is returned as is; its first element is
  not extracted.

  Raises:
    IndexError: If called without arguments.
  """
  leading = args[0]
  return leading


def run_pipeline(**flags):
  """Builds and runs the Beam pipeline that analyzes CodeNet submissions.

  Every submission is turned into a GraphData record by make_graph_data; the
  records are then aggregated into several text outputs (per-graph analyses,
  edge type counts, size histograms, error counts and examples, and
  cyclomatic-complexity/length counts).

  Args:
    **flags: Pipeline flags. All of them are forwarded to
      `PipelineOptions.from_dictionary`. Two are also read directly:
      `ids_file_prefix`, a GCS prefix whose listed objects are handed to
      `utils.get_problem_and_submission_ids_from_tuples` to obtain
      (problem_id, submission_id) pairs, and `output`, the directory the
      result files are written to.

  Raises:
    KeyError: If `ids_file_prefix` or `output` is missing from flags.
  """
  pipeline_options = PipelineOptions.from_dictionary(flags)
  # Pickle the main session so that this module's globals are available to the
  # functions when they are unpickled on the workers.
  pipeline_options.view_as(SetupOptions).save_main_session = True

  ids_file_prefix = flags['ids_file_prefix']
  out_dir = flags['output']

  def write_to_filename(filename):
    """Returns a text sink writing to `filename` inside the output directory."""
    return beam.io.WriteToText(os.path.join(out_dir, filename))

  def first_value(values):
    """Returns one element from the iterable of values handed to a combiner."""
    return next(iter(values))

  gcsio_client = gcsio.GcsIO()
  ids_files = gcsio_client.list_prefix(ids_file_prefix).keys()

  with beam.Pipeline(options=pipeline_options) as p:
    graph_data = (  # Compute all the data.
        p
        | 'IdsFiles' >> beam.Create(ids_files)
        | 'ProblemIdsAndSubmissionIds' >> beam.FlatMap(utils.get_problem_and_submission_ids_from_tuples)
        | 'Reshuffle' >> beam.Reshuffle()
        | 'MakeGraphData' >> beam.MapTuple(make_graph_data)
    )
    _ = (
        graph_data
        | 'GetPGAnalysis' >> beam.FlatMap(lambda x: [json.dumps(x.pg_analysis)] if x.pg_analysis is not None else None)
        | 'WritePGAnalysis' >> write_to_filename('pg-analysis.jsonl')
    )
    _ = (  # The number of programs exhibiting each edge type.
        graph_data
        | 'GetPgEdgeCountKeys' >> beam.FlatMap(lambda x: [edge_count[0] for edge_count in x.pg_edge_counts if edge_count[1] > 0])
        | 'CountTotalProgramsPGEdgeCounts' >> beam.combiners.Count.PerElement()
        | 'WritePGEdgeProgramCounts' >> write_to_filename('pg-edge-program-counts.txt')
    )
    _ = (  # The total number of occurrences of each edge type.
        graph_data
        | 'GetPgEdgeCounts' >> beam.FlatMap(lambda x: x.pg_edge_counts)
        | 'CountTotalPGEdgeCounts' >> beam.CombinePerKey(sum)
        | 'WritePGEdgeValueCounts' >> write_to_filename('pg-edge-value-counts.txt')
    )
    _ = (  # Histogram of CFG num nodes.
        graph_data
        | 'GetCFGNumNodes' >> beam.FlatMap(lambda x: [x.cfg_num_nodes] if x.cfg_num_nodes is not None else None)
        | 'CFGNumNodesHistogram' >> beam.combiners.Count.PerElement()
        | 'WriteCFGNumNodesHistogram' >> write_to_filename('cfg-num-nodes-histogram.txt')
    )
    _ = (  # Histogram of CFG num blocks.
        graph_data
        | 'GetCFGNumBlocks' >> beam.FlatMap(lambda x: [x.cfg_num_blocks] if x.cfg_num_blocks is not None else None)
        | 'CFGNumBlocksHistogram' >> beam.combiners.Count.PerElement()
        | 'WriteCFGNumBlocksHistogram' >> write_to_filename('cfg-num-blocks-histogram.txt')
    )
    _ = (  # Histogram of PG num nodes.
        graph_data
        | 'GetPGNumNodes' >> beam.FlatMap(lambda x: [x.pg_num_nodes] if x.pg_num_nodes is not None else None)
        | 'PGNumNodesHistogram' >> beam.combiners.Count.PerElement()
        | 'WritePGNumNodesHistogram' >> write_to_filename('pg-num-nodes-histogram.txt')
    )
    _ = (  # The number of occurrences of each CFG error.
        graph_data
        | 'FilterJustCfgErrors' >> beam.FlatMap(get_cfg_error)
        | 'CountCfgErrorTypes' >> beam.combiners.Count.PerElement()
        | 'WriteCfgErrorTypes' >> write_to_filename('cfg-error-type-counts.txt')
    )
    _ = (  # The number of occurrences of each PG error.
        graph_data
        | 'FilterJustPgErrors' >> beam.FlatMap(get_pg_error)
        | 'CountPgErrorTypes' >> beam.combiners.Count.PerElement()
        | 'WritePgErrorTypes' >> write_to_filename('pg-error-type-counts.txt')
    )
    _ = (  # For each PG error, an example.
        graph_data
        | 'KeyJustPgErrors' >> beam.FlatMap(add_pg_error_as_key)
        # CombinePerKey calls its function with an iterable of the values for a
        # key, so the combiner has to pick a single element out of it.
        | 'GetPgErrorExample' >> beam.CombinePerKey(first_value)
        # | 'PgErrorKeepSource' >> beam.Map(lambda x: (x[0], x[1].source))
        | 'WritePgErrorExamples' >> write_to_filename('pg-error-type-examples.txt')
    )
    _ = (  # Counts for (cyclomatic_complexity, LOC) pairs.
        graph_data
        | 'FilterJustGraphs' >> beam.FlatMap(get_cc_and_length)
        | 'CountCCLengthValues' >> beam.combiners.Count.PerElement()
        | 'WriteCCLengthValueCounts' >> write_to_filename('cc-length-counts.txt')
    )
    _ = (  # The number of occurrences for each CC error.
        graph_data
        | 'FilterJustCcErrors' >> beam.FlatMap(get_cc_error)
        | 'CountCcErrorTypes' >> beam.combiners.Count.PerElement()
        | 'WriteCcErrorTypes' >> write_to_filename('cc-error-type-counts.txt')
    )
    # _ = (
    #     graph_data
    #     | 'FilterJustCcErrors' >> beam.FlatMap(get_pg_)
    #     | 'CountCcErrorTypes' >> beam.combiners.Count.PerElement()
    #     | 'WriteCcErrorTypes' >> write_to_filename('pg-type-frequencies.txt')
    # )


if __name__ == '__main__':
  # Show INFO-level messages from the root logger while the pipeline runs.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  # Expose run_pipeline as the command-line entry point via python-fire.
  fire.Fire(run_pipeline)
