import logging
import os
from datetime import datetime
from typing import Optional


def set_logger(logger_name, output_dir: str, log_filename: str, log_level: str = "INFO"):
    """Configure and return a named logger that writes to a file and to the console.

    Any handlers already attached to the logger are closed and removed first, so
    calling this repeatedly for the same name neither duplicates output nor leaks
    open file handles.

    Args:
        logger_name: Name passed to ``logging.getLogger``.
        output_dir: Directory for the log file; created if missing. An empty
            string means the current working directory.
        log_filename: Base name of the log file; ``.log`` is appended.
        log_level: Minimum level for the console handler, given as a level name
            (case-insensitive, e.g. ``"INFO"`` or ``"debug"``) or an int level.
            The file handler always records DEBUG and above.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        ValueError: If ``log_level`` is not a recognised level name. The
            logger's existing handlers are left untouched in that case.
    """
    logger = logging.getLogger(logger_name)

    # Build and validate the console handler before tearing anything down, so a
    # bad log_level does not leave the logger stripped of its handlers.
    ch = logging.StreamHandler()
    ch.setLevel(log_level.upper() if isinstance(log_level, str) else log_level)

    # Close (not just detach) previous handlers: clearing the list alone would
    # leave their file descriptors open on every reconfiguration.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    # The logger passes everything through; each handler applies its own threshold.
    logger.setLevel(logging.DEBUG)

    # Prevent propagation to root logger to avoid duplicate messages
    logger.propagate = False

    # File handler records every level. os.makedirs("") would raise, and an
    # empty dir already resolves to the CWD via os.path.join.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    fh = logging.FileHandler(
        os.path.join(output_dir, f'{log_filename}.log'), encoding="utf-8"
    )
    fh.setLevel(logging.DEBUG)

    # File lines carry timestamp and logger name; console lines stay terse.
    fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    ch.setFormatter(logging.Formatter('%(levelname)s - %(message)s'))

    logger.addHandler(ch)
    logger.addHandler(fh)

    return logger


class SharedLogger:
    """
    Class-level holder for a single logging destination shared by every logger
    it hands out.

    - The destination directory and file base name default to the environment
      variables PRUNER_LOG_DIR and PRUNER_LOG_FILE (falling back to "logs" and
      "pruning"); they are read once, when this class body is executed.
    - SharedLogger.configure(...) overrides either value for all later calls.
    - get_logger(...) builds the named logger with set_logger, pointing it at
      the currently configured directory/file.
    """

    _log_dir: str = os.environ.get("PRUNER_LOG_DIR", "logs")
    _log_filename: str = os.environ.get("PRUNER_LOG_FILE", "pruning")

    @classmethod
    def configure(cls, log_dir: Optional[str] = None, log_filename: Optional[str] = None) -> None:
        """Override the shared log directory and/or file base name.

        Arguments left as None keep their current value; any other value
        (including an empty string) replaces it.
        """
        overrides = (("_log_dir", log_dir), ("_log_filename", log_filename))
        for attr_name, new_value in overrides:
            if new_value is None:
                continue
            setattr(cls, attr_name, new_value)

    @classmethod
    def get_logger(cls, name: Optional[str] = None) -> logging.Logger:
        """Return the logger called *name* wired to the shared destination.

        A missing or empty *name* falls back to "propagation_pruner".
        """
        resolved_name = name if name else "propagation_pruner"
        return set_logger(
            resolved_name,
            output_dir=cls._log_dir,
            log_filename=cls._log_filename,
        )