import os
import logging
from omegaconf import OmegaConf
from typing import Dict, Union
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import Callback, RichProgressBar, ModelCheckpoint


def build_callbacks(cfg: OmegaConf, phase: str = "train") -> list[Callback]:
    """Assemble the Lightning callbacks for one run phase.

    Args:
        cfg: Experiment configuration. It is only consulted when
            ``phase == "train"``, where it is handed to
            ``_get_checkpoint_callbacks``.
        phase: Name of the run phase. Only ``"train"`` adds the
            logging/checkpoint callbacks; any other value yields just the
            progress bar.

    Returns:
        A new list whose first element is a ``ProgressBar``; in the training
        phase it is followed by the callbacks from ``_get_checkpoint_callbacks``.
    """
    bar = ProgressBar()
    if phase != "train":
        return [bar]
    return [bar, *_get_checkpoint_callbacks(cfg)]


def _get_checkpoint_callbacks(cfg: OmegaConf) -> list[Callback]:
    """Create the training-time logger and checkpoint callbacks.

    The returned list contains, in order:

    1. A ``ProgressLogger`` driven by ``cfg.LOGGER.METRIC_MONITOR``.
    2. A ``ModelCheckpoint`` writing ``{epoch}`` files every
       ``int(VAL_EVERY_STEPS * 10)`` epochs with ``save_top_k=-1`` and
       ``save_last="link"``.
    3. A ``ModelCheckpoint`` writing ``latest-{epoch}`` every
       ``int(VAL_EVERY_STEPS)`` epochs with ``save_top_k=1``.
    4. One weights-only ``ModelCheckpoint`` per entry returned by
       ``_get_metric_specific_checkpoint_params``, monitoring that metric
       with the entry's ``mode`` and named ``<mode>-<abbr>-{epoch}``.

    All checkpoints go to ``<cfg.FOLDER_EXP>/checkpoints`` and are evaluated at
    the end of the training epoch.

    NOTE(review): ``cfg.LOGGER.VAL_EVERY_STEPS`` is used here as an *epoch*
    interval (``every_n_epochs``) despite its name.
    """
    callbacks = [
        ProgressLogger(
            metric_monitor=cfg.LOGGER.METRIC_MONITOR,
            log_every_n_steps=cfg.LOGGER.VAL_EVERY_STEPS,
        )
    ]

    # Periodic snapshots: monitoring the global "step" in "max" mode together
    # with save_top_k=-1 keeps every file that gets written.
    periodic_kwargs = {
        "dirpath": os.path.join(cfg.FOLDER_EXP, "checkpoints"),
        "monitor": "step",
        "mode": "max",
        "filename": "{epoch}",
        "save_top_k": -1,
        "save_last": "link",
        "save_on_train_epoch_end": True,
        "every_n_epochs": int(cfg.LOGGER.VAL_EVERY_STEPS * 10),
    }
    callbacks.append(ModelCheckpoint(**periodic_kwargs))

    # Rolling checkpoint: same "step"/"max" ranking, but only the newest file
    # is kept, written at the validation cadence.
    rolling_kwargs = {
        **periodic_kwargs,
        "every_n_epochs": int(cfg.LOGGER.VAL_EVERY_STEPS),
        "filename": "latest-{epoch}",
        "save_top_k": 1,
        "save_last": False,
    }
    callbacks.append(ModelCheckpoint(**rolling_kwargs))

    # Best-so-far weights for each configured metric; monitor, mode and
    # filename are overridden per metric below.
    best_kwargs = {
        **rolling_kwargs,
        "auto_insert_metric_name": False,
        "save_weights_only": True,
    }
    for monitored, spec in _get_metric_specific_checkpoint_params(cfg).items():
        mode = spec["mode"]
        callbacks.append(
            ModelCheckpoint(
                **{
                    **best_kwargs,
                    "monitor": monitored,
                    "mode": mode,
                    "filename": f"{mode}-{spec['abbr']}-{{epoch}}",
                }
            )
        )

    return callbacks


def _get_metric_specific_checkpoint_params(cfg: OmegaConf) -> dict:
    """Merge the checkpoint-monitor tables of all configured metric types.

    For each name in ``cfg.METRIC.TYPE`` (in order), the mapping stored under
    that name in ``cfg.LOGGER.CHEKPOINT_MONITOR`` is merged into the result;
    names without an entry contribute nothing, and a key repeated by a later
    metric type overrides the earlier value.

    Returns:
        A plain ``dict`` mapping monitored metric key -> its options (the
        caller reads ``"mode"`` and ``"abbr"`` from each value).
    """
    # Generator keeps the config lookups lazy: nothing under LOGGER is touched
    # when no metric types are configured.
    per_type_tables = (
        cfg.LOGGER.CHEKPOINT_MONITOR.get(type_name, {})
        for type_name in cfg.METRIC.TYPE
    )
    merged = {}
    for table in per_type_tables:
        merged.update(table)
    return merged


