# utils/logger.py
import logging
import os
from typing import Optional


def setup_logger(log_dir: str, name: Optional[str] = None, level: int = logging.INFO) -> logging.Logger:
    """
    Create a logger that outputs to both console and file.

    Calling this again for the same ``name`` replaces the previously attached
    handlers instead of stacking duplicates; the old handlers are closed so
    their log files are not left open.

    Args:
        log_dir: Directory for log files, will be created automatically.
        name:    Logger name, uses root logger if None.
        level:   Logging level, default is INFO.

    Returns:
        The configured ``logging.Logger``. It has a console ``StreamHandler``
        and a ``FileHandler`` writing to ``<log_dir>/<name>.log``
        (``train.log`` when ``name`` is falsy), opened in append mode with
        UTF-8 encoding.
    """
    os.makedirs(log_dir, exist_ok=True)
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Avoid adding duplicate handlers. Close each old handler as it is
    # detached: clearing the list alone would leave a FileHandler's file open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        try:
            old_handler.close()
        except Exception:
            pass  # best-effort: a failing close must not block reconfiguration

    formatter = logging.Formatter(
        fmt="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S"
    )

    # Console output
    ch = logging.StreamHandler()
    ch.setLevel(level)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    # File output
    log_file = os.path.join(log_dir, f"{name or 'train'}.log")
    fh = logging.FileHandler(log_file, mode="a", encoding="utf-8")
    fh.setLevel(level)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    return logger


class UnifiedLogger:
    """Unified logger that supports text logging and TensorBoard.

    Text messages go through the ``logging.Logger`` returned by
    ``setup_logger(log_dir, name)``. When ``torch.utils.tensorboard`` can be
    imported and a ``SummaryWriter`` can be created, scalars, histograms and
    images are also written under ``<log_dir>/tensorboard/<name>``. Otherwise
    scalars are only recorded in the text log, and histogram/image calls are
    no-ops.
    """

    def __init__(self, log_dir, name='experiment'):
        self.log_dir = log_dir
        self.name = name
        # Initialised before anything that can raise, so close()/__del__ can
        # always run on a partially constructed instance.
        self.writer = None
        self._handlers = []

        # Create log directory
        os.makedirs(log_dir, exist_ok=True)

        # Initialize TensorBoard writer (optional)
        try:
            from torch.utils.tensorboard import SummaryWriter
            tensorboard_dir = os.path.join(log_dir, 'tensorboard', name)
            self.writer = SummaryWriter(log_dir=tensorboard_dir)
            print(f"✓ TensorBoard logging enabled: {tensorboard_dir}")
        except ImportError:
            print("⚠ TensorBoard not available, only text logging will be used")
        except Exception as e:
            print(f"⚠ TensorBoard initialization failed: {e}")

        # Setup text logging
        self.logger = setup_logger(log_dir, name)
        # Remember the handlers attached for this instance so close() releases
        # exactly those and not handlers installed later by someone else.
        self._handlers = list(self.logger.handlers)

    @staticmethod
    def _format_value(value):
        """Render a metric with 6 decimals; fall back to ``str`` for non-numeric values."""
        try:
            return f"{value:.6f}"
        except (TypeError, ValueError):
            # Text logging is best-effort, like the TensorBoard path: never
            # crash the caller because a value cannot be formatted as a float.
            return str(value)

    def add_scalar(self, tag, scalar_value, global_step=None):
        """Add scalar data to TensorBoard and record it in the text log."""
        if self.writer is not None:
            try:
                self.writer.add_scalar(tag, scalar_value, global_step)
            except Exception as e:
                self.logger.warning(f"Failed to log scalar {tag}: {e}")

        # Also record to text log
        step_info = f" (step {global_step})" if global_step is not None else ""
        self.logger.info(f"METRIC [{tag}]: {self._format_value(scalar_value)}{step_info}")

    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None):
        """Add multiple scalars to TensorBoard and record each in the text log."""
        if self.writer is not None:
            try:
                self.writer.add_scalars(main_tag, tag_scalar_dict, global_step)
            except Exception as e:
                self.logger.warning(f"Failed to log scalars {main_tag}: {e}")

        # Record to text log
        step_info = f" (step {global_step})" if global_step is not None else ""
        for tag, value in tag_scalar_dict.items():
            self.logger.info(f"METRIC [{main_tag}/{tag}]: {self._format_value(value)}{step_info}")

    def add_histogram(self, tag, values, global_step=None):
        """Add histogram data to TensorBoard (no-op without a writer)."""
        if self.writer is not None:
            try:
                self.writer.add_histogram(tag, values, global_step)
            except Exception as e:
                self.logger.warning(f"Failed to log histogram {tag}: {e}")

    def add_image(self, tag, img_tensor, global_step=None):
        """Add image data to TensorBoard (no-op without a writer)."""
        if self.writer is not None:
            try:
                self.writer.add_image(tag, img_tensor, global_step)
            except Exception as e:
                self.logger.warning(f"Failed to log image {tag}: {e}")

    def info(self, message):
        """Record info level log"""
        self.logger.info(message)

    def warning(self, message):
        """Record warning level log"""
        self.logger.warning(message)

    def error(self, message):
        """Record error level log"""
        self.logger.error(message)

    def debug(self, message):
        """Record debug level log"""
        self.logger.debug(message)

    def close(self):
        """Close the TensorBoard writer and release this instance's log handlers.

        Safe to call repeatedly and on a partially constructed instance. After
        closing, TensorBoard calls are skipped. The handlers this instance
        created are detached from the underlying ``logging.Logger`` and closed,
        which releases the log file.
        """
        writer = getattr(self, 'writer', None)
        if writer is not None:
            self.writer = None  # later add_* calls skip TensorBoard
            try:
                writer.close()
            except Exception:
                pass  # best-effort shutdown

        logger = getattr(self, 'logger', None)
        handlers = getattr(self, '_handlers', [])
        self._handlers = []
        for handler in handlers:
            if logger is not None:
                logger.removeHandler(handler)
            try:
                handler.close()
            except Exception:
                pass  # best-effort shutdown

    def flush(self):
        """Flush the TensorBoard writer and all text log handlers."""
        if self.writer is not None:
            try:
                self.writer.flush()
            except Exception:
                pass

        # Flush text log handlers
        for handler in self.logger.handlers:
            try:
                handler.flush()
            except Exception:
                pass

    def __del__(self):
        """Destructor: best-effort close that never raises during garbage collection."""
        try:
            self.close()
        except Exception:
            pass


# Compatible alias
def create_logger(log_dir, name='train'):
    """Shorthand kept for backward compatibility: build a ``UnifiedLogger``.

    Args:
        log_dir: Root directory handed to ``UnifiedLogger``.
        name:    Logger/experiment name; defaults to ``'train'`` here.

    Returns:
        The newly constructed ``UnifiedLogger`` instance.
    """
    unified = UnifiedLogger(log_dir=log_dir, name=name)
    return unified


# Smoke test; runs only when this file is executed directly, not when imported
if __name__ == "__main__":
    # Exercise the logger end to end; try/finally guarantees logger.close()
    # runs even if one of the logging calls below raises.
    logger = UnifiedLogger("test_logs", "test")
    try:
        logger.info("Test info message")
        logger.warning("Test warning message")
        logger.add_scalar("test/loss", 0.5, 1)
        logger.add_scalars("test/metrics", {"accuracy": 0.85, "f1": 0.82}, 1)

        print("Logger test completed")
    finally:
        logger.close()