# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to create text versions of CLRS data."""
from typing import Any, Optional
from clrs._src import samplers
from clrs._src import specs
import numpy as np


# Algorithms whose text version can include hint traces.
CLRS_TASKS_WITH_HINTS = (
    'activity_selector',
    'articulation_points',
    'bellman_ford',
    'bfs',
    'binary_search',
    'bridges',
    'bubble_sort',
    'dag_shortest_paths',
    'dfs',
    'dijkstra',
    'find_maximum_subarray_kadane',
    'floyd_warshall',
    'graham_scan',
    'heapsort',
    'insertion_sort',
    'jarvis_march',
    'kmp_matcher',
    'lcs_length',
    'matrix_chain_order',
    'minimum',
    'mst_kruskal',
    'mst_prim',
    'naive_string_matcher',
    'optimal_bst',
    'quickselect',
    'quicksort',
    'strongly_connected_components',
    'task_scheduling',
    'topological_sort',
)
# With hints enabled, these string-matching tasks name their answer after the
# given hint field, and that hint field supplies the trace.
CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER = {
    'naive_string_matcher': 's',
    'kmp_matcher': 's',
}
# With hints enabled, these search tasks report the listed hint fields instead
# of their spec outputs: the answer is the last time step of each field, the
# first step goes into the input prompt and the steps in between form the
# trace.
CLRS_SEARCH_TAKS_OUTPUT_REPLACER = {
    'binary_search': ['low', 'high'],
    'find_maximum_subarray_kadane': ['best_low', 'best_high'],
    'quickselect': ['pivot'],
}
# Correctly spelled alias of the table above; the misspelled name is kept so
# existing imports keep working.
CLRS_SEARCH_TASKS_OUTPUT_REPLACER = CLRS_SEARCH_TAKS_OUTPUT_REPLACER
# Tasks whose multi-field answers and trace steps are wrapped in parentheses
# when hints are used.
CLRS_PARENTHESES_TRACES = frozenset(
    {'binary_search', 'find_maximum_subarray_kadane'}
)
CLRS_SORTING_TASKS = ['bubble_sort', 'heapsort', 'insertion_sort', 'quicksort']

# Joins the fields of a prompt and the steps of a trace.
DEFAULT_SEPARATOR = ', '
# Prefixes the initial hint in the input prompt.
INPUT_TRACE_MARKER = 'initial_trace:'
# Separates the hint trace from the final answer.
TRACE_ANSWER_SEPARATOR = ' | '
# Label for the trace in the answer header of the question.
OUTPUT_TRACE_MARKER = 'trace'
PERMUTATION_SEPARATOR = '->'
# Separates the elements of a rendered vector.
SEQUENCE_SEPARATOR = ' '

# Despite its name this is a suffix: hint versions of output fields are named
# '<output>_h', and the hint-matching code strips it from the end of names.
_HINT_PREFIX = '_h'


def format_clrs_example(
    algo: str,
    sample: samplers.Feedback,
    use_hints: bool = False,
) -> tuple[str, str]:
  """Formats a CLRS sample as a question/answer pair for an LLM.

  The question is the algorithm name, the rendered inputs and an answer header
  (prefixed with the trace label when a hint trace was added). The answer is
  the rendered output followed by a blank line.

  Args:
    algo: Name of the algorithm the sample comes from.
    sample: A sample generated by a CLRS sampler.
    use_hints: if True, the initial CLRS hint is added to the input and the
      remaining hints to the output.

  Returns:
    A (question, answer) tuple of strings.
  """
  rendered_input, answer_names, rendered_output, with_trace = sample_to_str(
      algo=algo,
      sample=sample,
      use_hints=use_hints,
  )

  if not with_trace:
    header = answer_names
  else:
    header = f'{OUTPUT_TRACE_MARKER}{TRACE_ANSWER_SEPARATOR}{answer_names}'

  return (
      f'{algo}:\n{rendered_input}\n{header}:\n',
      f'{rendered_output}\n\n',
  )


