# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to create text versions of CLRS data."""

from typing import Any, Optional

import numpy as np

from src.exps_performance.clrs import samplers, specs

# Names of the CLRS tasks listed as having hints.
CLRS_TASKS_WITH_HINTS = (
    "activity_selector",
    "articulation_points",
    "bellman_ford",
    "bfs",
    "binary_search",
    "bridges",
    "bubble_sort",
    "dag_shortest_paths",
    "dfs",
    "dijkstra",
    "find_maximum_subarray_kadane",
    "floyd_warshall",
    "graham_scan",
    "heapsort",
    "insertion_sort",
    "jarvis_march",
    "kmp_matcher",
    "lcs_length",
    "matrix_chain_order",
    "minimum",
    "mst_kruskal",
    "mst_prim",
    "naive_string_matcher",
    "optimal_bst",
    "quickselect",
    "quicksort",
    "strongly_connected_components",
    "task_scheduling",
    "topological_sort",
)
# For string-matching tasks, the single field name reported as the output name
# when hints are used.
CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER = dict(
    naive_string_matcher="s",
    kmp_matcher="s",
)
# For search tasks, the hint fields whose final value is reported as the output
# when hints are used. The "TAKS" spelling is part of the public name and is
# kept as-is because other code refers to it.
CLRS_SEARCH_TAKS_OUTPUT_REPLACER = dict(
    binary_search=["low", "high"],
    find_maximum_subarray_kadane=["best_low", "best_high"],
    quickselect=["pivot"],
)
# Tasks whose formatted hints are wrapped in parentheses by `_format_hint`.
CLRS_PARENTHESES_TRACES = frozenset(("binary_search", "find_maximum_subarray_kadane"))
CLRS_SORTING_TASKS = ["bubble_sort", "heapsort", "insertion_sort", "quicksort"]

# Separators and markers used when assembling the prompt text.
DEFAULT_SEPARATOR = ", "
INPUT_TRACE_MARKER = "initial_trace:"
TRACE_ANSWER_SEPARATOR = " | "
OUTPUT_TRACE_MARKER = "trace"
PERMUTATION_SEPARATOR = "->"
SEQUENCE_SEPARATOR = " "

# Despite the name, this is stripped from the *end* of hint field names.
_HINT_PREFIX = "_h"


def format_clrs_example(
    algo: str,
    sample: samplers.Feedback,
    use_hints: bool = False,
) -> tuple[str, str]:
    """Builds the question/answer prompt pair for a single CLRS sample.

    Args:
      algo: Name of the algorithm the sample comes from.
      sample: A sample generated by a CLRS sampler.
      use_hints: if True the initial CLRS hint is added to the input and the
        remaining hints are added to the output.

    Returns:
      A `(question, answer)` tuple of prompt strings.
    """
    input_text, names_text, output_text, has_hints = sample_to_str(
        algo=algo,
        sample=sample,
        use_hints=use_hints,
    )
    # When hints were actually added, the header announces the trace first.
    header = TRACE_ANSWER_SEPARATOR.join([OUTPUT_TRACE_MARKER, names_text]) if has_hints else names_text

    return f"{algo}:\n{input_text}\n{header}:\n", f"{output_text}\n\n"


def _get_output_names(
    algo_name: str,
    spec: specs.Spec,
    use_hints: bool,
) -> list[str]:
    """Returns the names reported as outputs for a CLRS algorithm.

    With hints enabled, string-matching and search tasks report the replacement
    names from the module-level replacer tables; in every other case the names
    are the spec entries whose stage is OUTPUT.
    """
    if use_hints:
        if algo_name in CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER:
            return [CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER[algo_name]]
        if algo_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER:
            return CLRS_SEARCH_TAKS_OUTPUT_REPLACER[algo_name]

    return [name for name in spec if spec[name][0] == specs.Stage.OUTPUT]


