# utils/tensorboard_utils.py
import os
from torch.utils.tensorboard import SummaryWriter
import logging


class UnifiedLogger:
    """Unified logger that writes text logs and TensorBoard events.

    Text messages go to ``<log_dir>/<name>.log`` and to the console;
    TensorBoard events go to ``<log_dir>/<name>/``. The object can be used
    as a context manager, which calls :meth:`close` on exit.

    Args:
        log_dir: Directory for all output; created if it does not exist.
        name: Experiment name. Used as the ``logging`` logger name, the
            TensorBoard sub-directory and the text-log file stem.
    """

    _FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
    # Attribute set on every handler this class installs, so those handlers
    # can be told apart from handlers added to the same logger by others.
    _HANDLER_TAG = '_unified_logger_handler'

    def __init__(self, log_dir, name='experiment'):
        self.log_dir = log_dir
        self.name = name
        self._handlers = []

        # Create log directory
        os.makedirs(log_dir, exist_ok=True)

        # TensorBoard writer
        self.writer = SummaryWriter(log_dir=os.path.join(log_dir, name))

        # Text logging. Handlers are attached to this named logger instead of
        # going through logging.basicConfig: basicConfig is a no-op once the
        # root logger has any handler, so a second instance (or an app that
        # configured logging first) would silently never write its log file.
        self.logger = logging.getLogger(name)
        self._setup_text_logging(os.path.join(log_dir, f'{name}.log'))

    def _setup_text_logging(self, log_file):
        """Attach a UTF-8 file handler and a console handler to ``self.logger``."""
        logger = self.logger
        # Loggers are process-wide singletons keyed by name. Drop handlers left
        # by an earlier UnifiedLogger with the same name; otherwise every
        # message would be emitted once per instance ever created.
        for handler in list(logger.handlers):
            if getattr(handler, self._HANDLER_TAG, False):
                logger.removeHandler(handler)
                handler.close()

        formatter = logging.Formatter(self._FORMAT)
        self._handlers = [
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),
        ]
        for handler in self._handlers:
            handler.setFormatter(formatter)
            setattr(handler, self._HANDLER_TAG, True)
            logger.addHandler(handler)

        logger.setLevel(logging.INFO)
        # Our handlers already emit every record. Propagating to the root
        # logger's handlers as well would duplicate console output.
        logger.propagate = False

    def add_scalar(self, tag, scalar_value, global_step=None, **kwargs):
        """Add scalar data to TensorBoard.

        Extra keyword arguments (e.g. ``walltime``) are forwarded to
        ``SummaryWriter.add_scalar``.
        """
        self.writer.add_scalar(tag, scalar_value, global_step, **kwargs)

    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, **kwargs):
        """Add multiple scalars under one main tag to TensorBoard.

        Extra keyword arguments are forwarded to ``SummaryWriter.add_scalars``.
        """
        self.writer.add_scalars(main_tag, tag_scalar_dict, global_step, **kwargs)

    def add_histogram(self, tag, values, global_step=None, **kwargs):
        """Add histogram data to TensorBoard.

        Extra keyword arguments (e.g. ``bins``) are forwarded to
        ``SummaryWriter.add_histogram``.
        """
        self.writer.add_histogram(tag, values, global_step, **kwargs)

    def add_image(self, tag, img_tensor, global_step=None, **kwargs):
        """Add image data to TensorBoard.

        Extra keyword arguments (e.g. ``dataformats='HWC'``) are forwarded to
        ``SummaryWriter.add_image``.
        """
        self.writer.add_image(tag, img_tensor, global_step, **kwargs)

    def info(self, message, *args, **kwargs):
        """Record an INFO-level message (``%``-style ``args`` are formatted lazily)."""
        self.logger.info(message, *args, **kwargs)

    def warning(self, message, *args, **kwargs):
        """Record a WARNING-level message (``%``-style ``args`` are formatted lazily)."""
        self.logger.warning(message, *args, **kwargs)

    def error(self, message, *args, **kwargs):
        """Record an ERROR-level message (``%``-style ``args`` are formatted lazily)."""
        self.logger.error(message, *args, **kwargs)

    def close(self):
        """Close the TensorBoard writer and detach and close the text-log handlers.

        After closing, messages passed to :meth:`info` and the other logging
        methods are no longer written to the log file. They propagate to the
        root logger instead. The handler list is cleared, so repeated calls do
        not close the same handlers again.
        """
        writer = getattr(self, 'writer', None)
        if writer is not None:
            writer.close()

        logger = getattr(self, 'logger', None)
        for handler in getattr(self, '_handlers', ()):
            logger.removeHandler(handler)
            handler.close()
        self._handlers = []

        # Restore normal propagation unless another instance with the same
        # name still has its handlers installed on this shared logger.
        if logger is not None and not any(
            getattr(h, self._HANDLER_TAG, False) for h in logger.handlers
        ):
            logger.propagate = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def __del__(self):
        """Best-effort cleanup when the object is garbage-collected."""
        try:
            self.close()
        except Exception:
            # A destructor must never raise. During interpreter shutdown, or
            # after a failed __init__, the objects close() uses may be gone.
            pass