import os
import csv
import logging
from typing import Dict, Union, List

from omegaconf import OmegaConf
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, RichProgressBar
from pytorch_lightning.utilities.rank_zero import rank_zero_only


def build_callbacks(
    cfg: OmegaConf,
    output_dir: str,
    logging_dir: str,
    phase: str = "train",
    logger: logging.Logger = None,
) -> List[Callback]:
    """
    Assemble the callbacks used for a run.

    A progress bar and a progress logger are always created. Checkpoint
    callbacks are appended only when `phase` equals "train".

    Parameters:
    - cfg: OmegaConf configuration. Reads `cfg.logging.logging_steps`
      (default 1) and `cfg.metric.monitor` (default None); the checkpoint
      helper reads further keys during the training phase.
    - output_dir: Directory handed to the checkpoint callbacks.
    - logging_dir: Directory handed to the progress logger.
    - phase: Current phase (e.g., "train" or "test"). Default is "train".
    - logger: Logger forwarded to the progress logger as `root_logger`.
      May be None.

    Returns:
    - List[Callback]: Progress bar first, progress logger second, followed
      by any checkpoint callbacks.
    """
    # These two are created for every phase, in this order.
    callbacks = [
        ProgressBar(),
        ProgressLogger(
            log_dir=logging_dir,
            log_every_n_steps=cfg.logging.get("logging_steps", 1),
            metric_monitor=cfg.metric.get("monitor", None),
            root_logger=logger,
        ),
    ]

    # Checkpointing only applies to the training phase.
    if phase != "train":
        return callbacks

    callbacks += _get_checkpoint_callbacks(cfg, output_dir)
    return callbacks


def _get_checkpoint_callbacks(cfg: OmegaConf, output_dir: str) -> List[Callback]:
    """
    Prepare the ModelCheckpoint callbacks described by the configuration.

    Every callback writes to `output_dir`. One parameter dict is updated in
    place between callbacks, so each callback inherits whatever earlier stages
    set unless it overrides the key itself.

    Callbacks created, in order:
    1. "{epoch}": every `cfg.logging.save_steps` epochs, save_top_k=5.
    2. "latest-{epoch}": every `cfg.training.eval_steps` epochs (default 1),
       save_top_k=1.
    3. One per entry of `cfg.metric.save_monitor` (optional): monitors that
       metric with the entry's "mode", names the file
       "<mode>-<abbr>-{epoch}", save_top_k=1, weights only.
    4. "step-{step}": every `cfg.training.save_steps` train steps,
       save_top_k=5. Only when the value is > 0.
    5. "step-{step}": every `cfg.training.save_steps_all` train steps,
       save_top_k=-1. Only when the value is > 0.

    Parameters:
    - cfg: OmegaConf configuration with `logging`, `training` and `metric`
      sections. `cfg.logging.save_steps` is required; the other keys are
      optional and may be absent or null.
    - output_dir: Directory passed to every callback as `dirpath`.

    Returns:
    - List[Callback]: The ModelCheckpoint callbacks in the order above.
    """
    callbacks = []

    # A key that exists but is left empty in YAML can come back as None
    # instead of the default, so fall back explicitly.
    val_steps = cfg.training.get("eval_steps", 1)
    if val_steps is None:
        val_steps = 1

    # General checkpoint parameters.
    # monitor="step" with mode="max" ranks checkpoints by training step.
    checkpoint_params = {
        "dirpath": output_dir,
        "monitor": "step",
        "mode": "max",
        "filename": "{epoch}",
        "save_top_k": 5,
        "save_last": False,
        "save_on_train_epoch_end": True,
        "every_n_epochs": cfg.logging.save_steps,
    }

    # Standard model checkpoint based on epoch
    callbacks.append(ModelCheckpoint(**checkpoint_params))

    # Single "latest" checkpoint refreshed every `eval_steps` epochs
    checkpoint_params.update(
        {
            "every_n_epochs": int(val_steps),
            "filename": "latest-{epoch}",
            "save_top_k": 1,
            "save_last": False,
            "mode": "max",
        }
    )
    callbacks.append(ModelCheckpoint(**checkpoint_params))

    # Best-checkpoint-per-metric callbacks. The two keys introduced here
    # (auto_insert_metric_name, save_weights_only) are never reset, so the
    # step-based callbacks further down inherit them as well.
    checkpoint_params.update(
        {
            "save_top_k": 1,
            "auto_insert_metric_name": False,
            "save_last": False,
            "save_weights_only": True,
        }
    )
    save_monitor = cfg.metric.get("save_monitor", None) or {}
    for metric, options in save_monitor.items():
        checkpoint_params.update(
            {
                "monitor": metric,
                "mode": options["mode"],
                "filename": f"{options['mode']}-{options['abbr']}" + "-{epoch}",
            }
        )
        callbacks.append(ModelCheckpoint(**checkpoint_params))

    # Step-based checkpoints: switch the trigger from epochs to train steps.
    # NOTE(review): the two callbacks below share the same filename template
    # and directory; if both options are enabled they may target the same
    # files. Confirm this is intended.
    save_steps = cfg.training.get("save_steps", 0) or 0
    if save_steps > 0:
        checkpoint_params.update(
            {
                "every_n_epochs": None,
                "every_n_train_steps": int(save_steps),
                "monitor": "step",
                "mode": "max",
                "filename": "step-{step}",
                "save_top_k": 5,
                "save_last": False,
            }
        )
        callbacks.append(ModelCheckpoint(**checkpoint_params))

    # Same trigger, but save_top_k=-1 so no checkpoint is pruned.
    save_steps_all = cfg.training.get("save_steps_all", 0) or 0
    if save_steps_all > 0:
        checkpoint_params.update(
            {
                "every_n_epochs": None,
                "every_n_train_steps": int(save_steps_all),
                "monitor": "step",
                "mode": "max",
                "filename": "step-{step}",
                "save_top_k": -1,
                "save_last": False,
            }
        )
        callbacks.append(ModelCheckpoint(**checkpoint_params))

    return callbacks