def _get_output_str(sample: samplers.Feedback, spec: specs.Spec, algo_name: str, use_hints: bool) -> list[str]:
    """Returns the rendered output value(s) for a CLRS algorithm.

    For search tasks with hints enabled, the output is the last step of each
    replacement hint field, joined into a single string (returned as a
    one-element list). Otherwise every OUTPUT-stage feature is rendered
    separately.
    """
    uses_hint_outputs = algo_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER and use_hints
    if not uses_hint_outputs:
        return _create_output_feature_strs(
            spec=spec,
            inputs=sample.features.inputs,
            outputs=sample.outputs,
        )

    rendered = []
    for hint_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER[algo_name]:
        # The last entry of the hint's data is used as the answer.
        final_step = _get_feature_by_name(sample.features.hints, hint_name).data[-1]
        rendered.append(
            _feature_to_str(
                name=hint_name,
                spec=spec,
                x=final_step,
                with_name=False,
                inputs=sample.features.inputs,
            )
        )
    return [DEFAULT_SEPARATOR.join(rendered)]


def sample_to_str(
    algo: str,
    sample: samplers.Feedback,
    use_hints: bool = False,
) -> tuple[str, str, str, bool]:
    """Renders a CLRS sample as input, output-name and output strings.

      The examples below are carried over from the original documentation and
      are illustrative only; exact values depend on the sampler and task spec.

      Output examples without hints:
        1. insertion_sort
            input_str = 'key: [0.549 0.715 0.603 0.545 0.424]'
            output_names_strs = 'pred'
            output_str = '[0.424 0.545 0.549 0.603 0.715]'
        2. find_maximum_subarray
            input_str = 'key: [0.098 0.43 0.206 0.09 -0.153]'
            output_names_strs = 'start, end'
            output_str = '0, 3'
        3. binary_search
            input_str = 'key: [0.424 0.545 0.549 0.603 0.715], target: 0.646'
            output_names_strs = 'return'
            output_str = '4'

      Output examples with hints:
        1. insertion_sort
            input_str = 'insertion_sort: key: [0.549 0.715 0.603], initial_trace:
            [0.549 0.715 0.603] trace | pred:'
            output_names_strs = 'pred'
            output_str = '[0.549 0.715 0.603] | [0.549 0.603 0.715]'
        2. find_maximum_subarray
            input_str = 'find_maximum_subarray_kadane: key: [0.098 0.43 0.206],
            initial_trace: (0, 0) trace | (best_low, best_high):'
            output_names_strs = 'best_low, best_high'
            output_str = '(0, 1) | (0, 2)'
        3. binary_search
            input_str = 'binary_search: key: [0.549 0.603 0.715], target: 0.545,
            initial_trace: (0, 2) trace | (low, high):'
            output_names_strs = 'return'
            output_str = '(0, 1) | (0, 0)'

      For more details about task specs refer to
      clrs._src.specs


    Args:
      algo: Name of the algorithm the sample comes from.
      sample: A sample generated by a CLRS sampler.
      use_hints: if True the initial CLRS hint is added to the input and the
        remaining hints are added to the output.

    Returns:
      A 4-tuple of (input, output_names, output, hints_added), where the first
      three are strings and `hints_added` tells whether any hint was rendered.
    """
    spec = specs.SPECS[algo]

    # Input side of the prompt.
    input_str = DEFAULT_SEPARATOR.join(_create_input_feature_strs(spec, sample.features.inputs))

    # Output side of the prompt: names first, then the rendered values.
    output_names = _get_output_names(algo_name=algo, spec=spec, use_hints=use_hints)
    output_str = DEFAULT_SEPARATOR.join(_get_output_str(sample, spec, algo_name=algo, use_hints=use_hints))
    output_names_strs = DEFAULT_SEPARATOR.join(output_names)

    if not use_hints:
        return input_str, output_names_strs, output_str, False

    initial_trace, later_trace, hints_added = _create_hint_feature_strs(
        algo_name=algo,
        spec=spec,
        inputs=sample.features.inputs,
        hints=sample.features.hints,
        output_names=output_names,
    )
    # Apply the same hint formatting (e.g. parentheses) to answer and names.
    output_str = _format_hint([output_str], algo_name=algo)
    output_names_strs = _format_hint([output_names_strs], algo_name=algo)

    if initial_trace:
        marked_trace = f"{INPUT_TRACE_MARKER} {initial_trace}"
        input_str = DEFAULT_SEPARATOR.join([input_str, marked_trace])
        output_str = TRACE_ANSWER_SEPARATOR.join([later_trace or "", output_str])

    return input_str, output_names_strs, output_str, hints_added