def _get_output_names(
    algo_name: str,
    spec: specs.Spec,
    use_hints: bool,
) -> list[str]:
  """Gets the names of the fields reported as the answer for an algorithm.

  With hints enabled, string-matching and search tasks report hint fields
  taken from their replacer tables. Otherwise, every OUTPUT-stage entry of
  `spec` is reported, in spec order.

  Args:
    algo_name: Name of the CLRS algorithm.
    spec: The algorithm's spec.
    use_hints: Whether hints are being added to the prompt.

  Returns:
    A new list of field names. It is always a fresh list, so callers may modify
    it without altering the module-level replacer tables.
  """
  if use_hints:
    if algo_name in CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER:
      return [CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER[algo_name]]
    if algo_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER:
      # Copy: returning the table's own list would let callers mutate it.
      return list(CLRS_SEARCH_TAKS_OUTPUT_REPLACER[algo_name])
  return [
      spec_name
      for spec_name in spec
      if spec[spec_name][0] == specs.Stage.OUTPUT
  ]


def _get_output_str(
    sample: samplers.Feedback, spec, algo_name: str, use_hints: bool
) -> list[str]:
  """Renders the answer values of a sample.

  With hints enabled, search tasks report the last time step of each of their
  replacer hint fields, joined into a single string. All other cases render
  each OUTPUT-stage feature of `spec` as a separate string.

  Args:
    sample: A sample generated by a CLRS sampler.
    spec: The algorithm's spec.
    algo_name: Name of the CLRS algorithm.
    use_hints: Whether hints are being added to the prompt.

  Returns:
    A list of rendered answer strings.
  """
  if not (use_hints and algo_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER):
    return _create_output_feature_strs(
        spec=spec,
        inputs=sample.features.inputs,
        outputs=sample.outputs,
    )

  hint_features = sample.features.hints
  input_features = sample.features.inputs
  rendered = [
      _feature_to_str(
          name=field,
          spec=spec,
          # The final time step of the hint holds the answer.
          x=_get_feature_by_name(hint_features, field).data[-1],
          with_name=False,
          inputs=input_features,
      )
      for field in CLRS_SEARCH_TAKS_OUTPUT_REPLACER[algo_name]
  ]
  return [DEFAULT_SEPARATOR.join(rendered)]


def sample_to_str(
    algo: str,
    sample: samplers.Feedback,
    use_hints: bool = False,
) -> tuple[str, str, str, bool]:
  """Converts a CLRS sample into input and output strings.

  Output examples without hints:
    1. insertion_sort
        input_str = 'key: [0.549 0.715 0.603 0.545 0.424]'
        output_names_strs = 'pred'
        output_str = '[0.424 0.545 0.549 0.603 0.715]'
    2. find_maximum_subarray_kadane
        input_str = 'key: [0.098 0.43 0.206 0.09 -0.153]'
        output_names_strs = 'start, end'
        output_str = '0, 3'
    3. binary_search
        input_str = 'key: [0.424 0.545 0.549 0.603 0.715], target: 0.646'
        output_names_strs = 'return'
        output_str = '4'

  Output examples with hints. The initial hint is appended to the input after
  INPUT_TRACE_MARKER. The intermediate hints precede the final answer,
  separated by TRACE_ANSWER_SEPARATOR:
    1. insertion_sort
        input_str = ('key: [0.549 0.715 0.603], '
                     'initial_trace: [0.549 0.715 0.603]')
        output_names_strs = 'pred'
        output_str = '[0.549 0.715 0.603] | [0.549 0.603 0.715]'
    2. find_maximum_subarray_kadane
        input_str = 'key: [0.098 0.43 0.206], initial_trace: (0, 0)'
        output_names_strs = '(best_low, best_high)'
        output_str = '(0, 1) | (0, 2)'
    3. binary_search
        input_str = ('key: [0.549 0.603 0.715], target: 0.545, '
                     'initial_trace: (0, 2)')
        output_names_strs = '(low, high)'
        output_str = '(0, 1) | (0, 0)'

  For more details about task specs refer to clrs._src.specs.

  Args:
    algo: Name of the algorithm the sample comes from.
    sample: A sample generated by a CLRS sampler.
    use_hints: if True, the initial CLRS hint is added to the input and the
      remaining hints to the output.

  Returns:
    A 4-tuple (input_str, output_names_strs, output_str, hints_added). The
    last element is True when a hint trace was added to the strings.
  """
  spec = specs.SPECS[algo]

  # Input prompt: one 'name: value' entry per included input feature.
  input_str = DEFAULT_SEPARATOR.join(
      _create_input_feature_strs(spec, sample.features.inputs)
  )
  # Names and rendered values of the answer fields.
  output_names = _get_output_names(
      algo_name=algo,
      spec=spec,
      use_hints=use_hints,
  )
  output_str = DEFAULT_SEPARATOR.join(
      _get_output_str(
          sample,
          spec,
          algo_name=algo,
          use_hints=use_hints,
      )
  )
  output_names_strs = DEFAULT_SEPARATOR.join(output_names)

  hints_added = False
  if use_hints:
    input_hint_str, output_hint_str, hints_added = _create_hint_feature_strs(
        algo_name=algo,
        spec=spec,
        inputs=sample.features.inputs,
        hints=sample.features.hints,
        output_names=output_names,
    )
    output_str = _format_hint([output_str], algo_name=algo)
    output_names_strs = _format_hint([output_names_strs], algo_name=algo)

    # Gate on `hints_added` rather than on `input_hint_str`: for parenthesised
    # tasks an empty hint list still formats as '()'. That would attach a trace
    # even though `hints_added` (which the caller uses to label the question)
    # is False.
    if hints_added:
      input_str = DEFAULT_SEPARATOR.join(
          [input_str, f'{INPUT_TRACE_MARKER} {input_hint_str}']
      )
      output_str = TRACE_ANSWER_SEPARATOR.join([output_hint_str, output_str])

  return input_str, output_names_strs, output_str, hints_added


