import logging
from typing import Dict, Any
import mlflow
from pathlib import Path
from hydra.core.hydra_config import HydraConfig


class MLflowHandler(logging.Handler):
    """Logging handler that appends records to a file and mirrors it to MLflow.

    Every accepted record is formatted and appended to ``log_path``. While an
    MLflow run is active, accepted records are also added to an in-memory
    buffer; once ``buffer_size`` records have accumulated, the *entire* log
    file is uploaded to the run as the ``haipr.log`` artifact.

    A record is accepted when its logger name starts with one of the
    configured ``loggers`` prefixes (plain string-prefix match). With no
    prefixes configured, no record is accepted.
    """

    def __init__(self, formatter_config: Dict[str, Any]):
        """Configure the handler from a settings mapping.

        Args:
            formatter_config: Handler settings. Recognised keys:
                ``log_path`` (default ``"logs/haipr.log"``),
                ``buffer_size`` (default ``100``),
                ``loggers`` (one name prefix or a list of them, default ``[]``),
                ``format`` (default ``"%(message)s"``) and ``datefmt``.
        """
        super().__init__()
        print("Initializing MLflowHandler")

        # NOTE: the path is used as given; any interpolation is assumed to
        # have been resolved upstream (e.g. by Hydra) - confirm with caller.
        self.log_path = Path(formatter_config.get("log_path", "logs/haipr.log"))
        self.buffer_size = formatter_config.get("buffer_size", 100)

        # Logger-name prefixes this handler accepts. A bare string is treated
        # as a single prefix; an explicit null in the config means "none".
        self.allowed_loggers = formatter_config.get("loggers", []) or []
        if isinstance(self.allowed_loggers, str):
            self.allowed_loggers = [self.allowed_loggers]
        print(f"MLflowHandler will handle logs from: {self.allowed_loggers}")
        print(f"Using log file: {self.log_path}")

        # Create log directory if it doesn't exist
        self.log_path.parent.mkdir(parents=True, exist_ok=True)

        # Records accepted since the last upload, and the ID of the first
        # MLflow run observed (used to upload after that run is no longer
        # active).
        self._current_messages = []
        self._current_run_id = None

        # Set up formatter from config
        formatter = logging.Formatter(
            fmt=formatter_config.get("format", "%(message)s"),
            datefmt=formatter_config.get("datefmt"),
        )
        self.setFormatter(formatter)

    def _should_handle_log(self, logger_name: str) -> bool:
        """Return True if ``logger_name`` starts with any allowed prefix."""
        # str.startswith accepts a tuple of candidates; an empty tuple
        # matches nothing, so an empty configuration rejects every record.
        return logger_name.startswith(tuple(self.allowed_loggers))

    def emit(self, record):
        """Append the formatted record to the log file and buffer it for MLflow.

        Records from loggers outside ``allowed_loggers`` are ignored. A failed
        file write is reported on stdout and does not prevent buffering. Any
        other failure is reported and passed to ``handleError``.
        """
        if not self._should_handle_log(record.name):
            return

        try:
            msg = self.format(record)

            # Best-effort file write. UTF-8 is explicit so that non-ASCII
            # messages do not depend on the platform's default encoding.
            try:
                with open(self.log_path, "a", encoding="utf-8") as f:
                    f.write(msg + "\n")
            except Exception as e:
                print(f"Error writing to {self.log_path}: {e}")

            # Only buffer for MLflow while a run is active.
            active_run = mlflow.active_run()
            if not active_run:
                return

            # Store run ID when we first see it
            if not self._current_run_id and active_run.info.run_id:
                self._current_run_id = active_run.info.run_id

            self._current_messages.append(msg)

            # Upload once enough records have accumulated.
            if len(self._current_messages) >= self.buffer_size:
                self.flush()

        except Exception as e:
            print(f"Error in emit: {e}")
            self.handleError(record)

    def flush(self):
        """Upload the current log file to MLflow if records are pending.

        The whole file is uploaded as the ``haipr.log`` artifact: to the
        active run if there is one, otherwise to the stored run ID. The
        pending buffer is cleared whether or not the upload succeeds; upload
        errors are reported on stdout and never raised.
        """
        if not self._current_messages:
            return

        # Reset the buffer *before* uploading: if a record reaches this
        # handler while the upload is in progress, emit() then sees a short
        # buffer instead of a full one and does not call flush() recursively.
        self._current_messages = []

        try:
            # errors="replace" keeps the upload working even if the file
            # holds bytes written earlier in a different encoding.
            with open(self.log_path, "r", encoding="utf-8", errors="replace") as f:
                content = f.read()

            if mlflow.active_run():
                mlflow.log_text(content, "haipr.log")
            elif self._current_run_id:
                # Log through the client rather than re-opening the run with
                # mlflow.start_run(run_id=...): resuming marks the run as
                # running and ends it as FINISHED on exit, which would
                # overwrite a FAILED/KILLED status of an already-ended run.
                from mlflow.tracking import MlflowClient

                MlflowClient().log_text(self._current_run_id, content, "haipr.log")
        except Exception as e:
            print(f"Error in flush: {e}")

    def close(self):
        """Upload any pending records, then close the handler."""
        try:
            self.flush()
        finally:
            super().close()


class JobContextFilter(logging.Filter):
    """Logging filter that stamps each record with a ``job_num`` attribute.

    The filter never rejects a record; it only guarantees the attribute is
    present so it can be referenced downstream.
    """

    def filter(self, record):
        """Ensure ``record.job_num`` is set and let the record through.

        A record that already carries ``job_num`` is left unchanged.
        Otherwise the value is read from ``HydraConfig.get().job.num`` when
        ``HydraConfig.initialized()`` is true, and falls back to ``0`` when
        it is not or when the lookup raises.

        Returns:
            Always ``True``.
        """
        if hasattr(record, "job_num"):
            return True

        try:
            job_number = HydraConfig.get().job.num if HydraConfig.initialized() else 0
        except Exception:
            # Any failure while querying Hydra degrades to the default.
            job_number = 0

        record.job_num = job_number
        return True
