import copy
import json
import os
import shutil
import time

import torch
from tensorboardX import SummaryWriter
from torch.backends import cudnn
from torch.utils.data import ConcatDataset, DataLoader
import numpy as np

from lib.core import AverageMeter, MCFM_Combiner, accuracy
from lib.core.CIL_combiner import CIL_Combiner
from lib.dataset import ExemplarDataset, SubDataset, TransformedDataset, AVAILABLE_TRANSFORMS, transforms, \
    TransformedDataset_for_exemplars
from lib.loss.loss import CrossEntropy_binary, CrossEntropy, IBLoss, trade_off_IB_Loss, mixup_trade_off_IB_Loss
from lib.model import DDC_Network
from lib.utils.utils import get_cls_num_list, get_model, get_optimizer, get_scheduler
from torchstat import stat


class MCFM_handler:
    """Our approach DDC"""

    def __init__(self, dataset_handler, exemplar_manager, cfg, logger, batch_train_logger, device):
        """Keep references to the collaborators; training state starts empty.

        The model, combiner and resume information are filled in later
        (``cil_train_main`` calls ``_first_task_init``), and the best/latest model
        snapshots are tracked while training.
        """
        # Injected collaborators and configuration.
        self.dataset_handler, self.exemplar_manager = dataset_handler, exemplar_manager
        self.cfg, self.device = cfg, device
        self.logger, self.batch_train_logger = logger, batch_train_logger

        # Network / resume state, populated once training (or resuming) begins.
        self.model = self.pre_tasks_model = None
        self.acc_result = self.start_task_id = None
        self.MCFM_combiner = None

        # Model-selection bookkeeping.
        self.latest_model = self.best_model = self.best_epoch = None
        self.best_acc = 0

    def _first_task_init(self):
        """Set up datasets, ``self.model`` and ``self.MCFM_combiner`` before training starts.

        Start modes, checked in this order:
          * ``cfg.RESUME.use_resume``: restore the data split and the exemplar memory from
            ``cfg.RESUME.resumed_file``, rebuild the model via ``resume_model()`` and
            sanity-check the restored state with ``is_resume_legal()``;
          * ``cfg.PRETRAINED_MODEL``: build a fresh ``DDC_Network`` and load those weights;
          * otherwise: build a fresh ``DDC_Network``.
        The combiner is ``MCFM_Combiner`` when ``cfg.re_mix`` is set, else ``CIL_Combiner``.
        """
        if self.cfg.RESUME.use_resume:
            self.logger.info(f"use_resume: {self.cfg.RESUME.resumed_model_path}")
            # In CPU mode remap CUDA tensors stored in the breakpoint file onto the CPU, so a
            # GPU-saved file can be loaded on a host without CUDA. GPU mode keeps torch's
            # default placement (map_location=None).
            breakpoint_data = torch.load(self.cfg.RESUME.resumed_file,
                                         map_location="cpu" if self.cfg.CPU_MODE else None)
            self.dataset_handler.update_split_selected_data(breakpoint_data["split_selected_data"])
            self.dataset_handler.get_dataset()
            self.exemplar_manager.resume_manager(breakpoint_data)
            self.resume_model()
            self.is_resume_legal()
        else:
            self.dataset_handler.get_dataset()
            pretrained_path = self.cfg.PRETRAINED_MODEL
            if pretrained_path:
                self.logger.info(f"use pretrained_model: {pretrained_path}")
            self.model = DDC_Network(self.cfg, mode="train", num_classes=self.dataset_handler.all_classes,
                                     rate=self.cfg.rate)
            if pretrained_path:
                self.model.load_model(pretrained_path)

            if self.cfg.CPU_MODE:
                self.model = self.model.to(self.device)
            elif pretrained_path and self.cfg.availabel_cudas:
                # NOTE(review): only the pretrained path honours cfg.availabel_cudas; a fresh
                # model always uses every visible GPU. Setting CUDA_VISIBLE_DEVICES here may
                # also have no effect if CUDA was already initialised (cil_train_main calls
                # torch.cuda.device_count() before this method) -- confirm.
                os.environ['CUDA_VISIBLE_DEVICES'] = self.cfg.availabel_cudas
                device_ids = list(range(len(self.cfg.availabel_cudas.strip().split(','))))
                self.model = torch.nn.DataParallel(self.model, device_ids=device_ids).cuda()
            else:
                self.model = torch.nn.DataParallel(self.model).cuda()

        if self.cfg.re_mix:
            self.MCFM_combiner = MCFM_Combiner(self.cfg, self.device)
        else:
            self.MCFM_combiner = CIL_Combiner(self.cfg, self.device)

    def resume_model(self):
        """Rebuild the network from ``cfg.RESUME.resumed_model_path`` and restore progress.

        Sets ``self.model``. In CPU mode it is moved to ``self.device``; otherwise it is
        wrapped in ``DataParallel``, on the GPUs listed in ``cfg.availabel_cudas`` or on all
        visible GPUs. Also reads ``acc_result`` and ``task_id`` from the same checkpoint.
        ``cil_train_main`` resumes from ``task_id + 1``, so ``task_id`` is the last
        completed task.
        """
        self.model = DDC_Network(self.cfg, mode="train", num_classes=self.dataset_handler.all_classes,
                                 rate=self.cfg.rate)
        self.model.load_model(self.cfg.RESUME.resumed_model_path)
        if self.cfg.CPU_MODE:
            self.model = self.model.to(self.device)
        elif self.cfg.availabel_cudas:
            os.environ['CUDA_VISIBLE_DEVICES'] = self.cfg.availabel_cudas
            device_ids = list(range(len(self.cfg.availabel_cudas.strip().split(','))))
            self.model = torch.nn.DataParallel(self.model, device_ids=device_ids).cuda()
        else:
            self.model = torch.nn.DataParallel(self.model).cuda()

        # Only metadata is read here. In CPU mode remap any CUDA tensors to the CPU so a
        # GPU-saved checkpoint loads without CUDA; GPU mode keeps torch's default placement.
        checkpoint = torch.load(self.cfg.RESUME.resumed_model_path,
                                map_location="cpu" if self.cfg.CPU_MODE else None)
        self.acc_result = checkpoint['acc_result']
        self.start_task_id = checkpoint['task_id']
        self.logger.info(f"start from task {self.start_task_id}")

    def is_resume_legal(self):
        """Check that the resumed exemplar memory matches the resumed task id, then re-evaluate.

        The number of stored exemplar sets must be a whole number of tasks
        (``classes_per_task`` each) equal to ``self.start_task_id``. Afterwards the resumed
        model is validated with the exemplar classifiers and the FC head, and the results
        are printed and logged.

        Raises:
            ValueError: if the exemplar memory is inconsistent with ``self.start_task_id``.
        """
        learned_classes_num = len(self.exemplar_manager.exemplar_sets)
        classes_per_task = self.dataset_handler.classes_per_task
        # Explicit checks rather than ``assert`` so the validation still runs under ``python -O``.
        if learned_classes_num % classes_per_task != 0:
            raise ValueError(
                f"Resumed exemplar memory holds {learned_classes_num} classes, which is not a "
                f"multiple of classes_per_task={classes_per_task}.")
        learned_tasks_num = learned_classes_num // classes_per_task
        if learned_tasks_num != self.start_task_id:
            raise ValueError(
                f"Resumed exemplar memory covers {learned_tasks_num} tasks, but the resumed "
                f"checkpoint reports task {self.start_task_id}.")

        print(f"Resume acc_result of resumed model: {self.acc_result}")
        self.logger.info(f"Resume acc_result of resumed model: {self.acc_result}")
        acc, centroid_acc = self.validate_with_exemplars(self.start_task_id)
        FC_acc = self.validate_with_FC(self.start_task_id)
        summary = f"validate resumed model: {acc.mean()} || {centroid_acc.mean()} || {FC_acc.mean()}"
        print(summary)
        self.logger.info(summary)

    def build_optimize(self, base_lr=None, typical_cls_train=True):
        """Return an optimizer for ``self.model`` created by ``get_optimizer``.

        ``base_lr`` is forwarded as ``BASE_LR`` (``train_main`` passes
        ``cfg.TRAIN.OPTIMIZER.IB_BASE_LR`` for the IB phase, ``None`` otherwise).
        ``typical_cls_train`` is accepted for backward compatibility and is unused.
        """
        return get_optimizer(self.cfg, self.model, BASE_LR=base_lr)

    def build_scheduler(self, optimizer, lr_step=None):
        """Return the learning-rate scheduler for ``optimizer`` built by ``get_scheduler``.

        ``lr_step`` is forwarded unchanged (``train_main`` passes
        ``cfg.TRAIN.LR_SCHEDULER.IB_LR_STEP`` for the IB phase, ``None`` otherwise).
        """
        lr_scheduler = get_scheduler(self.cfg, optimizer, lr_step=lr_step)
        return lr_scheduler

    def cil_train_main(self):
        """Train the model over all tasks of ``self.dataset_handler`` (class-incremental learning).

        Tasks are numbered from 1. For every task this method does four things:
          1. Builds the task's training set. With ``cfg.use_base_half`` the first
             ``int(all_tasks / 2)`` tasks are accumulated and trained together as one base
             task; without it, task 1 is the base task.
          2. Trains the model. The base task uses ``first_task_train_main``. Later tasks use
             ``train_main`` when ``cfg.re_mix`` is set, otherwise ``normal_train_main`` on
             train + exemplar data. Either may be followed by an IB phase when
             ``cfg.use_IB`` is set.
          3. Shrinks the stored exemplar sets, only when
             ``cfg.exemplar_manager.fixed_exemplar_num < 0`` and this is not the base task.
             Then builds exemplar sets for the newly learned classes.
          4. Evaluates with the exemplar classifiers (NCM / centroid) and the FC head, saves
             the best/latest model data and logs the accuracies.

        The source tree is snapshotted to ``<OUTPUT_DIR>/<NAME>/codes`` and model data is
        written under ``<OUTPUT_DIR>/<NAME>/models``.

        Raises:
            ValueError: if a non-base task has no exemplar dataset to replay
                (e.g. ``memory_budget <= 0``).
        """
        gpus = torch.cuda.device_count()
        self.logger.info(f"use {gpus} gpus")
        cudnn.benchmark = True
        cudnn.enabled = True
        # Build (or resume) datasets, network and combiner.
        self._first_task_init()
        train_dataset_transform = transforms.Compose([
            *AVAILABLE_TRANSFORMS[self.dataset_handler.dataset_name]['train_transform'],
        ])
        if not self.cfg.RESUME.use_resume:
            self.start_task_id = 1  # task ids start at 1
        else:
            # The checkpoint records the last finished task; continue with the next one.
            self.start_task_id += 1

        model_dir = self._prepare_output_dirs()

        classes_per_task = self.dataset_handler.classes_per_task
        half_task = int(self.dataset_handler.all_tasks / 2)
        train_dataset = None
        train_dataset_for_EM = None
        for task, original_imgs_train_dataset in enumerate(self.dataset_handler.original_imgs_train_datasets, 1):
            self.logger.info(f'New task {task} begin.')
            self.batch_train_logger.info(f'New task {task} begin.')
            if self.cfg.RESUME.use_resume and task < self.start_task_id:
                self.logger.info("Use resume. continue.")
                continue

            if self.cfg.use_base_half and task < half_task:
                # Base-half mode: only accumulate data here. The merged base task is trained,
                # and its exemplars built, once task == half_task.
                train_dataset = self._concat_or_first(
                    train_dataset,
                    TransformedDataset(original_imgs_train_dataset, transform=train_dataset_transform))
                train_dataset_for_EM = self._concat_or_first(
                    train_dataset_for_EM,
                    self._make_em_source_dataset(original_imgs_train_dataset, train_dataset_transform))
                self.logger.info('task continue.')
                continue

            # The base task is trained without replay; every later task replays exemplars.
            is_base_task = (task == half_task) if self.cfg.use_base_half else (task == 1)

            current_task_dataset = TransformedDataset(original_imgs_train_dataset,
                                                      transform=train_dataset_transform)
            if self.cfg.use_base_half and is_base_task:
                # train_dataset is still None when nothing was accumulated before the base
                # task (e.g. int(all_tasks / 2) == 1), so concatenate None-safely.
                train_dataset = self._concat_or_first(train_dataset, current_task_dataset)
                self.logger.info('base_half dataset construct end.')
                self.logger.info(f'train_dataset length: {len(train_dataset)}.')
            else:
                train_dataset = current_task_dataset

            self.pre_tasks_model = copy.deepcopy(self.model)

            exemplar_dataset = None
            if self.exemplar_manager.memory_budget > 0 and not is_base_task:
                exemplar_dataset = self.exemplar_manager.get_ExemplarDataset(for_train=True)

            if task > 1:
                if exemplar_dataset:
                    self.logger.info(f"exemplar_dataset length: {len(exemplar_dataset)} ")
                else:
                    self.logger.info("exemplar_dataset length: None ")

            # Classes seen so far, including the current task.
            active_classes_num = classes_per_task * task

            if is_base_task:
                self.first_task_train_main(train_dataset, active_classes_num, task, writer=None)
            else:
                self.logger.info(f'Task {task} begin:'
                                 f'Use exemplar_dataset to train model.')
                if exemplar_dataset is None:
                    raise ValueError(
                        f"Task {task} needs an exemplar_dataset for replay, but none is available "
                        f"(memory_budget={self.exemplar_manager.memory_budget}).")
                label_weight = self.build_label_weight(active_classes_num,
                                                       int(len(original_imgs_train_dataset) / classes_per_task))
                label_weight = label_weight.to(self.device)
                self.logger.info(f"use distance_loss, label_weight: {label_weight}")
                if self.cfg.re_mix:
                    self.train_main(train_dataset, exemplar_dataset, active_classes_num, task, writer=None,
                                    label_weight=label_weight)
                    if self.cfg.use_IB:
                        self.logger.info("IB train start.")
                        self.train_main(train_dataset, exemplar_dataset, active_classes_num, task, writer=None,
                                        label_weight=label_weight, reach_IB=True)
                        self.logger.info("IB train Ending.")
                else:
                    train_dataset = ConcatDataset([train_dataset, exemplar_dataset])
                    self.normal_train_main(train_dataset, active_classes_num, task, writer=None,
                                           label_weight=label_weight)
                    if self.cfg.use_IB:
                        self.logger.info("IB train start.")
                        self.normal_train_main(train_dataset, active_classes_num, task, writer=None,
                                               label_weight=label_weight, reach_IB=True)
                        self.logger.info("IB train Ending.")

            self.batch_train_logger.info(f'##########MCFM train task {task} End.#########')
            self.logger.info(f'#############MCFM train task {task} End.##############')
            self.logger.info(f'#############Example handler task {task} start.##############')

            # EXEMPLARS: either a fixed count per class, or the memory budget split evenly
            # over all classes seen so far.
            if self.cfg.exemplar_manager.fixed_exemplar_num > 0:
                exemplars_per_class = self.cfg.exemplar_manager.fixed_exemplar_num
            else:
                exemplars_per_class = int(np.floor(self.exemplar_manager.memory_budget /
                                                   (classes_per_task * task)))
            # Shrink existing exemplar sets to make room for the new classes.
            if self.cfg.exemplar_manager.fixed_exemplar_num < 0 and not is_base_task:
                self.exemplar_manager.reduce_exemplar_sets(exemplars_per_class)

            train_dataset_for_EM_temp = self._make_em_source_dataset(original_imgs_train_dataset,
                                                                     train_dataset_transform)
            if self.cfg.use_base_half and is_base_task:
                # Exemplars are built for every class of the merged base task at once.
                new_classes = list(range(0, classes_per_task * task))
                train_dataset_for_EM = self._concat_or_first(train_dataset_for_EM, train_dataset_for_EM_temp)
            else:
                new_classes = list(range(classes_per_task * (task - 1), classes_per_task * task))
                train_dataset_for_EM = train_dataset_for_EM_temp

            self.latest_model = self.model
            if self.cfg.use_best_model:
                # NOTE(review): train_main sets best_model only when validation runs
                # (cfg.VALID_STEP != -1); confirm best_model cannot be None here.
                self.model = self.best_model
                self.logger.info("Use best model. ")

            for class_id in new_classes:
                # Restrict the exemplar source data to this class and select its exemplars.
                self.logger.info(f"construct_exemplar_set class_id: {class_id}")
                class_dataset = SubDataset(original_dataset=train_dataset_for_EM,
                                           sub_labels=[class_id])
                self.exemplar_manager.construct_exemplar_set(class_dataset, self.model,
                                                             exemplars_per_class, self.cfg.TRAIN.BATCH_SIZE,
                                                             self.cfg.TRAIN.NUM_WORKERS)
                self.logger.info(
                    f"self.exemplar_manager exemplar_set length: {len(self.exemplar_manager.exemplar_sets)}")
            self.exemplar_manager.compute_means = True
            self.exemplar_manager.recompute_centroid_feature = True

            val_acc_with_exemplars_ncm, val_acc_with_exemplars_centroid = self.validate_with_exemplars(task)
            val_acc = self.validate_with_FC(task)
            test_acc = None
            if self.dataset_handler.val_datasets:
                test_acc_with_exemplars_ncm, test_acc_with_exemplars_centroid = self.validate_with_exemplars(
                    task, is_test=True)
                test_acc = self.validate_with_FC(task, is_test=True)

            # validate_with_FC returns an array-like result (it has .mean()). Truth-testing a
            # multi-element array raises, so compare against None explicitly.
            selection_acc = test_acc if test_acc is not None else val_acc
            self.save_best_latest_model_data(model_dir, task, selection_acc.mean(), self.cfg.TRAIN.MAX_EPOCH)

            self.logger.info(f'#############task: {task:0>3d} is finished Test begin. ##############')
            if self.dataset_handler.val_datasets:
                val_acc_ncm_str = f'task: {task} classififer:{"ncm"} val_acc: {val_acc_with_exemplars_ncm}, ' \
                                  f'avg: {val_acc_with_exemplars_ncm.mean()}, classififer:{"centroid"} ' \
                                  f'val_acc: {val_acc_with_exemplars_centroid}, ' \
                                  f'avg: {val_acc_with_exemplars_centroid.mean()}'
                val_acc_FC_str = f'task: {task} classififer:{"FC"} val_acc: {val_acc}, avg: {val_acc.mean()} '
                test_acc_ncm_str = f'task: {task} classififer:{"ncm"} test_acc: {test_acc_with_exemplars_ncm}, ' \
                                   f'avg: {test_acc_with_exemplars_ncm.mean()}, classififer:{"centroid"} ' \
                                   f'test_acc: {test_acc_with_exemplars_centroid}, ' \
                                   f'avg: {test_acc_with_exemplars_centroid.mean()}'
                test_acc_FC_str = f'task: {task} classififer:{"FC"} || test_acc: {test_acc}, avg: {test_acc.mean()} '
                self.logger.info(val_acc_ncm_str)
                self.logger.info(val_acc_FC_str)
                self.logger.info(test_acc_ncm_str)
                self.logger.info(test_acc_FC_str)
            else:
                # Without a separate validation split the validation results are reported as test results.
                test_acc_ncm_str = f'task: {task} classififer:{"ncm"} test_acc: {val_acc_with_exemplars_ncm}, ' \
                                   f'avg: {val_acc_with_exemplars_ncm.mean()}, classififer:{"centroid"} ' \
                                   f'test_acc: {val_acc_with_exemplars_centroid}, ' \
                                   f'avg: {val_acc_with_exemplars_centroid.mean()}'
                test_acc_FC_str = f'task: {task} classififer:{"FC"} || test_acc: {val_acc}, avg: {val_acc.mean()} '
                self.logger.info(test_acc_ncm_str)
                self.logger.info(test_acc_FC_str)

    def _prepare_output_dirs(self):
        """Create the model directory and snapshot the source tree; return the model directory."""
        model_dir = os.path.join(self.cfg.OUTPUT_DIR, self.cfg.NAME, "models")
        code_dir = os.path.join(self.cfg.OUTPUT_DIR, self.cfg.NAME, "codes")

        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        else:
            self.logger.info(
                "This directory has already existed, Please remember to modify your cfg.NAME"
            )
        # copytree() refuses to write into an existing destination. Remove a stale snapshot
        # whenever one exists, and only if it exists (rmtree raises on a missing path).
        if os.path.exists(code_dir):
            shutil.rmtree(code_dir)
        self.logger.info("=> output model will be saved in {}".format(model_dir))
        this_dir = os.path.dirname(__file__)
        ignore = shutil.ignore_patterns(
            "*.pyc", "*.so", "*.out", "*pycache*", "*.pth", "*build*", "*output*", "*datasets*"
        )
        shutil.copytree(os.path.join(this_dir, ".."), code_dir, ignore=ignore)
        return model_dir

    def _make_em_source_dataset(self, original_imgs_dataset, train_transform):
        """Wrap a task's images for exemplar selection (val/test transform if storing original images)."""
        if self.cfg.exemplar_manager.store_original_imgs:
            transform = self.dataset_handler.val_test_dataset_transform
        else:
            transform = train_transform
        return TransformedDataset_for_exemplars(original_imgs_dataset, transform=transform)

    @staticmethod
    def _concat_or_first(existing, new_dataset):
        """Return ``new_dataset`` if ``existing`` is None, else ``ConcatDataset([existing, new_dataset])``."""
        if existing is None:
            return new_dataset
        return ConcatDataset([existing, new_dataset])


    def train_main(self, train_dataset, exemplar_dataset, active_classes_num, task, writer=None, label_weight=None,
                   reach_IB=False):
        """Train on the current task while mixing an exemplar mini-batch into every step.

        Each ``cfg.TRAIN.BATCH_SIZE`` batch of ``train_dataset`` is paired with the next
        batch from ``exemplar_dataset`` and passed to ``train_a_batch``. The exemplar loader
        is restarted whenever it is exhausted, independently of the epoch boundary.

        Args:
            train_dataset: training set of the current task.
            exemplar_dataset: stored exemplars of previously learned classes.
            active_classes_num: number of classes seen so far, including this task.
            task: 1-based task id.
            writer: optional summary writer; it is only closed at the end.
            label_weight: weights forwarded to the IB criterion and ``train_a_batch``.
            reach_IB: when True (and ``cfg.use_IB``), use the IB learning rate, LR steps and
                epoch count (``IB_BASE_LR``, ``IB_LR_STEP``, ``IB_EPOCH``).

        When validation runs (``cfg.VALID_STEP != -1``), the best epoch of this call is
        recorded in ``self.best_model``, ``self.best_epoch`` and ``self.best_acc``.

        Raises:
            ValueError: if ``exemplar_dataset`` cannot fill a single exemplar mini-batch.
        """
        import re

        if self.cfg.use_IB and reach_IB:
            optimizer = self.build_optimize(base_lr=self.cfg.TRAIN.OPTIMIZER.IB_BASE_LR)
            scheduler = self.build_scheduler(optimizer, lr_step=self.cfg.TRAIN.LR_SCHEDULER.IB_LR_STEP)
            max_epoch = self.cfg.TRAIN.IB_EPOCH
        else:
            optimizer = self.build_optimize()
            scheduler = self.build_scheduler(optimizer)
            max_epoch = self.cfg.TRAIN.MAX_EPOCH

        if "binary" in self.cfg.LOSS.LOSS_TYPE:
            criterion = CrossEntropy_binary()
        else:
            criterion = CrossEntropy()
        if self.cfg.re_mix and self.cfg.Mixup.all:
            criterion_ib = mixup_trade_off_IB_Loss(weight=label_weight, alpha=self.cfg.IB_alpha, beta=self.cfg.beta,
                                                   active_classes_num=active_classes_num).cuda()
        else:
            criterion_ib = IBLoss(weight=label_weight, alpha=self.cfg.IB_alpha,
                                  active_classes_num=active_classes_num).cuda()

        # Newer PyTorch expects scheduler.step() after the epoch's optimizer steps, older
        # releases before them. Compare (major, minor) numerically: slicing the version
        # string ("1.10.0"[:3] == "1.1") would treat 1.10-1.13 as pre-1.3.
        # NOTE: PyTorch documents the ordering change as of 1.1.0; the 1.3 threshold is kept
        # to match normal_train_main.
        version_match = re.match(r"(\d+)\.(\d+)", torch.__version__)
        step_scheduler_before_epoch = (int(version_match.group(1)), int(version_match.group(2))) < (1, 3)

        # Exemplar mini-batch size: BATCH_SIZE scaled by (stored exemplars of the previous
        # tasks' classes) / (current task samples), with a floor of 32. When the floor kicks
        # in, balance_multiple is recomputed. Assumes every exemplar set has the size of
        # exemplar_sets[0].
        exemplars_per_class = len(self.exemplar_manager.exemplar_sets[0])
        mini_batch_size = int(self.cfg.TRAIN.BATCH_SIZE * exemplars_per_class *
                              self.dataset_handler.classes_per_task * (task - 1) / (len(train_dataset)))
        balance_multiple = int((len(train_dataset) / self.dataset_handler.classes_per_task)
                               / exemplars_per_class)
        if mini_batch_size < 32:
            mini_batch_size = 32
            balance_multiple = int(self.cfg.TRAIN.BATCH_SIZE / 32) * (task - 1)

        def cycle_exemplar_batches():
            # Endless stream of exemplar batches: start a fresh shuffled loader each time the
            # previous one is exhausted.
            while True:
                loader = DataLoader(dataset=exemplar_dataset, batch_size=mini_batch_size,
                                    num_workers=self.cfg.TRAIN.NUM_WORKERS, shuffle=True, drop_last=True)
                if len(loader) == 0:
                    # With drop_last=True an undersized dataset yields no batch at all; fail
                    # loudly instead of looping forever.
                    raise ValueError(
                        f"exemplar_dataset ({len(exemplar_dataset)} samples) cannot fill one "
                        f"exemplar mini-batch of size {mini_batch_size}.")
                yield from loader

        exemplar_batches = cycle_exemplar_batches()
        self.logger.info(f"mini_batch_size: {mini_batch_size} || balance_multiple: {balance_multiple}")

        best_acc = 0
        for epoch in range(1, max_epoch + 1):
            all_loss = [AverageMeter(), AverageMeter(), AverageMeter()]
            acc = AverageMeter()
            if step_scheduler_before_epoch:
                scheduler.step()

            data_loader = DataLoader(dataset=train_dataset, batch_size=self.cfg.TRAIN.BATCH_SIZE,
                                     num_workers=self.cfg.TRAIN.NUM_WORKERS, shuffle=True, drop_last=True)
            iter_num = len(data_loader)
            for iter_index, (x, y) in enumerate(data_loader):
                x, y = x.to(self.device), y.to(self.device)
                examplar_imgs, examplar_labels = next(exemplar_batches)
                cnt = y.shape[0]
                loss, now_acc, _ = self.train_a_batch(optimizer, criterion, x, y, active_classes_num,
                                                      examplar_imgs, examplar_labels, balance_multiple, task,
                                                      label_weight, criterion_ib=criterion_ib, reach_IB=reach_IB)
                all_loss[0].update(loss[0].data.item(), cnt)
                all_loss[1].update(loss[1].data.item(), cnt)
                if self.cfg.CLASSIFIER.NECK.distance_loss:
                    all_loss[-1].update(loss[-1].data.item(), cnt)
                acc.update(now_acc[0], cnt)
                if iter_index % self.cfg.SHOW_STEP == 0:
                    pbar_str = ("Epoch: {} || Batch:{:>3d}/{} || lr : {} || Batch_cls_Loss:{:>5.3f} || "
                                "Batch_distill_Loss:{:>5.3f} || Batch_Accuracy:{:>5.2f}").format(
                        epoch, iter_index, iter_num, optimizer.param_groups[0]['lr'],
                        all_loss[0].val, all_loss[1].val, acc.val * 100)
                    if reach_IB:
                        pbar_str = "IB_train," + pbar_str
                    if self.cfg.CLASSIFIER.NECK.distance_loss:
                        pbar_str += "Batch_distance_Loss:{:>5.3f}".format(all_loss[2].val)
                    self.batch_train_logger.info(pbar_str)
                    self.logger.info(pbar_str)

            if self.cfg.VALID_STEP != -1 and epoch % self.cfg.VALID_STEP == 0:
                # Epoch summary reports epoch averages (acc.avg, not the last batch's acc.val).
                pbar_str = ("Validate Epoch: {} || lr: {} || epoch_cls_Loss:{:>5.3f}  || "
                            "epoch_distill_Loss:{:>5.3f} || epoch_Accuracy:{:>5.2f}").format(
                    epoch, optimizer.param_groups[0]['lr'], all_loss[0].avg, all_loss[1].avg, acc.avg * 100)
                if reach_IB:
                    pbar_str = "IB_train," + pbar_str
                self.logger.info(pbar_str)

                val_acc = self.validate_with_FC(task)  # task ids start at 1
                if val_acc.mean() > best_acc:
                    best_acc, best_epoch = val_acc.mean(), epoch
                    self.best_model = copy.deepcopy(self.model)
                    self.best_epoch = best_epoch
                    self.best_acc = best_acc
                    self.logger.info(
                        "--------------Best_Epoch:{:>3d}    Best_Acc:{:>5.2f}%--------------".format(
                            best_epoch, best_acc * 100
                        )
                    )

            if not step_scheduler_before_epoch:
                scheduler.step()
        if writer:
            writer.close()

    def normal_train_main(self, train_dataset, active_classes_num, task, writer=None, label_weight=None,
                          reach_IB=False):
        if self.cfg.use_IB and reach_IB:
            optimizer = self.build_optimize(base_lr=self.cfg.TRAIN.OPTIMIZER.IB_BASE_LR)
            scheduler = self.build_scheduler(optimizer, lr_step=self.cfg.TRAIN.LR_SCHEDULER.IB_LR_STEP)
            MAX_EPOCH = self.cfg.TRAIN.IB_EPOCH
        else:
            optimizer = self.build_optimize()
            scheduler = self.build_scheduler(optimizer)
            MAX_EPOCH = self.cfg.TRAIN.MAX_EPOCH
        start_time = time.time()

        if "binary" in self.cfg.LOSS.LOSS_TYPE:
            criterion = CrossEntropy_binary()
        else:
            criterion = CrossEntropy()
        # weight = None, alpha = 10000., beta = 0.000001, active_classes_num = 100
        if self.cfg.Mixup.all:
            criterion_ib = mixup_trade_off_IB_Loss(weight=label_weight, alpha=self.cfg.IB_alpha, beta=self.cfg.beta,
                                                   active_classes_num=active_classes_num).cuda()
        else:
            criterion_ib = IBLoss(weight=label_weight, alpha=self.cfg.IB_alpha,
                                  active_classes_num=active_classes_num).cuda()
        '''criterion_ib = trade_off_IB_Loss(weight=label_weight, alpha=self.cfg.IB_alpha, beta = self.cfg.beta,
                              active_classes_num=active_classes_num).cuda()'''
        best_acc = 0

        for epoch in range(1, MAX_EPOCH + 1):
            all_loss = [AverageMeter(), AverageMeter(), AverageMeter()]
            acc = AverageMeter()
            if float(torch.__version__[:3]) < 1.3:
                scheduler.step()
            is_first_ite = True
            iters_left = 1
            iter_index = 0
            iter_num = 0
            while iters_left > 0:
                # Update # iters left on current data-loader(s) and, if needed, create new one(s)
                iters_left -= 1
                if is_first_ite:
                    is_first_ite = False
                    data_loader = iter(
                        DataLoader(dataset=train_dataset, batch_size=self.cfg.TRAIN.BATCH_SIZE,
                                   num_workers=self.cfg.TRAIN.NUM_WORKERS, shuffle=True, drop_last=True))
                    # NOTE:  [train_dataset]  is training-set of current task
                    #      [training_dataset] is training-set of current task with stored exemplars added (if requested)
                    iter_num = iters_left = len(data_loader)
                    continue

                #####-----CURRENT BATCH-----#####
                try:
                    x, y = next(data_loader)  # --> sample training data of current task
                except StopIteration:
                    raise ValueError("next(data_loader) error while read data. ")
                x, y = x.to(self.device), y.to(self.device)  # --> transfer them to correct device
                # ---> Train MAIN MODEL
                cnt = y.shape[0]
                loss, now_acc, now_cnt = self.normal_train_a_batch(optimizer, criterion, x, y, active_classes_num,
                                                                   task, label_weight, criterion_ib=criterion_ib,
                                                                   reach_IB=reach_IB)
                all_loss[0].update(loss[0].data.item(), cnt)
                all_loss[1].update(loss[1].data.item(), cnt)
                if self.cfg.CLASSIFIER.NECK.distance_loss:
                    all_loss[-1].update(loss[-1].data.item(), cnt)
                acc.update(now_acc[0], cnt)
                if iter_index % self.cfg.SHOW_STEP == 0:
                    pbar_str = "Epoch: {} || Batch:{:>3d}/{} || lr : {} || Batch_cls_Loss:{:>5.3f} || " \
                               "Batch_distill_Loss:{:>5.3f} || Batch_Accuracy:{:>5.2f}".format(epoch, iter_index,
                                                                                               iter_num,
                                                                                               optimizer.param_groups[
                                                                                                   0]['lr'],
                                                                                               all_loss[0].val,
                                                                                               all_loss[1].val,
                                                                                               acc.val * 100
                                                                                               )
                    if reach_IB:
                        pbar_str = "IB_train," + pbar_str
                    if self.cfg.CLASSIFIER.NECK.distance_loss:
                        pbar_str += "Batch_distance_Loss:{:>5.3f}".format(all_loss[2].val)
                    self.batch_train_logger.info(pbar_str)
                    self.logger.info(pbar_str)
                iter_index += 1

            # if epoch % self.cfg.epoch_show_step == 0:
            # train_acc, train_loss = acc.avg, all_loss.avg
            # loss_dict, acc_dict = {"train_loss": train_loss}, {"train_acc": train_acc}
            if self.cfg.VALID_STEP != -1 and epoch % self.cfg.VALID_STEP == 0:
                pbar_str = "Validate Epoch: {} || lr: {} || epoch_cls_Loss:{:>5.3f}  || epoch_distill_Loss:{:>5.3f}" \
                           "epoch_Accuracy:{:>5.2f}".format(epoch, optimizer.param_groups[0]['lr'],
                                                            all_loss[0].avg, all_loss[1].avg,
                                                            acc.val * 100
                                                            )
                if reach_IB:
                    pbar_str = "IB_train," + pbar_str
                # self.batch_train_logger.info(pbar_str)
                self.logger.info(pbar_str)

                val_acc = self.validate_with_FC(task)  # task_id 从1开始

                if val_acc.mean() > best_acc:
                    best_acc, best_epoch = val_acc.mean(), epoch
                    self.best_model = copy.deepcopy(self.model)
                    self.best_epoch = best_epoch
                    self.best_acc = best_acc
                    self.logger.info(
                        "--------------Best_Epoch:{:>3d}    Best_Acc:{:>5.2f}%--------------".format(
                            best_epoch, best_acc * 100
                        )
                    )
            # if writer:
            #     writer.add_scalars("scalar/acc", acc_dict, epoch)
            #     writer.add_scalars("scalar/loss", loss_dict, epoch)

            if float(torch.__version__[:3]) >= 1.3:
                scheduler.step()
        end_time = time.time()
        if writer:
            writer.close()

    def first_task_train_main(self, train_dataset, active_classes_num, task_id, writer=None):
        """Train ``self.model`` on the base task (no exemplars, no teacher model).

        For ``cfg.TRAIN.MAX_EPOCH`` epochs every batch of ``train_dataset`` goes through
        ``train_a_batch``. Every ``cfg.VALID_STEP`` epochs (disabled when it is -1) the model
        is evaluated with ``validate_with_FC``. Whenever the mean accuracy beats the best value
        seen in this call, a deep copy of the model is kept in ``self.best_model`` together
        with ``self.best_epoch`` / ``self.best_acc``.

        Args:
            train_dataset: training set of the current task, yielding ``(image, label)`` pairs.
            active_classes_num: number of output logits currently in use.
            task_id: 1-based task index, forwarded to ``train_a_batch`` / ``validate_with_FC``.
            writer: optional tensorboardX ``SummaryWriter``; receives per-epoch train
                loss/accuracy and is closed before returning.
        """
        optimizer = self.build_optimize(typical_cls_train=True)
        scheduler = self.build_scheduler(optimizer)
        if "binary" in self.cfg.LOSS.LOSS_TYPE:
            criterion = CrossEntropy_binary()
        else:
            criterion = CrossEntropy()
        # Compare (major, minor) numerically: ``float(torch.__version__[:3])`` reads "1.10"
        # as 1.1 and would pick the pre-1.3 scheduler-step ordering on torch 1.10-1.13.
        torch_major_minor = tuple(int(part) for part in torch.__version__.split(".")[:2])
        step_scheduler_first = torch_major_minor < (1, 3)
        best_acc = 0
        for epoch in range(1, self.cfg.TRAIN.MAX_EPOCH + 1):
            all_loss = AverageMeter()
            distance_loss = AverageMeter()
            acc = AverageMeter()
            if step_scheduler_first:
                scheduler.step()
            # A fresh, reshuffled loader per epoch; drop_last keeps every batch full-sized.
            data_loader = DataLoader(dataset=train_dataset, batch_size=self.cfg.TRAIN.BATCH_SIZE,
                                     num_workers=self.cfg.TRAIN.NUM_WORKERS, shuffle=True, drop_last=True)
            iter_num = len(data_loader)
            for iter_index, (x, y) in enumerate(data_loader):
                x, y = x.to(self.device), y.to(self.device)
                cnt = y.shape[0]
                loss, now_acc, _ = self.train_a_batch(optimizer=optimizer, criterion=criterion, current_image=x,
                                                      current_label=y, active_classes_num=active_classes_num,
                                                      task=task_id, pre_tasks_imgs=None, pre_tasks_labels=None,
                                                      balance_multiple=None)
                all_loss.update(loss[0].data.item(), cnt)
                if self.cfg.CLASSIFIER.NECK.distance_loss:
                    # With the distance loss enabled, the second loss term is the distance loss.
                    distance_loss.update(loss[1].data.item(), cnt)
                acc.update(now_acc[0], cnt)
                if iter_index % self.cfg.SHOW_STEP == 0:
                    pbar_str = "Epoch: {} || Batch:{:>3d}/{}|| lr: {} || Batch_cls_Loss:{:>5.3f} ||  " \
                               "Batch_Accuracy:{:>5.2f}".format(epoch, iter_index, iter_num,
                                                                optimizer.param_groups[0]['lr'],
                                                                all_loss.val, acc.val * 100
                                                                )
                    if self.cfg.CLASSIFIER.NECK.distance_loss:
                        # Per-batch value, matching the "Batch_" label.
                        pbar_str += " || Batch_distance_Loss:{:>5.3f}".format(distance_loss.val)
                    self.batch_train_logger.info(pbar_str)
                    self.logger.info(pbar_str)

            train_acc, train_loss = acc.avg, all_loss.avg
            loss_dict, acc_dict = {"train_loss": train_loss}, {"train_acc": train_acc}
            if self.cfg.VALID_STEP != -1 and epoch % self.cfg.VALID_STEP == 0:
                # Epoch-level averages, not the last batch's value.
                pbar_str = "Validate Epoch: {} || lr: {} || epoch_Loss:{:>5.3f}  || " \
                           "epoch_Accuracy:{:>5.2f}".format(epoch, optimizer.param_groups[0]['lr'],
                                                            train_loss,
                                                            train_acc * 100
                                                            )
                self.batch_train_logger.info(pbar_str)
                self.logger.info(pbar_str)

                val_acc = self.validate_with_FC(task_id)  # task_id starts from 1

                if val_acc.mean() > best_acc:
                    best_acc, best_epoch = val_acc.mean(), epoch
                    self.best_model = copy.deepcopy(self.model)
                    self.best_epoch = best_epoch
                    self.best_acc = best_acc
                    self.logger.info(
                        "--------------Best_Epoch:{:>3d}    Best_Acc:{:>5.2f}%--------------".format(
                            best_epoch, best_acc * 100
                        )
                    )
            if writer:
                writer.add_scalars("scalar/acc", acc_dict, epoch)
                writer.add_scalars("scalar/loss", loss_dict, epoch)

            if not step_scheduler_first:
                scheduler.step()
        if writer:
            writer.close()

    def train_a_batch(self, optimizer, criterion, current_image, current_label, active_classes_num,
                      pre_tasks_imgs, pre_tasks_labels, balance_multiple, task, label_weight=None, criterion_ib=None,
                      reach_IB=False):

        if task == 1 or (self.cfg.use_base_half and task == int(self.dataset_handler.all_tasks / 2)):
            self.model.train()
            if self.cfg.re_mix:
                res = self.MCFM_combiner.forward(self.model, criterion, current_image, current_label,
                                                 active_classes_num,
                                                 self.dataset_handler.classes_per_task, previous_task_model=None,
                                                 pre_tasks_imgs=None, pre_tasks_labels=None, balance_multiple=None,
                                                 task=task)
            else:
                res = self.MCFM_combiner.forward(self.model, criterion, current_image, current_label,
                                                 active_classes_num,
                                                 self.dataset_handler.classes_per_task, previous_task_model=None,
                                                 task=task)

            loss, now_acc, now_cnt = res
        else:
            self.pre_tasks_model.eval()
            self.model.train()
            res = self.MCFM_combiner.forward(self.model, criterion, current_image, current_label, active_classes_num,
                                             self.dataset_handler.classes_per_task, self.pre_tasks_model,
                                             pre_tasks_imgs, pre_tasks_labels, balance_multiple, task, label_weight,
                                             criterion_ib=criterion_ib, reach_IB=reach_IB)

            loss, now_acc, now_cnt = res

        optimizer.zero_grad()
        sum(loss).backward()
        optimizer.step()
        return res
        pass

    def normal_train_a_batch(self, optimizer, criterion, current_image, current_label, active_classes_num,
                             task, label_weight=None, criterion_ib=None, reach_IB=False):
        # todo
        if task == 1 or (self.cfg.use_base_half and task == int(self.dataset_handler.all_tasks / 2)):
            self.model.train()
            res = self.MCFM_combiner.forward(self.model, criterion, current_image, current_label, active_classes_num,
                                             self.dataset_handler.classes_per_task, previous_task_model=None,
                                             task=task)

            loss, now_acc, now_cnt = res
        else:
            self.pre_tasks_model.eval()
            self.model.train()
            res = self.MCFM_combiner.forward(self.model, criterion, current_image, current_label, active_classes_num,
                                             self.dataset_handler.classes_per_task, self.pre_tasks_model,
                                             task, criterion_ib=criterion_ib, reach_IB=reach_IB)

            loss, now_acc, now_cnt = res

        optimizer.zero_grad()
        sum(loss).backward()
        optimizer.step()
        return res
        pass

    def build_label_weight(self, active_classes_num, current_task_classes_imgs_num):
        pre_task_classes_num = active_classes_num - self.dataset_handler.classes_per_task
        label_weight = np.array([0, ] * active_classes_num)
        pre_tasks_classes_imgs_num = len(self.exemplar_manager.exemplar_sets[0])
        label_weight[0:pre_task_classes_num] = pre_tasks_classes_imgs_num
        label_weight[pre_task_classes_num:active_classes_num] = current_task_classes_imgs_num
        # label_weight = 1 / (active_classes_num * label_weight)
        per_cls_weights = 1.0 / label_weight
        per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * active_classes_num

        # label_weight = (label_weight.sum() / label_weight)
        # label_weight = label_weight / label_weight.sum()
        return torch.from_numpy(per_cls_weights).float()

    def IB_train_main(self, train_dataset, active_classes_num, task, label_weight, ib_alpha=1000):
        """Train ``self.model`` on ``train_dataset`` with the influence-balanced (IB) loss only.

        Each batch goes through ``IB_train_a_batch`` with an ``IBLoss`` built from
        ``label_weight``. Every ``cfg.VALID_STEP`` epochs (disabled when it is -1) the model
        is evaluated with ``validate_with_FC``. On a new best mean accuracy within this call,
        a deep copy is kept in ``self.best_model`` together with ``self.best_epoch`` /
        ``self.best_acc``.

        Args:
            train_dataset: dataset yielding ``(image, label)`` pairs.
            active_classes_num: number of logits currently in use.
            task: 1-based task index, forwarded to ``validate_with_FC``.
            label_weight: per-class weight tensor handed to ``IBLoss``.
            ib_alpha: ``alpha`` for ``IBLoss``; defaults to the previously hard-coded 1000.
        """
        # NOTE(review): 0.01 is passed positionally to build_optimize; presumably a learning rate, confirm.
        optimizer = self.build_optimize(0.01)
        scheduler = self.build_scheduler(optimizer)
        criterion_ib = IBLoss(weight=label_weight, alpha=ib_alpha, active_classes_num=active_classes_num).cuda()
        # Compare (major, minor) numerically: ``float(torch.__version__[:3])`` reads "1.10"
        # as 1.1 and would pick the pre-1.3 scheduler-step ordering on torch 1.10-1.13.
        torch_major_minor = tuple(int(part) for part in torch.__version__.split(".")[:2])
        step_scheduler_first = torch_major_minor < (1, 3)
        best_acc = 0
        for epoch in range(1, self.cfg.TRAIN.MAX_EPOCH + 1):
            all_loss = AverageMeter()
            acc = AverageMeter()
            if step_scheduler_first:
                scheduler.step()
            # A fresh, reshuffled loader per epoch; drop_last keeps every batch full-sized.
            data_loader = DataLoader(dataset=train_dataset, batch_size=self.cfg.TRAIN.BATCH_SIZE,
                                     num_workers=self.cfg.TRAIN.NUM_WORKERS, shuffle=True, drop_last=True)
            iter_num = len(data_loader)
            for iter_index, (x, y) in enumerate(data_loader):
                x, y = x.to(self.device), y.to(self.device)
                cnt = y.shape[0]
                loss, now_acc, _ = self.IB_train_a_batch(optimizer, criterion_ib, x, y, active_classes_num)
                all_loss.update(loss[0].data.item(), cnt)
                acc.update(now_acc[0], cnt)
                if iter_index % self.cfg.SHOW_STEP == 0:
                    pbar_str = "IB train Epoch: {} || Batch:{:>3d}/{} || lr : {} || Batch_cls_Loss:{:>5.3f} || " \
                               "Batch_Accuracy:{:>5.2f}".format(epoch, iter_index,
                                                                iter_num,
                                                                optimizer.param_groups[0]['lr'],
                                                                all_loss.val,
                                                                acc.val * 100
                                                                )
                    self.logger.info(pbar_str)

            if self.cfg.VALID_STEP != -1 and epoch % self.cfg.VALID_STEP == 0:
                # Epoch-level averages, not the last batch's value.
                pbar_str = "IB train Validate Epoch: {} || lr: {} || epoch_cls_Loss:{:>5.3f} || " \
                           "epoch_Accuracy:{:>5.2f}".format(epoch, optimizer.param_groups[0]['lr'],
                                                            all_loss.avg,
                                                            acc.avg * 100
                                                            )
                self.logger.info(pbar_str)

                val_acc = self.validate_with_FC(task)  # task starts from 1

                if val_acc.mean() > best_acc:
                    best_acc, best_epoch = val_acc.mean(), epoch
                    self.best_model = copy.deepcopy(self.model)
                    self.best_epoch = best_epoch
                    self.best_acc = best_acc
                    self.logger.info(
                        "--------------IB train Best_Epoch:{:>3d}    Best_Acc:{:>5.2f}%--------------".format(
                            best_epoch, best_acc * 100
                        )
                    )

            if not step_scheduler_first:
                scheduler.step()

    def IB_train_a_batch(self, optimizer, criterion_ib, x, y, active_classes_num):
        """Run one optimisation step with the IB criterion on a single batch.

        The model is called with ``feature_flag=True, classifier_flag=True`` and is expected
        to return ``(logits, features)``. Only the first ``active_classes_num`` logits are used,
        both for ``criterion_ib`` and for the top-1 prediction scored by ``accuracy``.

        Returns:
            ``([loss], [acc], [cnt])``, with ``acc`` / ``cnt`` as produced by ``accuracy``.
        """
        self.model.train()
        logits, features = self.model(x=x, feature_flag=True, classifier_flag=True)
        active_logits = logits[:, 0:active_classes_num]
        ib_loss = criterion_ib(active_logits, y, features)
        predicted = torch.max(active_logits, 1)[1]
        batch_acc, batch_cnt = accuracy(predicted.cpu().numpy(), y.cpu().numpy())
        # Standard update: clear stale gradients, back-propagate, then step.
        optimizer.zero_grad()
        ib_loss.backward()
        optimizer.step()
        return [ib_loss], [batch_acc], [batch_cnt]

    def validate_with_exemplars(self, task, is_test=False):
        # todo
        ncm_acc = []
        centroid_acc = []
        mode = self.model.training
        self.model.eval()
        for task_id in range(task):  # 这里的task 从0 开始
            if self.dataset_handler.val_datasets and (not is_test):
                predict_result = self.validate_with_exemplars_per_task(self.dataset_handler.val_datasets[task_id])
            else:
                predict_result = self.validate_with_exemplars_per_task(self.dataset_handler.test_datasets[task_id])
            ncm_acc.append(predict_result[0])
            centroid_acc.append(predict_result[1])
            self.logger.info(
                f"task : {task} || per task {task_id}, ncm acc:{predict_result[0]} || centroid acc: {predict_result[1]}"
            )
        self.model.train(mode=mode)
        return np.array(ncm_acc), np.array(centroid_acc)
        pass

    def validate_with_exemplars_per_task(self, val_dataset):
        # todo
        val_loader = DataLoader(dataset=val_dataset, batch_size=self.cfg.TRAIN.BATCH_SIZE,
                                num_workers=self.cfg.TRAIN.NUM_WORKERS, shuffle=False, drop_last=False)
        batch_time = AverageMeter()
        data_time = AverageMeter()
        NCM_top1 = AverageMeter()
        centroid_top1 = AverageMeter()
        end = time.time()

        for inputs, labels in val_loader:
            correct_temp = 0
            centroid_correct_temp = 0
            data_time.update(time.time() - end)
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            y_hat = self.exemplar_manager.classify_with_exemplars(inputs,
                                                                  self.model)  # x, model, classifying_approach="NCM", allowed_classes
            correct_temp += y_hat.eq(labels.data).cpu().sum()
            NCM_top1.update((correct_temp / inputs.size(0)).item(), inputs.size(0))
            if self.cfg.multi_centroid_classify:
                centroid_y_hat = self.exemplar_manager.classify_with_exemplars(inputs, self.model,
                                                                               classifying_approach="centroid")  # x, model, classifying_approach="NCM", allowed_classes
                centroid_correct_temp += centroid_y_hat.eq(labels.data).cpu().sum()
                centroid_top1.update((centroid_correct_temp / inputs.size(0)).item(), inputs.size(0))
            batch_time.update(time.time() - end)
            end = time.time()
        throughput = 1.0 / (batch_time.avg / self.cfg.TRAIN.BATCH_SIZE)
        if self.cfg.multi_centroid_classify:
            return NCM_top1.avg, centroid_top1.avg
        else:
            return NCM_top1.avg, 0
        pass

    def validate_with_FC(self, task, is_test=False):
        acc = []
        mode = self.model.training
        self.model.eval()
        for task_id in range(task):  # 这里的task 从0 开始
            if self.dataset_handler.val_datasets and (not is_test):
                predict_result = self.validate_with_FC_per_task(self.dataset_handler.val_datasets[task_id], task)
            else:
                predict_result = self.validate_with_FC_per_task(self.dataset_handler.test_datasets[task_id],
                                                                task)
            acc.append(predict_result)
            self.logger.info(
                f"task: {task} || per task {task_id}, validate_with_FC acc:{predict_result}"
            )
        acc = np.array(acc)
        self.model.train(mode=mode)
        # print(
        #     f"task {task} validate_with_exemplars, acc_avg:{acc.mean()}")
        # self.logger.info(
        #     f"per task {task}, validate_with_exemplars, avg acc:{acc.mean()}"
        #     f"-------------------------------------------------------------"
        # )
        return acc
        pass

    def validate_with_FC_per_task(self, val_dataset, task):
        # todo
        val_loader = DataLoader(dataset=val_dataset, batch_size=self.cfg.TRAIN.BATCH_SIZE,
                                num_workers=self.cfg.TRAIN.NUM_WORKERS, shuffle=False, drop_last=False)
        top1 = AverageMeter()
        correct = 0
        active_classes_num = self.dataset_handler.classes_per_task * task
        for inputs, labels in val_loader:
            correct_temp = 0
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            out = self.model(x=inputs, is_nograd=True, classifier_flag=True)
            _, balance_fc_y_hat = torch.max(out[:, 0:active_classes_num], 1)
            correct_temp += balance_fc_y_hat.eq(labels.data).cpu().sum()
            correct += correct_temp
            top1.update((correct_temp / inputs.size(0)).item(), inputs.size(0))
        return top1.avg
        pass

    def save_best_latest_model_data(self, model_dir, task_id, acc, epoch):
        if self.best_model is None:
            self.best_model = self.model
        if self.latest_model is None:
            self.latest_model = self.model
        if task_id == 1 or self.cfg.use_base_half and task_id == int(self.dataset_handler.all_tasks / 2):
            torch.save({
                'state_dict': self.best_model.state_dict(),
                'acc_result': self.best_acc,
                'best_epoch': self.best_epoch,
                'task_id': task_id
            }, os.path.join(model_dir, "base_best_model.pth")
            )
            torch.save({
                'state_dict': self.latest_model.state_dict(),
                'acc_result': acc,
                'latest_epoch': epoch,
                'task_id': task_id
            }, os.path.join(model_dir, "base_latest_model.pth")
            )
            split_selected_data = self.dataset_handler.get_split_selected_data()
            torch.save({
                'exemplar_sets': self.exemplar_manager.exemplar_sets,
                'store_original_imgs': self.exemplar_manager.store_original_imgs,
                'split_selected_data': split_selected_data
            }, os.path.join(model_dir, "base_exp_data_info.pkl")
            )
        else:
            torch.save({
                'state_dict': self.best_model.state_dict(),
                'acc_result': self.best_acc,
                'best_epoch': self.best_epoch,
                'task_id': task_id
            }, os.path.join(model_dir, "best_model.pth")
            )
            torch.save({
                'state_dict': self.latest_model.state_dict(),
                'acc_result': acc,
                'latest_epoch': epoch,
                'task_id': task_id
            }, os.path.join(model_dir, "latest_model.pth")
            )
            split_selected_data = self.dataset_handler.get_split_selected_data()
            torch.save({
                'exemplar_sets': self.exemplar_manager.exemplar_sets,
                'store_original_imgs': self.exemplar_manager.store_original_imgs,
                'split_selected_data': split_selected_data
            }, os.path.join(model_dir, "exp_data_info.pkl")
            )

        pass