def _create_input_feature_strs(
    spec: specs.Spec,
    inputs: samplers.Features,
) -> list[str]:
  """Renders each INPUT-stage feature of `spec` as a 'name: value' string.

  Features for which `_do_not_include_input_in_text` is True are skipped. The
  result follows the iteration order of `spec`.

  Args:
    spec: The algorithm's spec, mapping names to (stage, location, type).
    inputs: The sample's input features.

  Returns:
    A list of rendered input strings.
  """
  return [
      _feature_to_str(
          name=name,
          spec=spec,
          x=_get_feature_by_name(inputs, name).data,
          with_name=True,
      )
      for name, (stage, _, _) in spec.items()
      if stage == specs.Stage.INPUT
      and not _do_not_include_input_in_text(name, spec)
  ]


def _create_output_feature_strs(
    spec: specs.Spec,
    inputs: samplers.Features,
    outputs: samplers.Features,
) -> list[str]:
  """Renders each OUTPUT-stage feature of `spec` as a value string (no name).

  Args:
    spec: The algorithm's spec, mapping names to (stage, location, type).
    inputs: The sample's input features (needed e.g. to render permutations).
    outputs: The sample's output features.

  Returns:
    A list of rendered output strings, in the iteration order of `spec`.
  """
  return [
      _feature_to_str(
          name=name,
          spec=spec,
          x=_get_feature_by_name(outputs, name).data,
          with_name=False,
          inputs=inputs,
      )
      for name, (stage, _, _) in spec.items()
      if stage == specs.Stage.OUTPUT
  ]


def _is_hint_field(
    field_name: str,
    algo_name: str,
    output_names: list[str],
) -> bool:
  """Checks whether hint `field_name` contributes to the text trace.

  - String-matching tasks: only the replacer field is used.
  - Search tasks: the fields listed in CLRS_SEARCH_TAKS_OUTPUT_REPLACER.
  - Otherwise: a hint '<name>_h' is used iff '<name>' is in `output_names`.
    Names that do not end in the '_h' suffix never match. For example, a hint
    named 'pred' is not mistaken for an output named 'pr'.

  Args:
    field_name: Name of a hint feature.
    algo_name: Name of the CLRS algorithm.
    output_names: Names of the answer fields.

  Returns:
    True if the hint should be rendered in the trace.
  """
  if algo_name in CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER:
    return field_name == CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER[algo_name]
  if algo_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER:
    return field_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER[algo_name]
  if not field_name.endswith(_HINT_PREFIX):
    return False
  return field_name[: -len(_HINT_PREFIX)] in output_names


def _get_output_name(hint_name: str, algo_name: str) -> str:
  """Maps a hint field name to the spec name used to render it.

  String-matching tasks map to their replacer field. Search tasks keep the hint
  name. Every other name has its last len(_HINT_PREFIX) characters dropped.
  """
  if algo_name in CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER:
    output_name = CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER[algo_name]
  elif algo_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER:
    output_name = hint_name
  else:
    suffix_len = len(_HINT_PREFIX)
    output_name = hint_name[:-suffix_len]
  return output_name