class ProgressBar(RichProgressBar):
    """Rich progress bar with compact metric formatting and no version entry."""

    def __init__(self):
        super().__init__()
        # Render metric values with three significant digits.
        self.theme.metrics_format = ".3g"

    def get_metrics(
        self, trainer: Trainer, model: LightningModule
    ) -> Dict[str, Union[int, float]]:
        """
        Return the metrics to display, with the "v_num" entry removed.

        Parameters:
        - trainer: Trainer passed through to the parent implementation.
        - model: LightningModule passed through to the parent implementation.

        Returns:
        - The parent's metrics mapping, modified in place.
        """
        displayed = super().get_metrics(trainer, model)
        if "v_num" in displayed:
            del displayed["v_num"]
        return displayed


class ProgressLogger(Callback):
    """
    Custom logger for training progress and saving metrics to a CSV.

    On the rank-zero process the callback:
    - creates `<log_dir>/<filename>` (truncating any existing file) with the
      header row `epoch, metric_name, value`;
    - configures the module-level logger (`logging.getLogger(__name__)`) with
      the file handlers found on `root_logger` plus one console handler;
    - writes metrics to the console at train-epoch end, and to both console
      and CSV at validation/test-epoch end.

    On other ranks only the plain settings are stored; the logging methods are
    wrapped with `rank_zero_only`.
    """

    def __init__(
        self,
        log_dir="logs",
        filename="metrics.csv",
        metric_monitor: dict = None,
        metrics_format=".3g",
        log_every_n_steps: int = 1,
        root_logger: logging.Logger = None,
    ):
        """
        Parameters:
        - log_dir: Directory for the CSV file; created if missing.
        - filename: Name of the CSV file inside `log_dir`.
        - metric_monitor: Optional mapping of display name -> metric key.
          When given (and non-empty), only these metrics are printed to the
          console, under their display names.
        - metrics_format: Format spec applied to each value on the console.
        - log_every_n_steps: Console logging interval at train-epoch end. It
          is compared against the epoch index. A falsy value (0 or None)
          disables train-epoch console logging.
        - root_logger: Logger whose FileHandlers are reused. Defaults to the
          root logger when None.
        """
        super().__init__()
        self.metric_monitor = metric_monitor
        self.metrics_format = metrics_format
        self.log_every_n_steps = log_every_n_steps

        if rank_zero_only.rank == 0:
            # CSV file setup
            self.log_dir = log_dir
            self.filename = filename
            os.makedirs(self.log_dir, exist_ok=True)
            self.filepath = os.path.join(self.log_dir, self.filename)

            # Create CSV file with header ("w" starts the file from scratch)
            with open(self.filepath, "w", newline="") as f:
                csv.writer(f).writerow(["epoch", "metric_name", "value"])

            # Logger setup. Propagation is disabled so records are emitted
            # only by the handlers attached below.
            self.logger = logging.getLogger(__name__)
            self.logger.propagate = False

            # Inherit parent file handler
            root_logger = logging.getLogger() if root_logger is None else root_logger
            for handler in root_logger.handlers:
                if isinstance(handler, logging.FileHandler):
                    self.logger.addHandler(handler)

            # Custom console handler. The module logger is shared by every
            # instance, so attach the console handler only once; otherwise
            # each new ProgressLogger would duplicate every console line.
            # FileHandler subclasses StreamHandler, hence the exclusion.
            has_console_handler = any(
                isinstance(handler, logging.StreamHandler)
                and not isinstance(handler, logging.FileHandler)
                for handler in self.logger.handlers
            )
            if not has_console_handler:
                custom_console_handler = logging.StreamHandler()
                custom_console_formatter = logging.Formatter(
                    "\n%(asctime)s - %(message)s", datefmt="%H:%M"
                )
                custom_console_handler.setFormatter(custom_console_formatter)
                self.logger.addHandler(custom_console_handler)
            self.logger.setLevel(logging.INFO)

    @staticmethod
    def _to_scalar(value):
        """Return `value.item()` when the object provides it, else `value`."""
        return value.item() if hasattr(value, "item") else value

    @rank_zero_only
    def log_metrics_to_csv(self, metrics: Dict[str, Union[int, float]], epoch: int):
        """Append one `epoch, metric_name, value` row per metric to the CSV file."""
        with open(self.filepath, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerows(
                [epoch, metric_name, self._to_scalar(value)]
                for metric_name, value in metrics.items()
            )

    @rank_zero_only
    def log_metrics_to_console(self, metrics: Dict[str, Union[int, float]], epoch: int):
        """
        Log the metrics for `epoch` as a single "Epoch N: ..." line.

        With `metric_monitor` set, only the monitored metrics are shown (under
        their display names) and those missing from `metrics` are skipped.
        Nothing is logged when no metric remains.
        """
        metric_items = metrics.items()
        if self.metric_monitor:
            metric_items = [
                (name, metrics.get(path, None))
                for name, path in self.metric_monitor.items()
            ]

        # format() is applied to the value only, so metric names that contain
        # "{" or "}" cannot break the formatting.
        metrics_log = "   ".join(
            f"{name}: {format(self._to_scalar(value), self.metrics_format)}"
            for name, value in metric_items
            if value is not None
        )

        if metrics_log:
            self.logger.info("Epoch %s: %s", epoch, metrics_log)

    @rank_zero_only
    def on_train_start(self, *args, **kwargs):
        """Log that training has started."""
        self.logger.info("Training started.")

    @rank_zero_only
    def on_train_end(self, *args, **kwargs):
        """Log that training has finished."""
        self.logger.info("Training done.")

    def on_train_epoch_end(self, trainer: Trainer, *args, **kwargs):
        """Log `trainer.callback_metrics` to the console every `log_every_n_steps` epochs."""
        interval = self.log_every_n_steps
        # A falsy interval (0 or None) means "never" instead of raising on
        # the modulo below.
        if not interval:
            return
        if trainer.current_epoch % interval == 0:
            metrics = trainer.callback_metrics
            self.log_metrics_to_console(metrics, trainer.current_epoch)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Logs validation metrics to both console and CSV."""
        metrics = trainer.logged_metrics
        epoch = trainer.current_epoch

        # Log metrics to console and CSV
        self.log_metrics_to_console(metrics, epoch)
        self.log_metrics_to_csv(metrics, epoch)

    def on_test_epoch_end(self, trainer, pl_module):
        """Logs test metrics to both console and CSV."""
        metrics = trainer.logged_metrics
        epoch = trainer.current_epoch

        # Log metrics to console and CSV
        self.log_metrics_to_console(metrics, epoch)
        self.log_metrics_to_csv(metrics, epoch)
