"""Unified logging module for the fine-tuning pipeline.

Provides real-time logging to both a file and the console, with timestamps.
Also supports mirroring all stdout/stderr output into a log file.
"""

import logging
import sys
from pathlib import Path
from typing import TextIO


class TeeWriter:
    """Write to multiple streams simultaneously (stdout/stderr + file).

    This captures ALL terminal output including third-party library output
    (like OpenHands SDK rich logging) and writes it to both terminal and file.

    Attributes not defined on this class (``encoding``, ``errors``,
    ``closed``, ``buffer``, ...) are looked up on the original stream. Code
    that inspects ``sys.stdout``/``sys.stderr`` beyond ``write()`` therefore
    keeps working while the tee is installed. Note that anything written
    directly through a delegated attribute such as ``buffer`` goes only to
    the original stream, not to the log file.

    Args:
        original: Stream being wrapped (typically the real ``sys.stdout`` or
            ``sys.stderr``); every write reaches it.
        log_file: Text stream that receives a copy of every write while it
            is open.
    """

    def __init__(self, original: TextIO, log_file: TextIO):
        self.original = original
        self.log_file = log_file

    def write(self, text: str) -> int:
        """Write *text* to both streams, flushing each one immediately.

        The original stream always receives the text. The copy to
        ``log_file`` is skipped once that file has been closed. Objects that
        kept a reference to this tee (e.g. a ``logging.StreamHandler``
        created while capture was active) may still write after capture
        stops, and that must not raise ``ValueError`` in unrelated code.

        Returns:
            ``len(text)``, matching the ``TextIO.write`` contract.
        """
        # Write to original stream (terminal)
        self.original.write(text)
        self.original.flush()
        # Write to log file, if it is still open
        if not self.log_file.closed:
            self.log_file.write(text)
            self.log_file.flush()
        return len(text)

    def writelines(self, lines) -> None:
        """Write each string from the iterable *lines* via :meth:`write`."""
        for line in lines:
            self.write(line)

    def flush(self) -> None:
        """Flush the original stream and, if still open, the log file."""
        self.original.flush()
        if not self.log_file.closed:
            self.log_file.flush()

    def fileno(self) -> int:
        """Return the file descriptor of the original stream."""
        return self.original.fileno()

    def isatty(self) -> bool:
        """Report whether the original stream is a terminal."""
        return self.original.isatty()

    def __getattr__(self, name: str) -> object:
        # Only reached for attributes not found through normal lookup.
        # ``original`` is read via __dict__ so that a half-initialised
        # instance (e.g. during copy/unpickle, before __init__ has run)
        # raises AttributeError instead of recursing forever.
        original = self.__dict__.get("original")
        if original is None:
            raise AttributeError(name)
        return getattr(original, name)


class OutputCapture:
    """Capture all stdout/stderr to a log file.

    Can be used as context manager or with start()/stop() methods.

    Usage (context manager):
        with OutputCapture(log_file_path):
            print("This goes to both terminal and file")

    Usage (manual control):
        capture = OutputCapture(log_file_path)
        capture.start()
        print("This goes to both terminal and file")
        capture.stop()
    """

    def __init__(self, log_file: Path | str):
        self.log_file_path = Path(log_file)
        self.log_file_path.parent.mkdir(parents=True, exist_ok=True)
        self._file = None
        self._original_stdout = None
        self._original_stderr = None
        self._active = False

    def start(self) -> "OutputCapture":
        """Start capturing stdout/stderr to file."""
        if self._active:
            return self
        self._file = open(self.log_file_path, "a", encoding="utf-8")
        self._original_stdout = sys.stdout
        self._original_stderr = sys.stderr
        sys.stdout = TeeWriter(self._original_stdout, self._file)
        sys.stderr = TeeWriter(self._original_stderr, self._file)
        self._active = True
        return self

    def stop(self) -> None:
        """Stop capturing and restore original streams."""
        if not self._active:
            return
        sys.stdout = self._original_stdout
        sys.stderr = self._original_stderr
        if self._file:
            self._file.close()
            self._file = None
        self._active = False

    def __enter__(self):
        return self.start()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
        return False