def _create_input_feature_strs(
    spec: specs.Spec,
    inputs: samplers.Features,
) -> list[str]:
    """Renders every INPUT-stage feature of `spec` as a `name: value` string.

    Features rejected by `_do_not_include_input_in_text` are left out.
    """

    def _should_render(name: str) -> bool:
        stage, _, _ = spec[name]  # (stage, location, type)
        return stage == specs.Stage.INPUT and not _do_not_include_input_in_text(name, spec)

    return [
        _feature_to_str(
            name=name,
            spec=spec,
            x=_get_feature_by_name(inputs, name).data,
            with_name=True,
        )
        for name in spec
        if _should_render(name)
    ]


def _create_output_feature_strs(
    spec: specs.Spec,
    inputs: samplers.Features,
    outputs: samplers.Features,
) -> list[str]:
    """Renders every OUTPUT-stage feature of `spec` as an unnamed value string.

    `inputs` is forwarded to `_feature_to_str`, which needs it for some
    feature types.
    """

    def _is_output(name: str) -> bool:
        stage, _, _ = spec[name]
        return stage == specs.Stage.OUTPUT

    return [
        _feature_to_str(
            name=name,
            spec=spec,
            x=_get_feature_by_name(outputs, name).data,
            with_name=False,
            inputs=inputs,
        )
        for name in spec
        if _is_output(name)
    ]


def _is_hint_field(
    field_name: str,
    algo_name: str,
    output_names: list[str],
) -> bool:
    """Checks whether a hint field should be rendered as part of the trace.

    Args:
      field_name: Name of the hint feature being considered.
      algo_name: Name of the algorithm the hint belongs to.
      output_names: Output names selected for this algorithm.

    Returns:
      True if the field is the replacement field of a string-matching task, one
      of the replacement fields of a search task, or, for any other task, an
      output name followed by the hint suffix (e.g. `pred_h` for `pred`).
    """
    if algo_name in CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER:
        return field_name == CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER[algo_name]
    if algo_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER:
        return field_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER[algo_name]

    # Require the suffix to actually be present: blindly dropping the last
    # characters would also accept unrelated names that merely share a prefix
    # with an output name.
    if not field_name.endswith(_HINT_PREFIX):
        return False
    return field_name[: -len(_HINT_PREFIX)] in output_names


def _get_output_name(hint_name: str, algo_name: str) -> str:
    """Maps a hint field name to the spec name used to render it.

    String-matching tasks use their replacement name, search tasks keep the
    hint name unchanged, and every other task drops the trailing hint suffix
    (its last `len(_HINT_PREFIX)` characters).
    """
    try:
        return CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER[algo_name]
    except KeyError:
        pass

    if algo_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER:
        return hint_name

    suffix_length = len(_HINT_PREFIX)
    return hint_name[:-suffix_length]


def _format_hint(hints: list[str], algo_name: str) -> str:
    """Joins hint strings, adding parentheses for tasks that require them."""
    joined = DEFAULT_SEPARATOR.join(hints)
    return f"({joined})" if algo_name in CLRS_PARENTHESES_TRACES else joined