def _format_hint(hints: list[str], algo_name: str) -> str:
  """Joins hint strings with DEFAULT_SEPARATOR.

  The joined string is wrapped in parentheses for tasks in
  CLRS_PARENTHESES_TRACES.
  """
  joined = DEFAULT_SEPARATOR.join(hints)
  return f'({joined})' if algo_name in CLRS_PARENTHESES_TRACES else joined


def _create_hint_feature_strs(
    algo_name: str,
    spec: specs.Spec,
    inputs: samplers.Features,
    hints: samplers.Features,
    output_names: list[str],
) -> tuple[str, str, bool]:
  """Extracts the trace hints of a sample and renders them as strings.

  Only hints selected by `_is_hint_field` are used. Each one is rendered with
  the spec entry named by `_get_output_name`.

  Args:
    algo_name: Name of the CLRS algorithm.
    spec: The algorithm's spec.
    inputs: The sample's input features.
    hints: The sample's hint features.
    output_names: Names of the answer fields.

  Returns:
    A tuple (input_hint_str, output_hint_str, hints_found):
      input_hint_str: the first time step of every selected hint, combined
        with `_format_hint`.
      output_hint_str: one `_format_hint` group per intermediate time step
        (first and last steps excluded), joined with DEFAULT_SEPARATOR.
      hints_found: whether at least one hint field was selected.

  Raises:
    ValueError: if the selected hints have different numbers of intermediate
      steps.
  """
  input_hint_strs = []
  unrolled_hints_strs = []
  for hint in hints:
    hint_name = hint.name
    if not _is_hint_field(hint_name, algo_name, output_names):
      continue

    # Looked up by name so duplicate names are reported by the helper.
    result_hint = _get_feature_by_name(hints, hint_name).data
    output_name = _get_output_name(hint_name, algo_name)

    # The first element of `result_hint` is the initial hint that is used in
    # the input prompt.
    input_hint_strs.append(
        _feature_to_str(
            name=output_name,
            spec=spec,
            x=np.array(result_hint[0]),
            with_name=False,
            inputs=inputs,
        )
    )
    # The last element is identical to the output result, which is rendered
    # separately, so only the intermediate steps form the output trace.
    unrolled_hints_strs.append([
        _feature_to_str(
            name=output_name,
            spec=spec,
            x=np.array(step),
            with_name=False,
            inputs=inputs,
        )
        for step in result_hint[1:-1]
    ])

  # Both lists grow in lockstep, so either one tells whether any hint matched.
  hints_found = bool(input_hint_strs)

  input_hint_str = _format_hint(input_hint_strs, algo_name=algo_name)
  output_hint_strs = []
  if hints_found:
    step_counts = {len(steps) for steps in unrolled_hints_strs}
    if len(step_counts) != 1:
      raise ValueError(f'Output hints have to have equal length. Spec: {spec}')
    # The i-th trace entry groups the i-th intermediate step of every hint.
    output_hint_strs = [
        _format_hint(list(step_strs), algo_name)
        for step_strs in zip(*unrolled_hints_strs)
    ]

  output_hint_str = DEFAULT_SEPARATOR.join(output_hint_strs)

  return input_hint_str, output_hint_str, hints_found


def _feature_to_str(
    name: str,
    spec: specs.Spec,
    x: np.ndarray,
    with_name: bool,
    inputs: Optional[samplers.Features] = None,
    edge_masks_as_edge_list: bool = False,
) -> str:
  """Converts a numerical CLRS feature into a string.

  `x` must carry a leading batch dimension of size 1. That dimension is
  dropped, and the feature is rendered according to the location and type
  recorded for `name` in `spec`.

  Args:
    name: Spec entry describing the feature.
    spec: The algorithm's spec.
    x: Feature data with a leading batch dimension of size 1.
    with_name: If True, the result is prefixed with '<name>: '.
    inputs: Input features, forwarded to node rendering.
    edge_masks_as_edge_list: Forwarded to edge rendering.

  Returns:
    The rendered feature.

  Raises:
    ValueError: if the batch dimension is not 1.
    KeyError: if the feature's location is not node, graph or edge.
  """
  if x.shape[0] != 1:
    raise ValueError(
        f'Feature first dimension (batch) must be 1 but it has shape {x.shape}.',
    )

  unbatched = x[0]
  _, location, feature_type = spec[name]
  if location == specs.Location.NODE:
    rendered = _convert_node_features_to_str(
        x=unbatched,
        spec_name=name,
        spec=spec,
        spec_type=feature_type,
        inputs=inputs,
    )
  elif location == specs.Location.GRAPH:
    rendered = _convert_graph_features_to_str(
        x=unbatched,
        spec_name=name,
        spec=spec,
        spec_type=feature_type,
    )
  elif location == specs.Location.EDGE:
    rendered = _convert_edge_features_to_str(
        x=unbatched,
        spec_name=name,
        spec=spec,
        spec_type=feature_type,
        edge_masks_as_edge_list=edge_masks_as_edge_list,
    )
  else:
    raise KeyError(f'Hint location not supported in spec {spec[name]}')

  return f'{name}: {rendered}' if with_name else rendered