def setup_logger(
    output_dir: Path | str,
    name: str = "pipeline",
    level: int = logging.INFO,
) -> logging.Logger:
    """Setup unified logging to file and console.

    Args:
        output_dir: Directory to save log file
        name: Logger name
        level: Logging level (default: INFO)

    Returns:
        Configured logger instance
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Clear existing handlers to avoid duplicates
    logger.handlers.clear()

    # File handler - writes to pipeline.log in real-time
    log_file = output_dir / "pipeline.log"
    file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8")
    file_handler.setLevel(level)

    # Console handler - writes to stdout
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(level)

    # Format with timestamp
    formatter = logging.Formatter(
        "%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    # Prevent propagation to root logger
    logger.propagate = False

    return logger


def get_logger(name: str = "pipeline") -> logging.Logger:
    """Return the logger registered under *name* without configuring it.

    Args:
        name: Logger name to look up; defaults to the pipeline logger.

    Returns:
        The ``logging.Logger`` for *name*. It has no handlers of its own
        unless ``setup_logger`` (or other code) has configured it.
    """
    named_logger = logging.getLogger(name)
    return named_logger


class LoggerAdapter:
    """Adapter to redirect print() calls to logger.

    Each method forwards to the wrapped ``logging.Logger``.

    - Extra positional arguments are passed through as lazy %-style format
      arguments.
    - Keyword arguments (``exc_info``, ``extra``, ``stack_info``,
      ``stacklevel``) are passed through unchanged.
    - Records are attributed to the adapter's caller (filename, line number,
      function name), not to this wrapper.

    Note: despite the name, this class is unrelated to the stdlib
    ``logging.LoggerAdapter`` and does not subclass it.

    Usage:
        logger = setup_logger(output_dir)
        adapter = LoggerAdapter(logger)
        adapter.info("This will be logged")
        adapter.info("Step %d of %d", 3, 10)
        adapter.print("This will also be logged as INFO")
    """

    def __init__(self, logger: logging.Logger):
        self.logger = logger

    def _log(self, level: int, msg: object, args: tuple, kwargs: dict) -> None:
        """Forward one call to ``self.logger.log`` at *level*."""
        # Two adapter frames (this helper + the public method) sit between
        # the user's call site and Logger.log. Shift stacklevel past them so
        # a caller-supplied stacklevel keeps its usual meaning and the
        # default points at the caller.
        kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 2
        self.logger.log(level, msg, *args, **kwargs)

    def debug(self, msg: object, *args, **kwargs) -> None:
        """Log *msg* at DEBUG level."""
        self._log(logging.DEBUG, msg, args, kwargs)

    def info(self, msg: object, *args, **kwargs) -> None:
        """Log *msg* at INFO level."""
        self._log(logging.INFO, msg, args, kwargs)

    def warning(self, msg: object, *args, **kwargs) -> None:
        """Log *msg* at WARNING level."""
        self._log(logging.WARNING, msg, args, kwargs)

    def error(self, msg: object, *args, **kwargs) -> None:
        """Log *msg* at ERROR level."""
        self._log(logging.ERROR, msg, args, kwargs)

    def exception(self, msg: object, *args, **kwargs) -> None:
        """Log *msg* at ERROR level with the current exception's traceback.

        Intended to be called from an ``except`` block.
        """
        kwargs.setdefault("exc_info", True)
        self._log(logging.ERROR, msg, args, kwargs)

    def critical(self, msg: object, *args, **kwargs) -> None:
        """Log *msg* at CRITICAL level."""
        self._log(logging.CRITICAL, msg, args, kwargs)

    def print(self, msg: object, *args, **kwargs) -> None:
        """Alias for info() to replace print() calls."""
        self._log(logging.INFO, msg, args, kwargs)
