# file: prism/callbacks/artifact_saver.py
from pathlib import Path
import torch
from pytorch_lightning.utilities.rank_zero import rank_zero_only

from prism.core.base_objects import BaseCallback
from prism.core.registry import CALLBACKS


@CALLBACKS.register("ArtifactSaver")
class ArtifactSaver(BaseCallback):
    """Persist sub-network weights and gathered evaluation tensors to disk.

    Artifacts are written under ``<base>/artifacts/<subdir>``, where ``<base>``
    is ``trainer.logger.log_dir`` (or ``trainer.default_root_dir`` when the
    trainer has no logger). Saving happens:

    * after a validation epoch, if ``pl_module._is_scheduled_epoch`` is truthy
      (subdir ``epoch_NNN``, 1-based);
    * after the test epoch (subdir ``test_set_results``).

    The tensors are read from ``pl_module.gathered_validation_outputs`` /
    ``pl_module.gathered_test_outputs``. This callback only reads those
    attributes; NOTE(review): presumably the LightningModule fills them in
    before these hooks fire - confirm.
    """

    def __init__(self, config):
        super().__init__(config)
        # Evaluation section of the config. Not read by any method of this
        # class; NOTE(review): confirm it is still needed elsewhere.
        self.eval_cfg = self.config.evaluation

    def _save_artifacts(self, trainer, pl_module, gathered_data, artifact_subdir):
        """Save ``gathered_data`` into ``artifacts/<artifact_subdir>`` on global rank 0.

        Args:
            trainer: The Lightning trainer.
            pl_module: The LightningModule whose sub-networks are saved.
            gathered_data: Mapping whose entries become keyword arguments of
                :meth:`save_artifacts_from_data` (``z_full``, ``y_targets`` and
                optionally ``y_style``). A ``'data'`` entry, if present, is
                dropped. ``None`` or an empty mapping makes this a no-op.
            artifact_subdir: Name of the sub-directory created under
                ``artifacts/``.
        """
        if not gathered_data:
            return
        # Only the global-zero process writes files. save_artifacts_from_data
        # is additionally guarded by @rank_zero_only.
        if not trainer.is_global_zero:
            return

        print(f"\n--- Saving Artifacts to '{artifact_subdir}' ---")
        # 'data' is not a parameter of save_artifacts_from_data. Strip it so
        # the remaining keys map directly onto that method's keyword arguments.
        save_kwargs = {key: value for key, value in gathered_data.items() if key != 'data'}
        self.save_artifacts_from_data(
            trainer=trainer,
            pl_module=pl_module,
            artifact_dir_name=artifact_subdir,
            **save_kwargs,
        )
        print(f"  - Artifacts saved successfully to '{artifact_subdir}'.")

    @rank_zero_only
    def save_artifacts_from_data(self, trainer, pl_module, artifact_dir_name, z_full, y_targets, y_style=None):
        """Write module weights and evaluation tensors to an artifact directory.

        Creates ``<base>/artifacts/<artifact_dir_name>`` (including parents).
        ``<base>`` is ``trainer.logger.log_dir``, or ``trainer.default_root_dir``
        when the trainer has no logger attached. Files written:

        * ``encoder.pth`` / ``generator.pth`` / ``classifier.pth``: the
          ``state_dict()`` of ``pl_module.encoder`` / ``.generator`` /
          ``.classifier``;
        * ``Z_full.pt``: ``z_full``;
        * ``Y_targets.pt``: ``y_targets``;
        * ``Y_style.pt``: ``y_style``, only when it is not ``None``.

        Args:
            trainer: The Lightning trainer (used to locate the output directory).
            pl_module: Module exposing ``encoder``, ``generator`` and
                ``classifier`` sub-modules.
            artifact_dir_name: Sub-directory name under ``artifacts/``.
            z_full: Object passed to ``torch.save``. NOTE(review): presumably
                gathered latent codes; shape not established here.
            y_targets: Object passed to ``torch.save``.
            y_style: Optional object passed to ``torch.save``; skipped when
                ``None``. It defaults to ``None`` so that gathered outputs
                without a style entry can still be saved.
        """
        logger = trainer.logger
        # Without a logger there is no log_dir. Fall back to the trainer's
        # root directory rather than failing with an AttributeError.
        base_dir = logger.log_dir if logger is not None else trainer.default_root_dir
        artifact_dir = Path(base_dir) / "artifacts" / artifact_dir_name
        artifact_dir.mkdir(parents=True, exist_ok=True)

        torch.save(pl_module.encoder.state_dict(), artifact_dir / "encoder.pth")
        torch.save(pl_module.generator.state_dict(), artifact_dir / "generator.pth")
        torch.save(pl_module.classifier.state_dict(), artifact_dir / "classifier.pth")
        torch.save(z_full, artifact_dir / "Z_full.pt")
        torch.save(y_targets, artifact_dir / "Y_targets.pt")
        if y_style is not None:
            torch.save(y_style, artifact_dir / "Y_style.pt")

    def on_validation_epoch_end(self, trainer, pl_module):
        """Save validation artifacts, but only on scheduled epochs.

        ``pl_module._is_scheduled_epoch`` defaults to ``False`` when absent, so
        nothing is saved unless something sets that flag. Missing or empty
        ``gathered_validation_outputs`` is handled by ``_save_artifacts``.
        """
        if not getattr(pl_module, '_is_scheduled_epoch', False):
            return

        gathered_data = getattr(pl_module, 'gathered_validation_outputs', None)
        # current_epoch is 0-based; the directory name is 1-based and
        # zero-padded to three digits (e.g. epoch_001).
        artifact_subdir = f"epoch_{pl_module.current_epoch + 1:03d}"
        self._save_artifacts(trainer, pl_module, gathered_data, artifact_subdir)

    def on_test_epoch_end(self, trainer, pl_module):
        """Save test-set artifacts into ``artifacts/test_set_results``.

        Missing or empty ``gathered_test_outputs`` is a no-op, handled by
        ``_save_artifacts``.
        """
        gathered_data = getattr(pl_module, 'gathered_test_outputs', None)
        artifact_subdir = "test_set_results"
        self._save_artifacts(trainer, pl_module, gathered_data, artifact_subdir)