# file: prism/callbacks/data_gatherer.py
import torch
import torch.distributed as dist

from prism.core.base_objects import BaseCallback
from prism.core.registry import CALLBACKS


@CALLBACKS.register("DataGatherer")
class DataGathererCallback(BaseCallback):
    """Gather per-step outputs across ranks and attach them to the module.

    At the end of a validation or test epoch, the list stored on the module
    (``validation_step_outputs`` / ``test_step_outputs``) is collected from
    every process, concatenated field by field, and stored back on the module
    as ``gathered_validation_outputs`` / ``gathered_test_outputs``. The
    attached value is ``None`` when no rank produced any output.
    """

    # (key in the attached dict, key read from each step-output item).
    _FIELD_MAP = (
        ("z_full", "z"),
        ("y_targets", "target_labels"),
        ("y_style", "style_labels"),
        ("data", "data"),
    )

    def _gather_and_attach(self, trainer, pl_module, step_outputs_attr, attach_attr):
        """Collect step outputs from all ranks and set them on ``pl_module``.

        Args:
            trainer: Object exposing ``world_size``; values greater than 1
                enable the ``torch.distributed`` collectives.
            pl_module: Module holding the step outputs and receiving the
                gathered result.
            step_outputs_attr: Name of the attribute on ``pl_module`` holding
                this rank's step outputs. Each item must be indexable by the
                keys ``'z'``, ``'target_labels'``, ``'style_labels'`` and
                ``'data'``, with values accepted by ``torch.cat``.
            attach_attr: Name of the attribute to set on ``pl_module``. It
                receives a dict with keys ``z_full``, ``y_targets``,
                ``y_style`` and ``data`` (each concatenated along dim 0), or
                ``None`` if there is nothing to gather.
        """
        is_distributed = trainer.world_size > 1
        # Normalise a missing/None/empty attribute to an empty list so that it
        # can still be contributed to the collective below.
        step_outputs = getattr(pl_module, step_outputs_attr, None) or []

        if is_distributed:
            # Every rank must enter the collective, even with nothing to
            # contribute: returning early on an empty local list would leave
            # the remaining ranks blocked forever in all_gather_object.
            # NOTE(review): all_gather_object pickles its payload; if the step
            # outputs hold GPU tensors this implies a device round-trip --
            # confirm the outputs are detached / on CPU upstream.
            per_rank_outputs = [None] * trainer.world_size
            dist.all_gather_object(per_rank_outputs, step_outputs)
            all_outputs = [
                item
                for rank_outputs in per_rank_outputs
                for item in (rank_outputs or [])
            ]
        else:
            all_outputs = step_outputs

        gathered_data = None
        if all_outputs:
            gathered_data = {
                out_key: torch.cat([output[src_key] for output in all_outputs], dim=0)
                for out_key, src_key in self._FIELD_MAP
            }
        setattr(pl_module, attach_attr, gathered_data)

        if is_distributed:
            # After the gather all ranks hold the same `all_outputs`, so they
            # all reach this barrier together.
            dist.barrier()

    def on_validation_epoch_start(self, trainer, pl_module):
        """Drop the gathered validation data left over from a previous epoch."""
        if hasattr(pl_module, 'gathered_validation_outputs'):
            delattr(pl_module, 'gathered_validation_outputs')

    def on_validation_epoch_end(self, trainer, pl_module):
        """Gather ``validation_step_outputs`` into ``gathered_validation_outputs``."""
        self._gather_and_attach(
            trainer=trainer,
            pl_module=pl_module,
            step_outputs_attr='validation_step_outputs',
            attach_attr='gathered_validation_outputs'
        )

    def on_test_epoch_start(self, trainer, pl_module):
        """Drop the gathered test data left over from a previous test run."""
        if hasattr(pl_module, 'gathered_test_outputs'):
            delattr(pl_module, 'gathered_test_outputs')

    def on_test_epoch_end(self, trainer, pl_module):
        """Gather ``test_step_outputs`` into ``gathered_test_outputs``."""
        self._gather_and_attach(
            trainer=trainer,
            pl_module=pl_module,
            step_outputs_attr='test_step_outputs',
            attach_attr='gathered_test_outputs'
        )