import logging
from pathlib import Path
from typing import Any, Dict, Optional

import mlflow
from hydra.core.hydra_config import HydraConfig


class MLflowHandler(logging.Handler):
    """Logging handler that mirrors selected loggers into per-run MLflow logs.

    While an MLflow run is active, each accepted record is formatted, appended
    to ``<log_base_dir>/haipr_<run_id>.log`` and counted in a per-run buffer.
    Once ``buffer_size`` records have accumulated for a run, the whole local
    file is uploaded to that run as the text artifact ``haipr.log``.
    Records emitted while no run is active are ignored.

    Recognised ``formatter_config`` keys: ``log_base_dir`` (default
    ``"logs"``), ``buffer_size`` (default 100), ``loggers`` (one name or a
    list of names; default none), ``format`` (default ``"%(message)s"``) and
    ``datefmt``.
    """

    def __init__(self, formatter_config: Dict[str, Any]):
        super().__init__()
        print("Initializing MLflowHandler")

        # Get handler config - base directory is already resolved by Hydra
        self.log_base_dir = Path(formatter_config.get("log_base_dir", "logs"))
        self.buffer_size = formatter_config.get("buffer_size", 100)

        # Normalise ``loggers`` to a list: a bare string names one logger, and
        # an explicit null means "no loggers". Leaving None in place would make
        # _should_handle_log raise on every record.
        loggers = formatter_config.get("loggers", [])
        if loggers is None:
            loggers = []
        elif isinstance(loggers, str):
            loggers = [loggers]
        self.allowed_loggers = list(loggers)
        print(f"MLflowHandler will handle logs from: {self.allowed_loggers}")
        print(f"Using log base directory: {self.log_base_dir}")

        # Create log directory if it doesn't exist
        self.log_base_dir.mkdir(parents=True, exist_ok=True)

        # Per-run state, created lazily on the first record seen for each run.
        self._run_logs = {}  # {run_id: [messages emitted since the last flush]}
        self._run_log_files = {}  # {run_id: Path of the run's local log file}

        # Set up formatter from config
        self.setFormatter(
            logging.Formatter(
                fmt=formatter_config.get("format", "%(message)s"),
                datefmt=formatter_config.get("datefmt"),
            )
        )

    def _should_handle_log(self, logger_name: str) -> bool:
        """Return True if *logger_name* is, or descends from, an allowed logger.

        Matching follows the logger hierarchy, the same way ``logging.Filter``
        does: ``"a.b"`` accepts ``"a.b"`` and ``"a.b.c"`` but not ``"a.bc"``.
        An empty entry accepts every logger.
        """
        for allowed in self.allowed_loggers:
            prefix = allowed.rstrip(".")
            if (
                not prefix
                or logger_name == prefix
                or logger_name.startswith(prefix + ".")
            ):
                return True
        return False

    def _create_log_path(self, run_id: str) -> Path:
        """Return the local log file path for MLflow run *run_id*."""
        return self.log_base_dir / f"haipr_{run_id}.log"

    def emit(self, record: logging.LogRecord) -> None:
        """Append *record* to the active run's log file and buffer it for upload.

        Records from loggers outside ``allowed_loggers``, or emitted while no
        MLflow run is active, are ignored. Any failure is reported through
        ``handleError`` instead of propagating into the logging call site.
        """
        try:
            if not self._should_handle_log(record.name):
                return

            msg = self.format(record)

            active_run = mlflow.active_run()
            if not active_run:
                return
            run_id = active_run.info.run_id

            # First record for this run: start its buffer and choose its file.
            if run_id not in self._run_logs:
                self._run_logs[run_id] = []
                self._run_log_files[run_id] = self._create_log_path(run_id)

            # Write immediately so the local file stays complete even if a later
            # upload fails. An explicit UTF-8 encoding avoids failures on
            # platforms whose locale encoding cannot represent the message.
            log_file = self._run_log_files[run_id]
            try:
                with open(log_file, "a", encoding="utf-8") as f:
                    f.write(msg + "\n")
            except Exception as e:
                print(f"Error writing to {log_file}: {e}")

            self._run_logs[run_id].append(msg)

            # Upload once enough records have accumulated for this run.
            if len(self._run_logs[run_id]) >= self.buffer_size:
                self.flush(run_id)

        except Exception as e:
            print(f"Error in emit: {e}")
            self.handleError(record)

    @staticmethod
    def _upload_log_text(run_id: str, content: str) -> None:
        """Log *content* as the ``haipr.log`` text artifact of run *run_id*."""
        active_run = mlflow.active_run()
        if active_run and active_run.info.run_id == run_id:
            mlflow.log_text(content, "haipr.log")
        else:
            # The run is no longer (or not) the active one, for example when
            # flushing from close() after the run ended. Address it explicitly
            # by id so its trailing records are not lost.
            mlflow.tracking.MlflowClient().log_text(run_id, content, "haipr.log")

    def flush(self, run_id: Optional[str] = None) -> None:
        """Upload buffered logs to MLflow.

        Args:
            run_id: Run to flush. ``None`` (as passed by ``logging.shutdown``)
                flushes every tracked run.

        The entire local log file for the run is uploaded, not only the
        buffered messages. The run's buffer is cleared afterwards whether or
        not the upload succeeded, so a failing upload cannot make the buffer
        grow without bound.
        """
        if run_id is None:
            for rid in list(self._run_logs):
                self.flush(rid)
            return

        if not self._run_logs.get(run_id):
            return

        try:
            log_file = self._run_log_files[run_id]
            if log_file.exists():
                content = log_file.read_text(encoding="utf-8")
                self._upload_log_text(run_id, content)
        except Exception as e:
            print(f"Error in flush for run {run_id}: {e}")
        finally:
            self._run_logs[run_id] = []

    def close(self):
        """Flush all runs, then close the handler."""
        try:
            self.flush()
        finally:
            super().close()


class JobContextFilter(logging.Filter):
    """Attach a ``job_num`` attribute to every log record that lacks one.

    The value is Hydra's ``job.num`` when Hydra is initialized. Otherwise, or
    if the lookup fails for any reason, it is 0. Records that already carry
    ``job_num`` are left untouched, and no record is ever rejected.
    """

    def filter(self, record):
        if hasattr(record, "job_num"):
            return True

        job_num = 0
        try:
            if HydraConfig.initialized():
                job_num = HydraConfig.get().job.num
        except Exception:  # best-effort: context lookup must never break logging
            job_num = 0

        record.job_num = job_num
        return True
