import logging
import os

from typing import (
    Literal,
    Optional,
)


class Logger():
    """ The Logger class.
    """

    def __init__(
        self,
        local_rank: int,
        outputs_dir: str,
        root_path: str,
        format: str = '%(asctime)s - %(levelname)s - %(message)s',
        log_mode: Literal['both', 'file', 'stream'] = 'both',
        log_name: Optional[str] = None,
    ) -> None:
        """ Initialize the logger.

        Args:
            local_rank (int): The local rank.
            outputs_dir (str): The outputs directory.
            root_path (str): The root path.
            format (str, optional): The log format. Defaults to '%(asctime)s - %(levelname)s - %(message)s'.
            log_mode (Literal['both', 'file', 'stream'], optional): The log mode. Defaults to 'both'.
            log_name (Optional[str], optional): The log name. Defaults to None.
        """

        # Initialize the logger.
        self.logger = logging.getLogger()
        self.logger.setLevel(level=logging.INFO)

        # Set the format.
        format = f'[rank{local_rank}]: {format}'
        formatter = logging.Formatter(fmt=format)

        # Set the log path.
        log_path = os.path.join(
            outputs_dir,
            'logs',
            f'rank{local_rank}_{log_name}.log' if log_name is not None \
                else f'rank{local_rank}.log',
        )

        # Set file handler.
        if log_mode == 'both' or log_mode == 'file':
            fh = logging.FileHandler(
                filename=log_path,
                mode='a',
                encoding='utf-8',
            )
            fh.setFormatter(fmt=formatter)
            fh.setLevel(level=logging.INFO)

            self.logger.addHandler(hdlr=fh)

        # Set stream handler.
        if (log_mode == 'both' or log_mode == 'stream') and local_rank == 0:
            sh = logging.StreamHandler()
            sh.setFormatter(fmt=formatter)
            sh.setLevel(level=logging.INFO)

            self.logger.addHandler(hdlr=sh)

        # Set the root path.
        self.root_path = root_path if root_path.endswith('/') \
            else (root_path + '/')

    def log(
        self,
        message: str,
        level: Literal['error', 'info', 'warning'] = 'info',
        source: Optional[str] = None,
    ) -> None:
        """ Log the message

        Args:
            message (str): The message.
            level (Literal['error', 'info', 'warning'], optional): The log level. Defaults to 'info'.
            source (Optional[str], optional): The source. Defaults to None.
        """

        if source is not None:
            message = f'{self.remove_root(source_path=source)} - {message}'

        match level:
            case 'error':
                self.logger.error(msg=message)
            case 'info':
                self.logger.info(msg=message)
            case 'warning':
                self.logger.warning(msg=message)
            case _:
                message = 'The log level is not supported.'

                self.logger.error(msg=message)

                raise ValueError(message)

    def remove_root(
        self,
        source_path: str,
    ) -> str:
        """ Remove the root path from the source path.

        Args:
            source_path (str): The source path.

        Returns:
            str: The source path without the root path.
        """

        return source_path.replace(
            self.root_path,
            '',
        )
