# file: prism/callbacks/metrics_callback.py
import sys
from collections.abc import Mapping

from prism.core.base_objects import BaseCallback
from prism.core.registry import CALLBACKS, METRICS
from prism.evaluation.metrics import _TorchmetricsWrapper


@CALLBACKS.register("Metrics")
class MetricsCallback(BaseCallback):
    """Compute and log the metrics configured under ``config.evaluation.metrics``.

    Each configured metric has a ``mode`` (a string or a collection of
    strings) that decides where it is used:

    * ``"training"`` -- ``calculate(...)`` is called on the batch outputs every
      ``config.evaluation.log_interval`` global steps and logged under
      ``train_metric/<name>``.
    * ``"online"`` -- ``_TorchmetricsWrapper`` instances are ``update``-d on
      every validation/test batch and ``compute``-d at epoch end; every other
      metric is ``calculate``-d at epoch end from the gathered outputs stored
      on the module.

    A metric listing both modes is instantiated once and shared by both
    registries. Failures while creating, updating, or computing a metric are
    reported on stderr and never interrupt the run.
    """

    def __init__(self, config):
        """Instantiate every configured metric that declares at least one mode.

        Args:
            config: Project configuration; ``config.evaluation.metrics`` (when
                present) maps a metric name registered in ``METRICS`` to its
                per-metric settings.
        """
        super().__init__(config)
        self.online_metric_instances = {}
        self.training_metric_instances = {}

        if not hasattr(self.config.evaluation, 'metrics'):
            return

        for name, cfg in self.config.evaluation.metrics.items():
            modes = cfg.get("mode")
            if not modes:
                # A metric without a mode is not handled by this callback.
                continue

            if isinstance(modes, str):
                modes = [modes]

            try:
                metric_cls = METRICS.get(name)
                metric_instance = metric_cls(self.config)

                if "online" in modes:
                    self.online_metric_instances[name] = metric_instance
                if "training" in modes:
                    self.training_metric_instances[name] = metric_instance
            except Exception as e:
                # Best effort: one broken metric must not disable the others.
                print(f"  - [ERROR] Failed to instantiate metric '{name}': {e}", file=sys.stderr)

    @staticmethod
    def _build_batch_kwargs(outputs, pl_module):
        """Return the per-batch metric arguments, or ``None`` if unusable.

        ``outputs`` is whatever the step function returned. It is only usable
        when it is a mapping that carries both ``'x_rec'`` and ``'data'``; a
        ``None`` or bare tensor return is skipped instead of failing on the
        key lookup.
        """
        if not isinstance(outputs, Mapping):
            return None
        if 'x_rec' not in outputs or 'data' not in outputs:
            return None
        return {
            "x_rec": outputs['x_rec'],
            "data": outputs['data'],
            "device": pl_module.device,
        }

    def _update_online_metrics(self, trainer, pl_module, outputs, metric_kind):
        """Feed one batch to every online ``_TorchmetricsWrapper`` metric.

        Args:
            trainer: Trainer driving the loop; only ``is_global_zero`` is read.
            pl_module: Module whose ``device`` is forwarded to the metrics.
            outputs: Return value of the validation/test step.
            metric_kind: Word used in the error message (``"online"`` for
                validation, ``"test"`` for testing).
        """
        kwargs = self._build_batch_kwargs(outputs, pl_module)
        if kwargs is None:
            return

        for name, metric_instance in self.online_metric_instances.items():
            # Only wrapper metrics accumulate per batch; the others are
            # evaluated once at epoch end from the gathered outputs.
            if not isinstance(metric_instance, _TorchmetricsWrapper):
                continue
            try:
                metric_instance.update(**kwargs)
            except Exception as e:
                if trainer.is_global_zero:
                    print(
                        f"  - [ERROR] Failed to update {metric_kind} metric '{name}' on batch: {e}",
                        file=sys.stderr,
                    )

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Calculate and log the training-mode metrics every ``log_interval`` steps."""
        if not self.training_metric_instances:
            return
        if (trainer.global_step + 1) % self.config.evaluation.log_interval != 0:
            return

        kwargs = self._build_batch_kwargs(outputs, pl_module)
        if kwargs is None:
            return

        for name, metric_instance in self.training_metric_instances.items():
            try:
                # Calculated on every rank, logged from rank zero only.
                result = metric_instance.calculate(**kwargs)
                if trainer.is_global_zero:
                    pl_module.log(f'train_metric/{name}', result, on_step=True, on_epoch=False, logger=True)
            except Exception as e:
                if trainer.is_global_zero:
                    print(f"  - [ERROR] Failed to calculate training metric '{name}': {e}", file=sys.stderr)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Accumulate the validation batch into the online wrapper metrics."""
        self._update_online_metrics(trainer, pl_module, outputs, "online")

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Accumulate the test batch into the online wrapper metrics."""
        self._update_online_metrics(trainer, pl_module, outputs, "test")

    def _calculate_and_log_metrics(self, trainer, pl_module, gathered_data, log_prefix):
        """Compute every online metric and log it under ``<log_prefix>/<name>``.

        Args:
            trainer: Trainer driving the loop; provides ``is_global_zero`` and
                ``datamodule.style_feature_map``.
            pl_module: Module used for logging, ``device`` and ``current_epoch``.
            gathered_data: Mapping with ``'z_full'``, ``'y_targets'`` and
                ``'y_style'``, or a falsy value when nothing was gathered (in
                which case only wrapper metrics are computed).
            log_prefix: Logging namespace; a prefix starting with ``'val'``
                selects the per-epoch banner, anything else the test banner.
        """
        if not self.online_metric_instances:
            return

        kwargs = {}
        if gathered_data:
            kwargs = {
                "z_full": gathered_data['z_full'],
                "y_targets": gathered_data['y_targets'],
                "y_style": gathered_data['y_style'],
                "style_feature_map": trainer.datamodule.style_feature_map,
                "device": pl_module.device,
                "config": self.config,
            }

        if trainer.is_global_zero:
            epoch_str = f"Epoch {pl_module.current_epoch + 1}" if log_prefix.startswith('val') else "Test Set"
            print(f"\n--- Running Online Metrics for {epoch_str} ---")

        for name, metric_instance in self.online_metric_instances.items():
            try:
                if isinstance(metric_instance, _TorchmetricsWrapper):
                    # NOTE(review): the wrapper state is never reset in this
                    # callback, while validation batches update it on every
                    # epoch (including non-scheduled ones). If
                    # `_TorchmetricsWrapper.compute()` does not reset
                    # internally, earlier batches leak into this result --
                    # confirm against the wrapper implementation.
                    result = metric_instance.compute()
                elif gathered_data:
                    result = metric_instance.calculate(**kwargs)
                else:
                    # Nothing gathered, so this metric has no input.
                    continue

                if trainer.is_global_zero:
                    full_prefix = f'{log_prefix}/{name}'
                    if isinstance(result, dict):
                        pl_module.log_dict({f"{full_prefix}_{k}": v for k, v in result.items()}, on_epoch=True)
                    else:
                        pl_module.log(full_prefix, result, on_epoch=True)
                    print(f"  - {name}: Logged successfully.")

            except Exception as e:
                if trainer.is_global_zero:
                    print(f"  - [ERROR] Failed to calculate metric '{name}' on {log_prefix}: {e}", file=sys.stderr)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Run the online metrics, but only on epochs the module marks as scheduled."""
        if not getattr(pl_module, '_is_scheduled_epoch', False):
            return

        gathered_data = getattr(pl_module, 'gathered_validation_outputs', None)
        self._calculate_and_log_metrics(trainer, pl_module, gathered_data, 'val_online_metric')

    def on_test_epoch_end(self, trainer, pl_module):
        """Run the online metrics on the gathered test outputs."""
        gathered_data = getattr(pl_module, 'gathered_test_outputs', None)
        self._calculate_and_log_metrics(trainer, pl_module, gathered_data, 'test_metric')