# file: prism/callbacks/online_visualizations.py
from pathlib import Path
import sys
import traceback

import torch
import torch.distributed as dist

from prism.core.base_objects import BaseCallback
from prism.core.registry import CALLBACKS, VISUALIZATIONS


@CALLBACKS.register("OnlineVisualizations")
class OnlineVisualizationCallback(BaseCallback):
    """Run registered "online" visualizations at the end of scheduled validation epochs.

    At construction, every entry of ``config.evaluation.visualizations`` whose
    ``mode`` is ``"online"`` is looked up in the ``VISUALIZATIONS`` registry and
    instantiated with the full config. At the end of each validation epoch that
    the LightningModule flags via ``_is_scheduled_epoch``, the global-zero rank
    calls each instance's ``run`` method with the gathered validation outputs.
    Failures are reported on stderr per visualization and never stop training.
    """

    def __init__(self, config):
        """Instantiate the online visualizations declared in ``config``.

        Args:
            config: Experiment config. ``BaseCallback`` is assumed to store it as
                ``self.config`` (TODO confirm). Visualizations are read from
                ``config.evaluation.visualizations``, a mapping of registry name
                to per-visualization settings.

        A missing ``evaluation`` section or an empty ``visualizations`` entry
        simply results in no visualizations. An entry that cannot be resolved
        or constructed is reported on stderr and skipped.
        """
        super().__init__(config)
        # Registry name -> instantiated visualization object.
        self.visualization_instances = {}

        # Tolerate a missing ``evaluation`` section and a None/empty
        # ``visualizations`` value (e.g. a bare ``visualizations:`` key in YAML)
        # instead of failing with AttributeError.
        evaluation_cfg = getattr(self.config, 'evaluation', None)
        viz_cfgs = getattr(evaluation_cfg, 'visualizations', None) or {}

        for name, cfg in viz_cfgs.items():
            # Only "online" visualizations run during training; an entry with
            # no settings at all is treated as not online.
            if cfg is None or cfg.get("mode") != "online":
                continue
            try:
                viz_cls = VISUALIZATIONS.get(name)
                if viz_cls is None:
                    # NOTE(review): assumes the registry returns None for unknown
                    # names; report that clearly rather than as a confusing
                    # "'NoneType' object is not callable".
                    raise LookupError(f"no visualization registered under '{name}'")
                self.visualization_instances[name] = viz_cls(self.config)
            except Exception as e:
                print(
                    f"  - [ERROR] Failed to instantiate visualization '{name}': {e}",
                    file=sys.stderr
                )
                traceback.print_exc(file=sys.stderr)

    @staticmethod
    def _resolve_log_dir(trainer):
        """Return the base output directory for visualizations as a ``Path``.

        Uses ``trainer.logger.log_dir`` when a logger is attached and reports a
        directory; otherwise falls back to ``trainer.default_root_dir`` so that
        running without a logger does not crash the validation epoch.
        """
        # NOTE(review): ``trainer.log_dir`` is intentionally not used. In recent
        # Lightning versions it broadcasts across ranks, which would deadlock
        # when called from rank zero only. Verify for the pinned version.
        log_dir = getattr(getattr(trainer, 'logger', None), 'log_dir', None)
        if log_dir is None:
            log_dir = trainer.default_root_dir
        return Path(log_dir)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Generate every online visualization for the validation epoch that just ended.

        Does nothing in any of these cases:
            * ``pl_module._is_scheduled_epoch`` is falsy.
            * No visualization was instantiated.
            * ``pl_module.gathered_validation_outputs`` is missing or empty.
            * This process is not the global-zero rank.

        Otherwise it creates ``<log dir>/visualizations/epoch_NNN`` (1-based
        epoch) and calls each visualization's ``run``. The gathered outputs are
        forwarded as keyword arguments, so they must be a mapping with string
        keys.

        Args:
            trainer: The Lightning trainer.
            pl_module: The LightningModule being validated.
        """
        if not getattr(pl_module, '_is_scheduled_epoch', False):
            return

        if not self.visualization_instances:
            return

        # Presumably a dict of name -> outputs gathered across ranks by
        # pl_module; TODO confirm against the module that sets it.
        gathered_data = getattr(pl_module, 'gathered_validation_outputs', None)
        if not gathered_data:
            return

        # Only one process writes plots.
        if not trainer.is_global_zero:
            return

        epoch = pl_module.current_epoch + 1
        plot_dir = self._resolve_log_dir(trainer) / f"visualizations/epoch_{epoch:03d}"
        plot_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n--- Running Online Visualizations for Epoch {epoch} ---")
        for name, viz_instance in self.visualization_instances.items():
            try:
                viz_instance.run(
                    trainer=trainer,
                    pl_module=pl_module,
                    plot_dir=plot_dir,
                    epoch=epoch,
                    **gathered_data
                )
            except Exception as e:
                # Best-effort: one failing visualization must not abort the
                # others or the training run, but keep the stack for debugging.
                print(
                    f"  - [ERROR] Failed to generate online visualization '{name}': {e}",
                    file=sys.stderr
                )
                traceback.print_exc(file=sys.stderr)
            else:
                print(f"  - {name}: Generated successfully.")