"""
Logging utilities for the Fisher dimension framework.
"""

from __future__ import annotations

import logging
import sys
from pathlib import Path
from typing import Optional


# Cache of loggers handed out by get_logger(), keyed by the fully-qualified
# logger name (i.e. after the 'fisher_dimension.' prefix has been applied).
_loggers = {}


def setup_logging(
    level: str = "INFO",
    log_file: Optional[str] = None,
    format_str: Optional[str] = None,
    include_timestamps: bool = True
) -> logging.Logger:
    """
    Set up logging configuration for the 'fisher_dimension' logger.

    Handlers previously attached to that logger are detached and closed
    before the new ones are installed, so calling this function repeatedly
    replaces the configuration without leaking open log files.

    Args:
        level: Logging level name (DEBUG, INFO, WARNING, ERROR, CRITICAL),
            case-insensitive. Unrecognized names fall back to INFO.
        log_file: Optional file to write logs to. Missing parent
            directories are created.
        format_str: Custom format string. When given, include_timestamps
            has no effect.
        include_timestamps: Whether the default format includes timestamps

    Returns:
        The configured 'fisher_dimension' logger (the top of this
        framework's logger hierarchy, not the global logging root)
    """
    # Default format
    if format_str is None:
        if include_timestamps:
            format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        else:
            format_str = "%(name)s - %(levelname)s - %(message)s"

    # Convert level string to constant
    level_map = {
        'DEBUG': logging.DEBUG,
        'INFO': logging.INFO,
        'WARNING': logging.WARNING,
        'ERROR': logging.ERROR,
        'CRITICAL': logging.CRITICAL,
    }
    log_level = level_map.get(level.upper(), logging.INFO)

    # Create formatter
    formatter = logging.Formatter(format_str)

    # Get the framework's top-level logger
    root_logger = logging.getLogger('fisher_dimension')
    root_logger.setLevel(log_level)

    # Detach AND close existing handlers. Merely clearing the list would
    # leave any previous FileHandler's stream open until interpreter exit.
    # Iterate over a copy because removeHandler() mutates the list.
    for old_handler in list(root_logger.handlers):
        root_logger.removeHandler(old_handler)
        old_handler.close()

    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(log_level)
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)

    # File handler if specified
    if log_file:
        log_path = Path(log_file)
        log_path.parent.mkdir(parents=True, exist_ok=True)

        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(log_level)
        file_handler.setFormatter(formatter)
        root_logger.addHandler(file_handler)

    return root_logger


def get_logger(name: str = "fisher_dimension") -> logging.Logger:
    """
    Get a logger inside the 'fisher_dimension' logger hierarchy.

    Args:
        name: Logger name. 'fisher_dimension' itself and dotted children
            such as 'fisher_dimension.data' are used as given; any other
            name is prefixed with 'fisher_dimension.'. An empty name
            yields the 'fisher_dimension' logger.

    Returns:
        Logger instance (the same object for repeated calls with the
        same resulting name)
    """
    base = 'fisher_dimension'

    if not name:
        # Avoid creating a logger literally named 'fisher_dimension.'.
        name = base
    elif name != base and not name.startswith(base + '.'):
        # A bare startswith(base) check would also accept look-alikes such
        # as 'fisher_dimension_extras'. Logger hierarchy is dot-separated,
        # so those would sit outside the 'fisher_dimension' tree and never
        # reach the handlers installed on it by setup_logging().
        name = f'{base}.{name}'

    if name not in _loggers:
        _loggers[name] = logging.getLogger(name)

    return _loggers[name]


def log_experiment_start(
    experiment_name: str,
    config: dict,
    logger: Optional[logging.Logger] = None
) -> None:
    """
    Emit the opening banner and top-level settings for an experiment.

    All lines are logged at INFO level. Config entries whose value is a
    dict (nested sections) are left out of the listing.

    Args:
        experiment_name: Name of the experiment
        config: Experiment configuration
        logger: Logger to use (default: the logger from get_logger())
    """
    log = get_logger() if logger is None else logger
    rule = "=" * 60

    log.info(rule)
    log.info(f"Starting experiment: {experiment_name}")
    log.info(rule)

    # Only flat entries are listed; nested dict sections are skipped.
    flat_entries = (
        (option, setting)
        for option, setting in config.items()
        if not isinstance(setting, dict)
    )
    for option, setting in flat_entries:
        log.info(f"  {option}: {setting}")

