import logging
import math
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional


def setup_logging(
    log_dir: str = "logs",
    log_level: str = "INFO",
    experiment_name: Optional[str] = None,
    console_output: bool = True
) -> logging.Logger:
    """
    Set up centralized logging for the torchcfm project.

    Configures the ``"torchcfm"`` logger with a timestamped log file and,
    optionally, a stdout handler. Calling this again replaces (and closes)
    every handler previously attached to that logger.

    Parameters
    ----------
    log_dir : str
        Directory to store log files. It is created, including any missing
        parent directories, if it does not exist.
    log_level : str
        Logging level name, case-insensitive (DEBUG, INFO, WARNING, ERROR,
        CRITICAL).
    experiment_name : Optional[str]
        Name for the experiment, used as the log filename prefix. If it is
        empty or None, the prefix ``"torchcfm"`` is used instead.
    console_output : bool
        Whether to also output to console (stdout).

    Returns
    -------
    logging.Logger
        The configured ``"torchcfm"`` logger.

    Raises
    ------
    ValueError
        If ``log_level`` is not a known logging level name.
    """
    # Resolve the level first so a typo fails before anything touches disk.
    level = getattr(logging, log_level.upper(), None)
    if not isinstance(level, int):
        raise ValueError(f"Unknown log level: {log_level!r}")

    # parents=True so nested paths such as "runs/exp1/logs" work.
    log_path = Path(log_dir)
    log_path.mkdir(parents=True, exist_ok=True)

    # Generate log filename: <prefix>_<YYYYmmdd_HHMMSS>.log
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    prefix = experiment_name if experiment_name else "torchcfm"
    log_file = log_path / f"{prefix}_{timestamp}.log"

    logger = logging.getLogger("torchcfm")
    logger.setLevel(level)

    # Detach AND close handlers left by a previous call. Clearing the list
    # alone would leak the previous log file's open descriptor.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    detailed_formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
    )
    simple_formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s'
    )

    # File handler (detailed). The logger's own level still filters records
    # before they reach this handler, so DEBUG here only takes effect when
    # log_level is DEBUG.
    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(detailed_formatter)
    logger.addHandler(file_handler)

    # Console handler (simple)
    if console_output:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(level)
        console_handler.setFormatter(simple_formatter)
        logger.addHandler(console_handler)

    logger.info(f"Logging initialized. Log file: {log_file}")
    logger.info(f"Log level: {log_level}")

    return logger


def get_logger(name: str = "torchcfm") -> logging.Logger:
    """
    Return the logger registered under ``name``.

    Parameters
    ----------
    name : str
        Dotted logger name. Defaults to ``"torchcfm"``, the same logger
        that ``setup_logging`` configures.

    Returns
    -------
    logging.Logger
        The logger for ``name``. It is created on first request, and later
        requests for the same name return the same object.
    """
    requested = logging.getLogger(name)
    return requested


def log_numerical_error(
    logger: logging.Logger,
    error_type: str,
    details: dict,
    context: str = ""
):
    """
    Report a numerical failure as a short series of log records.

    The logger receives, in order:
    - an ERROR headline naming the error (and the context, when one is given);
    - an ERROR record containing the whole ``details`` dict;
    - a DEBUG record that repeats the error type and context;
    - one DEBUG record per ``details`` entry.

    Parameters
    ----------
    logger : logging.Logger
        Destination logger.
    error_type : str
        Error tag, e.g. "OT_PLAN_ZERO" or "DIVIDE_BY_ZERO".
    details : dict
        Diagnostic key/value pairs (cost matrix info, parameters, ...).
    context : str
        Optional free-text context. It is omitted from the headline when
        empty.
    """
    context_suffix = f" | Context: {context}" if context else ""
    logger.error(f"NUMERICAL_ERROR: {error_type}{context_suffix}")
    logger.error(f"Error details: {details}")

    # Verbose breakdown, only visible when DEBUG is enabled downstream.
    logger.debug(f"Full error context: {error_type} in {context}")
    for field_name, field_value in details.items():
        logger.debug(f"  {field_name}: {field_value}")


def log_ot_plan_info(
    logger: logging.Logger,
    plan_sum: float,
    cost_matrix_info: dict,
    parameters: dict,
    min_plan_sum: float = 1e-8,
):
    """
    Log OT plan information for debugging.

    Emits the plan sum, the cost-matrix info and the OT parameters at INFO
    level. It then emits a WARNING if the plan sum is non-finite (NaN or
    +/-inf) or below ``min_plan_sum``.

    Parameters
    ----------
    logger : logging.Logger
        Logger instance.
    plan_sum : float
        Sum of the OT plan. Any value accepted by ``float()`` and by the
        ``.2e`` format spec works.
    cost_matrix_info : dict
        Information about the cost matrix.
    parameters : dict
        OT parameters used.
    min_plan_sum : float
        Threshold below which the plan sum is reported as "very small".
        Defaults to ``1e-8``, the previously hard-coded value.
    """
    logger.info(f"OT Plan sum: {plan_sum:.2e}")
    logger.info(f"Cost matrix info: {cost_matrix_info}")
    logger.info(f"OT parameters: {parameters}")

    value = float(plan_sum)
    # NaN compares False against any threshold and +inf is never "small",
    # so a plain `< threshold` test would silently accept both failures.
    if not math.isfinite(value):
        logger.warning(f"OT plan sum is not finite ({value}) - numerical instability detected")
    elif value < min_plan_sum:
        logger.warning("OT plan sum is very small - potential numerical instability")
