# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Probing utilities.

The dataflow for an algorithm is represented by `(stage, loc, type, data)`
"probes" that are valid under that algorithm's spec (see `specs.py`).

When constructing probes, it is convenient to represent these fields in a nested
format (`ProbesDict`) to facilitate efficient context-based look-up.

"""

import functools
from typing import Any, Dict, List, Tuple, TypeAlias, Union

import attr
import jax
import jax.numpy as jnp
import numpy as np

from src.exps_performance.clrs import specs

# import tensorflow as tf


# Short module-private aliases for names defined in `specs`.
_Location: TypeAlias = specs.Location
_Stage: TypeAlias = specs.Stage
_Type: TypeAlias = specs.Type
_OutputClass: TypeAlias = specs.OutputClass

# Either a NumPy array or a JAX array. This union is also used at runtime in
# `isinstance` checks to tell array-valued `data` apart from a list of pushes.
_Array: TypeAlias = np.ndarray | jax.Array
# Probe payload: a single array, or a list of arrays.
_Data: TypeAlias = Union[_Array, List[_Array]]
# Value stored under a probe's fields: the payload (`data`) or a type string.
_DataOrType: TypeAlias = Union[_Data, str]

# Nested mapping indexed as `probes[stage][location][name][field]`, where
# `field` is "data" or "type_" (see `initialize`).
ProbesDict = Dict[str, Dict[str, Dict[str, Dict[str, _DataOrType]]]]


def _convert_to_str(element: Any) -> str:
    """Returns `element` as text.

    A NumPy array is first unwrapped with `.item()`; an unwrapped `bytes` or
    `bytearray` value is decoded as UTF-8. A plain `bytes` argument is decoded
    as UTF-8 as well. Everything else goes through `str()`.
    """
    if isinstance(element, np.ndarray):
        # `.item()` extracts the single Python scalar held by the array.
        scalar = element.item()
        return scalar.decode("utf-8") if isinstance(scalar, (bytes, bytearray)) else str(scalar)
    return element.decode("utf-8") if isinstance(element, bytes) else str(element)


@jax.tree_util.register_pytree_node_class
@attr.define
class DataPoint:
    """A named probe value together with its location and type metadata.

    Instances are registered as JAX pytree nodes: `data` is the only leaf,
    while the name, location and type travel as auxiliary (static) data.
    """

    # Raw metadata as passed to the constructor. The public properties below
    # run it through `_convert_to_str`, so callers always read plain `str`.
    _name: str
    _location: str
    _type_: str
    # Array payload of the probe.
    data: _Array

    def __init__(self, name: str, location: str, type_: str, data: _Array):
        """Stores the metadata and payload without copying or validating them."""
        self._name = name
        self._location = location
        self._type_ = type_
        self.data = data

    @property
    def name(self) -> str:
        """Probe name as a `str`."""
        return _convert_to_str(self._name)

    @property
    def location(self) -> str:
        """Probe location as a `str`."""
        return _convert_to_str(self._location)

    @property
    def type_(self) -> str:
        """Probe type as a `str`."""
        return _convert_to_str(self._type_)

    def __repr__(self) -> str:
        """Returns a tab-separated summary showing the shape of `data`."""
        s = f'DataPoint(name="{self.name}",\tlocation={self.location},\t'
        return s + f"type={self.type_},\tdata=Array{self.data.shape})"

    def tree_flatten(self) -> Tuple[Tuple[_Array], Tuple[str, str, str]]:
        """Splits the instance into pytree children and auxiliary metadata.

        Returns:
          A `(children, aux)` pair: `children` is the 1-tuple `(data,)` and
          `aux` is `(name, location, type_)` with every entry converted to `str`.
        """
        data = (self.data,)
        meta = (self.name, self.location, self.type_)
        return data, meta

    @classmethod
    def tree_unflatten(cls, meta: Tuple[str, str, str], data: Tuple[_Array]) -> "DataPoint":
        """Rebuilds an instance from the output of `tree_flatten`.

        Args:
          meta: the `(name, location, type_)` auxiliary tuple.
          data: 1-tuple holding the (possibly transformed) payload.

        Returns:
          A new instance of `cls`.
        """
        name, location, type_ = meta
        (subdata,) = data
        # Build through `cls` so that a registered subclass round-trips to
        # itself instead of silently decaying to the base class.
        return cls(name, location, type_, subdata)


class ProbeError(Exception):
    """Raised when a probe or `ProbesDict` is missing, malformed or misused."""


def initialize(spec: specs.Spec) -> ProbesDict:
    """Builds an empty `ProbesDict` laid out according to `spec`.

    Every (stage, location) combination gets an (initially empty) mapping.
    Each probe named in `spec` is then registered under its stage and location
    with an empty `data` list and its declared `type_`.

    Args:
      spec: mapping from probe name to a `(stage, location, type)` triple.

    Returns:
      The freshly created nested dictionary.
    """
    stages = (_Stage.INPUT, _Stage.OUTPUT, _Stage.HINT)
    locations = (_Location.NODE, _Location.EDGE, _Location.GRAPH)

    probes: Dict[Any, Any] = {stage: {loc: {} for loc in locations} for stage in stages}

    for name in spec:
        stage, loc, t = spec[name]
        probes[stage][loc][name] = {"data": [], "type_": t}
    return probes


def push(probes: ProbesDict, stage: str, next_probe: Dict[str, _DataOrType]) -> None:
    """Appends one value per registered probe of `stage` to `probes`.

    Args:
      probes: dictionary to extend in place; must not be finalized yet.
      stage: stage whose probes receive a new value.
      next_probe: mapping from probe name to the value to append. It must
        contain every probe registered for `stage`.

    Raises:
      ProbeError: if a registered probe is absent from `next_probe`, or if the
        probe's `data` is already an array (i.e. `probes` was finalized).
    """
    stage_probes = probes[stage]
    for loc in (_Location.NODE, _Location.EDGE, _Location.GRAPH):
        for name, entry in stage_probes[loc].items():
            if name not in next_probe:
                raise ProbeError(f"Missing probe for {name}.")
            collected = entry["data"]
            # Array-valued `data` means `finalize` has already run.
            if isinstance(collected, _Array):  # type: ignore[misc, arg-type]
                raise ProbeError("Attemping to push to finalized `ProbesDict`.")
            collected.append(next_probe[name])  # type: ignore[union-attr]


def finalize(probes: ProbesDict) -> None:
    """Converts every probe's list of pushed values into one array, in place.

    Hint probes are stacked with `np.stack` (one entry per push). Input and
    output probes are converted with `np.array` and then squeezed.

    Args:
      probes: dictionary to finalize.

    Raises:
      ProbeError: if some probe's `data` is already an array.
    """
    for stage in (_Stage.INPUT, _Stage.OUTPUT, _Stage.HINT):
        is_hint = stage == _Stage.HINT
        for loc in (_Location.NODE, _Location.EDGE, _Location.GRAPH):
            for entry in probes[stage][loc].values():
                collected = entry["data"]
                if isinstance(collected, _Array):  # type: ignore[misc, arg-type]
                    raise ProbeError("Attemping to re-finalize a finalized `ProbesDict`.")
                if is_hint:
                    # Hints are provided for each timestep; keep them all.
                    entry["data"] = np.stack(collected)  # type: ignore[arg-type]
                else:
                    # A single input/output instance exists; drop the leading axis.
                    entry["data"] = np.squeeze(np.array(collected))


def split_stages(
    probes: ProbesDict,
    spec: specs.Spec,
) -> Tuple[List[DataPoint], List[DataPoint], List[DataPoint]]:
    """Validates a finalized `ProbesDict` and groups its probes by stage.

    For every probe in `spec` the matching entry of `probes` is checked for
    presence, declared type and array-valued `data`. Mask-like types must only
    contain 0, 1 or -1; `MASK_ONE` and `CATEGORICAL` data must additionally
    have absolute values summing to one along the last axis. The data is then
    wrapped in a `DataPoint` after a singleton axis is inserted (axis 1 for
    hints, axis 0 otherwise).

    Args:
      probes: finalized probes.
      spec: mapping from probe name to a `(stage, location, type)` triple.

    Returns:
      A `(inputs, outputs, hints)` tuple of `DataPoint` lists, in `spec` order.

    Raises:
      ProbeError: if an entry is missing, mistyped, not finalized or holds
        invalid values.
    """
    inputs: List[DataPoint] = []
    outputs: List[DataPoint] = []
    hints: List[DataPoint] = []

    for name in spec:
        stage, loc, t = spec[name]

        # Walk down the nesting one level at a time, failing at the first gap.
        if stage not in probes:
            raise ProbeError(f"Missing stage {stage}.")
        by_location = probes[stage]
        if loc not in by_location:
            raise ProbeError(f"Missing location {loc}.")
        by_name = by_location[loc]
        if name not in by_name:
            raise ProbeError(f"Missing probe {name}.")
        entry = by_name[name]
        if "type_" not in entry:
            raise ProbeError(f"Probe {name} missing attribute `type_`.")
        if "data" not in entry:
            raise ProbeError(f"Probe {name} missing attribute `data`.")
        if t != entry["type_"]:
            raise ProbeError(f"Probe {name} of incorrect type {t}.")

        data = entry["data"]
        if not isinstance(data, _Array):  # type: ignore[misc, arg-type]
            raise ProbeError(f'Invalid `data` for probe "{name}". ' "Did you forget to call `probing.finalize`?")

        if t in (_Type.MASK, _Type.MASK_ONE, _Type.CATEGORICAL):
            allowed = (data == 0) | (data == 1) | (data == -1)
            if not allowed.all():  # type: ignore[union-attr]
                raise ProbeError(f'0|1|-1 `data` for probe "{name}"')
            needs_one_hot = t in (_Type.MASK_ONE, _Type.CATEGORICAL)
            if needs_one_hot and not np.all(np.sum(np.abs(data), -1) == 1):
                raise ProbeError(f'Expected one-hot `data` for probe "{name}"')

        is_hint = stage == _Stage.HINT
        expanded = np.expand_dims(data, 1 if is_hint else 0)
        data_point = DataPoint(name=name, location=loc, type_=t, data=expanded)

        if stage == _Stage.INPUT:
            bucket = inputs
        elif stage == _Stage.OUTPUT:
            bucket = outputs
        else:
            bucket = hints
        bucket.append(data_point)

    return inputs, outputs, hints


# pylint: disable=invalid-name


def array(A_pos: np.ndarray) -> np.ndarray:
    """Constructs an `array` probe.

    Starting from `np.arange(len(A_pos))`, each element `A_pos[i]` (for
    `i >= 1`) is made to point at the element preceding it, `A_pos[i - 1]`.
    Entries that are never written keep their own index.
    """
    pointers = np.arange(A_pos.shape[0])
    for previous, current in zip(A_pos[:-1], A_pos[1:]):
        pointers[current] = previous
    return pointers


def array_cat(A: np.ndarray, n: int) -> np.ndarray:
    """Constructs an `array_cat` probe.

    Returns a zero-initialised `(len(A), n)` array in which row `i` has a 1
    at column `A[i]`.
    """
    assert n > 0
    one_hot = np.zeros((A.shape[0], n))
    for row, category in enumerate(A):
        one_hot[row, category] = 1
    return one_hot


def heap(A_pos: np.ndarray, heap_size: int) -> np.ndarray:
    """Constructs a `heap` probe.

    Starting from `np.arange(len(A_pos))`, every heap slot `i` in
    `1..heap_size-1` makes element `A_pos[i]` point at the element stored in
    slot `(i - 1) // 2`. All other entries keep their own index.
    """
    assert heap_size > 0
    pointers = np.arange(A_pos.shape[0])
    for child in range(1, heap_size):
        parent = (child - 1) // 2
        pointers[A_pos[child]] = A_pos[parent]
    return pointers


def graph(A: np.ndarray) -> np.ndarray:
    """Constructs a `graph` probe.

    Args:
      A: square matrix whose non-zero entries mark edges.

    Returns:
      A float array of the same shape holding 1.0 where `A` is non-zero or on
      the main diagonal (self-loops), and 0.0 elsewhere.
    """
    # Combine the masks logically instead of testing `A + I != 0`: the
    # arithmetic form drops the self-loop whenever a diagonal entry of `A` is
    # exactly -1.
    self_loops = np.eye(A.shape[0], dtype=bool)
    return np.array(((A != 0) | self_loops) * 1.0)


def mask_one(i: int, n: int) -> np.ndarray:
    """Constructs a `mask_one` probe.

    Returns a length-`n` vector of zeros with a single 1 at index `i`.
    """
    assert n > i
    one_hot = np.zeros(n)
    one_hot[i] = 1
    return one_hot


def strings_id(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray:
    """Constructs a `strings_id` probe.

    Returns a vector with one 0 per entry of `T_pos` followed by one 1 per
    entry of `P_pos`.
    """
    text_len = T_pos.shape[0]
    pattern_len = P_pos.shape[0]
    return np.concatenate([np.zeros(text_len), np.ones(pattern_len)])


def strings_pair(pair_probe: np.ndarray) -> np.ndarray:
    """Constructs a `strings_pair` probe.

    Embeds the `(n, m)` matrix `pair_probe` into the top-right block (rows
    `0..n-1`, columns `n..n+m-1`) of an otherwise zero `(n+m, n+m)` matrix.
    """
    n = pair_probe.shape[0]
    m = pair_probe.shape[1]
    size = n + m
    embedded = np.zeros((size, size))
    embedded[:n, n:] = pair_probe
    return embedded


def strings_pair_cat(pair_probe: np.ndarray, nb_classes: int) -> np.ndarray:
    """Constructs a `strings_pair_cat` probe.

    The result has shape `(n+m, n+m, nb_classes + 1)` for an `(n, m)` input.
    Cell `(i, n + j)` is marked `POSITIVE` at the class given by
    `pair_probe[i, j]`. Every cell outside that top-right block is marked
    `MASKED` in the extra, final class.
    """
    assert nb_classes > 0
    n = pair_probe.shape[0]
    m = pair_probe.shape[1]
    size = n + m

    # One additional class is reserved for 'this cell left blank.'
    blank_class = nb_classes
    probe_ret = np.zeros((size, size, nb_classes + 1))
    for i, row in enumerate(pair_probe):
        for j, category in enumerate(row):
            probe_ret[i, n + j, int(category)] = _OutputClass.POSITIVE

    # Fill the blank cells: the top-left block, then all of the bottom rows.
    probe_ret[:n, :n, blank_class] = _OutputClass.MASKED
    probe_ret[n:, :, blank_class] = _OutputClass.MASKED
    return probe_ret


def strings_pi(T_pos: np.ndarray, P_pos: np.ndarray, pi: np.ndarray) -> np.ndarray:
    """Constructs a `strings_pi` probe.

    Starting from `np.arange(len(T_pos) + len(P_pos))`, each entry `p` of
    `P_pos` sets position `len(T_pos) + p` to `len(T_pos) + pi[p]`. Entries
    for the `T_pos` part keep their own index.
    """
    offset = T_pos.shape[0]
    pointers = np.arange(offset + P_pos.shape[0])
    for p in P_pos:
        pointers[offset + p] = offset + pi[p]
    return pointers


def strings_pos(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray:
    """Constructs a `strings_pos` probe.

    Each input is divided by its own length, and the two resulting float
    vectors are concatenated (`T_pos` part first).
    """
    scaled = [np.asarray(pos) * 1.0 / pos.shape[0] for pos in (T_pos, P_pos)]
    return np.concatenate(scaled)


def strings_pred(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray:
    """Constructs a `strings_pred` probe."""
    probe = np.arange(T_pos.shape[0] + P_pos.shape[0])
    for i in range(1, T_pos.shape[0]):
        probe[T_pos[i]] = T_pos[i - 1]
    for j in range(1, P_pos.shape[0]):
        probe[T_pos.shape[0] + P_pos[j]] = T_pos.shape[0] + P_pos[j - 1]
    return probe


@functools.partial(jnp.vectorize, signature="(n)->(n,n)")
def predecessor_pointers_to_permutation_matrix(pointers: jnp.ndarray) -> jnp.ndarray:
    """Turns predecessor pointers into the permutation matrix that sorts them.

    The pointers are assumed to encode a linear order (like a linked list):
    every node points to its predecessor and the first node points to itself.
    Row `i` of the returned matrix is the one-hot encoding of the node that
    occupies position `i` in that order.

    Example:
    ```
    pointers = [2, 1, 1]
    P = [[0, 1, 0],
         [0, 0, 1],
         [1, 0, 0]]
    ```

    Args:
      pointers: array of shape [N]; `pointers[i]` is the predecessor of
        node `i`.

    Returns:
      Permutation matrix `P` of shape [N, N]. For node features `x` of shape
      [N, F], `P @ x` yields the features in sorted order.
    """
    nb_nodes = pointers.shape[-1]

    # The last node in the order is the one no other node points to, i.e. the
    # column of the one-hot pointer matrix with the smallest sum.
    incoming = jax.nn.one_hot(pointers, nb_nodes).sum(-2)
    current = incoming.argmin()

    # Walk the chain backwards from the last node, recording one row per node.
    rows_back_to_front = []
    for _ in range(nb_nodes):
        rows_back_to_front.append(jax.nn.one_hot(current, nb_nodes))
        current = pointers[current]

    # Rows were gathered from position N-1 down to 0; flip them into order.
    return jnp.stack(rows_back_to_front[::-1])


@functools.partial(jnp.vectorize, signature="(n,n)->(n)")
def permutation_matrix_to_predecessor_pointers(perm: jnp.ndarray) -> jnp.ndarray:
    """Converts a permutation matrix to predecessor pointers.

    Given an [N, N] permutation matrix `P` that sorts a list of nodes, this
    function returns predecessor pointers that encode the sorted order.

    Example:
    ```
    P = [[0, 1, 0],
         [0, 0, 1],
         [1, 0, 0]]
    pointers = [2, 1, 1]
    ```

    Args:
      perm: permutation matrix of shape [N, N].

    Returns:
      An integer array of shape [N] containing predecessor pointers. The node
      in the first position points to itself.
    """
    nb_nodes = perm.shape[-1]

    # Initialize pointers to zeros.
    pointers = jnp.zeros([nb_nodes], dtype=int)

    # idx[i] is the index of the node at position i in the sorted order
    idx = perm.argmax(-1)

    # `jax.nn.one_hot` produces floats by default, which would promote the
    # accumulator and make the function return float "pointers". Request the
    # accumulator's integer dtype so the result stays usable as indices.
    # pointers[idx[0]] = idx[0]
    pointers += idx[0] * jax.nn.one_hot(idx[0], nb_nodes, dtype=pointers.dtype)

    for i in range(1, nb_nodes):
        # pointers[idx[i]] = idx[i - 1]
        pointers += idx[i - 1] * jax.nn.one_hot(idx[i], nb_nodes, dtype=pointers.dtype)

    # Ensure that the pointers are in the valid range even if the input is badly
    # formatted. This has no effect for well-formatted input.
    pointers = jnp.minimum(pointers, nb_nodes - 1)

    return pointers


@functools.partial(jnp.vectorize, signature="(n)->(n,n),(n)")
def predecessor_to_cyclic_predecessor_and_first(pointers: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Makes predecessor pointers cyclic and marks the first node.

    The pointers are assumed to encode a linear order (like a linked list):
    every node points to its predecessor and the first node points to itself.
    The result encodes the same pointers one-hot, except that the first node
    now points to the last one, together with a one-hot mask of the first node.

    Example:
    ```
    pointers = [2, 1, 1]
    P = [[0, 0, 1],
         [1, 0, 0],
         [0, 1, 0]],
    M = [0, 1, 0]
    ```

    Args:
      pointers: array of shape [N]; `pointers[i]` is the predecessor of
        node `i`.

    Returns:
      A pair `(P, M)`: `P` is the one-hot encoded cyclic predecessor matrix of
      shape [N, N] (row `i` encodes the predecessor of node `i`), and `M` is a
      one-hot vector of shape [N] marking the first node.
    """
    nb_nodes = pointers.shape[-1]
    predecessors = jax.nn.one_hot(pointers, nb_nodes)

    # First node: the one whose pointer targets itself (a 1 on the diagonal).
    first_mask = jax.nn.one_hot(predecessors.diagonal().argmax(), nb_nodes)
    # Last node: the one that receives no pointer from any node.
    last_one_hot = jax.nn.one_hot(predecessors.sum(-2).argmin(), nb_nodes)

    # In the first node's row, add the link to the last node and remove the
    # self-link; every other row is left untouched.
    first_column = first_mask[..., None]
    cyclic = predecessors + first_column * last_one_hot - first_column * first_mask
    return cyclic, first_mask