def _create_hint_feature_strs(
    algo_name: str,
    spec: specs.Spec,
    inputs: samplers.Features,
    hints: samplers.Features,
    output_names: list[str],
) -> tuple[str, str, bool]:
    """Extracts hint features and converts them into trace strings.

    Args:
      algo_name: Name of the algorithm the hints belong to.
      spec: Spec of the algorithm.
      inputs: Input features, forwarded to `_feature_to_str`.
      hints: Hint features of the sample; each one's data is indexed by step.
      output_names: Output names used to decide which hint fields are relevant.

    Returns:
      A tuple `(input_hint_str, output_hint_str, hints_found)`: the formatted
      first step of the relevant hints, the formatted intermediate steps joined
      by the default separator, and whether any relevant hint field was found.

    Raises:
      ValueError: If the relevant hint fields do not all have the same number
        of intermediate steps.
    """
    input_hint_strs: list[str] = []
    unrolled_hints_strs: list[list[str]] = []
    for hint in hints:
        hint_name = hint.name
        if not _is_hint_field(hint_name, algo_name, output_names):
            continue

        # Looked up by name (instead of reading `hint.data`) so that duplicated
        # hint names still raise, as they always did.
        result_hint = _get_feature_by_name(hints, hint_name).data
        output_name = _get_output_name(hint_name, algo_name)

        # The first element of `result_hint` is the initial hint that is used in
        # the input prompt.
        input_hint_strs.append(
            _feature_to_str(
                name=output_name,
                spec=spec,
                x=np.array(result_hint[0]),
                with_name=False,
                inputs=inputs,
            )
        )

        # Skip the first element (already used as the input hint) and the last
        # one, which the original implementation notes is identical to the
        # output result.
        unrolled_hints_strs.append(
            [
                _feature_to_str(
                    name=output_name,
                    spec=spec,
                    x=np.array(step),
                    with_name=False,
                    inputs=inputs,
                )
                for step in result_hint[1:-1]
            ]
        )

    # Logical check that both lists are non-empty. A bitwise AND of the two
    # lengths only gives the right answer while the lengths happen to be equal.
    hints_found = bool(input_hint_strs) and bool(unrolled_hints_strs)

    input_hint_str = _format_hint(input_hint_strs, algo_name=algo_name)
    output_hint_strs: list[str] = []
    if hints_found:
        step_counts = {len(steps) for steps in unrolled_hints_strs}
        if len(step_counts) != 1:
            raise ValueError(f"Output hints have to have equal length. Spec: {spec}")

        # Transpose from per-field lists to per-step groups and format each
        # step across all fields.
        output_hint_strs = [_format_hint(list(step_hints), algo_name) for step_hints in zip(*unrolled_hints_strs)]

    output_hint_str = DEFAULT_SEPARATOR.join(output_hint_strs)

    return input_hint_str, output_hint_str, hints_found


def _feature_to_str(
    name: str,
    spec: specs.Spec,
    x: np.ndarray,
    with_name: bool,
    inputs: Optional[samplers.Features] = None,
    edge_masks_as_edge_list: bool = False,
) -> str:
    """Converts a numerical CLRS feature into a string.

    Args:
      name: Spec name of the feature; selects location and type from `spec`.
      spec: Spec of the algorithm.
      x: Feature values; the first (batch) dimension must have size 1.
      with_name: If True the result is prefixed with `name: `.
      inputs: Input features, forwarded to the node-feature converter.
      edge_masks_as_edge_list: Forwarded to the edge-feature converter.

    Returns:
      The rendered feature.

    Raises:
      ValueError: If the batch dimension of `x` is not 1.
      KeyError: If the feature location is not node, graph or edge.
    """
    if x.shape[0] != 1:
        raise ValueError(
            f"Feature first dimension (batch) must be 1 but it has shape {x.shape}.",
        )

    unbatched = x[0]
    _, location, typ_ = spec[name]

    if location == specs.Location.NODE:
        rendered = _convert_node_features_to_str(
            x=unbatched,
            spec_name=name,
            spec=spec,
            spec_type=typ_,
            inputs=inputs,
        )
    elif location == specs.Location.GRAPH:
        rendered = _convert_graph_features_to_str(
            x=unbatched,
            spec_name=name,
            spec=spec,
            spec_type=typ_,
        )
    elif location == specs.Location.EDGE:
        rendered = _convert_edge_features_to_str(
            x=unbatched,
            spec_name=name,
            spec=spec,
            spec_type=typ_,
            edge_masks_as_edge_list=edge_masks_as_edge_list,
        )
    else:
        raise KeyError(f"Hint location not supported in spec {spec[name]}")

    return f"{name}: {rendered}" if with_name else rendered


