from calendar import c
import os
import sys
import time
import random
import copy
import logging
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import transforms
from collections import OrderedDict, defaultdict
from randaugment import RandAugment
from tqdm import tqdm

from models import get_model
from datasets import get_dataset
from datasets.multiDatasets import multiDatasets
from utils.augment import Cutout
from utils.memory import Memory
from utils.online_sampler import OnlineSampler, OnlineTestSampler, NIIDSampler
from utils.indexed_dataset import IndexedDataset
from utils.train_utils import select_optimizer, select_scheduler

##################################################################
# This is a trainer built on DistributedDataParallel.            #
# Based on the following tutorial:                               #
# https://github.com/pytorch/examples/blob/main/imagenet/main.py #
# And DeiT by Facebook                                           #
# https://github.com/facebookresearch/deit                       #
##################################################################


class _Trainer():

    def __init__(self, *args, **kwargs) -> None:

        self.args = kwargs

        self.method = kwargs.get("method")

        self.n = kwargs.get("n")
        self.m = kwargs.get("m")
        self.rnd_NM = kwargs.get("rnd_NM")

        self.n_tasks = kwargs.get("n_tasks")
        self.dataset_name = kwargs.get("dataset")
        self.rnd_seed = kwargs.get("rnd_seed")

        self.memory_size = kwargs.get("memory_size")
        self.log_path = kwargs.get("log_path")
        self.model_name = kwargs.get("model_name")
        self.opt_name = kwargs.get("opt_name")
        self.sched_name = kwargs.get("sched_name")
        self.batchsize = kwargs.get("batchsize")
        self.n_worker = kwargs.get("n_worker")
        self.lr = kwargs.get("lr")
        self.init_model = kwargs.get("init_model")
        self.init_opt = kwargs.get("init_opt")
        self.topk = kwargs.get("topk")
        self.use_amp = kwargs.get("use_amp")
        self.transforms = kwargs.get("transforms")
        self.reg_coef = kwargs.get("reg_coef")
        self.data_dir = kwargs.get("data_dir")
        self.debug = kwargs.get("debug")
        self.note = kwargs.get("note")
        self.selection_size = kwargs.get("selection_size")

        self.blurry_mod = kwargs.get("blurry_mod", 'gaussian')
        self.blurry_always = kwargs.get("blurry_always", 0.3)
        self.blurry_sudden = kwargs.get("blurry_sudden", 0.3)
        self.ba_sigma = kwargs.get("ba_sigma", 0.15)
        self.bs_alpha = kwargs.get("bs_alpha", 1.5)

        self.eval_period = kwargs.get("eval_period")
        self.temp_batchsize = kwargs.get("temp_batchsize")
        self.online_iter = kwargs.get("online_iter")
        self.num_gpus = kwargs.get("num_gpus")
        self.workers_per_gpu = kwargs.get("workers_per_gpu")
        self.imp_update_period = kwargs.get("imp_update_period")

        # for zero shot evaluation
        self.test_only_flag = kwargs.get("test_only", False)
        self.zero_shot_dataset = kwargs.get("zero_shot_dataset").split(',')

        # for distillation
        self.teacher_model = None
        self.kd_lambda = self.args.get('kd_lambda', 0)
        self.kd_ema_decay = self.args.get('kd_ema_decay', 0.999)

        # for distributed training
        self.dist_backend = 'nccl'
        self.dist_url = 'env://'

        self.lr_step = kwargs.get("lr_step")  # for adaptive LR
        self.lr_length = kwargs.get("lr_length")  # for adaptive LR
        self.lr_period = kwargs.get("lr_period")  # for adaptive LR

        self.memory_epoch = kwargs.get("memory_epoch")  # for RM
        self.distilling = kwargs.get("distilling")  # for BiC
        self.agem_batch = kwargs.get("agem_batch")  # for A-GEM
        self.mir_cands = kwargs.get("mir_cands")  # for MIR

        self.start_time = time.time()
        self.num_updates = 0
        self.train_count = 0

        self.ngpus_per_nodes = torch.cuda.device_count()
        self.world_size = 1
        if "WORLD_SIZE" in os.environ and os.environ["WORLD_SIZE"] != '':
            self.world_size = int(
                os.environ["WORLD_SIZE"]) * self.ngpus_per_nodes
        else:
            self.world_size = self.world_size * self.ngpus_per_nodes
        self.distributed = self.world_size > 1

        if self.distributed:
            self.batchsize = self.batchsize // self.world_size
        if self.temp_batchsize is None:
            self.temp_batchsize = self.batchsize // 2
        if self.temp_batchsize > self.batchsize:
            self.temp_batchsize = self.batchsize
        self.memory_batchsize = self.batchsize - self.temp_batchsize

        if 'debug' in self.note:
            self.log_dir = os.path.join(self.log_path, "debug")
        elif self.test_only_flag:
            self.log_dir = kwargs.get("log_path")
        else:
            if self.blurry_mod == 'niid':
                self.log_dir = os.path.join(
                    self.log_path, self.dataset_name,
                    f"TASK{self.n_tasks}BA{int(self.blurry_always*100)}_{self.ba_sigma}BS{int(100*self.blurry_sudden)}_{self.bs_alpha}",
                    f"{self.note}_{datetime.datetime.now().strftime('%y%m%d%H')}"
                )
            else:
                self.log_dir = os.path.join(
                    self.log_path, self.dataset_name,
                    f"TASK{self.n_tasks}N{self.n}M{self.m}",
                    f"{self.note}_{datetime.datetime.now().strftime('%y%m%d%H')}"
                )

    def setup_distributed_model(self):
        """Place the model on this process's device and build training state.

        Sets ``self.scaler``, ``self.model_without_ddp``, ``self.criterion``,
        ``self.optimizer`` and ``self.scheduler``. When ``self.distributed``
        is true the model is wrapped in ``DistributedDataParallel`` and
        ``self.model_without_ddp`` refers to the wrapped module. The
        criterion is the model's own ``loss_fn`` attribute when it has one,
        otherwise mean-reduced cross entropy.
        """
        logging.info("Building model...")
        # Module.to() moves the parameters in place, so one transfer suffices.
        self.model = self.model.to(self.device)
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

        self.model_without_ddp = self.model
        if self.distributed:
            self.model = torch.nn.parallel.DistributedDataParallel(self.model)
            self.model._set_static_graph()
            self.model_without_ddp = self.model.module

        if hasattr(self.model_without_ddp, "loss_fn"):
            self.criterion = self.model_without_ddp.loss_fn
        else:
            self.criterion = nn.CrossEntropyLoss(reduction="mean")
        self.optimizer = select_optimizer(self.opt_name, self.lr, self.model)
        self.scheduler = select_scheduler(self.sched_name, self.optimizer)

        # Count total and trainable parameters in a single pass.
        total_params = 0
        learnable_params = 0
        for p in self.model_without_ddp.parameters():
            total_params += p.numel()
            if p.requires_grad:
                learnable_params += p.numel()
        logging.info(f"Total Parameters :\t{total_params}")
        logging.info(f"Learnable Parameters :\t{learnable_params}")

    def setup_test_dataloader(self, dataset_names):
        """Build one evaluation DataLoader per named dataset.

        Each dataset is instantiated on its test split (``train=False``) with
        its own deep copy of a resize-to-``self.inp_size`` + ``ToTensor``
        transform. A tuple/list returned by ``get_dataset`` is wrapped in
        ``multiDatasets``; anything else is called directly as a constructor.

        Args:
            dataset_names: Iterable of names understood by ``get_dataset``.

        Returns:
            ``(test_dataloader, classes_names, meta)``: three parallel lists.
            Each ``meta`` entry is ``(dataset name, number of class names,
            number of samples)``.
        """
        base_transform = transforms.Compose([
            transforms.Resize((self.inp_size, self.inp_size)),
            transforms.ToTensor(),
        ])

        loaders = []
        class_name_lists = []
        summaries = []
        for name in dataset_names:
            source, _ = get_dataset(name)
            common = dict(root=self.data_dir,
                          train=False,
                          download=True,
                          transform=copy.deepcopy(base_transform))
            if isinstance(source, (tuple, list)):
                eval_set = multiDatasets(datasets=source, **common)
            else:
                eval_set = source(**common)

            class_name_lists.append(eval_set.classes_names)
            summaries.append(
                (name, len(eval_set.classes_names), len(eval_set)))
            loaders.append(
                DataLoader(eval_set,
                           batch_size=self.batchsize,
                           shuffle=False,
                           num_workers=self.n_worker,
                           drop_last=False))

        return loaders, class_name_lists, summaries

    def setup_dataset(self):
        """Instantiate the train/test datasets and reset stream bookkeeping.

        ``get_dataset`` supplies either a dataset constructor or a tuple/list
        of datasets (wrapped in ``multiDatasets``). ``self.n_classes`` is
        finally taken from the length of the train dataset's ``classes``.
        Also resets the exposed-class lists, the ``seen`` counter and the
        per-class ``_c_grad`` buffers.
        """
        self.dataset, self.n_classes = get_dataset(self.dataset_name)
        is_collection = isinstance(self.dataset, (tuple, list))

        def build(train, transform):
            # Same keyword arguments for both dataset flavours.
            options = dict(root=self.data_dir,
                           train=train,
                           download=True,
                           transform=transform)
            if is_collection:
                return multiDatasets(datasets=self.dataset, **options)
            return self.dataset(**options)

        self.train_dataset = build(True, self.train_transform)
        self.test_dataset = build(False, self.test_transform)
        self.n_classes = len(self.train_dataset.classes)

        self.exposed_classes = []
        self.exposed_classes_names = []
        self.seen = 0

        grad_shape = (self.n_classes, 3)
        self._c_grad = torch.zeros(*grad_shape)
        self._c_grad_count = torch.zeros(*grad_shape, dtype=torch.long)
        self._c_grad_timeline = []

    def setup_transforms(self, inp_size=224):
        """Compose ``self.train_transform`` and ``self.test_transform``.

        ``self.transforms`` is probed for "cutmix" (only recorded as
        ``self.cutmix``), "autoaug", "cutout" and "randaug". The optional
        augmentations run first in the train pipeline, followed by resize,
        padded random crop and random horizontal flip. The test pipeline is
        resize followed by ``ToTensor``.

        Args:
            inp_size: Side length used for the resize and random-crop steps.
        """
        requested = self.transforms
        self.cutmix = "cutmix" in requested

        augmentations = []
        if "autoaug" in requested:
            # The image is turned into a uint8 tensor before AutoAugment and
            # scaled back to float32 afterwards.
            augmentations.append(transforms.ToTensor())
            augmentations.append(lambda x: (x * 255).type(torch.uint8))
            # First matching keyword wins; no policy is added if none match.
            policy_by_keyword = (
                ('cifar', 'cifar10'),
                ('imagenet', 'imagenet'),
                ('svhn', 'svhn'),
            )
            for keyword, policy in policy_by_keyword:
                if keyword in self.dataset_name:
                    augmentations.append(
                        transforms.AutoAugment(
                            transforms.AutoAugmentPolicy(policy)))
                    break
            augmentations.append(lambda x: x.type(torch.float32) / 255)

        if "cutout" in requested:
            augmentations.append(Cutout(size=16))
        if "randaug" in requested:
            augmentations.append(RandAugment())

        geometric = [
            transforms.Resize((inp_size, inp_size)),
            transforms.RandomCrop(inp_size, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
        self.train_transform = transforms.Compose(augmentations + geometric)
        self.test_transform = transforms.Compose([
            transforms.Resize((inp_size, inp_size)),
            transforms.ToTensor(),
        ])

    def run(self):
        """Start training: one spawned worker per GPU, or a direct call.

        With more than one visible GPU, ``main_worker`` is spawned once per
        GPU and joined; otherwise it is invoked in-process with index 0.
        """
        if self.ngpus_per_nodes <= 1:
            self.main_worker(0)
            return
        # Distributed launch
        mp.spawn(self.main_worker, nprocs=self.ngpus_per_nodes, join=True)

    def main_worker(self, gpu) -> None:
        """Run the whole online training session inside one process.

        Steps, in order: create the log directory, select the device,
        initialise the process group when distributed, configure
        logging/printing, seed the RNGs, build transforms/datasets/memory,
        build the model, build the sampler and data loaders, then iterate
        over the tasks with periodic and end-of-task evaluation. Finally the
        main process saves the checkpoint, the metric arrays and a summary.

        Failures in the setup, training and reporting stages are logged with
        their traceback and do not propagate, so later stages still get a
        chance to run.

        Args:
            gpu: Local process index (supplied by ``mp.spawn``, or ``0`` for
                a single-process run); reduced modulo the GPU count.
        """
        os.makedirs(self.log_dir, exist_ok=True)
        self.gpu = gpu % self.ngpus_per_nodes
        self.device = torch.device(self.gpu)
        if self.distributed:
            self.local_rank = self.gpu
            if 'SLURM_PROCID' in os.environ.keys():
                self.rank = int(os.environ['SLURM_PROCID']
                                ) * self.ngpus_per_nodes + self.gpu
                print(
                    f"| Init Process group {os.environ['SLURM_PROCID']} : {self.local_rank}"
                )
            elif 'WORLD_SIZE' in os.environ.keys():
                self.rank = int(
                    os.environ['RANK']) * self.ngpus_per_nodes + self.gpu
                print(
                    f"| Init Process group {os.environ['RANK']} : {self.local_rank}"
                )
            else:
                self.rank = self.gpu
                print(f"| Init Process group 0 : {self.local_rank}")
            if 'MASTER_ADDR' not in os.environ.keys():
                os.environ['MASTER_ADDR'] = '127.0.0.1'
                os.environ['MASTER_PORT'] = '12701'
            torch.cuda.set_device(self.gpu)
            time.sleep(self.rank * 0.1)  # prevent port collision
            dist.init_process_group(backend=self.dist_backend,
                                    init_method=self.dist_url,
                                    world_size=self.world_size,
                                    rank=self.rank)
            torch.distributed.barrier()
            self.setup_for_distributed(self.is_main_process())
        else:
            self.setup_for_distributed(True)

        logging.info(str(self.args))

        if self.rnd_seed is not None:
            random.seed(self.rnd_seed)
            np.random.seed(self.rnd_seed)
            torch.manual_seed(self.rnd_seed)
            torch.cuda.manual_seed(self.rnd_seed)
            torch.cuda.manual_seed_all(self.rnd_seed)  # if use multi-GPU
            cudnn.deterministic = True
            logging.info('You have chosen to seed training. '
                         'This will turn on the CUDNN deterministic setting, '
                         'which can slow down your training considerably! '
                         'You may see unexpected behavior when restarting '
                         'from checkpoints.')
        cudnn.benchmark = False

        try:
            logging.info(f"Set dataset ({self.method})")
            self.setup_transforms()
            self.setup_dataset()
            self.memory = Memory()
        except Exception as e:
            # Best effort is kept, but the traceback is preserved for triage.
            logging.exception("Dataset setup failed: %s", e)

        logging.info(f"Select a CIL method ({self.method})")
        logging.info(f"Building model ({self.model_name})")
        self.model, self.inp_size = get_model(
            model_name=self.model_name,
            method=self.method,
            num_classes=self.n_classes,
            device=self.device,
            peft_encoder=self.args['peft_encoder'],
        )

        self.setup_distributed_model()
        self.total_samples = len(self.train_dataset)

        train_dataset = IndexedDataset(self.train_dataset)
        try:
            if self.blurry_mod == 'niid':
                self.train_sampler = NIIDSampler(
                    train_dataset,
                    self.n_tasks,
                    self.rnd_seed,
                    blurry_always=self.blurry_always,
                    blurry_sudden=self.blurry_sudden,
                    sigma=self.ba_sigma,
                    alpha=self.bs_alpha,
                    varing_NM=self.rnd_NM)
            else:
                self.train_sampler = OnlineSampler(train_dataset, self.n_tasks,
                                                   self.m, self.n,
                                                   self.rnd_seed, self.rnd_NM)
            self.train_dataloader = DataLoader(train_dataset,
                                               batch_size=self.batchsize,
                                               sampler=self.train_sampler,
                                               num_workers=self.n_worker,
                                               pin_memory=True)
            self.test_dataloader = DataLoader(self.test_dataset,
                                              batch_size=self.batchsize,
                                              shuffle=False,
                                              num_workers=self.n_worker,
                                              pin_memory=True)
        except Exception as e:
            logging.exception("Sampler/DataLoader setup failed: %s", e)

        logging.info(f"Incrementally training {self.n_tasks} tasks")
        task_records = defaultdict(list)
        eval_results = defaultdict(list)
        samples_cnt = 0

        # Sample count at which the next periodic evaluation fires; None
        # disables periodic evaluation.
        num_eval = self.eval_period
        # Most recent evaluation result; stays None until one has happened.
        eval_dict = None

        try:
            for task_id in range(self.n_tasks):
                # NOTE: "joint" leaves after its first task and therefore
                # skips the result saving below.
                if self.method == "joint" and task_id > 0:
                    return

                logging.info("#" * 50)
                logging.info(f"# Task {task_id} Session")
                logging.info("#" * 50)
                logging.info("[2-1] Prepare a datalist for the current task")

                self.train_sampler.set_task(task_id)
                self.online_before_task(task_id)

                for i, (images, labels,
                        idx) in enumerate(self.train_dataloader):
                    if self.debug and (i + 1) * self.temp_batchsize >= 500:
                        break
                    samples_cnt += images.size(0) * self.world_size
                    loss, acc = self.online_step(images, labels, idx)

                    self.report_training(samples_cnt, loss, acc)

                    if num_eval is not None and samples_cnt > num_eval:
                        with torch.no_grad():
                            test_sampler = OnlineTestSampler(
                                self.test_dataset, self.exposed_classes)
                            test_dataloader = DataLoader(
                                self.test_dataset,
                                batch_size=self.batchsize * 2,
                                sampler=test_sampler,
                                num_workers=self.n_worker)
                            eval_dict = self.online_evaluate(test_dataloader)
                            if self.distributed:
                                # Sum the metrics onto rank 0, then average.
                                eval_dict = torch.tensor([
                                    eval_dict['avg_loss'],
                                    eval_dict['avg_acc'], *eval_dict['cls_acc']
                                ],
                                                         device=self.device)
                                dist.reduce(eval_dict,
                                            dst=0,
                                            op=dist.ReduceOp.SUM)
                                eval_dict = eval_dict.cpu().numpy()
                                eval_dict = {
                                    'avg_loss': eval_dict[0] / self.world_size,
                                    'avg_acc': eval_dict[1] / self.world_size,
                                    'cls_acc': eval_dict[2:] / self.world_size
                                }
                            if self.is_main_process():
                                eval_results["test_acc"].append(
                                    eval_dict['avg_acc'])
                                eval_results["avg_acc"].append(
                                    eval_dict['cls_acc'])
                                eval_results["data_cnt"].append(num_eval)
                                self.report_test(samples_cnt,
                                                 eval_dict["avg_loss"],
                                                 eval_dict['avg_acc'])
                            num_eval += self.eval_period
                    sys.stdout.flush()
                # Nothing to report if no evaluation has run yet (e.g. the
                # first task is shorter than the evaluation period).
                if eval_dict is not None:
                    self.report_test(samples_cnt, eval_dict["avg_loss"],
                                     eval_dict['avg_acc'])
                self.online_after_task(task_id)

                test_sampler = OnlineTestSampler(self.test_dataset,
                                                 self.exposed_classes)
                test_dataloader = DataLoader(self.test_dataset,
                                             batch_size=self.batchsize * 2,
                                             sampler=test_sampler,
                                             num_workers=self.n_worker)
                eval_dict = self.online_evaluate(test_dataloader)

                if self.distributed:
                    confusion_matrix = torch.tensor(
                        eval_dict['confusion_matrix'], device=self.device)
                    eval_dict = torch.tensor([
                        eval_dict['avg_loss'], eval_dict['avg_acc'],
                        *eval_dict['cls_acc']
                    ],
                                             device=self.device)
                    dist.reduce(eval_dict, dst=0, op=dist.ReduceOp.SUM)
                    dist.reduce(confusion_matrix, dst=0, op=dist.ReduceOp.SUM)
                    eval_dict = eval_dict.cpu().numpy()
                    confusion_matrix = confusion_matrix.cpu().numpy()
                    eval_dict = {
                        'avg_loss': eval_dict[0] / self.world_size,
                        'avg_acc': eval_dict[1] / self.world_size,
                        'cls_acc': eval_dict[2:] / self.world_size,
                        "confusion_matrix": confusion_matrix
                    }
                task_acc = eval_dict['avg_acc']

                logging.info(
                    "[2-4] Update the information for the current task")
                task_records["task_acc"].append(task_acc)
                task_records["cls_acc"].append(eval_dict["cls_acc"])
                if self.is_main_process() and 'confusion_matrix' in eval_dict:
                    np.save(
                        os.path.join(
                            self.log_dir,
                            f'seed_{self.rnd_seed}_T{task_id}_confusion_matrix.npy'
                        ), eval_dict['confusion_matrix'])

                logging.info("[2-5] Report task result")
        except Exception as e:
            logging.exception("Training loop aborted: %s", e)

        try:
            if self.is_main_process():
                self.save_model(os.path.join(self.log_dir, 'checkpoint.pth'))
                np.save(
                    os.path.join(self.log_dir, f'seed_{self.rnd_seed}.npy'),
                    task_records["task_acc"])

                if self.eval_period is not None:
                    np.save(
                        os.path.join(self.log_dir,
                                     f'seed_{self.rnd_seed}_eval.npy'),
                        eval_results['test_acc'])
                    np.save(
                        os.path.join(self.log_dir,
                                     f'seed_{self.rnd_seed}_eval_time.npy'),
                        eval_results['data_cnt'])
                    # eval_dict is None when training stopped before any
                    # evaluation took place.
                    if eval_dict is not None:
                        if 'confusion_matrix' in eval_dict:
                            np.save(
                                os.path.join(
                                    self.log_dir,
                                    f'seed_{self.rnd_seed}_confusion_matrix.npy'
                                ), eval_dict['confusion_matrix'])
                        if len(eval_dict.get('feature', {})) > 0:
                            np.save(os.path.join(self.log_dir, 'feature.npy'),
                                    eval_dict['feature'])

                # Accuracy (A)
                A_auc = np.mean(eval_results["test_acc"])
                A_avg = np.mean(task_records["task_acc"])
                A_last = task_records["task_acc"][self.n_tasks - 1]

                # Forgetting (F): best earlier accuracy minus final accuracy,
                # averaged over classes that were ever above zero. It needs
                # at least one earlier task to compare against.
                cls_acc = np.array(task_records["cls_acc"])
                acc_diff = []
                if len(cls_acc) > 1:
                    for j in range(self.n_classes):
                        best_before = np.max(cls_acc[:-1, j])
                        if best_before > 0:
                            acc_diff.append(best_before - cls_acc[-1, j])
                F_last = np.mean(acc_diff) if acc_diff else 0.0

                logging.info("======== Summary =======")
                logging.info(f"Exposed Classes: {self.exposed_classes}")
                for i in range(len(cls_acc)):
                    logging.info(f"Task {i}\n" + str(cls_acc[i]))
                logging.info(
                    f"A_auc {A_auc:.5f} | A_avg {A_avg:.5f} | A_last {A_last:.5f} | F_last {F_last:.5f}"
                )
                with open(os.path.join(self.log_dir, 'result.txt'), 'w') as f:
                    f.write(
                        f"Dataset:{self.dataset_name} | A_auc {A_auc:.5f} | A_avg {A_avg:.5f} | A_last {A_last:.5f} | F_last {F_last:.5f}\n"
                    )
        except Exception as e:
            logging.exception("Saving results failed: %s", e)

    def add_new_class(self, class_name):
        """Register newly observed labels as exposed classes.

        Unseen labels are appended to ``self.exposed_classes`` in order of
        appearance. In distributed runs the lists of all ranks are gathered
        and de-duplicated (first occurrence wins). The memory is then told
        about the class list, ``self.exposed_classes_names`` is rebuilt, and
        the schedule is reset when the scheduler name contains 'reset'.

        Args:
            class_name: Iterable whose elements expose ``.item()`` (e.g. a
                1-D label tensor).
        """
        for label in class_name:
            value = label.item()
            if value in self.exposed_classes:
                continue
            self.exposed_classes.append(value)

        if self.distributed:
            local = torch.tensor(self.exposed_classes, device=self.device)
            merged = torch.cat(self.all_gather(local)).cpu().tolist()
            # dict keys keep insertion order, i.e. the first occurrence.
            self.exposed_classes = list(dict.fromkeys(merged))

        self.memory.add_new_class(cls_list=self.exposed_classes)

        names = self.train_dataset.classes_names
        self.exposed_classes_names = [names[i] for i in self.exposed_classes]
        if 'reset' in self.sched_name:
            self.update_schedule(reset=True)

    def online_step(self, sample, samples_cnt):
        """Hook for one online training step; must be overridden.

        NOTE: ``main_worker`` calls this hook as
        ``online_step(images, labels, idx)`` and unpacks the result as
        ``(loss, acc)``, so overrides have to accept that call shape.
        """
        raise NotImplementedError

    def online_before_task(self, task_id):
        """Hook invoked before each task's stream starts; must be overridden."""
        raise NotImplementedError

    def online_after_task(self, task_id):
        """Hook invoked after each task's stream ends; must be overridden."""
        raise NotImplementedError

    def online_evaluate(self, test_loader, samples_cnt):
        """Hook that evaluates the model on a loader; must be overridden.

        NOTE: ``main_worker`` calls this hook with the loader only
        (``online_evaluate(test_dataloader)``) and reads the keys
        'avg_loss', 'avg_acc' and 'cls_acc' from the returned mapping.
        """
        raise NotImplementedError

    def is_dist_avail_and_initialized(self):
        """Return True only if torch.distributed is available and initialised."""
        # Short-circuit keeps is_initialized() from running when unavailable.
        return dist.is_available() and dist.is_initialized()

    def get_world_size(self):
        """Return the process-group size, or 1 when no group is initialised."""
        if self.is_dist_avail_and_initialized():
            return dist.get_world_size()
        return 1

    def get_rank(self):
        """Return this process's rank, or 0 when no group is initialised."""
        if self.is_dist_avail_and_initialized():
            return dist.get_rank()
        return 0

    def is_main_process(self):
        """Return True when this process has rank 0."""
        rank = self.get_rank()
        return rank == 0

    def setup_for_distributed(self, is_master):
        """Configure logging and silence ``print`` outside the master process.

        The root logger is set up via ``setup_root_logger`` and the builtin
        ``print`` is replaced process-wide by a wrapper that only prints when
        ``is_master`` is true or the caller passes ``force=True``.

        Args:
            is_master: Whether this process should produce output.
        """
        self.setup_root_logger(is_master=is_master)
        import builtins
        real_print = builtins.print

        def print(*args, **kwargs):
            # 'force' is consumed here and never forwarded to the real print.
            force = kwargs.pop('force', False)
            if not (is_master or force):
                return
            real_print(*args, **kwargs)

        builtins.print = print

    def setup_root_logger(self, is_master=True, filename="log.txt"):
        """Attach stdout and file handlers to the root logger on the master.

        Args:
            is_master: When false nothing is configured.
            filename: Log file name inside ``self.log_dir``; the file is
                opened in write mode, truncating any previous content.

        Returns:
            The root logger when ``is_master`` is true, otherwise ``None``.
        """
        if not is_master:
            return None

        root = logging.getLogger()
        root.setLevel(logging.INFO)
        line_format = logging.Formatter("%(asctime)s | %(message)s")

        console = logging.StreamHandler(stream=sys.stdout)
        console.setLevel(logging.INFO)
        console.setFormatter(line_format)
        root.addHandler(console)

        log_file = logging.FileHandler(os.path.join(self.log_dir, filename),
                                       mode='w')
        log_file.setLevel(logging.INFO)
        log_file.setFormatter(line_format)
        root.addHandler(log_file)
        return root

    def report_test(self, sample_num, avg_loss, avg_acc):
        """Log one evaluation line at INFO level.

        Args:
            sample_num: Sample counter shown in the line.
            avg_loss: Test loss, formatted with four decimals.
            avg_acc: Test accuracy, formatted with four decimals.
        """
        message = (f"Test | Sample # {sample_num} | "
                   f"test_loss {avg_loss:.4f} | test_acc {avg_acc:.4f} | ")
        logging.info(message)

    def report_training(self, sample_num, train_loss, train_acc):
        """Print one training progress line with a linearly extrapolated ETA.

        The ETA is ``elapsed * remaining / processed`` where ``remaining`` is
        ``self.total_samples - sample_num`` clamped at zero. When no sample
        has been processed yet no estimate is possible and ``N/A`` is shown.

        Args:
            sample_num: Number of samples processed so far.
            train_loss: Loss to report (four decimals).
            train_acc: Accuracy to report (four decimals).
        """
        # Read the clock once so running time and ETA agree with each other.
        elapsed = time.time() - self.start_time
        if sample_num > 0:
            remaining = max(self.total_samples - sample_num, 0)
            eta = datetime.timedelta(
                seconds=int(elapsed * remaining / sample_num))
        else:
            eta = "N/A"
        print(
            f"Train | Sample # {sample_num} | train_loss {train_loss:.4f} | train_acc {train_acc:.4f} | "
            f"lr {self.optimizer.param_groups[0]['lr']:.6f} | "
            f"Num_Classes {len(self.exposed_classes)} | "
            f"running_time {datetime.timedelta(seconds=int(elapsed))} | "
            f"ETA {eta}"
        )

    def _interpret_pred(self, y, pred, n_classes=None):
        """Count, per class, the samples in a batch and the correct ones.

        Args:
            y: Label tensor of the batch.
            pred: Predicted labels, compared element-wise with ``y``.
            n_classes: Length of the returned tensors; defaults to
                ``self.n_classes``.

        Returns:
            ``(num_data, corrects)``: two float tensors of length
            ``n_classes`` indexed by class id.
        """
        size = self.n_classes if n_classes is None else n_classes

        def per_class_counts(labels):
            # Histogram of the label values that occur in `labels`.
            histogram = torch.zeros(size)
            values, counts = labels.unique(return_counts=True)
            for value, count in zip(values, counts):
                histogram[value] = count
            return histogram

        num_data = per_class_counts(y)
        # Only labels whose prediction matched are kept for the second count.
        corrects = per_class_counts(y.masked_select(y == pred))
        return num_data, corrects

    def reset_opt(self):
        """Recreate the optimizer and a scheduler bound to the new optimizer."""
        fresh_optimizer = select_optimizer(self.opt_name, self.lr, self.model)
        self.optimizer = fresh_optimizer
        self.scheduler = select_scheduler(self.sched_name, fresh_optimizer)

    def all_gather(self, item):
        """Gather a variable-length tensor from every rank onto every rank.

        The exchange is built from repeated ``dist.gather`` calls in which
        each rank takes one turn as the destination. Every rank must call
        this method, because each ``dist.gather`` is a collective and all
        ranks have to issue the same sequence of them.

        Args:
            item: Tensor sliced along dim 0. The padding below is 1-D, so
                ``item`` is expected to be 1-D whenever ranks disagree on
                length -- TODO confirm with callers.

        Returns:
            A list with one tensor per rank, each trimmed back to the length
            that rank contributed.
        """
        # Step 1: exchange the local lengths so every rank knows all sizes.
        local_size = torch.tensor(item.size(0), device=self.device)
        all_sizes = [
            torch.zeros_like(local_size) for _ in range(dist.get_world_size())
        ]
        for i in range(dist.get_world_size()):
            # Only the destination rank passes the receive list.
            if i == dist.get_rank():
                dist.gather(local_size, all_sizes, dst=i)
            else:
                dist.gather(local_size, dst=i)
        # dist.all_gather(all_sizes, local_size, async_op=False)
        max_size = max(all_sizes)

        # Step 2: zero-pad the local tensor to the longest length so all
        # ranks send tensors of equal shape.
        size_diff = max_size.item() - local_size.item()
        if size_diff:
            padding = torch.zeros(size_diff,
                                  device=self.device,
                                  dtype=item.dtype)
            item = torch.cat((item, padding))

        all_qs_padded = [
            torch.zeros_like(item) for _ in range(dist.get_world_size())
        ]

        # Step 3: exchange the padded tensors, again one destination per turn.
        for i in range(dist.get_world_size()):
            if i == dist.get_rank():
                dist.gather(item, all_qs_padded, dst=i)
            else:
                dist.gather(item, dst=i)

        # dist.all_gather(all_qs_padded, item)
        # Step 4: strip the padding using the sizes collected in step 1.
        all_qs = []
        for q, size in zip(all_qs_padded, all_sizes):
            all_qs.append(q[:size])
        return all_qs

    def train_data_config(self, n_task, train_dataset, train_sampler):
        """Log and save the per-class sample counts of every task's stream.

        For each task the sampler is switched to that task, the dataset is
        iterated once, and the number of samples per label is logged. The
        counts, re-keyed through ``convert_class_label``, are saved to
        ``seed_{rnd_seed}_task{t}_train_data.npy`` in ``self.log_dir``.

        Args:
            n_task: Number of tasks to iterate over.
            train_dataset: Dataset whose batches carry the labels as their
                second field, i.e. ``(image, label)`` or
                ``(image, label, index)``.
            train_sampler: Sampler exposing ``set_task(task_index)``.
        """
        # Honour the configured worker count like the other loaders of this
        # class; fall back to the former fixed value when it is unset.
        workers = 4 if self.n_worker is None else self.n_worker
        for t_i in range(n_task):
            train_sampler.set_task(t_i)
            train_dataloader = DataLoader(train_dataset,
                                          batch_size=self.batchsize,
                                          sampler=train_sampler,
                                          num_workers=workers)
            data_info = {}
            for batch in train_dataloader:
                # Index instead of unpacking so batches with an extra sample
                # index field are counted too; .item() needs no device move.
                for entry in batch[1]:
                    key = 'Class_' + str(entry.item())
                    data_info[key] = data_info.get(key, 0) + 1
            logging.info(f"[Train] Task{t_i} Data Info")
            logging.info(data_info)

            convert_data_info = self.convert_class_label(data_info)
            np.save(
                os.path.join(self.log_dir,
                             f"seed_{self.rnd_seed}_task{t_i}_train_data.npy"),
                convert_data_info)
            logging.info(convert_data_info)

    def test_data_config(self, test_dataloader, task_id):
        """Log the per-class sample counts seen in a test dataloader.

        Iterates ``test_dataloader`` once, counts how often each label
        occurs, and logs: the exposed class ids paired with their names, the
        counts keyed by ``'Class_<label>'``, and the counts keyed by class
        name (as produced by ``self.convert_class_label``).

        Args:
            test_dataloader: Iterable whose batches unpack into
                ``(input, label)``.
            task_id: Task identifier, used only in the log message.
        """
        data_info = {}
        for _, label in test_dataloader:
            # Counting only needs the label values: fetch the batch with one
            # ``tolist()`` rather than moving it to ``self.device`` and
            # calling ``.item()`` once per sample.
            # NOTE(review): assumes labels are collated into a tensor with
            # one label per sample -- confirm against the dataset.
            for value in label.reshape(-1).tolist():
                key = 'Class_' + str(value)
                data_info[key] = data_info.get(key, 0) + 1

        logging.info("<<Exposed Class>>")
        logging.info(
            list(zip(self.exposed_classes, self.exposed_classes_names)))

        logging.info(f"[Test] Task {task_id} Data Info")
        logging.info(data_info)
        logging.info("<<Convert>>")
        # convert_class_label rewrites the keys of ``data_info`` in place.
        convert_data_info = self.convert_class_label(data_info)
        logging.info(convert_data_info)

    def convert_class_label(self, data_info):
        """Re-key a ``'Class_<index>'`` count dictionary by class name.

        ``self.class_list`` is refreshed from ``self.train_dataset.classes``
        and every key of ``data_info`` is replaced by the class name found at
        the integer index that follows the ``'Class_'`` prefix.

        Args:
            data_info: Dictionary keyed by ``'Class_<index>'``; it is
                modified in place.

        Returns:
            The same ``data_info`` object, now keyed by class name.
        """
        #* self.class_list => original class label
        names = self.train_dataset.classes
        self.class_list = names

        prefix_len = len('Class_')
        # Snapshot the keys first: the dictionary is mutated while walking.
        for raw_key in tuple(data_info):
            class_index = int(raw_key[prefix_len:])
            count = data_info.pop(raw_key)
            data_info[names[class_index]] = count

        return data_info

    def current_task_data(self, train_loader):
        """Log the per-class sample counts of the current task's loader.

        The loader is iterated once; the counts are logged keyed by
        ``'Class_<label>'`` and then keyed by class name (via
        ``self.convert_class_label``).

        Args:
            train_loader: Iterable whose batches unpack into
                ``(input, label)``.
        """
        counts = {}
        for _, label in train_loader:
            for row in range(label.shape[0]):
                name = 'Class_' + str(label[row].item())
                counts[name] = counts.get(name, 0) + 1

        logging.info("Current Task Data Info")
        logging.info(counts)
        logging.info("<<Convert to str>>")
        # convert_class_label rewrites the keys of ``counts`` in place.
        logging.info(self.convert_class_label(counts))

    def save_model(self, path):
        """Save the model weights and the run arguments to ``path``.

        The checkpoint is a dictionary with two entries: ``"state_dict"``
        (the model's state dict) and ``"args"`` (``self.args``).  In
        distributed mode the wrapped ``self.model.module`` is saved and a
        leading ``"module."`` prefix, if any, is stripped from the keys.

        Args:
            path: Destination passed straight to ``torch.save``.
        """
        # NOTE: ``.to('cpu')`` moves the live model in place; callers that
        # keep using the model afterwards must move it back themselves.
        if self.distributed:
            model = self.model.module.to('cpu')
            prefix = "module."
            state_dict = OrderedDict()
            # ``.items()`` is required: iterating the state dict directly
            # yields only the key strings and cannot be unpacked into (k, v).
            for k, v in model.state_dict().items():
                # Strip only a leading wrapper prefix; a blanket replace
                # would also mangle names such as "submodule.weight".
                if k.startswith(prefix):
                    k = k[len(prefix):]
                state_dict[k] = v
        else:
            model = self.model.to('cpu')
            state_dict = model.state_dict()
        torch.save(
            {
                "state_dict": state_dict,
                "args": self.args,
            },
            path,
        )

    def update_teacher_model(self):
        """Refresh the distillation teacher from the current model.

        Nothing is updated unless ``self.kd_lambda`` is positive.  Otherwise
        the teacher becomes a deep copy of ``self.model`` when no teacher
        exists yet or ``self.kd_ema_decay`` is not positive; in the remaining
        case each teacher parameter is updated in place as
        ``decay * teacher + (1 - decay) * student``.  The teacher's
        parameters are then frozen and it is put in eval mode.

        Returns:
            ``self.teacher_model`` (unchanged when ``kd_lambda`` is not
            positive).
        """
        if not self.kd_lambda > 0:
            return self.teacher_model

        if self.teacher_model is not None and self.kd_ema_decay > 0:
            decay = self.kd_ema_decay
            paired = zip(self.teacher_model.parameters(),
                         self.model.parameters())
            for ema_param, live_param in paired:
                # In-place exponential moving average on the raw tensors.
                ema_param.data.mul_(decay)
                ema_param.data.add_(live_param.data, alpha=1.0 - decay)
        else:
            self.teacher_model = copy.deepcopy(self.model)

        for frozen in self.teacher_model.parameters():
            frozen.requires_grad = False
        self.teacher_model.eval()
        return self.teacher_model

    def _KD_loss(self, pred, soft, T=2.0):
        pred = torch.log_softmax(pred / T, dim=1)
        soft = torch.softmax(soft / T, dim=1)
        return -1 * torch.mul(soft, pred).sum() / pred.shape[0]

    def load_state_dict(self, state_dict):
        """Load the shape-compatible parameter entries of ``state_dict``.

        Only names returned by ``self.model.named_parameters()`` are
        considered.  A parameter missing from ``state_dict`` has its name
        printed; one whose shape differs has its name and both shapes
        printed.  Neither is loaded.

        Args:
            state_dict: Mapping from parameter name to tensor.

        Returns:
            The result of ``self.model.load_state_dict(..., strict=False)``.
        """
        compatible = OrderedDict()
        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print(name)
                continue
            candidate = state_dict[name]
            if param.shape == candidate.shape:
                compatible[name] = candidate
            else:
                print(name, param.shape, candidate.shape)
        return self.model.load_state_dict(compatible, strict=False)

    def zero_shot_evaluation(self, state_dict):
        """Evaluate a saved checkpoint on the training and zero-shot datasets.

        Seeds the RNGs (when ``self.rnd_seed`` is set), rebuilds the model
        with ``get_model``, loads ``state_dict`` through
        ``self.load_state_dict``, freezes every parameter and runs
        ``self.offline_evaluate`` on the dataloaders returned by
        ``self.setup_test_dataloader``.  The training dataset is evaluated
        first, followed by the entries of ``self.zero_shot_dataset`` that are
        not the training dataset.

        Args:
            state_dict: Mapping from parameter name to tensor.

        Returns:
            Whatever ``self.offline_evaluate`` returns.
        """
        self.device = torch.device('cuda')
        if self.rnd_seed is not None:
            random.seed(self.rnd_seed)
            np.random.seed(self.rnd_seed)
            torch.manual_seed(self.rnd_seed)
            torch.cuda.manual_seed(self.rnd_seed)
            torch.cuda.manual_seed_all(self.rnd_seed)  # if use multi-GPU
            cudnn.deterministic = True
        cudnn.benchmark = False
        assert hasattr(self, 'offline_evaluate')

        print(f"Load a CIL method ({self.method})")
        print(f"Building model ({self.model_name})")
        # Build on CPU, load the weights, then move to the GPU.
        self.model, self.inp_size = get_model(
            model_name=self.model_name,
            method=self.method,
            device='cpu',
            peft_encoder=self.args['peft_encoder'],
        )
        self.load_state_dict(state_dict)
        self.model = self.model.to(self.device)
        for param in self.model.parameters():
            param.requires_grad = False

        print("Start lifelong zero shot evaluation")
        train_name = self.dataset_name
        if isinstance(train_name, str):
            # Compare names for equality: ``x not in "cifar100"`` is a
            # substring test and would wrongly drop e.g. "cifar10".
            test_datasets = [
                x for x in self.zero_shot_dataset if x != train_name
            ]
        else:
            test_datasets = [
                x for x in self.zero_shot_dataset if x not in train_name
            ]
        # The training dataset always comes first in the evaluation list.
        test_datasets.insert(0, train_name)
        zs_dataloader, zs_classes, zs_meta = self.setup_test_dataloader(
            test_datasets)
        result = self.offline_evaluate(zs_dataloader, zs_classes, zs_meta)
        return result

    @torch.no_grad()
    def feature_extract_only(self, state_dict, dir_path):
        """Dump image and text features of a saved checkpoint to disk.

        Seeds the RNGs (when ``self.rnd_seed`` is set), rebuilds the model
        with ``get_model``, loads ``state_dict`` and freezes every parameter.
        For every dataloader returned by ``self.setup_test_dataloader`` (the
        zero-shot datasets followed by the training dataset) a dictionary
        holding the image features, text features, labels and label names is
        written with ``np.save`` to ``<dir_path>/feature_<meta[idx][0]>.npy``.

        Args:
            state_dict: Mapping from parameter name to tensor.
            dir_path: Directory the ``feature_*.npy`` files are written to.
        """
        self.device = torch.device('cuda')
        seed = self.rnd_seed
        if seed is not None:
            seeders = (
                random.seed,
                np.random.seed,
                torch.manual_seed,
                torch.cuda.manual_seed,
                torch.cuda.manual_seed_all,  # if use multi-GPU
            )
            for apply_seed in seeders:
                apply_seed(seed)
            cudnn.deterministic = True
        cudnn.benchmark = False
        assert hasattr(self, 'offline_evaluate')

        print(f"Load a CIL method ({self.method})")
        print(f"Building model ({self.model_name})")
        # Build on CPU, load the weights, then move to the GPU.
        self.model, self.inp_size = get_model(
            model_name=self.model_name,
            method=self.method,
            device='cpu',
            peft_encoder=self.args['peft_encoder'],
        )
        self.load_state_dict(state_dict)
        self.model = self.model.to(self.device)
        for param in self.model.parameters():
            param.requires_grad = False

        print("Start feature extract")
        datasets = list(self.zero_shot_dataset) + [self.dataset_name]
        dataloader, classes, meta = self.setup_test_dataloader(datasets)
        self.model.eval()
        for idx in range(len(dataloader)):
            class_names = classes[idx]
            tokens = self.model.labels_tokenize(class_names).to(self.device)
            text_features = self.model.forward_text(tokens)

            dataset_tag = meta[idx][0]
            image_features = []
            labels = []
            for x, y in tqdm(dataloader[idx], desc=dataset_tag):
                batch_features = self.model.forward_image(x.to(self.device))
                # Keep per-batch features on the CPU until concatenation.
                image_features.append(batch_features.cpu())
                labels.extend(y.tolist())

            result = {
                'image_features': torch.cat(image_features, 0).cpu().numpy(),
                'text_features': text_features.cpu().numpy(),
                'labels': labels,
                'label_names': class_names
            }
            dst = os.path.join(dir_path, 'feature_{}.npy'.format(dataset_tag))
            np.save(dst, result)
