import os
import sys

import numpy as np

sys.path.append("../")
from mmcv.runner.hooks.hook import Hook


class UpdateCacheHook(Hook):
    """Periodically refresh the model's cache memory and dump it to disk.

    At the start of every epoch whose index is a multiple of
    ``update_epoch_interval``, this hook calls
    ``runner.model.module.update_cachememory(self.update_flag)``.
    ``update_flag`` becomes ``True`` once ``runner.epoch`` reaches
    ``start_update`` and stays ``True`` afterwards. While it is ``True``,
    every refresh also saves ``source_data`` and ``target_data`` as
    ``source_data.npy`` / ``target_data.npy`` in ``runner.work_dir`` and
    prints their shapes.

    Args:
        update_epoch_interval (int): Epoch period of the refresh. Must be
            positive.
        start_update (int): First epoch at which ``update_flag`` is enabled.

    Raises:
        ValueError: If ``update_epoch_interval`` is not positive.
    """

    def __init__(self, update_epoch_interval=10, start_update=0):
        # A zero interval would raise ZeroDivisionError in the modulo check
        # on the first epoch; reject it up front with a clear message.
        if update_epoch_interval <= 0:
            raise ValueError(
                "update_epoch_interval must be positive, got {}".format(
                    update_epoch_interval))
        self.update_epoch_interval = update_epoch_interval
        self.start_update = start_update
        self.update_flag = False

    def before_train_epoch(self, runner):
        """Refresh (and, once enabled, dump) the cache memory on interval epochs.

        Args:
            runner: The mmcv runner. Reads ``runner.epoch``,
                ``runner.work_dir`` and ``runner.model.module``.
        """
        if runner.epoch >= self.start_update:
            # Sticky: once enabled, the flag is never reset by this hook.
            self.update_flag = True

        if runner.epoch % self.update_epoch_interval != 0:
            return

        # NOTE(review): ``.module`` assumes the model is wrapped (e.g. by a
        # (Distributed)DataParallel wrapper) — confirm for unwrapped runs.
        model = runner.model.module
        model.update_cachememory(self.update_flag)
        if not self.update_flag:
            return

        # Copy each buffer to host memory once and reuse the array for both
        # saving and shape logging (presumably torch tensors, given .cpu()).
        source_data = model.source_data.cpu().numpy()
        target_data = model.target_data.cpu().numpy()
        np.save(os.path.join(runner.work_dir, "source_data.npy"), source_data)
        np.save(os.path.join(runner.work_dir, "target_data.npy"), target_data)
        print("source_data shape is ", source_data.shape)
        print("target_data shape is ", target_data.shape)
