import logging
import torch.distributed as dist


def create_logger(logging_dir, *, rank=None):
    """
    Create a logger that writes to a log file and the console on rank 0 only.

    On rank 0 the *root* logger is configured via ``logging.basicConfig`` at
    INFO level, with two handlers:
    - a ``StreamHandler`` (standard error, the ``StreamHandler`` default);
    - a ``FileHandler`` appending to ``<logging_dir>/log.txt``.

    Both handlers share one format. It contains ANSI colour escapes, so those
    escapes also appear in ``log.txt``. On every other rank, the returned logger
    has a ``NullHandler`` and does not propagate, so it emits nothing.

    Args:
        logging_dir: Directory in which ``log.txt`` is created (rank 0 only).
        rank: Process rank. If ``None`` (default), it is read from
            ``torch.distributed`` when initialized, otherwise treated as 0.

    Returns:
        The ``logging.Logger`` named after this module.
    """
    if rank is None:
        # Not running under torch.distributed (e.g. single process): behave as rank 0.
        rank = dist.get_rank() if dist.is_initialized() else 0

    logger = logging.getLogger(__name__)
    if rank == 0:  # real logger
        file_handler = logging.FileHandler(f"{logging_dir}/log.txt")
        logging.basicConfig(
            level=logging.INFO,
            format='[\033[34m%(asctime)s\033[0m] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[logging.StreamHandler(), file_handler],
        )
        if file_handler not in logging.getLogger().handlers:
            # basicConfig does nothing when the root logger already has handlers;
            # close the unused handler rather than leaking its open file.
            file_handler.close()
    else:  # dummy logger (does nothing)
        # Avoid stacking a new NullHandler on every call.
        if not any(isinstance(h, logging.NullHandler) for h in logger.handlers):
            logger.addHandler(logging.NullHandler())
        # Without this, records would still reach any handlers on the root logger.
        logger.propagate = False

    return logger
