"""Groupings of SymPy elementary functions by type, and helpers that find them in expressions."""
import dataclasses
import collections
from typing import Optional, Set

import sympy as sp

# typedefs
# NodeOrId = Union[sp.Expr, int]


###################################################################################


# Named groups of SymPy elementary functions, keyed by group name. The values
# are SymPy attribute names at this point; they are resolved to the actual
# SymPy function classes right after this table. Secants, cotangents and
# similar functions are deliberately left out because we do not generate them
# by default.
FN_TYPES = {
    # 'exp_like': exp plus the forward trigonometric/hyperbolic functions.
    # 'log_like': log plus the corresponding inverse functions.
    'exp_like': ('exp', 'sin', 'sinh', 'cos', 'cosh', 'tan', 'tanh'),
    'log_like': ('log', 'asin', 'asinh', 'acos', 'acosh', 'atan', 'atanh'),

    # Circular trigonometric functions: all / forward only / inverse only.
    'trigonometric_any': ('sin', 'cos', 'tan', 'asin', 'acos', 'atan'),
    'trigonometric_fwd': ('sin', 'cos', 'tan'),
    'trigonometric_inv': ('asin', 'acos', 'atan'),

    # Hyperbolic functions: all / forward only / inverse only.
    'hyperbolic_any': ('sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh'),
    'hyperbolic_fwd': ('sinh', 'cosh', 'tanh'),
    'hyperbolic_inv': ('asinh', 'acosh', 'atanh'),

    # sin-like family (circular and hyperbolic variants).
    'sin_like_any': ('sin', 'sinh', 'asin', 'asinh'),
    'sin_like_fwd': ('sin', 'sinh'),
    'sin_like_inv': ('asin', 'asinh'),

    # cos-like family (circular and hyperbolic variants).
    'cos_like_any': ('cos', 'cosh', 'acos', 'acosh'),
    'cos_like_fwd': ('cos', 'cosh'),
    'cos_like_inv': ('acos', 'acosh'),

    # tan-like family (circular and hyperbolic variants).
    'tan_like_any': ('tan', 'tanh', 'atan', 'atanh'),
    'tan_like_fwd': ('tan', 'tanh'),
    'tan_like_inv': ('atan', 'atanh'),

    # A single circular function together with its inverse.
    'sin_any': ('sin', 'asin'),
    'cos_any': ('cos', 'acos'),
    'tan_any': ('tan', 'atan'),

    # A single hyperbolic function together with its inverse.
    'sinh_any': ('sinh', 'asinh'),
    'cosh_any': ('cosh', 'acosh'),
    'tanh_any': ('tanh', 'atanh'),
}


def _resolve_sympy_fns(type_to_fn_names):
    """Return a copy of ``type_to_fn_names`` with names replaced by SymPy classes.

    Each value (an iterable of SymPy attribute names) becomes a frozenset of
    the corresponding ``sympy`` attributes, looked up with ``getattr``.
    """
    return {
        group: frozenset(getattr(sp, fn_name) for fn_name in fn_names)
        for group, fn_names in type_to_fn_names.items()
    }


# From here on, FN_TYPES maps each group name to a frozenset of SymPy function
# classes, so membership can be tested directly against ``node.func``.
FN_TYPES = _resolve_sympy_fns(FN_TYPES)


def _map_fn_to_its_types(fn_types):
    ret = collections.defaultdict(set)
    for fn_type, fns in fn_types.items():
        for fn in fns:
            ret[fn].add(fn_type)
    return {
        k: frozenset(v)
        for k, v in ret.items()
    }


# Reverse index: SymPy function class -> frozenset of FN_TYPES group names.
FN_TO_TYPES_SET = _map_fn_to_its_types(FN_TYPES)

# Every SymPy function class that belongs to at least one group.
SUPPORTED_FNS = frozenset(FN_TO_TYPES_SET)


###################################################################################


@dataclasses.dataclass(eq=False)
class NodeInfo:
    """Pair of an expression node and the symbol it is considered with respect to.

    Constructed as ``NodeInfo(expr, x)``, exactly as before. The dataclass
    adds a readable ``__repr__``; ``eq=False`` keeps the previous
    identity-based ``==`` and hashing.
    """

    expr: sp.Expr  # the expression (sub)tree
    x: sp.Symbol  # presumably the independent variable -- confirm with callers


