

import logging
import os

import torch

def set_basic_config(level):
    logging.basicConfig(format="%(levelname)s:%(asctime)s:%(message)s", level=level)

def log_to_file(string):
    print(string)
    if os.path.isdir("logs"):
        with open(f"logs/log_{torch.distributed.get_rank()}", "a+") as f:
            f.write(string + "\n")
