import os
import logging
import importlib

# Logger named after this module.
logger = logging.getLogger(name=__name__)

# Name of the currently selected backend; None until init_backend() sets it.
MESHFLOW_BACKEND = None

# Public API re-exported from the selected backend submodule. init_backend()
# looks up each of these names on the backend module and binds it as a global
# here. Note that "min" and "max" shadow the builtins for star-importers.
__all__ = [
    "add",
    "equal",
    "zeros_like",
    "min",
    "max",
    "allclose",
    "concatenate",
    "chunk",
    "narrow",
    "Tensor",
    "tree_flatten",
    "tree_unflatten",
    "clone",
    "from_numpy",
]


def backend_valid(_backend):
    """Report whether *_backend* is a supported backend name.

    The supported names are ``"torch"``, ``"jax"`` and ``"tvm"``.
    """
    supported_backends = frozenset(("torch", "jax", "tvm"))
    return _backend in supported_backends


def init_backend(backend="torch"):
    """Select *backend* and re-export its public API from this module.

    Imports the submodule ``.<backend>`` relative to this module's package and
    binds every name listed in ``__all__`` as a global of this module. Calls
    such as ``add(...)`` on this module then dispatch to the chosen backend.

    Args:
        backend: Backend name; must be accepted by ``backend_valid``.

    Raises:
        ValueError: If *backend* is not a supported backend name.
        ImportError: If the backend submodule cannot be imported.
        AttributeError: If the backend submodule lacks a name from ``__all__``.
    """
    global MESHFLOW_BACKEND
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let an arbitrary (possibly
    # environment-supplied) name reach ``import_module``.
    if not backend_valid(backend):
        raise ValueError(f"Unsupported MeshFlow backend: {backend!r}")
    backend_module = importlib.import_module("." + backend, __name__)
    # Resolve every exported symbol before mutating module state. A missing
    # attribute then leaves the previously selected backend fully intact
    # instead of half-switched.
    exports = {attr: getattr(backend_module, attr) for attr in __all__}
    globals().update(exports)
    # Record the backend only after it has actually been loaded and exported.
    MESHFLOW_BACKEND = backend
    logger.info("========= MeshFlow init with backend %s. =========", backend)


def get_backend():
    """Return the name of the active backend (``None`` before any is set)."""
    # Reading a module global needs no ``global`` declaration.
    return MESHFLOW_BACKEND


# Initialise eagerly at import time. The backend is chosen through the
# MESHFLOW_BACKEND environment variable and defaults to "torch".
init_backend(backend=os.getenv("MESHFLOW_BACKEND", "torch"))