def predecessors_to_order(x: np.ndarray) -> np.ndarray:
  """From list of predecessors to list of ordered node indices.

  `x[i]` is the predecessor of node `i` in a single chain. The head is expected
  to point to itself. For example, `[2, 0, 2]` encodes the chain 2 -> 0 -> 1
  and yields `[2, 0, 1]`.

  Args:
    x: 1-D array (or sequence) of predecessor indices; values are cast to int.

  Returns:
    Integer array of node indices ordered from head to tail. Empty and
    single-node inputs return `arange(len(x))`.

  Raises:
    ValueError: if there is not exactly one node that is no other node's
      predecessor, i.e. `x` does not encode a single chain.
  """
  x = np.asarray(x).astype(int)
  n = len(x)
  if n <= 1:
    # Trivially ordered. The general path cannot find a tail here, because a
    # lone head is its own predecessor.
    return np.arange(n)

  # The tail is the only node that never appears as someone's predecessor.
  is_tail = np.ones(n, dtype=bool)
  is_tail[x] = False
  tails = np.flatnonzero(is_tail)
  if len(tails) != 1:
    raise ValueError(
        'Predecessor array must encode a single chain; found'
        f' {len(tails)} candidate tail nodes in {x}.'
    )

  order = np.empty(n, dtype=int)
  order[-1] = tails[0]
  # Walk backwards from the tail along predecessor links.
  for i in range(n - 2, -1, -1):
    order[i] = x[order[i + 1]]
  return order


def _convert_node_features_to_str(
    x: np.ndarray,
    spec_name: str,
    spec: specs.Spec,
    spec_type: str,
    inputs: Optional[samplers.Features] = None,
) -> str:
  """Converts an unbatched node feature into a string.

  - SHOULD_BE_PERMUTATION: `x` is treated as a predecessor array. The 'key'
    input values are listed in the order it encodes, in '.3g' format.
  - MASK_ONE: the index of the single nonzero entry.
  - SCALAR: the values in '.3g' format.
  - MASK / POINTER: the values cast to int. CATEGORICAL: the argmax over the
    last axis.
  Vector results are space-separated and wrapped in brackets.

  Raises:
    KeyError: for any other feature type.
  """
  if spec_type == specs.Type.SHOULD_BE_PERMUTATION:
    # The permutation itself is not printed. For the text version of CLRS, the
    # "key" input values are shown in the order given by the permutation.
    keys = _get_feature_by_name(inputs, 'key').data[0]
    order = np.array(predecessors_to_order(x)).astype(int)
    permuted_keys = np.array([keys[position] for position in order])
    return _bracket(
        SEQUENCE_SEPARATOR.join(f'{value:.3g}' for value in permuted_keys)
    )

  if spec_type == specs.Type.MASK_ONE:
    [hot_index] = x.nonzero()[0]
    return f'{hot_index}'

  if spec_type == specs.Type.SCALAR:
    return _bracket(SEQUENCE_SEPARATOR.join(f'{value:.3g}' for value in x))

  if spec_type in (specs.Type.MASK, specs.Type.POINTER, specs.Type.CATEGORICAL):
    if spec_type == specs.Type.CATEGORICAL:
      as_ints = np.argmax(x, axis=-1)
    else:
      as_ints = x.astype(int)
    return _bracket(SEQUENCE_SEPARATOR.join(f'{value}' for value in as_ints))

  raise KeyError(f'Feature type not supported in spec {spec[spec_name]}')


