from typing import Union, Optional
# from datetime import datetime
import logging as py_logging
import sys
import os

# Public API of this module: only `set_logger` is exported by a star-import.
__all__ = ["set_logger"]


# def time_log() -> str:
#     a = datetime.now()
#     return f"-" * 72 + f"     {a.year:>4}-{a.month:>2}-{a.day:>2} {a.hour:>2}:{a.minute:>2}:{a.second:>2}"


class PadoFormatter(py_logging.Formatter):
    """Formatter for PadoLogger.

    Every record is rendered as a header line (a 72-dash rule, the first
    letter of the level name, the timestamp and, for ``rank > 0``, a
    ``RANK:<n>`` tag) followed by the log message on the next line. The
    header is wrapped in ANSI color escapes when ``colored`` is true and the
    record's level has an entry in ``DEFAULT_COLORS``.

    Args:
        rank: Process rank shown in the header; omitted when ``rank <= 0``.
        colored: Whether to wrap the header in ANSI color escape sequences.
    """

    # `%(message)s` (not `%(msg)s`) so that %-style arguments passed to the
    # logging call are interpolated into the output.
    DEFAULT_FORMAT = (
        "%(color)s"
        "------------------------------------------------------------------------ "
        "[%(levelname)1.1s] %(asctime)s %(rank)s "
        "%(end_color)s\n"
        "%(message)s"
    )
    DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

    # ANSI escape sequences (SGR foreground colors), keyed by log level.
    DEFAULT_COLORS = {
        py_logging.DEBUG: "\033[34m",
        py_logging.INFO: "\033[32m",
        py_logging.WARNING: "\033[33m",
        py_logging.ERROR: "\033[35m",
        py_logging.CRITICAL: "\033[31m",
    }

    def __init__(self, rank: int = 0, colored: bool = True):
        super().__init__(fmt=self.DEFAULT_FORMAT, datefmt=self.DEFAULT_DATE_FORMAT)
        self.rank = rank
        self._colored = colored

    def format(self, record) -> str:
        """Return the formatted text for ``record``.

        Sets ``message``, ``asctime``, ``color``, ``end_color`` and ``rank``
        on the record, fills ``DEFAULT_FORMAT`` from the record's attributes,
        strips a single trailing newline, and appends exception and stack
        information when the record carries any.
        """
        # Merge the user-supplied args into the message, as the standard
        # Formatter does; reading the raw `msg` would leave "%s" untouched.
        record.message = record.getMessage()
        record.asctime = self.formatTime(record, self.datefmt)

        if self._colored and record.levelno in self.DEFAULT_COLORS:
            record.color = self.DEFAULT_COLORS[record.levelno]
            record.end_color = "\033[0m"  # reset
        else:
            record.color = ""
            record.end_color = ""

        record.rank = f"RANK:{self.rank}" if self.rank > 0 else ""

        formatted = self.DEFAULT_FORMAT % record.__dict__
        if formatted.endswith("\n"):
            formatted = formatted[:-1]

        # Keep tracebacks and stack dumps (e.g. from `logger.exception`);
        # the formatted traceback is cached on the record in `exc_text`.
        if record.exc_info and not record.exc_text:
            record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            formatted = f"{formatted}\n{record.exc_text}"
        if record.stack_info:
            formatted = f"{formatted}\n{self.formatStack(record.stack_info)}"
        return formatted


