import time
import logging
import os
import sys

class MultilineFormatter(logging.Formatter):
    """Formatter that repeats the log prefix on every line of a multi-line record.

    The record is first formatted normally, which includes the message and any
    exception or stack text. If the result spans several lines, each line after
    the first is run through the same format string again, so it carries the
    same prefix as the first line.
    """

    def format(self, record):
        """Format *record*, prefixing every continuation line.

        Args:
            record: The ``logging.LogRecord`` to format.

        Returns:
            The formatted text, with lines joined by ``'\\n'``. Single-line
            output is returned exactly as ``logging.Formatter.format``
            produced it.
        """
        original_message = super().format(record)
        lines = original_message.split('\n')

        if len(lines) <= 1:
            return original_message

        # Clone the record so continuation lines keep every attribute of the
        # original: msecs, relativeCreated, funcName, thread/process info and
        # any ``extra=`` fields. A freshly constructed LogRecord would lose
        # these. It would also recompute msecs from the current time, and a
        # format string that references an extra field would fail.
        line_record = logging.makeLogRecord(record.__dict__)
        line_record.args = ()  # each line is already-rendered text; no %-interpolation
        # The exception/stack text is already part of ``lines``. Clear it so it
        # is not appended again to each continuation line.
        line_record.exc_info = None
        line_record.exc_text = None
        line_record.stack_info = None

        formatted_lines = [lines[0]]  # the first line already has its prefix
        for line in lines[1:]:
            line_record.msg = line
            formatted_lines.append(super().format(line_record))

        return '\n'.join(formatted_lines)

class CondaFilter(logging.Filter):
    """Suppress records whose source path contains ``'envs'``, except Optuna's.

    Records are always kept when the logger name contains ``'optuna'``. All
    other records are dropped when ``'envs'`` appears anywhere in their
    ``pathname``.

    Note: this is a plain substring test. Paths such as ``.../.virtualenvs/...``
    or ``.../devenvs/...`` are dropped as well.
    """

    def filter(self, record):
        """Return True to keep *record* and False to drop it."""
        # ``or`` short-circuits, so pathname is only inspected for non-Optuna loggers.
        return 'optuna' in record.name or 'envs' not in record.pathname

# One shared filter instance, attached to both output handlers below.
conda_filter = CondaFilter()

os.makedirs("log", exist_ok=True)

# Under SLURM, name the log after the job; otherwise use the local start time
# (second resolution).
slurm_job_id = os.getenv('SLURM_JOB_ID')
if slurm_job_id:
    job_name = os.getenv('SLURM_JOB_NAME', 'unnamed_job')
    log_filename = f"{job_name}_{slurm_job_id}.out"
else:
    from datetime import datetime
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_filename = f"{current_time}.log"

# Explicit UTF-8 so non-ASCII messages are written instead of failing with an
# encoding error under a non-UTF-8 locale.
file_handler = logging.FileHandler(os.path.join("log", log_filename), encoding="utf-8")
stream_handler = logging.StreamHandler(sys.stdout)

file_handler.addFilter(conda_filter)
stream_handler.addFilter(conda_filter)

formatter = MultilineFormatter(
    fmt='(%(asctime)s %(filename)s@%(lineno)d %(levelname)s) %(message)s',
    datefmt='%H:%M:%S'
)

file_handler.setFormatter(formatter)
stream_handler.setFormatter(formatter)

logging.basicConfig(
    level=logging.DEBUG,
    handlers=[file_handler, stream_handler]
)

# Route the "optuna" logger tree to the same file/stdout handlers at INFO level,
# replacing any handlers already attached to it.
optuna_logger = logging.getLogger("optuna")
for existing_handler in list(optuna_logger.handlers):
    optuna_logger.removeHandler(existing_handler)
optuna_logger.addHandler(file_handler)
optuna_logger.addHandler(stream_handler)
optuna_logger.setLevel(logging.INFO)
# The root logger holds these same handler objects (basicConfig above). If
# propagation stayed on, every optuna record would be emitted twice.
optuna_logger.propagate = False

logger = logging.getLogger(__name__)


def handle_exception(exc_type, exc_value, exc_traceback):
    """``sys.excepthook`` replacement that logs uncaught exceptions.

    KeyboardInterrupt and its subclasses are passed to the interpreter's
    default hook (``sys.__excepthook__``). Any other exception is logged at
    CRITICAL level, with its traceback, through this module's logger.
    """
    exc_info = (exc_type, exc_value, exc_traceback)
    if not issubclass(exc_type, KeyboardInterrupt):
        logger.critical("Uncaught exception", exc_info=exc_info)
        return
    sys.__excepthook__(*exc_info)


sys.excepthook = handle_exception