def predecessors_to_order(x: np.ndarray) -> np.ndarray:
    """Converts a list of predecessors into the list of ordered node indices.

    `x[i]` is the index of the node that precedes node `i`. The last node of
    the order is the only one that is nobody's predecessor; the order is then
    rebuilt backwards by repeatedly following predecessors.

    Args:
      x: 1-D array of predecessor indices (cast to int).

    Returns:
      An integer array of the same length holding node indices from first to
      last. An empty input gives an empty array and a single-node input gives
      `[0]`.

    Raises:
      ValueError: If there is not exactly one node that is nobody's predecessor
        (for inputs with two or more nodes).
    """
    preds = x.astype(int)
    n = len(preds)
    if n == 0:
        return np.zeros(0, dtype=int)

    # Mark every node that appears as somebody's predecessor.
    is_predecessor = np.zeros(n, dtype=bool)
    is_predecessor[preds] = True

    if n == 1:
        # A lone node precedes itself, so the "nobody's predecessor" search
        # below would find nothing; it is trivially both first and last.
        return np.zeros(1, dtype=int)

    tails = np.where(~is_predecessor)[0]
    if len(tails) != 1:
        raise ValueError(
            f"Expected exactly one node that is nobody's predecessor, found {len(tails)}.",
        )

    order = np.zeros(n, dtype=int)
    order[-1] = tails[0]
    # Walk backwards from the last node, following predecessor links.
    for i in range(n - 2, -1, -1):
        order[i] = preds[order[i + 1]]
    return order


def _convert_node_features_to_str(
    x: np.ndarray,
    spec_name: str,
    spec: specs.Spec,
    spec_type: str,
    inputs: Optional[samplers.Features] = None,
) -> str:
    """Renders a node-level feature according to its spec type.

    Raises:
      KeyError: If `spec_type` is not one of the supported node feature types.
    """
    if spec_type == specs.Type.SHOULD_BE_PERMUTATION:
        # For the text version of CLRS, a permutation output is shown as the
        # "key" input values arranged in the order given by the permutation.
        key_values = _get_feature_by_name(inputs, "key").data[0]  # type: ignore
        order = np.array(predecessors_to_order(x)).astype(int)
        reordered = np.array([key_values[position] for position in order])
        return _bracket(SEQUENCE_SEPARATOR.join(str(value) for value in reordered))

    if spec_type == specs.Type.MASK_ONE:
        # Exactly one non-zero entry is expected; unpacking enforces that.
        (index,) = x.nonzero()[0]
        return f"{index}"

    if spec_type == specs.Type.SCALAR:
        return _bracket(SEQUENCE_SEPARATOR.join(str(value) for value in x))

    if spec_type in (specs.Type.MASK, specs.Type.POINTER, specs.Type.CATEGORICAL):
        if spec_type == specs.Type.CATEGORICAL:
            as_ints = np.argmax(x, axis=-1)
        else:
            as_ints = x.astype(int)
        return _bracket(SEQUENCE_SEPARATOR.join(f"{value}" for value in as_ints))

    raise KeyError(f"Feature type not supported in spec {spec[spec_name]}")