def _convert_graph_features_to_str(
    x: np.ndarray,
    spec_name: str,
    spec: specs.Spec,
    spec_type: str,
) -> str:
  """Converts an unbatched graph-level feature into a string.

  SCALAR uses '.3f'. CATEGORICAL gives the argmax over the last axis. MASK,
  MASK_ONE and POINTER are cast to int.

  Raises:
    KeyError: for any other feature type.
  """
  if spec_type == specs.Type.SCALAR:
    return f'{x:.3f}'
  if spec_type == specs.Type.CATEGORICAL:
    return f'{np.argmax(x, axis=-1)}'
  integer_types = (specs.Type.MASK, specs.Type.MASK_ONE, specs.Type.POINTER)
  if spec_type in integer_types:
    return f'{x.astype(int)}'
  raise KeyError(f'Feature type not supported in spec {spec[spec_name]}')


def _convert_edge_features_to_str(
    x: np.ndarray,
    spec_name: str,
    spec: specs.Spec,
    spec_type: str,
    edge_masks_as_edge_list: bool,
) -> str:
  """Converts an unbatched edge feature into a string.

  When `edge_masks_as_edge_list` is True and the feature is mask-like (a MASK,
  or a SCALAR whose entries are all -1/0/1 to 4 decimals), it is rendered as a
  comma-separated list of '(i,j)' pairs for the entries > 0.

  Otherwise, including non-mask features when the flag is set, it is rendered
  as a bracketed matrix with one bracketed row per first-axis index:
  - MASK / POINTER: values cast to int.
  - CATEGORICAL: argmax over the last axis, or -1 where any class entry equals
    `specs.OutputClass.MASKED`.
  - SCALAR: values in '.3g' format.

  Raises:
    KeyError: for any other feature type.
  """
  if edge_masks_as_edge_list and (
      spec_type == specs.Type.MASK
      or (spec_type == specs.Type.SCALAR and _is_binary(x))
  ):
    edges = zip(*np.nonzero(x > 0))
    return DEFAULT_SEPARATOR.join(f'({i},{j})' for i, j in edges)

  if spec_type in (specs.Type.POINTER, specs.Type.MASK, specs.Type.CATEGORICAL):
    if spec_type == specs.Type.CATEGORICAL:
      # lcs_length includes masked elements where the category is -1.
      masked = np.any(x == specs.OutputClass.MASKED, axis=-1)
      int_output = np.argmax(x, axis=-1)
      int_output[masked] = -1
    else:
      int_output = x.astype(int)
    return _bracket(
        DEFAULT_SEPARATOR.join(
            _bracket(SEQUENCE_SEPARATOR.join(f'{a}' for a in row))
            for row in int_output
        ),
    )

  if spec_type == specs.Type.SCALAR:
    return _bracket(
        DEFAULT_SEPARATOR.join(
            _bracket(SEQUENCE_SEPARATOR.join(f'{a:.3g}' for a in row))
            for row in x
        ),
    )

  raise KeyError(f'Feature type not supported in spec {spec[spec_name]}')


def _get_feature_by_name(examples: samplers.Features, spec_name: str) -> Any:
  """Returns the unique feature in `examples` whose `.name` equals `spec_name`.

  Args:
    examples: Iterable of features, each with a `.name` attribute.
    spec_name: Name to look up.

  Returns:
    The matching feature.

  Raises:
    ValueError: if no feature, or more than one feature, has that name.
  """
  matches = [example for example in examples if example.name == spec_name]

  if not matches:
    raise ValueError(f"No example has name '{spec_name}'")
  if len(matches) > 1:
    raise ValueError(f"More than one example has name '{spec_name}'")

  return matches[0]


def _is_binary(x: np.ndarray) -> bool:
  precision = 10000
  elements = set(np.unique(np.round(x * precision).astype(int) / precision))
  return elements.issubset({-1, 0, 1})


def _bracket(s: str) -> str:
  return f'[{s}]'


def _do_not_include_input_in_text(spec_name: str, spec: specs.Spec) -> bool:
  """Returns True for input features that are left out of the text prompt.

  'pos' is always left out. 'adj' is left out whenever the spec also has an
  'A' entry, because 'adj' is redundant with A in all cases.
  """
  if spec_name == 'adj':
    return 'A' in spec
  return spec_name == 'pos'