class PadoLogger(py_logging.Logger):
    """Logger class of the Pado framework.

    Mirrors the standard severity constants as class attributes and turns
    off propagation to ancestor loggers on construction.
    """

    # Standard-library severity levels, re-exported for convenience.
    NOTSET, DEBUG, INFO = py_logging.NOTSET, py_logging.DEBUG, py_logging.INFO  # 0, 10, 20
    WARNING, ERROR, CRITICAL = py_logging.WARNING, py_logging.ERROR, py_logging.CRITICAL  # 30, 40, 50

    def __init__(self, name: str, level: Union[int, str] = INFO) -> None:
        """Create a logger called ``name`` at ``level`` with propagation disabled."""
        super().__init__(name, level)
        self.propagate = False

    def get_level(self) -> int:
        """Return the effective level (snake_case alias of ``getEffectiveLevel``)."""
        effective = self.getEffectiveLevel()
        return effective

    def set_level(self, level: Union[int, str]) -> None:
        """Set the logger's level (snake_case alias of ``setLevel``)."""
        self.setLevel(level)


# Register PadoLogger at import time as the class the logging module
# instantiates for loggers created from here on.
py_logging.setLoggerClass(PadoLogger)


class _ErrorFilter(py_logging.Filter):
    """Filter that lets through only records strictly below ERROR severity."""

    def filter(self, record) -> bool:
        """Return True when ``record`` is less severe than ``PadoLogger.ERROR``."""
        below_error = PadoLogger.ERROR > record.levelno
        return below_error


def set_logger(log_dir: str, rank: int = 0, world_size: int = 1,
               *, name: Optional[str] = "pado",
               level: Union[int, str] = PadoLogger.INFO,
               colored: bool = True,
               stdout_all: bool = False) -> py_logging.Logger:
    """Set up logger configuration.

    Replaces all handlers of the logger called ``name`` with a file handler
    and, on rank 0 (or on every rank when ``stdout_all`` is set), a pair of
    console handlers: records below ERROR go to stdout, ERROR and above go
    to stderr.

    Args:
        log_dir: Directory for the log file; created if it does not exist.
        rank: Rank of the current process.
        world_size: Total number of processes. With ``world_size <= 1`` the
            file is ``out.log`` and propagation is enabled; otherwise the
            file is ``out_rank{rank}.log``.
        name: Name of the logger to configure.
        level: Minimum level for the logger, the file handler and stdout.
            Overridden with DEBUG when the ``PADO_DEBUG`` environment
            variable is set to a non-empty value.
        colored: Whether the formatter emits ANSI color escapes.
        stdout_all: Attach console handlers on every rank, not just rank 0.

    Returns:
        The configured logger.
    """
    py_logging.setLoggerClass(PadoLogger)
    base_logger = py_logging.getLogger(name=name)

    if world_size <= 1:
        log_file = os.path.join(log_dir, "out.log")
        # NOTE(review): the original author questioned whether enabling
        # propagation here is right; behavior is kept unchanged.
        base_logger.propagate = True
    else:
        log_file = os.path.join(log_dir, f"out_rank{rank}.log")

    # force set level
    if os.environ.get("PADO_DEBUG"):  # CLI PADO_DEBUG=1 ...
        level = PadoLogger.DEBUG

    formatter = PadoFormatter(rank, colored=colored)

    # FileHandler opens the file immediately and fails if the directory is
    # missing, so make sure it exists first (skip "" = current directory).
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    file_handler = py_logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(level)

    # Detach AND close handlers left over from an earlier configuration;
    # merely clearing the list would leak their open log files.
    for old_handler in list(base_logger.handlers):
        base_logger.removeHandler(old_handler)
        old_handler.close()

    base_logger.setLevel(level)
    base_logger.addHandler(file_handler)

    # add stdout, stderr handlers
    if (rank == 0) or stdout_all:
        stdout_handler = py_logging.StreamHandler(sys.stdout)
        stderr_handler = py_logging.StreamHandler(sys.stderr)

        stdout_handler.setFormatter(formatter)
        stderr_handler.setFormatter(formatter)

        stdout_handler.setLevel(level)  # above level and below error
        stdout_handler.addFilter(_ErrorFilter())
        stderr_handler.setLevel(PadoLogger.ERROR)  # only error

        base_logger.addHandler(stdout_handler)
        base_logger.addHandler(stderr_handler)

    return base_logger
