# file: user_extensions/baselines/callbacks.py
import sys
from pathlib import Path

from prism.core.base_objects import BaseCallback
from prism.core.registry import CALLBACKS, METRICS, VISUALIZATIONS
from prism.evaluation.metrics import _TorchmetricsWrapper


@CALLBACKS.register("BaselineEvaluation")
class BaselineEvaluationCallback(BaseCallback):
    """Run the configured "online" metrics and visualizations during validation/testing.

    For every entry under ``config.evaluation.metrics`` /
    ``config.evaluation.visualizations`` whose ``mode`` contains ``"online"``,
    the class registered under that name in ``METRICS`` / ``VISUALIZATIONS`` is
    instantiated with the full config.

    Two kinds of metric are handled:

    * ``_TorchmetricsWrapper`` instances are streamed: ``update`` is called on
      every validation/test batch and ``compute`` at epoch end. Their state is
      reset at the start of every validation and test epoch (when the wrapper
      exposes ``reset``) so a result never mixes epochs or the val/test splits.
    * Every other metric is computed once per epoch via ``calculate`` on the
      data the module gathered into ``gathered_validation_outputs`` /
      ``gathered_test_outputs``.

    Exceptions raised inside an individual metric or visualization are
    reported on stderr and that item is skipped.
    """

    def __init__(self, config):
        super().__init__(config)
        evaluation = self.config.evaluation
        # name -> instance, only for entries configured with an "online" mode.
        self.metric_instances = self._instantiate_online(
            getattr(evaluation, 'metrics', None), METRICS, "metric")
        self.visualization_instances = self._instantiate_online(
            getattr(evaluation, 'visualizations', None), VISUALIZATIONS, "viz")

    # ------------------------------------------------------------------ #
    # Construction helpers
    # ------------------------------------------------------------------ #
    def _instantiate_online(self, section, registry, kind):
        """Instantiate each entry of ``section`` whose ``mode`` includes "online".

        Args:
            section: Mapping of registry name -> per-entry config (each entry
                must support ``.get("mode", ...)``), or ``None``/empty.
            registry: Registry whose ``get(name)`` returns the class to build.
            kind: Word used in the error message ("metric" or "viz").

        Returns:
            dict mapping name -> instance. Entries that fail to build are
            reported on stderr and left out.
        """
        instances = {}
        if not section:  # missing, None, or empty section -> nothing to build
            return instances
        for name, cfg in section.items():
            if "online" not in cfg.get("mode", []):
                continue
            try:
                instances[name] = registry.get(name)(self.config)
            except Exception as e:
                print(f"  - [ERROR] Failed to instantiate {kind} '{name}': {e}", file=sys.stderr)
        return instances

    # ------------------------------------------------------------------ #
    # Streaming (torchmetrics) metric helpers
    # ------------------------------------------------------------------ #
    def _streaming_metrics(self):
        """Yield (name, metric) pairs for the batch-accumulated metric wrappers."""
        for name, metric in self.metric_instances.items():
            if isinstance(metric, _TorchmetricsWrapper):
                yield name, metric

    def _reset_streaming_metrics(self, trainer):
        """Clear accumulated state so the next ``compute`` only sees the new epoch."""
        for name, metric in self._streaming_metrics():
            # NOTE(review): assumes the wrapper forwards torchmetrics' ``reset``;
            # wrappers without one are left untouched.
            reset = getattr(metric, "reset", None)
            if not callable(reset):
                continue
            try:
                reset()
            except Exception as e:
                if trainer.is_global_zero:
                    print(f"  - [ERROR] Failed to reset metric '{name}': {e}", file=sys.stderr)

    def _update_streaming_metrics(self, trainer, pl_module, outputs, label):
        """Feed one batch's ``x_rec``/``data`` outputs into every streaming metric.

        ``label`` is inserted into the error message ("" for validation,
        "test " for test). Batches whose outputs lack either key are skipped.
        """
        if not outputs or 'x_rec' not in outputs or 'data' not in outputs:
            return

        kwargs = {"x_rec": outputs['x_rec'], "data": outputs['data'], "device": pl_module.device}
        for name, metric in self._streaming_metrics():
            try:
                metric.update(**kwargs)
            except Exception as e:
                if trainer.is_global_zero:
                    print(f"  - [ERROR] Failed to update {label}metric '{name}': {e}", file=sys.stderr)

    # ------------------------------------------------------------------ #
    # Epoch-end helpers
    # ------------------------------------------------------------------ #
    def _build_epoch_kwargs(self, trainer, gathered_data, device):
        """Assemble the keyword arguments passed to non-streaming ``calculate``.

        Raises KeyError if ``gathered_data`` lacks ``z_full``, ``y_targets`` or
        ``y_style`` (same as before: this is not caught).
        """
        return {
            "z_full": gathered_data['z_full'],
            "y_targets": gathered_data['y_targets'],
            "y_style": gathered_data['y_style'],
            "style_feature_map": trainer.datamodule.style_feature_map,
            "device": device,
            "config": self.config,
        }

    def _compute_and_log_metrics(self, trainer, pl_module, kwargs, stage, label):
        """Compute every metric and log it under ``"{stage}/{name}"``.

        Dict results are logged as ``"{stage}/{name}_{key}"``. ``label`` is
        inserted into the error message ("" or "test ").
        """
        for name, metric in self.metric_instances.items():
            try:
                # compute()/calculate() run on every rank, outside the
                # is_global_zero gate; presumably so any cross-process state
                # synchronization inside compute() involves all ranks.
                if isinstance(metric, _TorchmetricsWrapper):
                    result = metric.compute()
                else:
                    result = metric.calculate(**kwargs)

                if trainer.is_global_zero:
                    log_prefix = f'{stage}/{name}'
                    if isinstance(result, dict):
                        pl_module.log_dict({f"{log_prefix}_{k}": v for k, v in result.items()}, on_epoch=True)
                    else:
                        pl_module.log(log_prefix, result, on_epoch=True)
                    print(f"  - {name}: Logged successfully.")
            except Exception as e:
                if trainer.is_global_zero:
                    print(f"  - [ERROR] Failed to calculate {label}metric '{name}': {e}", file=sys.stderr)

    def _run_visualizations(self, trainer, pl_module, gathered_data):
        """Render every online visualization into ``<log_dir>/visualizations/epoch_NNN``.

        Intended to be called on the global-zero rank only.
        """
        epoch = pl_module.current_epoch + 1
        # Without a logger (or with one lacking ``log_dir``) fall back to the
        # trainer's root dir instead of crashing outside the per-viz try.
        log_dir = getattr(trainer.logger, 'log_dir', None) or trainer.default_root_dir
        plot_dir = Path(log_dir) / f"visualizations/epoch_{epoch:03d}"
        plot_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n--- Running Baseline Online Visualizations for Epoch {epoch} ---")
        for name, viz in self.visualization_instances.items():
            try:
                viz.run(trainer=trainer, pl_module=pl_module, plot_dir=plot_dir, epoch=epoch, **gathered_data)
                print(f"  - {name}: Generated successfully.")
            except Exception as e:
                print(f"  - [ERROR] Failed to generate viz '{name}': {e}", file=sys.stderr)

    # ------------------------------------------------------------------ #
    # Trainer hooks
    # ------------------------------------------------------------------ #
    def on_validation_epoch_start(self, trainer, pl_module):
        """Start each validation run (including any sanity check) with empty metric state."""
        self._reset_streaming_metrics(trainer)

    def on_test_epoch_start(self, trainer, pl_module):
        """Drop any state left over from validation before test batches are streamed in."""
        self._reset_streaming_metrics(trainer)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Stream this validation batch's outputs into the torchmetrics wrappers."""
        self._update_streaming_metrics(trainer, pl_module, outputs, "")

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Stream this test batch's outputs into the torchmetrics wrappers."""
        self._update_streaming_metrics(trainer, pl_module, outputs, "test ")

    def on_validation_epoch_end(self, trainer, pl_module):
        """Compute/log online metrics and render visualizations on scheduled epochs.

        Does nothing unless ``pl_module._is_scheduled_epoch`` is truthy and
        ``pl_module.gathered_validation_outputs`` is non-empty.
        """
        if not getattr(pl_module, '_is_scheduled_epoch', False):
            return

        gathered_data = getattr(pl_module, 'gathered_validation_outputs', None)
        if not gathered_data:
            return

        if trainer.is_global_zero:
            print(f"\n--- Running Baseline Online Metrics for Epoch {pl_module.current_epoch + 1} ---")

        kwargs = self._build_epoch_kwargs(trainer, gathered_data, pl_module.device)
        self._compute_and_log_metrics(trainer, pl_module, kwargs, 'val_metric', "")

        if trainer.is_global_zero and self.visualization_instances:
            self._run_visualizations(trainer, pl_module, gathered_data)

    def on_test_epoch_end(self, trainer, pl_module):
        """Compute and log every online metric on the gathered test outputs."""
        gathered_data = getattr(pl_module, 'gathered_test_outputs', None)
        if not gathered_data:
            if trainer.is_global_zero:
                print("  - [Warning] No test data was gathered. Skipping test metric calculation.", file=sys.stderr)
            return

        if trainer.is_global_zero:
            print("\n--- Running Final Test Set Metrics ---")

        # Test-time calculate() receives device "cpu" (validation passes
        # pl_module.device). NOTE(review): rationale not visible here; confirm.
        kwargs = self._build_epoch_kwargs(trainer, gathered_data, "cpu")
        self._compute_and_log_metrics(trainer, pl_module, kwargs, 'test_metric', "test ")