def log_experiment_end(
    experiment_name: str,
    duration: float,
    success: bool = True,
    logger: Optional[logging.Logger] = None
) -> None:
    """
    Emit the closing summary for an experiment.

    Four INFO lines are written: a dashed rule, the outcome
    (COMPLETED or FAILED), the duration with two decimals, and a
    closing '=' rule.

    Args:
        experiment_name: Name of the experiment
        duration: Duration in seconds
        success: Whether experiment completed successfully
        logger: Logger to use (default: the logger from get_logger())
    """
    log = get_logger() if logger is None else logger

    if success:
        outcome = "COMPLETED"
    else:
        outcome = "FAILED"

    log.info("-" * 60)
    log.info(f"Experiment {experiment_name} {outcome}")
    log.info(f"Duration: {duration:.2f} seconds")
    log.info("=" * 60)


def log_progress(
    current: int,
    total: int,
    prefix: str = "",
    logger: Optional[logging.Logger] = None
) -> None:
    """
    Log a progress update at DEBUG level.

    The message has the form '<prefix> Progress: <current>/<total> (<pct>%)'.
    When prefix is empty the message starts directly with 'Progress:'
    instead of a stray leading space.

    Args:
        current: Current iteration
        total: Total iterations; a value <= 0 is reported as 0.0%
        prefix: Prefix message
        logger: Logger to use (default: the logger from get_logger())
    """
    if logger is None:
        logger = get_logger()

    # Guard against division by zero (and nonsensical negative totals).
    percentage = 100 * current / total if total > 0 else 0
    summary = f"Progress: {current}/{total} ({percentage:.1f}%)"

    # Only insert the separating space when there is a prefix to separate.
    logger.debug(f"{prefix} {summary}" if prefix else summary)


class ExperimentLogger:
    """Context manager for experiment logging.

    Entering the context logs the experiment name and configuration via
    log_experiment_start(); leaving it logs the outcome and elapsed time
    via log_experiment_end(). Exceptions raised inside the ``with`` body
    are never suppressed.

    A run is reported as FAILED when the body raises, or when error() was
    called at least once during the run.
    """

    def __init__(
        self,
        experiment_name: str,
        config: dict,
        logger: Optional[logging.Logger] = None
    ):
        """
        Args:
            experiment_name: Name of the experiment
            config: Experiment configuration
            logger: Logger to use (default: the logger from get_logger())
        """
        self.experiment_name = experiment_name
        self.config = config
        self.logger = logger or get_logger()
        # Wall-clock timestamp (time.time()) of the latest __enter__;
        # None until the context has been entered.
        self.start_time = None
        # perf_counter reading taken alongside start_time; the reported
        # duration is computed from this so that system clock adjustments
        # cannot distort it.
        self._start_counter = None
        # Becomes False once error() is called or the body raises.
        self.success = True

    def __enter__(self) -> "ExperimentLogger":
        import time
        self.start_time = time.time()
        self._start_counter = time.perf_counter()
        # Start every run with a clean slate so that a reused instance does
        # not inherit a failure flag from a previous run.
        self.success = True
        log_experiment_start(self.experiment_name, self.config, self.logger)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        import time
        duration = time.perf_counter() - self._start_counter
        # Only ever downgrade the status here: unconditionally assigning
        # `exc_type is None` would discard a failure recorded by error().
        if exc_type is not None:
            self.success = False
        log_experiment_end(self.experiment_name, duration, self.success, self.logger)
        return False  # Don't suppress exceptions

    def info(self, message: str) -> None:
        """Log info message."""
        self.logger.info(message)

    def debug(self, message: str) -> None:
        """Log debug message."""
        self.logger.debug(message)

    def warning(self, message: str) -> None:
        """Log warning message."""
        self.logger.warning(message)

    def error(self, message: str) -> None:
        """Log error message and mark the experiment as failed."""
        self.logger.error(message)
        self.success = False
