import os
import time
from pytorch_lightning.callbacks.base import Callback

from spaghettini import quick_register

# Practically-infinite interval used as the default for SimpleCheckpointer's epoch/minute
# triggers: a minutes-based trigger with this value never fires in practice. Note that an
# epoch trigger of the form `current_epoch % BIG == 0` is still satisfied at epoch 0.
BIG = 1.0e20


@quick_register
class SimpleCheckpointer(Callback):
    """Periodically write a single Lightning checkpoint to a temporary and a final directory.

    At the end of every training epoch, ``checkpoint.ckpt`` is written to ``tmp_ckpt_dir`` when
    either the epoch trigger (``current_epoch % save_every_n_epoch == 0``) or the wall-clock
    trigger (more than ``save_every_t_min`` minutes since the last save) fires. If
    ``final_ckpt_dir`` differs from ``tmp_ckpt_dir``, the checkpoint is additionally written to
    ``final_ckpt_dir`` according to the analogous ``sync_every_*`` triggers. When training ends,
    the checkpoint is written to both directories unconditionally.

    Because ``0 % n == 0`` for any ``n``, the epoch triggers also fire at epoch 0, even when left
    at the ``BIG`` default.

    Args:
        tmp_ckpt_dir: Directory for the frequent saves. Falls back to ``final_ckpt_dir`` if None.
        final_ckpt_dir: Directory for the "synced" saves. Falls back to ``tmp_ckpt_dir`` if None,
            in which case no separate sync is performed.
        save_every_n_epoch: Epoch interval for saving to ``tmp_ckpt_dir``.
        save_every_t_min: Interval in minutes for saving to ``tmp_ckpt_dir``.
        sync_every_n_epoch: Epoch interval for saving to ``final_ckpt_dir``.
        sync_every_t_min: Interval in minutes for saving to ``final_ckpt_dir``.

    Raises:
        ValueError: If both ``tmp_ckpt_dir`` and ``final_ckpt_dir`` are None.
    """

    _CKPT_FILENAME = "checkpoint.ckpt"

    def __init__(self, tmp_ckpt_dir, final_ckpt_dir=None, save_every_n_epoch=BIG, save_every_t_min=BIG,
                 sync_every_n_epoch=BIG, sync_every_t_min=BIG):
        super().__init__()
        if tmp_ckpt_dir is None and final_ckpt_dir is None:
            # Without any directory every save would fail later inside os.path.join(None, ...).
            raise ValueError("At least one of tmp_ckpt_dir and final_ckpt_dir must be provided.")

        # Each directory falls back to the other. With a single configured directory both attributes
        # are equal, so the sync step below is skipped instead of joining a path onto None.
        self.tmp_ckpt_dir = tmp_ckpt_dir if tmp_ckpt_dir is not None else final_ckpt_dir
        self.final_ckpt_dir = final_ckpt_dir if final_ckpt_dir is not None else self.tmp_ckpt_dir

        self.save_every_n_epoch = save_every_n_epoch
        self.save_every_t_min = save_every_t_min

        # Attribute name kept as-is for backward compatibility; holds the `sync_every_t_min` argument.
        self.sync_every_min = sync_every_t_min
        self.sync_every_n_epoch = sync_every_n_epoch

        now = time.time()
        self.last_saved = now
        self.last_synched = now

        print("Warning: Haven't verified that resuming training from checkpoint continues uninterrupted. ")

    @staticmethod
    def _is_due(epoch, every_n_epoch, last_time, every_t_min):
        """Return True if the epoch-interval or the wall-clock-interval trigger fires."""
        mins_elapsed = (time.time() - last_time) / 60
        return epoch % every_n_epoch == 0 or mins_elapsed > every_t_min

    def on_train_epoch_end(self, trainer, pl_module, outputs=None):
        """Save to the temporary directory and/or sync to the final directory when due.

        ``outputs`` is unused. It defaults to None so the hook can be called with or without it;
        whether Lightning passes it depends on the Lightning version.
        """
        epoch = trainer.current_epoch

        # ____ Save to temporary directory. ____
        if self._is_due(epoch, self.save_every_n_epoch, self.last_saved, self.save_every_t_min):
            save_path = os.path.join(self.tmp_ckpt_dir, self._CKPT_FILENAME)
            trainer.save_checkpoint(save_path)
            print(f"\nCheckpoint saved at {save_path}\n")
            self.last_saved = time.time()

        # ____ Sync with final directory. ____
        if self.final_ckpt_dir != self.tmp_ckpt_dir:
            # Uses the *sync* interval (previously the save interval was used here by mistake).
            if self._is_due(epoch, self.sync_every_n_epoch, self.last_synched, self.sync_every_min):
                sync_path = os.path.join(self.final_ckpt_dir, self._CKPT_FILENAME)
                trainer.save_checkpoint(sync_path)
                print(f"\nCheckpoint synced at {sync_path}\n")
                self.last_synched = time.time()

    def on_train_end(self, trainer, pl_module):
        """Unconditionally write the final checkpoint to both directories."""
        trainer.save_checkpoint(os.path.join(self.final_ckpt_dir, self._CKPT_FILENAME))
        if self.final_ckpt_dir != self.tmp_ckpt_dir:
            trainer.save_checkpoint(os.path.join(self.tmp_ckpt_dir, self._CKPT_FILENAME))

