import contextlib
import os

# Public names exported by `from <this module> import *`.
__all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"]


class CloudpickleWrapper:
    """
    Wrap a callable so that pickling it goes through ``cloudpickle``.

    ``cloudpickle`` can serialize objects the standard ``pickle`` module
    rejects, such as lambdas and closures. Pickling this wrapper serializes
    ``fn`` with ``cloudpickle.dumps``. Unpickling uses plain ``pickle.loads``,
    because cloudpickle emits ordinary pickle data (cloudpickle must still be
    importable on the loading side).

    Calling the wrapper calls the wrapped function, forwarding any arguments.
    """

    def __init__(self, fn):
        self.fn = fn

    def __getstate__(self):
        # Imported lazily so cloudpickle is only required when actually pickling.
        import cloudpickle

        return cloudpickle.dumps(self.fn)

    def __setstate__(self, ob):
        import pickle

        # NOTE(review): pickle.loads can execute arbitrary code. This is only
        # safe when `ob` comes from a trusted producer (e.g. our own
        # __getstate__ in a parent process). Never feed it external data.
        self.fn = pickle.loads(ob)

    def __call__(self, *args, **kwargs):
        # Forward arguments so wrapped callables with parameters also work.
        # Calling with no arguments behaves exactly as before.
        return self.fn(*args, **kwargs)

    def __repr__(self):
        return "{}({!r})".format(type(self).__name__, self.fn)


@contextlib.contextmanager
def clear_mpi_env_vars(*, prefixes=("OMPI_", "PMI_")):
    """
    Temporarily remove MPI-related environment variables.

    `from mpi4py import MPI` will call `MPI_Init` by default. If the child
    process has MPI environment variables, MPI will think that the child process
    is an MPI process just like the parent and do bad things such as hang.

    This context manager is a hacky way to clear those environment variables
    temporarily, such as when we are starting multiprocessing Processes.

    Args:
        prefixes: Name prefix (a string) or iterable of name prefixes. Every
            environment variable whose name starts with one of them is removed
            for the duration of the ``with`` block. Defaults to
            ``("OMPI_", "PMI_")``.

    On exit, whether normal or via an exception, every removed variable is
    restored with its original value. If a variable with the same name was set
    inside the block, the restored value overwrites it.
    """
    # str.startswith needs a str or a tuple. A bare string must not be
    # turned into a tuple of single characters.
    if isinstance(prefixes, str):
        prefixes = (prefixes,)
    else:
        prefixes = tuple(prefixes)

    removed_environment = {}
    try:
        # Iterate over a snapshot, because we delete from os.environ as we go.
        for key, value in list(os.environ.items()):
            if key.startswith(prefixes):
                # Record before deleting, so a failure part-way through the
                # loop still restores everything removed so far.
                removed_environment[key] = value
                del os.environ[key]
        yield
    finally:
        os.environ.update(removed_environment)
