# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import contextlib
import numpy as np
from inspect import signature
from collections import OrderedDict
from sklearn.metrics import accuracy_score, balanced_accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, top_k_accuracy_score
import torchvision
import heapq
import math
import clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from sklearn.metrics import accuracy_score

from semilearn.datasets.augmentation import RandAugment, RandomResizedCropAndInterpolation
from torchvision import transforms
from semilearn.core.hooks import Hook, get_priority, CheckpointHook, TimerHook, LoggingHook, DistSamplerSeedHook, ParamUpdateHook, EvaluationHook, EMAHook, WANDBHook, AimHook
from semilearn.core.utils import get_dataset, get_data_loader, get_optimizer, get_cosine_schedule_with_warmup, Bn_Controller
from semilearn.core.criterions import CELoss, ConsistencyLoss


class CLIP_adapter(nn.Module):
    """Trainable adapter over CLIP image features and text-classifier weights.

    Components:
      * ``adapter``: bottleneck MLP (feat_dim -> feat_dim // 2 -> feat_dim, no
        bias) whose output is added residually to the image features;
      * ``fc``: linear classifier applied to LeakyReLU(slope=0.1) of the
        (adapted) image features;
      * ``res``: ``(cate_num, feat_dim)`` residual added to the text weights,
        initialised to zero.
    ``adapter``, ``fc`` and ``res`` are created in half precision.

    Args:
        clip_weights: 2-D text-classifier weights; only ``.shape`` is read here
            and it gives ``(cate_num, feat_dim)``.
        gpu: CUDA device passed to ``.cuda()`` for the parameters created here.
        cache_keys: unused; kept for call-site compatibility.
    """

    def __init__(self, clip_weights, gpu, cache_keys=None):
        super(CLIP_adapter, self).__init__()
        self.cate_num, self.feat_dim = clip_weights.shape
        # NOTE(review): not read by forward() in this class.
        self.ent_scale = torch.nn.Parameter(torch.tensor(0.2).cuda(gpu))
        # Convert to half *before* wrapping in nn.Parameter. Calling ``.half()``
        # on an existing Parameter returns a plain non-leaf tensor that the
        # module does not register: it would be missing from ``parameters()``
        # (so the optimizer never updates it) and from ``state_dict()``.
        self.res = nn.Parameter(
            torch.zeros([self.cate_num, self.feat_dim]).cuda(gpu).half(),
            requires_grad=True,
        )
        self.adapter = nn.Sequential(
            nn.Linear(self.feat_dim, self.feat_dim // 2, bias=False),
            nn.ReLU(inplace=False),
            nn.Linear(self.feat_dim // 2, self.feat_dim, bias=False),
        ).half()
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=False)
        self.fc = nn.Linear(self.feat_dim, self.cate_num).half()

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Checkpoints written while ``res`` was not a registered parameter carry
        # no ``res`` entry. That residual was never trained (it stayed zero), so
        # restore it as zeros instead of failing a strict ``load_state_dict``.
        key = prefix + 'res'
        if key not in state_dict:
            state_dict[key] = torch.zeros_like(self.res)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(self, clip_feat, clip_weights, init=False):
        """Adapt image features and text weights.

        Args:
            clip_feat: image features fed to ``adapter`` / ``fc`` (last dim must
                be ``feat_dim``).
            clip_weights: text weights; ``res`` is added to them elementwise.
            init: if True, skip the adapter branch and use ``clip_feat`` as-is.

        Returns:
            ``(clip_feat_new, clip_logits, new_clip_weights)`` where
            ``clip_feat_new`` is ``clip_feat`` (+ adapter output unless ``init``),
            ``clip_logits = fc(LeakyReLU(clip_feat_new))`` and
            ``new_clip_weights = clip_weights + res``.
        """
        if init:
            # The adapter output would be discarded, so do not compute it.
            clip_feat_new = clip_feat
        else:
            clip_feat_new = clip_feat + self.adapter(clip_feat)
        clip_logits = self.fc(self.relu(clip_feat_new))
        new_clip_weights = clip_weights + self.res
        return clip_feat_new, clip_logits, new_clip_weights
    
class AlgorithmBase:
    """
        Base class for algorithms
        init algorithm specific parameters and common parameters
        
        Args:
            - args (`argparse`):
                algorithm arguments
            - net_builder (`callable`):
                network loading function
            - tb_log (`TBLog`):
                tensorboard logger
            - logger (`logging.Logger`):
                logger to use
    """
    def __init__(
        self,
        args,
        net_builder,
        tb_log=None,
        logger=None,
        **kwargs):
        """Set up common training state, CLIP + adapter, data, model, optimizers and hooks.

        Args:
            args: run configuration namespace; many attributes are read here
                (num_classes, epoch, num_train_iter, amp, gpu, rank, ...).
            net_builder: callable that builds the task network; called as
                ``net_builder(num_classes=..., ...)`` by ``set_model`` /
                ``set_ema_model``.
            tb_log: tensorboard logger, stored as ``self.tb_log``.
            logger: logger whose ``info`` becomes ``self.print_fn``; ``print``
                is used when it is None.
            **kwargs: accepted but not used here (the ``self.init(**kwargs)``
                call below is commented out).
        """
        
        # common arguments
        self.args = args
        self.num_classes = args.num_classes
        self.ema_m = args.ema_m
        self.epochs = args.epoch
        self.num_train_iter = args.num_train_iter
        self.num_eval_iter = args.num_eval_iter
        self.num_log_iter = args.num_log_iter
        # Training iterations per epoch (floor division of the total budget).
        self.num_iter_per_epoch = int(self.num_train_iter // self.epochs)
        self.lambda_u = args.ulb_loss_ratio 
        self.use_cat = args.use_cat
        self.use_amp = args.amp
        self.clip_grad = args.clip_grad
        self.save_name = args.save_name
        self.save_dir = args.save_dir
        self.resume = args.resume
        self.algorithm = args.algorithm
        # commaon utils arguments
        self.tb_log = tb_log
        self.print_fn = print if logger is None else logger.info
        self.ngpus_per_node = torch.cuda.device_count()
        self.loss_scaler = GradScaler()
        # Context-manager factory: torch autocast when AMP is enabled, a no-op otherwise.
        self.amp_cm = autocast if self.use_amp else contextlib.nullcontext
        self.gpu = args.gpu
        self.rank = args.rank
        self.distributed = args.distributed
        self.world_size = args.world_size

        # clip
        # CLIP backbone (train()/evaluate() keep it in eval mode) plus the cached
        # text-classifier weights for this dataset and a trainable adapter on top.
        self.clip_model, self.preprocess = clip.load(args.clip_backbone)
        self.clip_weights = self.load_text_feature(self.args.dataset)
        self.clip_adapter = CLIP_adapter(self.clip_weights, self.gpu).cuda(self.gpu)

        self.c_feature_dim = self.clip_weights.shape[1]
        
        # Separate optimizer/schedule for the adapter. T_max is num_train_iter, so the
        # scheduler is presumably stepped once per iteration — TODO confirm in train_step.
        self.c_optimizer = torch.optim.AdamW(self.clip_adapter.parameters(), lr=self.args.c_lr, eps=0.0001, weight_decay=self.args.weight_decay)
        self.c_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.c_optimizer, self.num_train_iter)

        
        # common model related parameters
        self.it = 0
        self.start_epoch = 0
        self.best_eval_acc, self.best_it = 0.0, 0
        self.bn_controller = Bn_Controller()
        self.net_builder = net_builder
        # NOTE(review): evaluate()/get_tzsl_lbind() call self.ema.apply_shadow();
        # presumably EMAHook assigns self.ema during "before_run" — confirm.
        self.ema = None
        
        # Order below matters: loaders need dataset_dict, the EMA model copies
        # self.model, and the optimizer is built over self.model.
        # build dataset
        # (also copies the T-zero-shot fields onto self when args.tzsl is set)
        self.dataset_dict = self.set_dataset()

        # build data loader
        self.loader_dict = self.set_data_loader()

        # cv, nlp, speech builder different arguments
        self.model = self.set_model()
        self.ema_model = self.set_ema_model()

        # build optimizer and scheduler
        self.optimizer, self.scheduler = self.set_optimizer()

        # build supervised loss and unsupervised loss
        self.ce_loss = CELoss()
        self.consistency_loss = ConsistencyLoss()

        # other arguments specific to the algorithm
        # self.init(**kwargs)

        # set common hooks during training
        self._hooks = []  # record underlying hooks 
        self.hooks_dict = OrderedDict() # actual object to be used to call hooks
        self.set_hooks()

    def init(self, **kwargs):
        """Algorithm-specific setup hook; concrete algorithms must override it.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
    
    def load_text_feature(self, dataset):
        """Return the cached text-classifier weights for ``dataset``.

        The object is read with ``torch.load`` from
        ``./caches/<dataset>/text_weights_cupl_t.pt``. The cache directory is
        created first if it is missing; the file itself must already exist.
        """
        cache_dir = os.path.join('./caches', dataset)
        os.makedirs(cache_dir, exist_ok=True)
        weights_file = cache_dir + "/text_weights_cupl_t.pt"
        return torch.load(weights_file)
  
    def set_dataset(self):
        """Build ``dataset_dict`` via ``get_dataset``.

        With ``args.tzsl`` set, every entry of the returned T-zero-shot dict is
        also stored as an attribute on ``self`` (these are read later by
        ``set_tzsl_lb``). When datasets exist, ``args.ulb_dest_len`` and
        ``args.lb_dest_len`` are filled in.

        In distributed runs, non-zero ranks wait on a barrier while rank 0 builds
        the datasets; rank 0 then releases them with a matching barrier. Rank 0
        reaches that barrier on every path, including when ``get_dataset``
        returns no datasets.

        Returns:
            The dataset dict from ``get_dataset`` (may be None).
        """
        distributed = self.distributed
        if self.rank != 0 and distributed:
            torch.distributed.barrier()

        dataset_dict, tzsl_dict = get_dataset(self.args, self.algorithm, self.args.dataset, self.args.num_labels,
                                              self.args.num_classes, self.args.data_dir, self.args.include_lb_to_ulb)
        if self.args.tzsl:
            for key, value in tzsl_dict.items():
                setattr(self, key, value)

        if dataset_dict is not None:
            self.args.ulb_dest_len = len(dataset_dict['train_ulb']) if dataset_dict['train_ulb'] is not None else 0
            self.args.lb_dest_len = len(dataset_dict['train_lb'])
            self.print_fn("unlabeled data number: {}, labeled data number {}".format(self.args.ulb_dest_len, self.args.lb_dest_len))

        # Must run even when dataset_dict is None: the other ranks are blocked in
        # the barrier above and would otherwise wait forever.
        if self.rank == 0 and distributed:
            torch.distributed.barrier()
        return dataset_dict

    def set_data_loader(self):
        """Build the data loaders from ``self.dataset_dict``.

        Always creates ``'train_lb'``, ``'train_ulb'`` and ``'eval'``; adds
        ``'tzsl'`` when ``args.tzsl`` is set and ``'test'`` when a test dataset
        exists. Returns None (building nothing) when there are no datasets.
        """
        if self.dataset_dict is None:
            return

        self.print_fn("Create train and test data loaders")
        args = self.args

        def make_eval_loader(dataset):
            # Evaluation loaders never use a sampler (data_sampler=None) and keep the last partial batch.
            return get_data_loader(args,
                                   dataset,
                                   args.eval_batch_size,
                                   data_sampler=None,
                                   num_workers=args.num_workers,
                                   drop_last=False)

        loader_dict = {
            'train_lb': get_data_loader(args,
                                        self.dataset_dict['train_lb'],
                                        args.batch_size,
                                        data_sampler=args.train_sampler,
                                        num_iters=self.num_train_iter,
                                        num_epochs=self.epochs,
                                        num_workers=args.num_workers,
                                        distributed=self.distributed),
            # The unlabeled stream uses uratio x the labeled batch size and twice the workers.
            'train_ulb': get_data_loader(args,
                                         self.dataset_dict['train_ulb'],
                                         args.batch_size * args.uratio,
                                         data_sampler=args.train_sampler,
                                         num_iters=self.num_train_iter,
                                         num_epochs=self.epochs,
                                         num_workers=2 * args.num_workers,
                                         distributed=self.distributed),
            'eval': make_eval_loader(self.dataset_dict['eval']),
        }

        if args.tzsl:
            loader_dict['tzsl'] = make_eval_loader(self.dataset_dict['tzsl'])

        if self.dataset_dict['test'] is not None:
            loader_dict['test'] = make_eval_loader(self.dataset_dict['test'])

        self.print_fn(f'[!] data loader keys: {loader_dict.keys()}')
        return loader_dict

    def set_optimizer(self):
        """Create the task-model optimizer and its warmup + cosine LR schedule.

        Returns:
            ``(optimizer, scheduler)``.
        """
        self.print_fn("Create optimizer and scheduler")
        args = self.args
        optimizer = get_optimizer(self.model, args.optim, args.lr, args.momentum,
                                  args.weight_decay, args.layer_decay)
        return optimizer, get_cosine_schedule_with_warmup(
            optimizer, self.num_train_iter, num_warmup_steps=args.num_warmup_iter)

    def set_model(self):
        """Build the task network through ``self.net_builder``.

        Pretraining options come from ``args.use_pretrain`` / ``args.pretrain_path``.
        """
        args = self.args
        return self.net_builder(
            num_classes=self.num_classes,
            pretrained=args.use_pretrain,
            pretrained_path=args.pretrain_path,
        )

    def set_ema_model(self):
        """Build a fresh network and copy ``self.model``'s weights into it.

        Returns:
            The new network, to be used as the EMA copy.
        """
        source_state = self.model.state_dict()
        ema_model = self.net_builder(num_classes=self.num_classes)
        ema_model.load_state_dict(source_state)
        return ema_model

    def set_hooks(self):
        """Register the default training hooks, plus the optional W&B / Aim hooks.

        The parameter-update hook is intended to be invoked from each
        ``train_step`` (single-hook ``call_hook``), not only by the loop.
        """
        default_hooks = (
            (ParamUpdateHook, "HIGHEST"),
            (EMAHook, "HIGH"),
            (EvaluationHook, "HIGH"),
            (CheckpointHook, "HIGH"),
            (DistSamplerSeedHook, "NORMAL"),
            (TimerHook, "LOW"),
            (LoggingHook, "LOWEST"),
        )
        for hook_cls, priority in default_hooks:
            self.register_hook(hook_cls(), None, priority)

        if self.args.use_wandb:
            self.register_hook(WANDBHook(), None, "LOWEST")
        if self.args.use_aim:
            self.register_hook(AimHook(), None, "LOWEST")

    def process_batch(self, input_args=None, **kwargs):
        """Keep the keyword arguments ``train_step`` accepts and move them to the GPU.

        Args:
            input_args: names to keep; defaults to the parameter names of
                ``self.train_step``.
            **kwargs: candidate inputs. Entries whose name is not in
                ``input_args`` or whose value is None are dropped. Dict values
                have each of their items moved with ``.cuda(self.gpu)``;
                other values are moved directly.

        Returns:
            dict of the kept, device-moved inputs.
        """
        if input_args is None:
            input_args = list(signature(self.train_step).parameters.keys())

        input_dict = {}
        for name, value in kwargs.items():
            if name not in input_args or value is None:
                continue
            if isinstance(value, dict):
                input_dict[name] = {k: v.cuda(self.gpu) for k, v in value.items()}
            else:
                input_dict[name] = value.cuda(self.gpu)
        return input_dict
    

    def process_out_dict(self, out_dict=None, **kwargs):
        """Merge ``kwargs`` into ``out_dict`` (a new dict when None) and return it.

        The passed-in dict is updated in place.
        """
        merged = {} if out_dict is None else out_dict
        merged.update(kwargs)
        return merged


    def process_log_dict(self, log_dict=None, prefix='train', **kwargs):
        """Add each ``kwargs`` entry to ``log_dict`` under the key ``'<prefix>/<name>'``.

        A new dict is used when ``log_dict`` is None; otherwise it is updated
        in place and returned.
        """
        target = {} if log_dict is None else log_dict
        target.update({f'{prefix}/' + name: value for name, value in kwargs.items()})
        return target

    def compute_prob(self, logits):
        """Return the softmax of ``logits`` over the last dimension."""
        return logits.softmax(dim=-1)

    def cls_acc(output, target, topk=1):
        pred = output.topk(topk, 1, True, True)[1].t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        acc = float(correct[: topk].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
        acc = 100 * acc / target.shape[0]
        return acc

    def train_step(self, idx_lb, x_lb, y_lb, idx_ulb, x_ulb_w, x_ulb_s):
        """One optimization step; each concrete algorithm must implement it.

        An implementation is expected to compute the losses, update the model
        and return ``(out_dict, log_dict)``, which ``train()`` unpacks. The
        parameter names here also determine which batch fields
        ``process_batch`` forwards by default.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


    def train(self):
        """Run the main training loop.

        Puts CLIP in eval mode and the task model / CLIP adapter in train mode,
        fires the ``before_run`` hooks and evaluates once on ``'eval'``. It then
        iterates epochs, zipping the labeled and unlabeled loaders. With
        ``args.tzsl`` set, the labeled set is rebuilt at the start of each epoch
        from the adapter's top-k pseudo-labels. Training stops as soon as
        ``self.it`` reaches ``self.num_train_iter``.
        """
        self.clip_model.eval()
        self.model.train()
        self.clip_adapter.train()
        self.call_hook("before_run")

        self.print_fn("validating...")
        # Run for its side effects (confusion-matrix logging); the metrics are not used here.
        self.evaluate('eval')

        for epoch in range(self.start_epoch, self.epochs):
            self.epoch = epoch

            # Check the iteration budget *before* the T-zero-shot relabeling: that
            # pass scans the whole tzsl split and rebuilds the labeled loader, which
            # is wasted work if training is about to stop anyway.
            if self.it >= self.num_train_iter:
                break

            if self.args.tzsl:
                self._refresh_tzsl_lb_loader()

            self.call_hook("before_train_epoch")
            for data_lb, data_ulb in zip(self.loader_dict['train_lb'],
                                         self.loader_dict['train_ulb']):
                if self.it >= self.num_train_iter:
                    break

                self.call_hook("before_train_step")
                self.out_dict, self.log_dict = self.train_step(**self.process_batch(**data_lb, **data_ulb))
                self.call_hook("after_train_step")
                self.it += 1

            self.call_hook("after_train_epoch")

        self.call_hook("after_run")

    def _refresh_tzsl_lb_loader(self):
        """Replace the labeled dataset/loader with the current top-k pseudo-labeled samples."""
        new_indices, new_labels = self.get_tzsl_lbind()
        new_indices = [int(idx) for idx in new_indices]
        self.dataset_dict['train_lb'] = self.set_tzsl_lb(new_indices, new_labels)
        self.loader_dict['train_lb'] = get_data_loader(
            self.args,
            self.dataset_dict['train_lb'],
            self.args.batch_size,
            data_sampler=self.args.train_sampler,
            num_iters=self.num_train_iter,
            num_epochs=self.epochs,
            num_workers=self.args.num_workers,
            distributed=self.distributed
        )
        self.print_fn("lb_loader_dict changed")


    def c_mm_transform(self, c_feats, dul_mask=None, alpha=0.4):
        """Mixup the rows of ``c_feats`` selected by ``dul_mask`` among themselves.

        The selected rows are replaced, *in place*, by
        ``lam * rows + (1 - lam) * rows[index]``, where ``index`` is a random
        permutation of the selected rows and ``lam = max(b, 1 - b)`` with
        ``b ~ Beta(alpha, alpha)`` (so ``lam >= 0.5``).

        Args:
            c_feats: tensor whose rows (dim 0) are mixed; modified in place.
            dul_mask: per-row mask (any dtype, converted with ``.bool()``). None
                selects every row.
            alpha: Beta distribution parameter.

        Returns:
            ``(c_feats, index, lam)``; ``(c_feats, None, None)`` when no row is
            selected. ``lam`` is a Python float.
        """
        if dul_mask is None:
            mask = torch.ones(c_feats.size(0), dtype=torch.bool, device=c_feats.device)
        else:
            mask = dul_mask.bool()

        selected_c_feats = c_feats[mask]
        if selected_c_feats.size(0) == 0:
            return c_feats, None, None

        lam = np.random.beta(alpha, alpha)
        lam = float(max(lam, 1 - lam))

        # Put the permutation on the features' own device. The previous
        # ``.cuda(device)`` call failed for CPU features on a CUDA machine.
        index = torch.randperm(selected_c_feats.size(0)).to(selected_c_feats.device)

        mixed_x = lam * selected_c_feats + (1 - lam) * selected_c_feats[index, :]
        c_feats[mask] = mixed_x
        return c_feats, index, lam

    
    
    def get_tzsl_lbind(self, zl_dest='tzsl'):
        """Pick the most confident CLIP-adapter predictions per predicted class.

        Scans ``self.loader_dict[zl_dest]`` with the CLIP image encoder and
        ``self.clip_adapter``. Per-sample logits are the adapter's ``fc`` logits
        plus ``100 * normalized_feature @ adapted_text_weights.T``. For every
        predicted class, the ``top_k = 3 * round(sqrt(num_classes))`` samples
        with the highest softmax confidence are kept. On the first call of
        training (``self.it == 0``) the adapter's image branch is bypassed
        (``init=True``). EMA weights are applied during the scan and always
        restored afterwards, and the adapter is put back in train mode.

        Args:
            zl_dest: key of the loader in ``self.loader_dict`` to scan.

        Returns:
            ``(new_indices, new_labels)``: sample indices (from the batch's
            ``'idx_lb'``) and their predicted labels. Results are grouped by
            label in first-seen order and ordered by decreasing confidence
            within each label.
        """
        self.clip_model.eval()
        self.clip_adapter.eval()
        self.ema.apply_shadow()
        top_k = 3 * round(math.sqrt(self.num_classes))
        self.print_fn(f"num_labels per class: {top_k}")

        init = self.it == 0
        if init:
            self.print_fn("init")

        # predicted label -> min-heap of (confidence, sample index, true label).
        # The root is the least confident kept entry, so it is the one evicted.
        # The true label does not drive selection; it only breaks exact ties.
        label_to_heap = {}
        try:
            with torch.no_grad():
                for data in self.loader_dict[zl_dest]:
                    x_c = data['x_c_lb'].cuda(self.gpu)
                    c_feats = self.clip_model.encode_image(x_c)
                    c_feat_new, c_logits, clip_weights_new = self.clip_adapter(c_feats, self.clip_weights, init)
                    c_feat_new /= c_feat_new.norm(dim=-1, keepdim=True)
                    c_logits = c_logits + 100. * c_feat_new @ clip_weights_new.T
                    max_prob, pred_label = torch.max(F.softmax(c_logits, dim=-1), dim=-1)

                    # Move each batch to Python once rather than syncing with .item() per sample.
                    batch = zip(pred_label.tolist(), data['idx_lb'].tolist(),
                                max_prob.tolist(), data['y_lb'].tolist())
                    for label, sample_idx, prob, y_true in batch:
                        heap = label_to_heap.setdefault(label, [])
                        heapq.heappush(heap, (prob, sample_idx, y_true))
                        if len(heap) > top_k:
                            heapq.heappop(heap)
        finally:
            self.ema.restore()
            self.clip_adapter.train()

        new_indices, new_labels = [], []
        for label, heap in label_to_heap.items():
            # Most confident first; sorted() is stable, so equal confidences keep heap order.
            ranked = sorted(heap, key=lambda entry: -entry[0])
            new_indices.extend(entry[1] for entry in ranked)
            new_labels.extend([label] * len(ranked))
        return new_indices, new_labels
    

    def set_tzsl_lb(self, new_indices, new_labels):
        """Build a labeled dataset from T-zero-shot pseudo-labels.

        Uses attributes stored on ``self`` by ``set_dataset`` (``data_dir``,
        ``tulb_idx``, ``raw_data``, ``tfm_wk``, ``tfm_st``, ``clip_tfm``).

        Args:
            new_indices: selected sample indices.
            new_labels: pseudo-label for each selected sample.

        Returns:
            The dataset for eurosat, cifar10/cifar100/dtd or
            imagenet/imagenet127; None for any other ``args.dataset``.
        """
        from semilearn.datasets.cv_datasets.datasetbase import BasicDataset
        from semilearn.datasets.cv_datasets.eurosat import EuroSat
        from semilearn.datasets.cv_datasets.imagenet import ImagenetDataset

        name = self.args.dataset
        if name == "eurosat":
            return EuroSat(self.algorithm, self.data_dir, split="trainval",
                           idx_list=self.tulb_idx[new_indices],
                           transform=self.tfm_wk, transform_strong=self.tfm_st,
                           clip_transform=self.clip_tfm, cg_tgts=True, tgts=new_labels)
        if name in ["cifar10", "cifar100", "dtd"]:
            return BasicDataset(self.algorithm, self.raw_data[new_indices], new_labels,
                                self.num_classes, self.tfm_wk, False, self.tfm_st,
                                self.tfm_st, self.clip_tfm, False)
        if name in ["imagenet", "imagenet127"]:
            return ImagenetDataset(root=self.data_dir, transform=self.tfm_wk,
                                   clip_transform=self.clip_tfm, ulb=False,
                                   alg=self.algorithm, sample_ind=new_indices,
                                   tgts=new_labels, shff=False)
        return None

   
    
    def evaluate(self, eval_dest='eval', out_key='logits', return_logits=False):
        """Evaluate the task model and the CLIP-adapter branch on one split.

        Runs over ``self.loader_dict[eval_dest]`` with EMA weights applied
        (``self.ema.apply_shadow()``). EMA is always restored and the model and
        adapter are put back in train mode, even if the loop raises. Both
        confusion matrices are logged through ``self.print_fn``. CLIP-branch
        logits are the adapter's ``fc`` logits plus
        ``100 * normalized_feature @ adapted_text_weights.T``.

        Args:
            eval_dest: key of the loader in ``self.loader_dict``.
            out_key: key of the logits in the dict returned by ``self.model(x)``.
            return_logits: also return the stacked task-model logits under
                ``'<eval_dest>/logits'``.

        Returns:
            dict mapping ``'<eval_dest>/<metric>'`` to its value; keys containing a
            ``c`` marker (``top-1c-acc``, ``precisionc``, ``recall_c``, ...) are for
            the CLIP branch.
        """
        self.model.eval()
        self.clip_model.eval()
        self.clip_adapter.eval()
        self.ema.apply_shadow()

        total_loss = 0.0
        total_num = 0.0
        y_true, y_pred, y_c_pred = [], [], []
        y_probs, y_c_probs, y_logits = [], [], []
        try:
            with torch.no_grad():
                for data in self.loader_dict[eval_dest]:
                    x = data['x_lb']
                    if isinstance(x, dict):
                        x = {k: v.cuda(self.gpu) for k, v in x.items()}
                    else:
                        x = x.cuda(self.gpu)
                    y = data['y_lb'].cuda(self.gpu)
                    x_c = data['x_c_lb'].cuda(self.gpu)

                    num_batch = y.shape[0]
                    total_num += num_batch

                    c_feats = self.clip_model.encode_image(x_c)
                    c_feat_new, c_logits, clip_weights_new = self.clip_adapter(c_feats, self.clip_weights)
                    c_feat_new /= c_feat_new.norm(dim=-1, keepdim=True)
                    c_logits = c_logits + 100. * c_feat_new @ clip_weights_new.T

                    logits = self.model(x)[out_key]
                    loss = F.cross_entropy(logits, y, reduction='mean', ignore_index=-1)

                    y_true.extend(y.cpu().tolist())
                    y_pred.extend(torch.max(logits, dim=-1)[1].cpu().tolist())
                    y_c_pred.extend(torch.max(c_logits, dim=-1)[1].cpu().tolist())
                    y_logits.append(logits.cpu().numpy())
                    y_probs.extend(torch.softmax(logits, dim=-1).cpu().tolist())
                    y_c_probs.extend(torch.softmax(c_logits, dim=-1).cpu().tolist())
                    total_loss += loss.item() * num_batch
        finally:
            # Put the live weights back even if evaluation fails midway.
            self.ema.restore()
            self.model.train()
            self.clip_adapter.train()

        y_true = np.array(y_true)
        y_pred = np.array(y_pred)
        y_c_pred = np.array(y_c_pred)
        y_probs = np.array(y_probs)
        y_c_probs = np.array(y_c_probs)
        y_logits = np.concatenate(y_logits)

        top1 = accuracy_score(y_true, y_pred)
        top1_c = accuracy_score(y_true, y_c_pred)
        # Pass the full label set: without it sklearn raises ValueError whenever the
        # split does not contain every class (y_true vs. score-column count mismatch).
        top5 = top_k_accuracy_score(y_true, y_probs, k=5, labels=np.arange(y_probs.shape[1]))
        top5_c = top_k_accuracy_score(y_true, y_c_probs, k=5, labels=np.arange(y_c_probs.shape[1]))
        balanced_top1 = balanced_accuracy_score(y_true, y_pred)
        precision_c = precision_score(y_true, y_c_pred, average='macro')
        precision = precision_score(y_true, y_pred, average='macro')
        recall_c = recall_score(y_true, y_c_pred, average='macro')
        recall = recall_score(y_true, y_pred, average='macro')
        F1_c = f1_score(y_true, y_c_pred, average='macro')
        F1 = f1_score(y_true, y_pred, average='macro')

        cf_mat = confusion_matrix(y_true, y_pred, normalize='true')
        cf_mat_c = confusion_matrix(y_true, y_c_pred, normalize='true')
        self.print_fn('confusion matrix:\n' + np.array_str(cf_mat))
        self.print_fn('confusion matrix_c:\n' + np.array_str(cf_mat_c))

        eval_dict = {eval_dest+'/loss': total_loss / total_num, eval_dest+'/top-1-acc': top1, eval_dest+'/top-1c-acc': top1_c, eval_dest+'/top-5-acc': top5, eval_dest+'/top-5c-acc': top5_c,
                     eval_dest+'/balanced_acc': balanced_top1, eval_dest+'/precision': precision, eval_dest+'/precisionc': precision_c, eval_dest+'/recall': recall, eval_dest+'/recall_c': recall_c, eval_dest+'/F1': F1, eval_dest+'/F1c': F1_c}
        if return_logits:
            eval_dict[eval_dest+'/logits'] = y_logits
        return eval_dict


    def get_save_dict(self):
        """Collect everything needed to resume training into one dict.

        Holds the state dicts of the model, CLIP adapter (and its optimizer and
        scheduler), EMA model, optimizer and loss scaler, plus the counters.
        ``it`` and ``epoch`` are stored as the *next* values (+1). The task
        scheduler's state is included only when a scheduler exists.
        """
        components = ('model', 'clip_adapter', 'c_scheduler', 'c_optimizer',
                      'ema_model', 'optimizer', 'loss_scaler')
        save_dict = {name: getattr(self, name).state_dict() for name in components}
        save_dict.update(
            it=self.it + 1,
            epoch=self.epoch + 1,
            best_it=self.best_it,
            best_eval_acc=self.best_eval_acc,
        )
        if self.scheduler is not None:
            save_dict['scheduler'] = self.scheduler.state_dict()
        return save_dict
    

    def save_model(self, save_name, save_path):
        """Write ``get_save_dict()`` to ``<save_path>/<save_name>`` with ``torch.save``.

        The directory is created when it does not exist, and the written path
        is logged through ``self.print_fn``.
        """
        save_filename = os.path.join(save_path, save_name)
        if not os.path.exists(save_path):
            os.makedirs(save_path, exist_ok=True)
        torch.save(self.get_save_dict(), save_filename)
        self.print_fn(f"model saved: {save_filename}")


    def load_model(self, load_path):
        """Restore training state from a checkpoint written by ``save_model``.

        The file is loaded onto the CPU. State dicts are restored into the
        model, CLIP adapter (and its scheduler/optimizer), EMA model, loss
        scaler and optimizer. The task scheduler is restored only when both the
        scheduler and its checkpoint entry exist. ``it``, ``start_epoch``/
        ``epoch``, ``best_it`` and ``best_eval_acc`` are set from the checkpoint.

        Returns:
            The raw checkpoint dict.
        """
        checkpoint = torch.load(load_path, map_location='cpu')
        for name in ('model', 'clip_adapter', 'c_scheduler', 'c_optimizer',
                     'ema_model', 'loss_scaler'):
            getattr(self, name).load_state_dict(checkpoint[name])

        self.it = checkpoint['it']
        self.start_epoch = self.epoch = checkpoint['epoch']
        self.best_it = checkpoint['best_it']
        self.best_eval_acc = checkpoint['best_eval_acc']

        self.optimizer.load_state_dict(checkpoint['optimizer'])
        has_scheduler_state = self.scheduler is not None and 'scheduler' in checkpoint
        if has_scheduler_state:
            self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.print_fn('Model loaded')
        return checkpoint

    def check_prefix_state_dict(self, state_dict):
        """Strip a leading ``module.`` component from state-dict keys.

        That prefix is what wrapping a network in ``DataParallel`` /
        ``DistributedDataParallel`` adds. Only keys whose first dotted
        component is exactly ``module`` are rewritten. Keys that merely start
        with those letters (e.g. ``modules_x.weight``) are kept unchanged; the
        old ``startswith('module')`` test wrongly stripped them too.

        Returns:
            A new dict with the same values.
        """
        prefix = 'module.'
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def register_hook(self, hook, name=None, priority='NORMAL'):
        """Insert ``hook`` into the priority-ordered hook list.

        Ref: https://github.com/open-mmlab/mmcv/blob/a08517790d26f8761910cac47ce8098faac7b627/mmcv/runner/base_runner.py#L263

        ``self._hooks`` is kept sorted by priority (lower value = higher
        priority). Hooks with equal priority fire in registration order.
        ``self.hooks_dict`` is rebuilt afterwards as a name -> hook mapping in
        that order.

        Args:
            hook (:obj:`Hook`): the hook to register; it must not already
                have a ``priority`` attribute.
            name (str, optional): registration name; defaults to the hook's
                class name.
            priority (int or str or :obj:`Priority`): hook priority.

        Raises:
            ValueError: if ``hook`` already defines ``priority``.
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority  # type: ignore
        hook.name = type(hook).__name__ if name is None else name

        # Place the new hook right after the last registered hook whose priority
        # is <= its own; if there is none, it goes to the front.
        position = 0
        for i in reversed(range(len(self._hooks))):
            if priority >= self._hooks[i].priority:  # type: ignore
                position = i + 1
                break
        self._hooks.insert(position, hook)

        self.hooks_dict = OrderedDict((registered.name, registered) for registered in self._hooks)
        


    def call_hook(self, fn_name, hook_name=None, *args, **kwargs):
        """Invoke ``fn_name`` on registered hooks, passing ``self`` first.

        Args:
            fn_name (str): hook method to call, e.g. ``"before_train_epoch"``.
            hook_name (str): if given, call only that hook (for example the
                parameter-update hook from inside ``train_step``) and return its
                result. Otherwise every hook that defines ``fn_name`` is called
                in priority order and None is returned.
        """
        if hook_name is not None:
            target = self.hooks_dict[hook_name]
            return getattr(target, fn_name)(self, *args, **kwargs)

        for hook in self.hooks_dict.values():
            if not hasattr(hook, fn_name):
                continue
            getattr(hook, fn_name)(self, *args, **kwargs)

    def registered_hook(self, hook_name):
        """Return True if a hook named ``hook_name`` has been registered."""
        return hook_name in self.hooks_dict.keys()


    @staticmethod
    def get_argument():
        """
        Get specificed arguments into argparse for each algorithm
        """
        return {}



class ImbAlgorithmBase(AlgorithmBase):
    """AlgorithmBase variant for class-imbalanced training.

    Also stores the imbalance settings from ``args`` and chooses the optimizer
    setup by backbone and dataset.
    """

    def __init__(self, args, net_builder, tb_log=None, logger=None, **kwargs):
        super().__init__(args, net_builder, tb_log, logger, **kwargs)

        # imbalanced arguments (read after the base class has stored args)
        imb_args = self.args
        self.lb_imb_ratio = imb_args.lb_imb_ratio
        self.ulb_imb_ratio = imb_args.ulb_imb_ratio
        self.imb_algorithm = imb_args.imb_algorithm

    def imb_init(self, *args, **kwargs):
        """Hook for imbalanced-algorithm parameters; a no-op by default."""
        pass

    def set_optimizer(self):
        """Choose the optimizer setup for the current backbone/dataset.

        ViT backbones on cifar100/food101/semi_aves/semi_aves_out, and any
        backbone on imagenet/imagenet127, use the parent's optimizer with a
        warmup + cosine schedule. Every other configuration gets an optimizer
        built with ``bn_wd_skip=False`` and no scheduler.

        Returns:
            ``(optimizer, scheduler)``; ``scheduler`` may be None.
        """
        uses_parent_setup = (
            ('vit' in self.args.net and self.args.dataset in ['cifar100', 'food101', 'semi_aves', 'semi_aves_out'])
            or self.args.dataset in ['imagenet', 'imagenet127']
        )
        if uses_parent_setup:
            return super().set_optimizer()

        self.print_fn("Create optimizer and scheduler")
        optimizer = get_optimizer(self.model, self.args.optim, self.args.lr, self.args.momentum,
                                  self.args.weight_decay, self.args.layer_decay, bn_wd_skip=False)
        return optimizer, None
