


# Only the factory function is exported by `from <module> import *`.
__all__ = ['setup_checkpointer']

import os.path as osp

import torch


class Checkpointer(object):
    """Save and restore training state (model, optimizer, ``reg_params``).

    A checkpoint written by this class is a dict with the keys ``epoch``,
    ``model_state_dict``, ``optimizer_state_dict``, ``num_tasks`` and
    ``reg_params``. The checkpoint file named in the constructor is loaded
    eagerly (onto the CPU), so that ``load`` can later apply it to a model.

    Model weights are always stored from the *unwrapped* module, so that
    checkpoints do not depend on whether the model was wrapped in
    ``DataParallel`` / ``DistributedDataParallel`` when it was saved.
    """

    def __init__(self, cfg, checkpointer_task_root, phase):
        """Load the checkpoint at ``checkpointer_task_root`` if it exists.

        Args:
            cfg: config object. ``cfg.OUTPUT_DIR`` is read. When resuming
                training, ``cfg.SOLVER.START_EPOCH`` is updated in place.
            checkpointer_task_root: path to a checkpoint file, or None.
            phase: ``'train'`` or another phase name.

        Raises:
            RuntimeError: if ``phase != 'train'`` and no checkpoint file
                exists at ``checkpointer_task_root``.
        """
        self.checkpoint = self._load_checkpoint(checkpointer_task_root)
        if self.checkpoint is not None and phase == 'train':
            # Resuming: shift the start epoch by the epoch stored in the file.
            cfg.SOLVER.START_EPOCH += self.checkpoint.get('epoch', 0)
        elif self.checkpoint is None and phase != 'train':
            # Outside training a checkpoint is mandatory; report the path
            # that was actually probed.
            raise RuntimeError('Cannot find checkpoint {}'.format(checkpointer_task_root))

        self.output_dir = cfg.OUTPUT_DIR

    @staticmethod
    def _unwrap(model):
        """Return the inner module of a (Distributed)DataParallel wrapper, else ``model``."""
        if isinstance(model, (torch.nn.DataParallel,
                              torch.nn.parallel.DistributedDataParallel)):
            return model.module
        return model

    def load(self, model, checkpointer_task_root, optimizer=None):
        """Apply the loaded checkpoint to ``model`` (and ``optimizer``).

        Does nothing if no checkpoint was found at construction time.

        Args:
            model: module to restore. It may be wrapped in DataParallel /
                DistributedDataParallel. ``model.reg_params`` is set from the
                checkpoint.
            checkpointer_task_root: unused; kept for interface compatibility.
            optimizer: optional optimizer whose state is also restored.

        Raises:
            KeyError: if the checkpoint lacks one of the expected keys.
        """
        if self.checkpoint is None:
            return
        state_dict = self.checkpoint['model_state_dict']
        target = self._unwrap(model)
        # Checkpoints written from a wrapper (older versions of `save` did this
        # on hosts with <= 1 GPU) carry a 'module.' prefix on every key; only
        # the wrapper itself can load those.
        if (target is not model and state_dict
                and all(key.startswith('module.') for key in state_dict)):
            target = model
        target.load_state_dict(state_dict)
        if optimizer is not None:
            optimizer.load_state_dict(self.checkpoint['optimizer_state_dict'])
        model.reg_params = self.checkpoint['reg_params']

    def _build_state(self, epoch, model, optimizer, num_tasks):
        """Assemble the checkpoint dict shared by ``save`` and ``save_best``."""
        return {
            'epoch': epoch,
            'model_state_dict': self._unwrap(model).state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'num_tasks': num_tasks,
            # Read from the object passed in (wrapper or not), as `load` sets it.
            'reg_params': model.reg_params,
        }

    def save(self, epoch, model, optimizer, num_tasks):
        """Write the training state to ``OUTPUT_DIR/task_<num_tasks>_epoch-<epoch>.pth``."""
        path = osp.join(self.output_dir, 'task_{}_epoch-{}.pth'.format(num_tasks, epoch))
        torch.save(self._build_state(epoch, model, optimizer, num_tasks), path)

    def save_best(self, epoch, model, optimizer, num_tasks):
        """Write the training state to ``OUTPUT_DIR/task_<num_tasks>_best.pth``.

        The stored epoch is always 0 (``epoch`` is not recorded), so resuming
        training from this file leaves ``cfg.SOLVER.START_EPOCH`` unchanged.
        """
        path = osp.join(self.output_dir, 'task_{}_best.pth'.format(num_tasks))
        torch.save(self._build_state(0, model, optimizer, num_tasks), path)

    def _load_checkpoint(self, checkpoint):
        """Return the deserialized checkpoint at path ``checkpoint``, or None if absent."""
        if checkpoint is None or not osp.isfile(checkpoint):
            return None
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        return torch.load(checkpoint, map_location=torch.device('cpu'))


def setup_checkpointer(cfg, checkpointer_task_root, phase):
    """Build a ``Checkpointer`` for the given config, checkpoint path and phase.

    Args:
        cfg: config object forwarded unchanged to ``Checkpointer``.
        checkpointer_task_root: checkpoint file path (or None).
        phase: run phase, e.g. ``'train'``.

    Returns:
        A new ``Checkpointer`` instance.
    """
    checkpointer = Checkpointer(
        cfg=cfg,
        checkpointer_task_root=checkpointer_task_root,
        phase=phase,
    )
    return checkpointer