class NestedFnTree:
    """Nesting relationships between supported functions inside one expression.

    For every node of ``expr`` whose ``func`` is in ``SUPPORTED_FNS``, this
    records the ids of all supported-function nodes found anywhere below it
    (not only its direct children). Nodes are identified by ``id()``; the ids
    stay valid because ``self.expr`` keeps every node alive.
    """

    def __init__(self, expr: sp.Expr, x: sp.Symbol):
        self.expr = expr
        self.x = x  # stored for callers; not used by this class yet

        # id(node) -> node, for every node in the expression tree.
        self._id_to_node = {
            id(node): node
            for node in sp.preorder_traversal(expr)
        }
        # id(supported fn node) -> ids of supported fn nodes nested inside it.
        self._id_to_children_ids = self._compute_tree_info()

    def _compute_tree_info(self):
        """Map each supported-function node id to its supported-function descendants.

        Returns:
            dict[int, list[int]]: every supported-function node id (leaves map
            to an empty list) -> ids of the supported-function nodes nested
            anywhere inside it, in traversal order and without repeats.
        """
        id_to_children_ids = {}

        def _fn(node, context):
            # ``context`` is the stack of ids of the supported-function nodes
            # that enclose ``node``.
            node_id = id(node)
            is_supported = node.func in SUPPORTED_FNS

            if is_supported:
                # SymPy can reuse one object for identical subexpressions, so
                # the same node may be reached more than once; record it once.
                id_to_children_ids.setdefault(node_id, [])
                for parent_id in context:
                    descendants = id_to_children_ids[parent_id]
                    if node_id not in descendants:
                        descendants.append(node_id)
                context.append(node_id)

            for child in node.args:
                _fn(child, context)

            if is_supported:
                assert context.pop() == node_id

        _fn(self.expr, [])
        return id_to_children_ids

    def has_nested_fn_types(self, outer_type: str, inner_type: str) -> bool:
        """Return whether a function of ``outer_type`` contains one of ``inner_type``.

        Nesting is checked at any depth: in ``sin(2*exp(tan(x)))`` the ``tan``
        counts as nested inside the ``sin`` even though ``exp`` sits between.

        Args:
            outer_type: key of ``FN_TYPES`` the enclosing function must belong to.
            inner_type: key of ``FN_TYPES`` the nested function must belong to.

        Raises:
            ValueError: if either argument is not a key of ``FN_TYPES`` (a
                typo would otherwise silently yield ``False``).
        """
        for type_name in (outer_type, inner_type):
            if type_name not in FN_TYPES:
                raise ValueError(f"Unknown function type: {type_name!r}")

        def _has_type(node_id, fn_type):
            return fn_type in FN_TO_TYPES_SET[self._id_to_node[node_id].func]

        return any(
            _has_type(outer_id, outer_type)
            and any(_has_type(inner_id, inner_type) for inner_id in inner_ids)
            for outer_id, inner_ids in self._id_to_children_ids.items()
        )


# def compute_nested_fn_tree(expr: sp.Expr):
#     pass

###################################################################################


# class FnTypeComputer:
#     def __init__(self, expr: sp.Expr, x: sp.Symbol):
#         self.expr = expr
#         self.x = x

#         self._set_up()

#     def _set_up(self):
#         # id_to_node, id_to_is_constant
#         pass


###########################################################


def compute_all_fn_types(expr: sp.Expr) -> Set[str]:
    """Return the names of all FN_TYPES groups that occur anywhere in ``expr``.

    Every node of ``expr`` is visited. A node contributes the groups of its
    ``func``; nodes whose ``func`` is not a supported function contribute
    nothing.
    """
    return {
        type_name
        for node in sp.preorder_traversal(expr)
        for type_name in FN_TO_TYPES_SET.get(node.func, ())
    }


def compute_all_nested_fn_types(expr: sp.Expr):
    """Placeholder, not implemented yet: always returns ``None``.

    TODO: compute which function types occur nested inside which in ``expr``.
    """
    return None


def compute_all_directly_nested_fn_types(expr: sp.Expr):
    """Placeholder, not implemented yet: always returns ``None``.

    TODO: like the nested variant, but only for directly nested functions.
    """
    return None


# TODO: distinguish functions applied to the variable, fn(x), from functions applied to a constant, fn(c).