"""
Functions that work on nested structures of torch.Tensor or numpy array
"""
from collections.abc import Sequence
import numpy as np
import torch
import tree
def any_describe_str(x, shape_only=False):
    """
    Build a short, human-readable summary string for a single object.

    Handy when debugging: the result names the kind of object and, where
    it applies, its shape, dtype and device.

    Args:
        x: object to summarize.
        shape_only: when true, numpy arrays and torch tensors are
            reported without their dtype (and device, for tensors).

    Returns:
        - numpy array: "np: <shape> <dtype>"; single-element arrays are
          reported as "np scalar: <value> <shape> <dtype>".
        - torch tensor: "torch: <shape> <dtype> <device>"; single-element
          tensors are reported as "torch scalar: <value> <shape> <dtype>
          <device>".
        - str: the string itself, unchanged.
        - other sequences: "<type name>[<length>]".
        - None: "None".
        - numeric / boolean scalars: "<type name>: <value>".
        - anything else: just the type name.
    """
    if isinstance(x, np.ndarray):
        dims = list(x.shape)
        # A single-element array is shown together with its value.
        prefix = f"np scalar: {x.item()} {dims}" if x.size == 1 else f"np: {dims}"
        return prefix if shape_only else f"{prefix} {x.dtype}"

    if torch.is_tensor(x):
        dims = list(x.size())
        # A single-element tensor is shown together with its value.
        if x.numel() == 1:
            prefix = f"torch scalar: {x.item()} {dims}"
        else:
            prefix = f"torch: {dims}"
        return prefix if shape_only else f"{prefix} {x.dtype} {x.device}"

    # Strings must be handled before the generic Sequence check below,
    # because str is itself a Sequence.
    if isinstance(x, str):
        return x

    kind = type(x)
    label = kind.__name__

    if isinstance(x, Sequence):
        return f"{label}[{len(x)}]"
    if x is None:
        return "None"

    # Python and numpy numeric / boolean scalars are shown with their value.
    is_scalar_value = np.issubdtype(kind, np.number) or np.issubdtype(kind, np.bool_)
    if is_scalar_value:
        return f"{label}: {x}"
    return f"{label}"
def any_describe(x, msg="", *, shape_only=False):
    """
    Print a description of every element of a (possibly nested) structure.

    Each element visited by ``tree.map_structure`` is replaced by the
    string returned from ``any_describe_str`` and the resulting structure
    is pretty-printed, optionally preceded by a label.

    Args:
        x: the structure to describe.
        msg: optional label printed as "<msg>: " before the description.
            If ``x`` is a string and ``msg`` is non-empty, the two are
            swapped, so ``any_describe("label", obj)`` also works.
        shape_only: forwarded to ``any_describe_str``.
    """
    from pprint import pprint

    def _summarize(leaf):
        return any_describe_str(leaf, shape_only=shape_only)

    # Accept the (label, object) argument order as well.
    if isinstance(x, str) and msg != "":
        x, msg = msg, x
    if msg:
        msg += ": "
    # The label is emitted first, without a newline, so the pretty-printed
    # description starts on the same line.
    print(msg, end="")
    described = tree.map_structure(_summarize, x)
    pprint(described)
