import logging
import os
from typing import TYPE_CHECKING, Union

from .constants import FINETRAINERS_LOG_LEVEL


if TYPE_CHECKING:
    from .parallel import ParallelBackendType


class FinetrainersLoggerAdapter(logging.LoggerAdapter):
    def __init__(self, logger: logging.Logger, parallel_backend: "ParallelBackendType" = None) -> None:
        super().__init__(logger, {})
        self.parallel_backend = parallel_backend
        self._log_freq = {}
        self._log_freq_counter = {}

    def log(
        self,
        level,
        msg,
        *args,
        main_process_only: bool = False,
        local_main_process_only: bool = True,
        in_order: bool = False,
        **kwargs,
    ):
        # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
        kwargs.setdefault("stacklevel", 2)

        if not self.isEnabledFor(level):
            return

        if self.parallel_backend is None:
            if int(os.environ.get("RANK", 0)) == 0:
                msg, kwargs = self.process(msg, kwargs)
                self.logger.log(level, msg, *args, **kwargs)
            return

        if (main_process_only or local_main_process_only) and in_order:
            raise ValueError(
                "Cannot set `main_process_only` or `local_main_process_only` to True while `in_order` is True."
            )

        if (main_process_only and self.parallel_backend.is_main_process) or (
            local_main_process_only and self.parallel_backend.is_local_main_process
        ):
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, msg, *args, **kwargs)
            return

        if in_order:
            for i in range(self.parallel_backend.world_size):
                if self.rank == i:
                    msg, kwargs = self.process(msg, kwargs)
                    self.logger.log(level, msg, *args, **kwargs)
                self.parallel_backend.wait_for_everyone()
            return

        if not main_process_only and not local_main_process_only:
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, msg, *args, **kwargs)
            return

    def log_freq(
        self,
        level: str,
        name: str,
        msg: str,
        frequency: int,
        *,
        main_process_only: bool = False,
        local_main_process_only: bool = True,
        in_order: bool = False,
        **kwargs,
    ) -> None:
        if frequency <= 0:
            return
        if name not in self._log_freq_counter:
            self._log_freq[name] = frequency
            self._log_freq_counter[name] = 0
        if self._log_freq_counter[name] % self._log_freq[name] == 0:
            self.log(
                level,
                msg,
                main_process_only=main_process_only,
                local_main_process_only=local_main_process_only,
                in_order=in_order,
                **kwargs,
            )
        self._log_freq_counter[name] += 1


def get_logger() -> Union[logging.Logger, FinetrainersLoggerAdapter]:
    global _logger
    return _logger


def _set_parallel_backend(parallel_backend: "ParallelBackendType") -> FinetrainersLoggerAdapter:
    _logger.parallel_backend = parallel_backend


_logger = logging.getLogger("finetrainers")
_logger.setLevel(FINETRAINERS_LOG_LEVEL)
_console_handler = logging.StreamHandler()
_console_handler.setLevel(FINETRAINERS_LOG_LEVEL)
_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
_console_handler.setFormatter(_formatter)
_logger.addHandler(_console_handler)
_logger = FinetrainersLoggerAdapter(_logger)
