import logging
import os
import signal

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback

class PermissionCallback(Callback):
    """Make checkpoint files (and some of their parent folders) accessible.

    Every ``.ckpt`` file found recursively under ``dirpath`` is chmod-ed to
    ``mode``, together with up to ``num_parent_folders`` enclosing directories
    (starting with the checkpoint's own folder; the walk may go above
    ``dirpath``).  The update runs:

    * every ``every_n_steps`` training batches (based on ``batch_idx``),
    * at the end of ``fit``,
    * when the process receives SIGTERM or SIGINT during ``fit``; the signal
      is then re-delivered with the previously installed handler restored.

    Args:
        dirpath: Root directory searched for ``.ckpt`` files.
        num_parent_folders: Number of directory levels to chmod per checkpoint.
        every_n_steps: Batch interval of the periodic update. A falsy value
            (``0``/``None``) disables the periodic update.
        mode: Permission bits applied to files and folders. Defaults to
            ``0o777`` (read/write/execute for everyone).
    """

    def __init__(self,
                 dirpath,
                 num_parent_folders=3,
                 every_n_steps=10000,
                 mode=0o777):
        super().__init__()
        self.dirpath = dirpath
        self.num_parent_folders = num_parent_folders
        self.every_n_steps = every_n_steps
        self.mode = mode
        self._original_sigterm_handler = None
        self._original_sigint_handler = None
        # True only while our handlers are registered; makes restoring idempotent.
        self._handlers_installed = False
        # Trainer of the current fit, so the signal handler can reach it
        # without relying on any global trainer lookup.
        self._trainer = None

    def on_fit_start(self, trainer, pl_module):
        """Remember the trainer and route SIGTERM/SIGINT through this callback."""
        self._trainer = trainer
        self._original_sigterm_handler = signal.getsignal(signal.SIGTERM)
        self._original_sigint_handler = signal.getsignal(signal.SIGINT)
        try:
            signal.signal(signal.SIGTERM, self._handle_kill_signal)
            signal.signal(signal.SIGINT, self._handle_kill_signal)
        except ValueError:
            # signal.signal() is only allowed in the main thread; fall back to
            # the periodic and end-of-fit updates instead of aborting fit().
            logging.getLogger(__name__).warning(
                "PermissionCallback: cannot install signal handlers outside the "
                "main thread; permissions will not be fixed on kill signals."
            )
            return
        self._handlers_installed = True

    def on_fit_end(self, trainer, pl_module):
        """Restore the original signal handlers and do a final permission pass."""
        self._restore_signal_handlers()
        self._change_permissions(trainer)

    def teardown(self, trainer, pl_module, stage=None):
        """Ensure the signal handlers never outlive the run.

        ``on_fit_end`` may not be reached if ``fit()`` fails; restoring here is
        a no-op when the handlers were already restored.
        """
        self._restore_signal_handlers()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
                           dataloader_idx=0):
        """Run the periodic permission update every ``every_n_steps`` batches.

        ``dataloader_idx`` is optional because newer Lightning versions call
        this hook without it.  NOTE(review): ``batch_idx`` is presumably the
        index within the current epoch, so the interval is counted per epoch.
        """
        if not self.every_n_steps:
            return  # periodic update disabled; also avoids modulo by zero
        if (batch_idx + 1) % self.every_n_steps == 0:
            self._change_permissions(trainer)

    def _handle_kill_signal(self, signum, frame):
        """Fix permissions, then re-deliver ``signum`` to the original handler."""
        try:
            if self._trainer is not None:
                self._change_permissions(self._trainer)
        finally:
            # Even if the permission pass fails, the process must still react
            # to the signal the way it would have without this callback.
            self._restore_signal_handlers()
            os.kill(os.getpid(), signum)

    def _restore_signal_handlers(self):
        """Reinstate the handlers saved in ``on_fit_start`` (idempotent)."""
        if not self._handlers_installed:
            return
        for signum, handler in (
            (signal.SIGTERM, self._original_sigterm_handler),
            (signal.SIGINT, self._original_sigint_handler),
        ):
            # getsignal() returns None for handlers not installed from Python,
            # and signal.signal() rejects None; SIG_DFL is the closest stand-in.
            signal.signal(signum, signal.SIG_DFL if handler is None else handler)
        self._handlers_installed = False

    def _change_permissions(self, trainer):
        """chmod every ``.ckpt`` under ``dirpath`` and its parent folders.

        Each directory is chmod-ed at most once per call, even when it holds
        many checkpoints.  ``trainer`` is accepted for hook symmetry but unused.
        """
        if not os.path.exists(self.dirpath):
            return
        visited_dirs = set()
        for root, _, files in os.walk(self.dirpath):
            for name in files:
                if not name.endswith(".ckpt"):
                    continue
                filepath = os.path.join(root, name)
                self._chmod(filepath)
                parent_dir = os.path.dirname(filepath)
                for _ in range(self.num_parent_folders):
                    # dirname() of a relative top-level path yields "" -> stop.
                    if not parent_dir or not os.path.exists(parent_dir):
                        break
                    if parent_dir not in visited_dirs:
                        visited_dirs.add(parent_dir)
                        self._chmod(parent_dir)
                    parent_dir = os.path.dirname(parent_dir)

    def _chmod(self, path):
        """Best-effort ``os.chmod(path, self.mode)`` that never aborts training."""
        try:
            os.chmod(path, self.mode)
        except FileNotFoundError:
            # The path can disappear between os.walk() listing it and this
            # call (e.g. if old checkpoints are pruned concurrently).
            pass
        except OSError as err:
            # e.g. a parent folder owned by another user.
            logging.getLogger(__name__).warning(
                "PermissionCallback: could not chmod %s: %s", path, err
            )
