import numpy as np
import wandb
import torch
import os
import glob
# from .runner import Runner
import torch.distributed as dist 
# from .runner import Runner
from utils.parallel import get_dist_info
from utils.display import display_metrics_dict
from utils.evaluation import AverageMeter, accuracy, ConfusionMatrix

# Task identifiers accepted by CMDRunner.train_base() and CMDRunner.val();
# both assert that the requested task type is one of these.
TASK_TYPES = [
    'source_ssl',
    'source_ssl_lv',
    'target_ssl',
    'distill',
    'target_lv',
    'target_ft',
    'clip',
    'clipm',
    'target_seglv',
    'target_segft',
]
class CMDRunner:
    """Runner of Kernel ROBOT: trains and validates the configured models task by task.

    Args:
        See ``__init__`` for the constructor arguments.
    """
    def __init__(self, 
                 models,
                 optims,
                 schedulers,
                 losses,
                 logger,
                 work_dir,
                 device,
                 backbone_types=('resnet18', 'resnet18'),
                 pretrained_path_dict=None,
                 sample_size=500,
                 max_epochs=(100, 100),
                 train_tasks=('source', 'distill', 'target'),
                 load_pretrained_ssl=False,
                 wandb=True,
                 print_every=100,
                 val_every=20000,
                 save_every=20000,
                 resume_from='',
                 resume_latest=False,
                 meta=None):
        """Store the training components and optionally restore weights.

        Args:
            models: Dict of networks. ``train_base``/``val`` read the keys
                ``psi_0``, ``phi_0``, ``f_0``, ``g_0``, ``psi_1``, ``phi_1``,
                ``f_1`` and ``g_1``.
            optims: Dict of optimizers, indexed by the model names listed in
                each task's optimizer keys.
            schedulers: Dict of LR schedulers, indexed like ``optims``.
            losses: Dict of loss callables keyed ``'task_<idx>'``, where
                ``idx`` is the position of the task in ``train_tasks``.
            logger: Object exposing ``info(...)``.
            work_dir: Directory under which checkpoints are written.
            device: Device that batches are moved to; also used as
                ``map_location`` when resuming from ``resume_from``.
            backbone_types: Backbone name per model index; only consumed by
                ``load_pretrained``.
            pretrained_path_dict: Mapping from backbone name to pretrained
                weight file; only consumed by ``load_pretrained``.
            sample_size: Stored on the instance as-is.
            max_epochs: Number of epochs for each entry of ``train_tasks``.
            train_tasks: Task types to run, in order. ``train_base`` asserts
                that each entry is listed in ``TASK_TYPES``.
            load_pretrained_ssl: If true, call ``load_pretrained`` first.
            wandb: Whether metrics are logged through ``wandb``.
            print_every: Training-iteration interval for metric logging.
            val_every: Training-iteration interval for validation.
            save_every: Training-iteration interval for checkpointing.
            resume_from: Checkpoint path to resume from; ``''`` (or ``None``)
                disables explicit resuming.
            resume_latest: If ``resume_from`` does not exist, try the newest
                ``iter_*.pth`` checkpoint found under ``work_dir``.
            meta: Optional dict of meta information saved with checkpoints.
        """
        self.models = models
        self.optims = optims
        self.schedulers = schedulers
        self.losses = losses
        self.logger = logger
        self.work_dir = work_dir
        self.device = device
        self._backbone_types = backbone_types
        self._pretrained_path_dict = pretrained_path_dict
        self.sample_size = sample_size
        self._model_name = 'KernelROBOTRunner'
        self._rank, self._world_size = get_dist_info()
        self._epoch = 1
        self._iter = 1
        self._max_epochs = max_epochs
        self._train_tasks = train_tasks
        self._max_iters = max_epochs
        self._load_pretrained_ssl = load_pretrained_ssl
        self._wandb = wandb
        self._print_every = print_every
        self._val_every = val_every
        self._save_every = save_every
        self.meta = meta if meta is not None else {}
        resume_succ = False
        if self._load_pretrained_ssl:
            self.load_pretrained(self._backbone_types, self._pretrained_path_dict, self.device)

        # Treat None like the empty string so os.path.exists() cannot raise.
        resume_from = resume_from or ''
        self.logger.info(f'resume_from:{resume_from}')
        if resume_from != '':
            self.logger.info(f'Try to resume from {resume_from}. \n Existence:{os.path.exists(resume_from)}')

        if os.path.exists(resume_from):
            self.logger.info(f'resume from {resume_from}')
            try:
                self.resume(resume_from, map_location=self.device)
                resume_succ = True
            except Exception as err:
                # Best-effort resume: keep going, but do not hide the cause and
                # do not swallow KeyboardInterrupt/SystemExit.
                self.logger.info('Resume Error. Start from random init weights.')
                self.logger.info(f'Resume failure detail: {err!r}')
            if resume_succ:
                self.logger.info(f'Resume Success from {resume_from}.')
        elif resume_latest:
            # Directory name (including its spelling) must match save_checkpoint().
            ckpt_dir = os.path.join(self.work_dir, 'chckpoints')

            def _iter_number(path):
                """Return the iteration encoded in ``iter_<n>.pth`` (-1 if not numeric)."""
                stem = os.path.basename(path)[len('iter_'):-len('.pth')]
                return int(stem) if stem.isdigit() else -1

            # save_checkpoint() writes one sub-directory per task type, so look
            # at the configured tasks (latest task first) before the legacy names.
            task_dirs = []
            for task_type in [*reversed(self._train_tasks), 'target', 'distill', 'source']:
                if task_type not in task_dirs:
                    task_dirs.append(task_type)

            for task_type in task_dirs:
                ckpt_paths = glob.glob(f'{ckpt_dir}/{task_type}/iter_*.pth')
                # Newest iteration first; fall back to older ones on failure.
                for latest_ckpt_path in sorted(ckpt_paths, key=_iter_number, reverse=True):
                    self.logger.info(f'resume from {latest_ckpt_path}')
                    map_location = next(self.models['psi_0'].parameters()).device
                    try:
                        self.resume(latest_ckpt_path, map_location=map_location)
                    except Exception as err:
                        self.logger.info('Resume Error. Try Previous Iteraiton.')
                        self.logger.info(f'Resume failure detail: {err!r}')
                    else:
                        resume_succ = True
                        break
                if resume_succ:
                    break

    def initialize_metrics_dict(self, task_type):
        """Create fresh metric trackers for ``task_type``.

        Returns:
            tuple: ``(metrics_names, metrics_dict)`` where ``metrics_names`` is
            the ordered list of reported metrics and ``metrics_dict`` maps each
            name to a new ``AverageMeter``. Segmentation tasks additionally get
            a ``'confmat'`` entry holding a ``ConfusionMatrix(40)``.
        """
        is_seg_task = task_type in ['target_seglv', 'target_segft']

        if task_type in ['distill', 'clip', 'clipm']:
            # paired source/target tasks report accuracy for both sides
            metrics_names = ['loss',
                             'source_acc1', 'source_acc5',
                             'target_acc1', 'target_acc5']
        elif is_seg_task:
            metrics_names = ['loss', 'acc_global', 'mIoU']
        else:
            metrics_names = ['loss', 'acc1', 'acc5']

        metrics_dict = {}
        for name in metrics_names:
            metrics_dict[name] = AverageMeter()
        if is_seg_task:
            metrics_dict['confmat'] = ConfusionMatrix(40)
        return metrics_names, metrics_dict

    def display_info(self, tag, 
                     epoch_i, iter_i, max_epoch, each_iters,
                     metrics_names, metrics_dict):
        """Log one progress line: tag, epoch/iteration counters, LR and metrics.

        ``epoch_i`` and ``iter_i`` are printed incremented by one. The learning
        rate shown is the first param group of the first optimizer returned by
        ``current_lr()``.
        """
        metrics_message = display_metrics_dict(metrics_names, metrics_dict)
        epoch_message = f'Epoch[{epoch_i+1}/{max_epoch}], Iter[{iter_i+1}/{each_iters}]'
        lrs_per_optim = list(self.current_lr().values())
        first_lr = lrs_per_optim[0][0]
        optim_message = f'cur_lr:{first_lr:.7f}'
        pieces = (tag, ' ', epoch_message, ' (', optim_message, ') :', metrics_message)
        self.logger.info(''.join(pieces))

    def metric_info(self, metrics_names, metrics_dict, 
                    loss_val, preds=None, gt_labels=None):
        """Update the metric trackers with one batch and return their latest values.

        Args:
            metrics_names: Names whose ``.val`` is returned.
            metrics_dict: Mapping from metric name to tracker (``update(value, n)``
                plus ``val``/``avg`` attributes); may hold a ``'confmat'`` entry.
            loss_val: Scalar loss tensor of the batch.
            preds: Prediction tensor, or a dict of prediction tensors keyed by
                a prefix such as ``'source'``/``'target'``. May be ``None``.
            gt_labels: Labels matching ``preds`` (a dict when ``preds`` is one).

        Returns:
            dict: ``{name: metrics_dict[name].val}`` for ``metrics_names``.
        """
        # The weight of this batch in the running averages. Without predictions
        # (or with an empty dict) there is no batch dimension to read, so the
        # loss is counted once instead of failing on ``None.size``.
        if preds is None:
            batch_size = 1
        elif isinstance(preds, dict):
            batch_size = next(iter(preds.values())).size(0) if preds else 1
        else:
            batch_size = preds.size(0)

        with torch.no_grad():
            # update metrics
            metrics_dict['loss'].update(loss_val.detach().cpu().item(), batch_size)
            if preds is not None and gt_labels is not None:
                if isinstance(preds, dict):
                    for key in preds:
                        if f'{key}_acc1' in metrics_dict:
                            # accuracy() result is indexed by k here (accs[1], accs[5])
                            accs = accuracy(preds[key], gt_labels[key], topk=(1, 5))
                            metrics_dict[f'{key}_acc1'].update(accs[1], batch_size)
                            metrics_dict[f'{key}_acc5'].update(accs[5], batch_size)
                else:
                    if 'acc1' in metrics_dict:
                        accs = accuracy(preds, gt_labels, topk=(1, 5))
                        metrics_dict['acc1'].update(accs[1], batch_size)
                        metrics_dict['acc5'].update(accs[5], batch_size)
                    if 'mIoU' in metrics_dict:
                        confmat = metrics_dict['confmat']
                        confmat.update(gt_labels.flatten(), preds.argmax(1).flatten())
                        acc_global, _, iu = confmat.compute()
                        # The confusion matrix already accumulates over batches,
                        # so its results overwrite the meters instead of being
                        # averaged a second time.
                        global_acc = acc_global.item()
                        mean_iou = iu.mean().item()
                        metrics_dict['acc_global'].avg = global_acc
                        metrics_dict['acc_global'].val = global_acc
                        metrics_dict['mIoU'].avg = mean_iou
                        metrics_dict['mIoU'].val = mean_iou

            metrics_dict_val = {mn: metrics_dict[mn].val for mn in metrics_names}

        return metrics_dict_val
    
    def update_info(self, epoch_i, iter_i, whole_iter_i, 
                    max_epoch, each_iters,
                    metrics_names, metrics_dict, 
                    loss_val, preds, gt_labels, 
                    task_type, val_loader):
        """Periodic bookkeeping after a training step (rank 0 only).

        Depending on ``whole_iter_i`` this logs the training metrics (every
        ``_print_every`` iterations, also to wandb when enabled), runs
        validation (every ``_val_every``) and writes an ``iter_<n>.pth``
        checkpoint (every ``_save_every``).
        """
        # Only the rank-0 process reports, validates and saves.
        if self._rank != 0:
            return

        if whole_iter_i % self._print_every == 0:
            latest_vals = self.metric_info(metrics_names, metrics_dict,
                                           loss_val, preds, gt_labels)
            self.display_info(f'{task_type} Train',
                              epoch_i, iter_i, max_epoch, each_iters,
                              metrics_names, latest_vals)
            if self._wandb:
                progress = (epoch_i + iter_i / each_iters) * 100
                self.wandb_info(task_type, 'train', metrics_names, metrics_dict,
                                step=int(progress))

        if whole_iter_i % self._val_every == 0:
            self.val(task_type, val_loader, epoch_i, iter_i, max_epoch, each_iters)

        if whole_iter_i % self._save_every == 0:
            ckpt_name = f'iter_{whole_iter_i}.pth'
            self.save_checkpoint(self.work_dir, task_type, ckpt_name)
    

    def wandb_info(self, task_type, tag, metrics_names, metrics_dict, step):
        """Send the metrics to wandb under ``/<task_type>/<tag>/<metric>``.

        Training logs each tracker's latest value (``.val``); validation logs
        the running average (``.avg``). ``step`` is logged under ``'epoch'``.
        """
        assert tag in ['train', 'val']
        attr_name = 'val' if tag == 'train' else 'avg'
        payload = {
            '/'.join(('', task_type, tag, mn)): getattr(metrics_dict[mn], attr_name)
            for mn in metrics_names
        }
        payload['epoch'] = step
        wandb.log(payload)
    
    def train_source_batch(self, psi_1, phi_1, g_1, data_batch, optim_keys):
        """Run one self-supervised optimisation step on a single model.

        ``data_batch`` is ``(views, labels)``; the views are concatenated along
        dim 0 and the labels are ignored. Only the optimizers named in
        ``optim_keys`` are zeroed and stepped.

        Returns:
            tuple: ``(logits, labels, loss_val)`` as produced by the task loss.
        """
        criterion = self.losses[f'task_{self._task_idx}']
        views, _ = data_batch
        batch = torch.cat(views, dim=0).to(self.device)

        for key in optim_keys:
            self.optims[key].zero_grad()

        # forward: psi -> phi -> g
        ssl_feats = g_1(phi_1(psi_1(batch)))

        logits, labels, loss_val = criterion(ssl_feats)
        loss_val.backward()
        for key in optim_keys:
            self.optims[key].step()
        return logits, labels, loss_val

    def train_clip_batch(self, psi_0, phi_0, g_0, psi_1, phi_1, g_1,
                             data_batch, optim_keys):
        """Run one optimisation step of the CLIP-style source/target task.

        ``data_batch`` is ``(source_views, target_views)``; each list is
        concatenated along dim 0. Source data goes through ``psi_1/phi_1/g_1``
        and target data through ``psi_0/phi_0/g_0``. The task loss is called as
        ``loss(target_feats, source_feats)``.

        Returns:
            tuple: ``(preds, labels, loss_val)`` where ``preds`` and ``labels``
            are dicts keyed ``'source'``/``'target'`` (both label entries are
            the same object).
        """
        criterion = self.losses[f'task_{self._task_idx}']
        source_views, target_views = data_batch
        source_batch = torch.cat(source_views, dim=0).to(self.device)
        target_batch = torch.cat(target_views, dim=0).to(self.device)

        for key in optim_keys:
            self.optims[key].zero_grad()

        # forward: source branch first, then target branch
        source_ssl_feats = g_1(phi_1(psi_1(source_batch)))
        target_ssl_feats = g_0(phi_0(psi_0(target_batch)))

        # loss arguments are (pred, gt) = (target, source)
        logits_per_source, logits_per_target, labels, loss_val = criterion(
            target_ssl_feats, source_ssl_feats)
        loss_val.backward()
        for key in optim_keys:
            self.optims[key].step()

        preds = {'source': logits_per_source, 'target': logits_per_target}
        targets = {'source': labels, 'target': labels}
        return preds, targets, loss_val
    
    def train_distill_batch(self, psi_0, phi_0, g_0, psi_1, phi_1, g_1,
                             data_batch, optim_keys):
        """Run one optimisation step of the source-to-target distillation task.

        ``data_batch`` is ``(source_views, target_views)``; each list is
        concatenated along dim 0. Source data goes through ``psi_1/phi_1/g_1``
        and target data through ``psi_0/phi_0/g_0``. The task loss is called as
        ``loss(target_feats, source_feats)``.

        Returns:
            tuple: ``(preds, labels, loss_val)`` where ``preds`` and ``labels``
            are dicts keyed ``'source'``/``'target'``.
        """
        criterion = self.losses[f'task_{self._task_idx}']
        source_views, target_views = data_batch
        source_batch = torch.cat(source_views, dim=0).to(self.device)
        target_batch = torch.cat(target_views, dim=0).to(self.device)

        for key in optim_keys:
            self.optims[key].zero_grad()

        # forward: source branch first, then target branch
        source_ssl_feats = g_1(phi_1(psi_1(source_batch)))
        target_ssl_feats = g_0(phi_0(psi_0(target_batch)))

        # loss arguments are (pred, gt) = (target, source)
        loss_outputs = criterion(target_ssl_feats, source_ssl_feats)
        source_logits, source_labels, target_logits, target_labels, loss_val = loss_outputs
        loss_val.backward()
        for key in optim_keys:
            self.optims[key].step()

        preds = {'source': source_logits, 'target': target_logits}
        targets = {'source': source_labels, 'target': target_labels}
        return preds, targets, loss_val
    
    def train_target_batch(self, psi_0, phi_0, f_0,
                            data_batch, optim_keys):
        """Run one supervised optimisation step on the target model.

        ``data_batch`` is ``(inputs, labels)``; both are moved to
        ``self.device`` and the labels are cast to ``torch.long``. Only the
        optimizers named in ``optim_keys`` are zeroed and stepped.

        Returns:
            tuple: ``(preds, gt_labels, loss_val)``.
        """
        criterion = self.losses[f'task_{self._task_idx}']
        inputs, gt_labels = [item.to(self.device) for item in data_batch]
        gt_labels = gt_labels.to(torch.long)

        for key in optim_keys:
            self.optims[key].zero_grad()

        # forward: psi -> phi -> f
        preds = f_0(phi_0(psi_0(inputs)))

        loss_val = criterion(preds, gt_labels)
        loss_val.backward()
        for key in optim_keys:
            self.optims[key].step()
        return preds, gt_labels, loss_val

    def get_loaders(self, data_names, val_data_names, train_loaders, val_loaders):
        """Pick the train/val loaders of the task currently being run.

        ``data_names`` and ``val_data_names`` are indexed by ``self._task_idx``
        to obtain the dataset names; the loader dicts map each name to a list
        whose first element is the loader that is returned.
        """
        task_idx = self._task_idx
        train_name = data_names[task_idx]
        # each dict value is a list; only its first entry is used
        train_loader = train_loaders[train_name][0]
        val_name = val_data_names[task_idx]
        val_loader = val_loaders[val_name][0]
        return train_loader, val_loader

    def train_base(self, task_type, max_epoch, data_names, val_data_names, train_loaders, val_loaders):
        """Train a single task for ``max_epoch`` epochs.

        The task type selects which per-batch training routine is used, which
        networks it receives and which optimizers/schedulers are stepped.
        After every batch ``update_info`` handles logging, validation and
        periodic checkpoints; the schedulers of the task are stepped once per
        epoch and ``final.pth`` is saved at the end when ``max_epoch > 0``.

        Args:
            task_type: One of ``TASK_TYPES``.
            max_epoch: Number of passes over the training loader.
            data_names: Training dataset names, indexed by task index.
            val_data_names: Validation dataset names, indexed by task index.
            train_loaders: Dict mapping dataset name to a list of loaders.
            val_loaders: Dict mapping dataset name to a list of loaders.

        Raises:
            AssertionError: If ``task_type`` is not in ``TASK_TYPES``.
        """
        assert task_type in TASK_TYPES, f'Training: Unsupported task type {task_type}.'

        train_loader, val_loader = self.get_loaders(data_names, val_data_names, train_loaders, val_loaders)
        each_iters = len(train_loader)

        # index 1: source model, index 0: target model
        psi_1, phi_1, f_1, g_1 = self.models['psi_1'], self.models['phi_1'], self.models['f_1'], self.models['g_1']
        psi_0, phi_0, f_0, g_0 = self.models['psi_0'], self.models['phi_0'], self.models['f_0'], self.models['g_0']

        # task type -> (batch routine, networks passed to it, optimizer keys).
        # Resolved once up front: it does not change between batches, and the
        # optimizer keys stay defined even if the loader yields no batch.
        source_ssl_nets = (psi_1, phi_1, g_1)
        target_ssl_nets = (psi_0, phi_0, g_0)
        paired_nets = target_ssl_nets + source_ssl_nets
        target_sup_nets = (psi_0, phi_0, f_0)
        task_plans = {
            'source_ssl': (self.train_source_batch, source_ssl_nets, ['psi_1', 'phi_1', 'g_1']),
            'source_ssl_lv': (self.train_source_batch, source_ssl_nets, ['g_1']),
            'target_ssl': (self.train_source_batch, target_ssl_nets, ['psi_0', 'phi_0', 'g_0']),
            'distill': (self.train_distill_batch, paired_nets, ['psi_0', 'phi_0', 'g_0']),
            'clip': (self.train_clip_batch, paired_nets, ['psi_0', 'phi_0', 'g_0']),
            'clipm': (self.train_clip_batch, paired_nets, ['psi_0', 'phi_0', 'g_0', 'g_1']),
            # linear evaluation: only the head is optimised
            'target_lv': (self.train_target_batch, target_sup_nets, ['f_0']),
            'target_seglv': (self.train_target_batch, target_sup_nets, ['f_0']),
            # fine-tuning: backbone and head are optimised
            'target_ft': (self.train_target_batch, target_sup_nets, ['psi_0', 'phi_0', 'f_0']),
            'target_segft': (self.train_target_batch, target_sup_nets, ['psi_0', 'phi_0', 'f_0']),
        }
        batch_fn, nets, optim_keys = task_plans[task_type]
        is_seg_task = 'seg' in task_type

        # metrics record
        metrics_names, metrics_dict = self.initialize_metrics_dict(task_type)
        whole_iter_i = 0

        self.logger.info('Start Training...')
        for epoch_i in range(max_epoch):
            for iter_i, data_batch in enumerate(train_loader):
                if is_seg_task:
                    # the heads are told the spatial size of the current input
                    input_shape = data_batch[0].size()[2:]
                    f_0.reset_input_shape(input_shape)
                    f_1.reset_input_shape(input_shape)

                preds, gt_labels, loss_val = batch_fn(*nets, data_batch, optim_keys)

                whole_iter_i += 1
                # counters are passed 1-based
                self.update_info(epoch_i+1, iter_i+1, whole_iter_i,
                                 max_epoch, each_iters,
                                 metrics_names, metrics_dict,
                                 loss_val, preds, gt_labels,
                                 task_type, val_loader)
            # schedulers advance once per epoch
            for key in optim_keys:
                self.schedulers[key].step()
        if max_epoch > 0:
            self.save_checkpoint(self.work_dir, task_type, 'final.pth')

    def train(self, data_names, val_data_names, train_loaders, val_loaders):
        """Run every configured task in order via ``train_base``.

        ``self._task_idx`` is set to the position of the running task before
        each call, so that losses and loaders of that task are selected.

        Raises:
            AssertionError: If the number of tasks and epoch budgets differ.
        """
        n_tasks = len(self._train_tasks)
        n_budgets = len(self._max_epochs)
        assert n_tasks == n_budgets, f'Max Epochs should have the same length as the train tasks. Now {n_tasks} tasks, {n_budgets} max epochs.'
        self.logger.info('Before in Training...')
        for task_idx, task_type in enumerate(self._train_tasks):
            max_epoch = self._max_epochs[task_idx]
            self._task_idx = task_idx
            self.train_base(task_type, max_epoch, data_names, val_data_names, train_loaders, val_loaders)

    def val_source_batch(self, psi_1, phi_1, g_1, data_batch):
        """Evaluate the self-supervised loss of one model on one batch.

        ``data_batch`` is ``(views, labels)``; the views are concatenated along
        dim 0 and the labels are ignored. No optimizer is touched.

        Returns:
            tuple: ``(logits, labels, loss_val)`` as produced by the task loss.
        """
        criterion = self.losses[f'task_{self._task_idx}']
        views, _ = data_batch
        batch = torch.cat(views, dim=0).to(self.device)
        # forward: psi -> phi -> g
        ssl_feats = g_1(phi_1(psi_1(batch)))
        logits, labels, loss_val = criterion(ssl_feats)
        return logits, labels, loss_val
        
    def val_distill_batch(self, psi_0, phi_0, g_0, psi_1, phi_1, g_1,
                             data_batch):
        """Evaluate the distillation loss on one ``(source, target)`` batch.

        Source data goes through ``psi_1/phi_1/g_1`` and target data through
        ``psi_0/phi_0/g_0``; the task loss is called as
        ``loss(target_feats, source_feats)``. No optimizer is touched.

        Returns:
            tuple: ``(preds, labels, loss_val)`` where ``preds`` and ``labels``
            are dicts keyed ``'source'``/``'target'``.
        """
        criterion = self.losses[f'task_{self._task_idx}']
        source_views, target_views = data_batch
        source_batch = torch.cat(source_views, dim=0).to(self.device)
        target_batch = torch.cat(target_views, dim=0).to(self.device)

        # forward: source branch first, then target branch
        source_ssl_feats = g_1(phi_1(psi_1(source_batch)))
        target_ssl_feats = g_0(phi_0(psi_0(target_batch)))

        # loss arguments are (pred, gt) = (target, source)
        loss_outputs = criterion(target_ssl_feats, source_ssl_feats)
        source_logits, source_labels, target_logits, target_labels, loss_val = loss_outputs

        preds = {'source': source_logits, 'target': target_logits}
        targets = {'source': source_labels, 'target': target_labels}
        return preds, targets, loss_val
    
    def val_clip_batch(self, psi_0, phi_0, g_0, psi_1, phi_1, g_1,
                             data_batch):
        """Evaluate the CLIP-style loss on one ``(source, target)`` batch.

        Source data goes through ``psi_1/phi_1/g_1`` and target data through
        ``psi_0/phi_0/g_0``; the task loss is called as
        ``loss(target_feats, source_feats)``. No optimizer is touched.

        Returns:
            tuple: ``(preds, labels, loss_val)`` where ``preds`` and ``labels``
            are dicts keyed ``'source'``/``'target'`` (both label entries are
            the same object).
        """
        criterion = self.losses[f'task_{self._task_idx}']
        source_views, target_views = data_batch
        source_batch = torch.cat(source_views, dim=0).to(self.device)
        target_batch = torch.cat(target_views, dim=0).to(self.device)

        # forward: source branch first, then target branch
        source_ssl_feats = g_1(phi_1(psi_1(source_batch)))
        target_ssl_feats = g_0(phi_0(psi_0(target_batch)))

        # loss arguments are (pred, gt) = (target, source)
        logits_per_source, logits_per_target, labels, loss_val = criterion(
            target_ssl_feats, source_ssl_feats)

        preds = {'source': logits_per_source, 'target': logits_per_target}
        targets = {'source': labels, 'target': labels}
        return preds, targets, loss_val
    

    def val_target_batch(self, psi_0, phi_0, f_0,
                            data_batch):
        """Evaluate the supervised loss of the target model on one batch.

        ``data_batch`` is ``(inputs, labels)``; both are moved to
        ``self.device`` and the labels are cast to ``torch.long``. No optimizer
        is touched.

        Returns:
            tuple: ``(preds, gt_labels, loss_val)``.
        """
        criterion = self.losses[f'task_{self._task_idx}']
        inputs, gt_labels = [item.to(self.device) for item in data_batch]
        gt_labels = gt_labels.to(torch.long)
        # forward: psi -> phi -> f
        preds = f_0(phi_0(psi_0(inputs)))
        loss_val = criterion(preds, gt_labels)
        return preds, gt_labels, loss_val

    def val(self, task_type, val_loader, epoch_i=-1, iter_i=-1,  max_epoch=1, each_iters=1):
        """Evaluate ``task_type`` on ``val_loader`` and report averaged metrics.

        The backbone and head models are put into eval mode, every batch of the
        loader is evaluated under ``torch.no_grad()``, the averaged metrics are
        logged (and sent to wandb when enabled), and the same models are put
        back into train mode.

        Args:
            task_type: One of ``TASK_TYPES``.
            val_loader: Iterable of validation batches.
            epoch_i: Training epoch counter used for the wandb step and the
                logged progress.
            iter_i: Training iteration counter used for the wandb step and the
                logged progress.
            max_epoch: Total number of epochs, for display only.
            each_iters: Training iterations per epoch; divides ``iter_i`` when
                the wandb step is computed.

        Raises:
            AssertionError: If ``task_type`` is not in ``TASK_TYPES``.
        """
        # NOTE(review): the projection heads g_0/g_1 are not switched to eval
        # mode here although the SSL/distill/clip tasks use them - confirm
        # whether that is intended.
        val_keys = ['psi_0', 'phi_0', 'f_0',
                    'psi_1', 'phi_1', 'f_1']
        for key in val_keys:
            self.models[key].eval()
        psi_1, phi_1, f_1, g_1 = self.models['psi_1'], self.models['phi_1'], self.models['f_1'], self.models['g_1']
        psi_0, phi_0, f_0, g_0 = self.models['psi_0'], self.models['phi_0'], self.models['f_0'], self.models['g_0']
        metrics_names, metrics_dict = self.initialize_metrics_dict(task_type)
        assert task_type in TASK_TYPES, f'Validation: Unsupported task type {task_type}.'
        with torch.no_grad():
            # The loop must not reuse the name ``iter_i``: that argument carries
            # the *training* iteration needed for the wandb step and the
            # progress line below.
            for data_batch in val_loader:
                if 'seg' in task_type:
                    input_shape = data_batch[0].size()[2:]
                    f_0.reset_input_shape(input_shape)
                    f_1.reset_input_shape(input_shape)
                # forward loss
                preds, gt_labels = None, None
                if task_type in ['source_ssl', 'source_ssl_lv']:
                    preds, gt_labels, loss_val = self.val_source_batch(psi_1, phi_1, g_1, data_batch)
                elif task_type == 'target_ssl':
                    preds, gt_labels, loss_val = self.val_source_batch(psi_0, phi_0, g_0, data_batch)
                elif task_type == 'distill':
                    preds, gt_labels, loss_val = self.val_distill_batch(psi_0, phi_0, g_0, psi_1, phi_1, g_1, data_batch)
                elif task_type in ['clip', 'clipm']:
                    preds, gt_labels, loss_val = self.val_clip_batch(psi_0, phi_0, g_0, psi_1, phi_1, g_1, data_batch)
                elif task_type in ['target_lv', 'target_seglv', 'target_ft', 'target_segft']:
                    preds, gt_labels, loss_val = self.val_target_batch(psi_0, phi_0, f_0, data_batch)
                self.metric_info(metrics_names, metrics_dict, 
                                 loss_val, preds, gt_labels)
                
            metrics_dict_avg = {m:metrics_dict[m].avg for m in metrics_names}
            if self._wandb:
                self.wandb_info(task_type, 'val', metrics_names, metrics_dict, 
                                step=int((epoch_i + iter_i / each_iters) * 100))
            self.logger.info(f'The length of val dataloader: {len(val_loader)}')
            self.display_info(f'{task_type} Val', epoch_i, iter_i, max_epoch, each_iters,
                               metrics_names, metrics_dict_avg)
        for key in val_keys:
            self.models[key].train()

    def run(self, data_names, val_data_names,
            train_loaders, val_loaders,
            workflow, **kwargs):
        """Entry point of the runner; forwards to ``train``.

        ``workflow`` and any extra keyword arguments are accepted but not used
        (workflows are not implemented).
        """
        self.train(data_names=data_names,
                   val_data_names=val_data_names,
                   train_loaders=train_loaders,
                   val_loaders=val_loaders)

    def current_lr(self):
        """Get current learning rates.

        Returns:
            list[float] | dict[str, list[float]]: Current learning rates of all
            param groups. If the runner has a dict of optimizers, this method
            will return a dict keyed like ``self.optims``.

        Raises:
            RuntimeError: If ``self.optims`` is neither an optimizer nor a dict.
        """
        if isinstance(self.optims, torch.optim.Optimizer):
            # single optimizer: read it from ``self.optims`` (there is no
            # ``self.optimizer`` attribute on this runner)
            return [group['lr'] for group in self.optims.param_groups]
        if isinstance(self.optims, dict):
            return {
                name: [group['lr'] for group in optim.param_groups]
                for name, optim in self.optims.items()
            }
        raise RuntimeError(
            'lr is not applicable because optimizer does not exist.')

    def current_momentum(self):
        """Get current momentums.

        For each param group the ``'momentum'`` entry is used when present,
        otherwise the first of ``'betas'``, otherwise ``0``.

        Returns:
            list[float] | dict[str, list[float]]: Current momentums of all
            param groups. If the runner has a dict of optimizers, this method
            will return a dict keyed like ``self.optims``.

        Raises:
            RuntimeError: If ``self.optims`` is ``None`` or of an unsupported type.
        """

        def _get_momentum(optimizer):
            momentums = []
            for group in optimizer.param_groups:
                if 'momentum' in group:
                    momentums.append(group['momentum'])
                elif 'betas' in group:
                    momentums.append(group['betas'][0])
                else:
                    momentums.append(0)
            return momentums

        if self.optims is None:
            raise RuntimeError(
                'momentum is not applicable because optims does not exist.')
        # ``torch.optim.Optimizer`` is the optimizer base class.
        if isinstance(self.optims, torch.optim.Optimizer):
            return _get_momentum(self.optims)
        if isinstance(self.optims, dict):
            return {name: _get_momentum(optim)
                    for name, optim in self.optims.items()}
        raise RuntimeError(
            'momentum is not applicable because optims has an unsupported type.')

    
    def save_checkpoint(self,
                        out_dir,
                        task_type,
                        filename_tmpl,
                        save_optimizer = False,
                        meta = None,) -> None:
        """Write a checkpoint to ``<out_dir>/chckpoints/<task_type>/<filename_tmpl>``.

        The file holds a dict with the keys ``'models'`` (state dicts, or CPU
        copies for entries that are plain tensors), ``'optims'`` (optimizer
        state dicts, only filled when ``save_optimizer`` is true) and
        ``'meta'``.

        Args:
            out_dir: Base output directory.
            task_type: Sub-directory name the checkpoint is stored under.
            filename_tmpl: File name of the checkpoint.
            save_optimizer: Whether optimizer states are stored as well.
            meta: Meta information to store; falls back to ``self.meta``.
                ``'epoch'`` and ``'iter'`` are always added from the runner.
        """
        # The directory spelling is kept as-is: __init__ looks for the latest
        # checkpoint under the same name.
        ckpt_dir = os.path.join(out_dir, 'chckpoints', task_type)
        os.makedirs(ckpt_dir, exist_ok=True)
        save_filename = os.path.join(ckpt_dir, filename_tmpl)

        save_dict = {'models': {}, 'optims': {}, 'meta': {}}
        # models
        for name, model in self.models.items():
            if isinstance(model, torch.Tensor):
                save_dict['models'][name] = model.cpu()
            else:
                save_dict['models'][name] = model.state_dict()

        # optimizer
        if save_optimizer:
            for name, optim in self.optims.items():
                save_dict['optims'][name] = optim.state_dict()

        # meta: build a new dict so that neither the caller's ``meta`` nor
        # ``self.meta`` is modified as a side effect of saving.
        base_meta = meta if meta is not None else self.meta
        save_dict['meta'] = {**base_meta, 'epoch': self._epoch, 'iter': self._iter}
        torch.save(save_dict, save_filename)

    def load_checkpoint(
        self,
        filename,
        map_location = 'cpu',
        strict = False,
        revise_keys = [(r'^module.', '')],
    ):
        """Load model weights from ``filename`` into ``self.models``.

        Entries of ``self.models`` that are plain tensors are replaced by the
        stored tensor moved to the current tensor's device; all other entries
        receive the stored state dict via ``load_state_dict``.

        NOTE(review): ``strict`` and ``revise_keys`` are accepted but not used
        by this method.

        Returns:
            dict: The loaded checkpoint.
        """
        checkpoint = torch.load(filename, map_location=map_location)
        for name, current in self.models.items():
            if not isinstance(current, torch.Tensor):
                self.logger.info(f'load state dict {name}')
                current.load_state_dict(checkpoint['models'][name])
                continue
            # raw tensor entry: swap it for the stored value on the same device
            self.models[name] = checkpoint['models'][name].to(current.device)

        return checkpoint

    def resume(self,
               checkpoint,
               resume_optimizer = False,
               map_location = 'default'):
        """Restore models, counters and meta information from a checkpoint file.

        Args:
            checkpoint: Path of the checkpoint file.
            resume_optimizer: Whether optimizer states are restored as well.
            map_location: Passed to ``load_checkpoint``. With ``'default'`` the
                current CUDA device is used when CUDA is available, otherwise
                the default of ``load_checkpoint`` applies.
        """
        if map_location != 'default':
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)
        elif torch.cuda.is_available():
            device_id = torch.cuda.current_device()
            checkpoint = self.load_checkpoint(
                checkpoint,
                map_location=lambda storage, loc: storage.cuda(device_id))
        else:
            checkpoint = self.load_checkpoint(checkpoint)

        # counters and meta information stored with the checkpoint
        stored_meta = checkpoint['meta']
        self._epoch = stored_meta['epoch']
        self._iter = stored_meta['iter']
        self.meta = stored_meta

        # optimizer
        if 'optims' in checkpoint and resume_optimizer:
            for name in self.optims:
                self.optims[name].load_state_dict(checkpoint['optims'][name])

        self.logger.info('resumed epoch %d, iter %d', self._epoch, self._iter)

    def load_pretrained(self,
                        model_types,
                        pretrained_path_dict,
                        map_location = 'cpu',):
        """Initialise ``psi_<i>``/``phi_<i>`` from pretrained SimCLR ResNets.

        For every entry of ``model_types`` (index ``i``) that is one of
        ``resnet50x1``, ``resnet50x2`` or ``resnet50x4``, the matching weight
        file from ``pretrained_path_dict`` (joined onto the current working
        directory) is loaded into a fresh ResNet, which is then handed to
        ``load_from_resnet`` of both models. Other backbone types keep their
        initial weights; either outcome is logged.
        """
        from models.backbones.resnet_wide import resnet50x1, resnet50x2, resnet50x4

        builders = {
            'resnet50x1': resnet50x1,
            'resnet50x2': resnet50x2,
            'resnet50x4': resnet50x4,
        }

        for model_i, model_type in enumerate(model_types):
            if model_type not in builders:
                self.logger.info(f'Model {model_i}: Unsupported backbone {model_type}, so just maintain the initial weights.')
                continue

            filename = os.path.join(os.getcwd(), pretrained_path_dict[model_type])
            resnet_model = builders[model_type]().to(map_location)
            checkpoint = torch.load(filename, map_location=map_location)
            resnet_model.load_state_dict(checkpoint['state_dict'])

            for prefix in ('psi', 'phi'):
                self.models[f'{prefix}_{model_i}'].load_from_resnet(resnet_model)

            self.logger.info(f'Model {model_i}: Supported backbone {model_type}. Successfully load pretrained simclr from {filename}')
            