import os

import torch

class Checkpoint:
    """Periodically save a model's ``state_dict`` with ``torch.save``.

    The ``step`` / ``epoch`` hooks receive a mapping of loop state named
    ``local`` (presumably the training loop's ``locals()`` -- confirm against
    the caller); only ``local["epoch"]`` is read.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        path: Destination for the checkpoint. A ``str``/``os.PathLike`` path
            is written atomically (temp file + rename); any other object is
            handed directly to ``torch.save`` (e.g. a writable file object).
        etc: Save on each epoch whose number is divisible by ``etc``; with
            the default ``1`` every epoch (including epoch 0) is saved.
    """

    def __init__(self, model, path="./model.pkl", etc=1):
        self.model = model
        self.path = path
        self.etc = etc

    def step(self, local):
        """Per-step hook; intentionally does nothing."""
        pass

    def epoch(self, local):
        """Save ``{"modelsd": model.state_dict()}`` if the epoch is due.

        Args:
            local: Mapping containing at least an ``"epoch"`` entry.

        Raises:
            KeyError: if ``local`` has no ``"epoch"`` key.
            ZeroDivisionError: if ``etc`` is 0.
        """
        if local["epoch"] % self.etc != 0:
            return

        payload = {"modelsd": self.model.state_dict()}

        target = (
            os.fspath(self.path)
            if isinstance(self.path, (str, os.PathLike))
            else None
        )
        if not isinstance(target, str):
            # File-like (or non-str path) destination: nothing to rename,
            # so write directly exactly as before.
            torch.save(payload, self.path)
            return

        # Write to a sibling temp file and atomically swap it in, so a crash
        # mid-save can never leave a truncated file in place of the last good
        # checkpoint (os.replace is atomic on the same filesystem).
        tmp = target + ".tmp"
        try:
            torch.save(payload, tmp)
            os.replace(tmp, target)
        except BaseException:
            # Best-effort cleanup of the partial temp file; the original
            # error is what the caller needs to see, so re-raise it.
            try:
                os.remove(tmp)
            except OSError:
                pass
            raise