class ProgressBar(RichProgressBar):
    """Rich progress bar that hides the ``v_num`` entry.

    Any keyword arguments are forwarded unchanged to ``RichProgressBar``
    (for example ``refresh_rate``, ``leave``, ``theme``, ``console_kwargs``);
    ``ProgressBar()`` with no arguments uses the ``RichProgressBar`` defaults.
    """

    def __init__(self, **kwargs) -> None:
        # Forward configuration rather than accepting nothing, so callers can
        # tune the underlying RichProgressBar without another subclass.
        super().__init__(**kwargs)

    def get_metrics(
        self, trainer: Trainer, model: LightningModule
    ) -> Dict[str, Union[int, float]]:
        """Return the parent's progress-bar metrics with ``v_num`` removed.

        Args:
            trainer: The running Lightning trainer.
            model: The LightningModule being trained.

        Returns:
            The metrics mapping produced by ``RichProgressBar.get_metrics``,
            minus the ``"v_num"`` key if it was present.
        """
        items = super().get_metrics(trainer, model)
        # Lightning adds the experiment version as "v_num" by default; drop it
        # (pop with a default so a missing key is not an error).
        items.pop("v_num", None)
        return items


class ProgressLogger(Callback):
    """Lightning callback that logs training start/end and periodic epoch metrics.

    Messages go to the module logger (``logging.getLogger(__name__)``) with
    propagation disabled. That logger reuses every ``logging.FileHandler``
    attached to the top-level package logger and gets a console handler
    formatted as ``"\\n<HH:MM> - <message>"``.

    Args:
        metric_monitor: Mapping of display name -> key in
            ``trainer.callback_metrics``; keys absent from the metrics are
            skipped.
        precision: Number of decimal places used for metric values.
        log_every_n_steps: Metrics are logged on epochs where
            ``trainer.current_epoch % log_every_n_steps == 0`` (the interval
            counts epochs, despite the name).
    """

    # Attribute set on the console handler installed by this class, so that
    # constructing another ProgressLogger can detect it instead of adding a
    # duplicate to the shared module logger.
    _CONSOLE_HANDLER_MARK = "_progress_logger_console"

    def __init__(
        self, metric_monitor: dict, precision: int = 3, log_every_n_steps: int = 1
    ) -> None:
        super().__init__()

        self.metric_monitor = metric_monitor
        self.precision = precision
        self.log_every_n_steps = log_every_n_steps

        self.logger = logging.getLogger(__name__)
        self.logger.propagate = False
        self._attach_handlers()

    def _attach_handlers(self) -> None:
        """Attach the package's file handlers and, once per process, the console handler."""
        # Inherit parent file handlers. Logger.addHandler ignores a handler that
        # is already attached, so repeating this is harmless. Iterate over a
        # snapshot in case the package logger is this very logger.
        root_logger = logging.getLogger(__name__.split(".")[0])
        for handler in list(root_logger.handlers):
            if isinstance(handler, logging.FileHandler):
                self.logger.addHandler(handler)

        # getLogger(__name__) returns a process-wide logger, so a fresh console
        # handler per instance would print every message once per
        # ProgressLogger ever created.
        if any(
            getattr(handler, self._CONSOLE_HANDLER_MARK, False)
            for handler in self.logger.handlers
        ):
            return

        console_handler = logging.StreamHandler()
        console_handler.setFormatter(
            logging.Formatter("\n%(asctime)s - %(message)s", datefmt="%H:%M")
        )
        setattr(console_handler, self._CONSOLE_HANDLER_MARK, True)
        self.logger.addHandler(console_handler)

    def on_train_start(self, *args, **kwargs) -> None:
        """Log that training has started."""
        self.logger.info("Training started.")

    def on_train_end(self, *args, **kwargs) -> None:
        """Log that training has finished."""
        self.logger.info("Training done.")

    def on_train_epoch_end(self, trainer: Trainer, *args, **kwargs) -> None:
        """Log the monitored metrics on every ``log_every_n_steps``-th epoch.

        Nothing is logged when none of the monitored keys are present in
        ``trainer.callback_metrics``.
        """
        epoch = trainer.current_epoch
        if epoch % self.log_every_n_steps != 0:
            return

        metrics = trainer.callback_metrics
        # Format the value directly in the f-string: splicing the display name
        # into a str.format template would break on names containing braces.
        # float() accepts single-element tensors as well as plain numbers.
        metrics_log = "   ".join(
            f"{name}: {float(metrics[path]):.{self.precision}f}"
            for name, path in self.metric_monitor.items()
            if path in metrics
        )
        if metrics_log:
            self.logger.info("Epoch %s: %s", epoch, metrics_log)