def _convert_graph_features_to_str(
    x: np.ndarray,
    spec_name: str,
    spec: specs.Spec,
    spec_type: str,
) -> str:
    """Renders a graph-level feature according to its spec type.

    Raises:
      KeyError: If `spec_type` is not one of the supported graph feature types.
    """
    if spec_type == specs.Type.SCALAR:
        return str(x)

    if spec_type == specs.Type.CATEGORICAL:
        category = np.argmax(x, axis=-1)
        return f"{category}"

    integer_like_types = [
        specs.Type.MASK,
        specs.Type.MASK_ONE,
        specs.Type.POINTER,
    ]
    if spec_type in integer_like_types:
        return f"{x.astype(int)}"

    raise KeyError(f"Feature type not supported in spec {spec[spec_name]}")


def _convert_edge_features_to_str(
    x: np.ndarray,
    spec_name: str,
    spec: specs.Spec,
    spec_type: str,
    edge_masks_as_edge_list: bool,
) -> str:
    """Renders an edge-level feature according to its spec type.

    With `edge_masks_as_edge_list` set, only mask features (or scalar features
    that `_is_binary` accepts) are supported and are rendered as a list of
    `(row,col)` pairs of positive entries. Otherwise the feature is rendered as
    a bracketed matrix, one bracketed row per node.

    Raises:
      KeyError: If the spec type is not supported in the selected mode.
    """

    def _matrix_to_str(rows: Any, cell_to_str: Any) -> str:
        # One bracketed, space-separated group per row, all wrapped in brackets.
        return _bracket(
            DEFAULT_SEPARATOR.join(_bracket(" ".join([cell_to_str(cell) for cell in row])) for row in rows),
        )

    if edge_masks_as_edge_list:
        mask_like = spec_type == specs.Type.MASK or (spec_type == specs.Type.SCALAR and _is_binary(x))
        if mask_like:
            pairs = list(zip(*np.nonzero(x > 0)))
            return DEFAULT_SEPARATOR.join([f"({src},{dst})" for src, dst in pairs])
    elif spec_type in (specs.Type.POINTER, specs.Type.MASK, specs.Type.CATEGORICAL):
        if spec_type == specs.Type.CATEGORICAL:
            # lcs_length includes masked elements where the category is -1
            masked = np.any(x == specs.OutputClass.MASKED, axis=-1)
            as_ints = np.argmax(x, axis=-1)
            as_ints[masked] = -1
        else:
            as_ints = x.astype(int)
        return _matrix_to_str(as_ints, "{}".format)
    elif spec_type == specs.Type.SCALAR:
        return _matrix_to_str(x, str)

    raise KeyError(f"Feature type not supported in spec {spec[spec_name]}")


def _get_feature_by_name(examples: samplers.Features, spec_name: str) -> Any:
    """Returns the single feature in `examples` whose `name` is `spec_name`.

    Args:
      examples: Iterable of features, each exposing a `name` attribute.
      spec_name: Name of the feature to look up.

    Returns:
      The matching feature object.

    Raises:
      ValueError: If more than one feature has the requested name.
      IndexError: If no feature has the requested name.
    """
    matches = [example for example in examples if example.name == spec_name]

    if len(matches) > 1:
        raise ValueError("More than one example has name '{}'".format(spec_name))
    if not matches:
        # Same exception type as the bare `[0]` lookup used to raise, but with
        # a message that says which feature is missing.
        raise IndexError("No example has name '{}'".format(spec_name))

    return matches[0]


def _is_binary(x: np.ndarray) -> bool:
    """Tells whether every value of `x`, rounded to 4 decimals, is -1, 0 or 1."""
    scale = 10000
    quantized = np.round(x * scale).astype(int) / scale
    allowed = {-1, 0, 1}
    return all(value in allowed for value in np.unique(quantized))


def _bracket(s: str) -> str:
    """Wraps `s` in square brackets."""
    return "[{}]".format(s)


def _do_not_include_input_in_text(spec_name: str, spec: specs.Spec) -> bool:
    """Tells whether an input feature should be omitted from the text prompt.

    "pos" is always omitted; "adj" is omitted whenever the spec also has "A",
    because in all cases 'adj' is redundant with A.
    """
    return spec_name == "pos" or (spec_name == "adj" and "A" in spec)
