
  
import tensorflow as tf
from datetime import datetime as dt

"""=========================================================="""

class CheckPointTracker(object):
    """Manage a set of named TensorFlow checkpoints saved on a shared schedule.

    Each registered trackable gets its own ``tf.train.Checkpoint``, which carries
    an integer ``step`` counter alongside the trackable. It also gets a
    ``tf.train.CheckpointManager`` that writes to ``<base_path>/<name>`` and
    keeps the two most recent checkpoints.

    Args:
        base_path: Path prefix under which per-name checkpoint directories are
            created. Trailing slashes are ignored.
        save_freq: A save happens whenever a checkpoint's step counter becomes a
            multiple of this value. It must be non-zero, because ``update`` uses
            it as a modulus.
    """

    def __init__(self, base_path, save_freq):
        self.base_path = base_path
        self.save_freq = save_freq
        # name -> (tf.train.Checkpoint, tf.train.CheckpointManager)
        self.ckpt_dict = {}

    def set_path(self, path):
        """Set the base path used by subsequent ``register`` calls.

        Managers that are already registered keep the directory they were
        created with.
        """
        self.base_path = path

    def set_save_freq(self, save_freq):
        """Set the step interval used by ``update`` to decide when to save."""
        self.save_freq = save_freq

    def register(self, trackable, name):
        """Create a checkpoint and manager for ``trackable`` under ``name``.

        The checkpoint directory is ``base_path`` and ``name`` joined by exactly
        one ``/``. Registering an existing name replaces its previous entry.

        Args:
            trackable: Object to checkpoint. It is stored in the checkpoint
                under the key ``trackable``.
            name: Key in ``ckpt_dict`` and sub-directory below ``base_path``.
        """
        directory = self.base_path.rstrip("/") + "/" + name.lstrip("/")
        ckpt = tf.train.Checkpoint(step=tf.Variable(0), trackable=trackable)
        manager = tf.train.CheckpointManager(ckpt, directory, max_to_keep=2)
        self.ckpt_dict[name] = (ckpt, manager)

    def update(self):
        """Advance every step counter by one and save when it hits ``save_freq``.

        Each checkpoint whose new step is a multiple of ``save_freq`` is saved
        through its manager.
        """
        for ckpt, manager in self.ckpt_dict.values():
            ckpt.step.assign_add(1)
            if int(ckpt.step) % self.save_freq == 0:
                manager.save()

    def restore(self):
        """Restore every registered checkpoint from its manager's latest save.

        Returns:
            The restored ``step`` of the last entry in ``ckpt_dict``, or 0 if
            nothing is registered. Saves only happen on multiples of
            ``save_freq``, so this is the step at the last save. It can lag
            the number of ``update`` calls by up to ``save_freq - 1``.
        """
        epochs_completed = 0
        for ckpt, manager in self.ckpt_dict.values():
            # latest_checkpoint is None when nothing has been saved yet. TF's
            # Checkpoint.restore accepts None, so this does not raise in that
            # case (per the TF docs).
            ckpt.restore(manager.latest_checkpoint)
            # All counters advance together in update(). The last one is taken
            # as representative, which assumes every trackable was registered
            # before the first update().
            epochs_completed = int(ckpt.step.numpy())
        return epochs_completed

"""=========================================================="""

class ExperimentManager(object):
    """Base set of experiment lifecycle hooks; every hook is a no-op.

    Each method returns None. Subclasses override only the hooks they need.
    For example, a subclass's ``on_epoch_end`` may return True to signal that
    the experiment should end.
    """

    def __init__(self):
        """Construct a manager; the base class takes no configuration."""

    def on_experiment_start(self):
        """Hook for the start of the experiment; does nothing here."""

    def on_epoch_start(self):
        """Hook for the start of an epoch; does nothing here."""

    def on_epoch_end(self):
        """Hook for the end of an epoch; does nothing and returns None."""

    def on_experiment_end(self):
        """Hook for the end of the experiment; does nothing here."""

"""=========================================================="""
    
class TimedExperimentManager(ExperimentManager):
    """Experiment manager that requests a stop once a time budget is exceeded.

    Args:
        max_time: Time budget in seconds. ``on_epoch_end`` reports that the
            experiment should end once the elapsed local wall-clock time since
            ``on_experiment_start`` is strictly greater than this value.
    """

    def __init__(self, max_time):
        super().__init__()
        self.max_time = max_time
        # Set by on_experiment_start(); None until the experiment has started.
        self.start_time = None

    def on_experiment_start(self):
        """Record the current local time as the experiment's start time."""
        self.start_time = dt.now()

    def _elapsed_seconds(self):
        """Return the seconds elapsed since ``on_experiment_start`` was called."""
        if self.start_time is None:
            raise RuntimeError(
                "on_experiment_start() must be called before on_epoch_end()")
        return (dt.now() - self.start_time).total_seconds()

    def on_epoch_end(self):
        """Print the elapsed time and check it against the time budget.

        Returns:
            True if the experiment should end because the elapsed seconds
            exceed ``max_time``. False otherwise.

        Raises:
            RuntimeError: If called before ``on_experiment_start``.
        """
        # Sample the clock once so that the printed value, the budget check
        # and the final message all agree.
        elapsed = self._elapsed_seconds()
        print(elapsed)
        if elapsed <= self.max_time:
            return False
        print("Ending at {} seconds elapsed".format(elapsed))
        return True
        
