import logging
from copy import deepcopy

import torch
import torch.nn as nn
from base import Trainer
from dataloader.data_utils import get_dataloader
from utils import *
from augmentations import fantasy
from .continual_learning import compute_fisher_matrix, store_optimal_params, adjust_gradients

from .helper import *
from .Network import MYNET
from ..resnet18_encoder import mode_context


class FSCILTrainer(Trainer):
    """Trainer for few-shot class-incremental learning (FSCIL) built around MYNET.

    Session 0 trains (or loads) the base model — optionally through a Reptile
    meta-learning stage — and then replaces the classifier weights with class
    prototypes. Every later session appends prototypes for that session's new
    classes and runs ``incremental_train``. Checkpoints and result files are
    written under ``args.save_path``.
    """

    def __init__(self, args):
        super().__init__(args)
        self.base_model_dict = None
        self.scheduler_novel = None
        self.optimizer_novel = None
        # Weights each session / fine-tuning stage starts from.
        self.best_model_dict = None
        # Checkpoint loaded from args.meta_model_dir, if one was given.
        self.best_model_dict_meta = None
        # Class prototypes, one row per class seen so far (filled in by train()).
        self.prototypes = None
        self.args = args
        self.fisher_information = None
        self.optimal_params = None
        self.criterion = nn.CrossEntropyLoss()
        if args.fantasy is not None:
            # The named fantasy factory returns a (transform, num_trans) pair.
            self.transform, self.num_trans = fantasy.__dict__[args.fantasy]()
        else:
            self.transform = None
            self.num_trans = 0
        self.set_up_model()

    def set_up_model(self):
        """Build the DataParallel MYNET model and pick its initial weights.

        The initial weights come from the first available source, in this order:
        the ``args.model_dir`` checkpoint, then the ``args.meta_model_dir``
        checkpoint, then the freshly initialised weights. In every case
        ``self.best_model_dict`` ends up non-None, because ``train()`` and the
        fine-tuning methods pass it straight to ``load_state_dict``.
        """
        self.model = MYNET(self.args, mode=self.args.base_mode)
        self.model = nn.DataParallel(self.model, list(range(self.args.num_gpu)))
        self.model = self.model.cuda()

        if self.args.model_dir is not None:
            logging.info('Loading init parameters from: %s' % self.args.model_dir)
            # Remap tensors stored on cuda:3 onto cuda:0 (presumably the device the
            # checkpoints were saved from — confirm if training hardware changes).
            self.best_model_dict = torch.load(self.args.model_dir,
                                              map_location={'cuda:3': 'cuda:0'})['params']
        elif self.args.meta_model_dir is not None:
            logging.info('Loading meta parameters from: %s' % self.args.meta_model_dir)
            self.best_model_dict_meta = torch.load(self.args.meta_model_dir,
                                                   map_location={'cuda:3': 'cuda:0'})['params']
            # Everything downstream starts from best_model_dict; without this it would
            # stay None and load_state_dict(None) would fail at the start of train().
            self.best_model_dict = self.best_model_dict_meta
        else:
            logging.info('random init params')
            if self.args.start_session > 0:
                logging.info('WARNING: Random init weights for new sessions!')
            self.best_model_dict = deepcopy(self.model.state_dict())

    # ------------------------------------------------------------------ #
    # Shared helpers for the epoch-based fine-tuning loops
    # ------------------------------------------------------------------ #
    def _prepare_finetune(self, args):
        """Unfreeze the backbone, switch to 'normal' mode and reload the best weights."""
        logging.info('>>> Load Model &&& Finish Meta base train...')
        self.model.module.train_backbone()
        mode_context.set('normal')
        if args.meta_train_is_done:
            # A finished meta-training stage requires its checkpoint to have been supplied.
            assert args.meta_model_dir is not None
        self.model.load_state_dict(self.best_model_dict)

    def _save_if_best(self, tsa, epoch, session, optimizer, args):
        """Checkpoint model + optimizer when test accuracy ``tsa`` ties or beats the session best."""
        # '>=' deliberately lets a later epoch with equal accuracy replace the checkpoint.
        if (tsa * 100) >= self.trlog['max_acc'][session]:
            self.trlog['max_acc'][session] = float('%.3f' % (tsa * 100))
            self.trlog['max_acc_epoch'] = epoch
            save_model_dir = os.path.join(args.save_path, 'session' + str(session) + '_max_acc.pth')
            torch.save(dict(params=self.model.state_dict()), save_model_dir)
            torch.save(optimizer.state_dict(), os.path.join(args.save_path, 'optimizer_best.pth'))
            self.best_model_dict = deepcopy(self.model.state_dict())
            logging.info('********A better model is found!!**********')
            logging.info('Saving model to: %s' % save_model_dir)

    def _run_epochs(self, train_step, testloader, optimizer, scheduler, args, session, result_list, epochs):
        """Run ``epochs`` train/test cycles, keeping the best checkpoint and the trlog up to date.

        ``train_step(epoch)`` must run one training epoch and return ``(loss, acc)``.
        It is called before evaluation, so it may also replace ``self.model``.
        """
        for epoch in range(epochs):
            start_time = time.time()

            tl, ta = train_step(epoch)
            tsl, tsa = test(self.model, testloader, epoch, args, session, result_list)

            self._save_if_best(tsa, epoch, session, optimizer, args)
            logging.info(
                f'Best epoch {self.trlog["max_acc_epoch"]}, best test acc={self.trlog["max_acc"][session]:.3f}')

            self.trlog['train_loss'].append(tl)
            self.trlog['train_acc'].append(ta)
            self.trlog['test_loss'].append(tsl)
            self.trlog['test_acc'].append(tsa)
            lrc = scheduler.get_last_lr()[0]

            logging.info(
                f'Epoch: {epoch}, LR: {lrc:.4f}, Training Loss: {tl:.5f}, Training Acc: {ta:.5f}, Test Loss: {tsl:.5f}, Test Acc: {tsa:.5f}')
            elapsed = time.time() - start_time
            # Estimate uses this loop's own epoch budget; `epoch` has just finished.
            remaining_epochs = epochs - epoch - 1
            print(
                f'This epoch takes {elapsed:.0f} seconds, \nstill need around {elapsed * remaining_epochs / 60:.2f} mins to finish this session')
            scheduler.step()

    # ------------------------------------------------------------------ #
    # Public fine-tuning entry points
    # ------------------------------------------------------------------ #
    def fine_tune_and_evaluate(self, trainloader, train_set, testloader, args, session, result_list, epochs, mask,
                               is_meta=False):
        """Fine-tune from ``self.best_model_dict`` with ``base_train`` for ``epochs`` epochs.

        ``mask`` and ``is_meta`` are accepted for signature compatibility and unused here.
        """
        self._prepare_finetune(args)
        optimizer, scheduler = get_optimizer(args, self.model)

        def train_step(epoch):
            return base_train(self.model, trainloader, train_set, optimizer, scheduler, epoch, args, session)

        self._run_epochs(train_step, testloader, optimizer, scheduler, args, session, result_list, epochs)

    def pretrain_before_meta(self, trainloader, train_set, testloader, args, session, result_list, epochs, mask,
                             is_meta=False):
        """Pre-train from ``self.best_model_dict`` with ``base_pretrain`` for ``epochs`` epochs.

        ``train_set``, ``mask`` and ``is_meta`` are accepted for signature compatibility and unused.
        """
        self._prepare_finetune(args)
        optimizer, scheduler = get_optimizer(args, self.model)

        def train_step(epoch):
            return base_pretrain(self.model, trainloader, optimizer, scheduler, epoch, args, session)

        self._run_epochs(train_step, testloader, optimizer, scheduler, args, session, result_list, epochs)

    def sup_con_finetune(self, trainloader, train_set, testloader, args, session, result_list, epochs, mask,
                         is_meta=False):
        """Fine-tune with ``supcon_train`` and rebuild the fc via ``replace_base_fc`` after every epoch.

        ``is_meta`` is accepted for signature compatibility and unused.
        """
        self._prepare_finetune(args)
        optimizer, scheduler = get_optimizer(args, self.model)

        def train_step(epoch):
            result = supcon_train(self.model, trainloader, optimizer, scheduler, epoch, args, mask)
            # Rebuild the classifier from train_set embeddings before evaluating.
            self.model = replace_base_fc(train_set, testloader.dataset.transform, self.model, args)
            return result

        self._run_epochs(train_step, testloader, optimizer, scheduler, args, session, result_list, epochs)

    def fine_tune_post(self, trainloader, testloader, args, session, result_list, epochs, is_meta=False):
        """Fine-tune with a feature-deviation constraint against ``self.prototypes``.

        Uses ``get_optimizer_post``. ``is_meta`` is accepted for signature compatibility and unused.
        """
        self._prepare_finetune(args)
        logging.info(f'New classes for this session: {np.unique(trainloader.dataset.targets)}')
        optimizer, scheduler = get_optimizer_post(args, self.model)

        def train_step(epoch):
            return base_train_with_feature_deviation_constraint(self.model, trainloader, optimizer, scheduler,
                                                                epoch, args, self.prototypes)

        self._run_epochs(train_step, testloader, optimizer, scheduler, args, session, result_list, epochs)

    # ------------------------------------------------------------------ #
    # Session pipeline
    # ------------------------------------------------------------------ #
    def _meta_pretrain_base(self, train_set, trainloader, testloader, args, session, result_list, mask):
        """Run Reptile meta-training on fake incremental splits, checkpoint it, then fine-tune."""
        pretrain_loader, postrain_loader, support_loader, query_loader = get_fake_incremental_data(
            train_set, np.arange(args.base_class), args)

        logging.info("Executing meta-learning pretraining...")
        _meta_loss_avg, _meta_acc_avg = meta_train_reptile(self.model, pretrain_loader, postrain_loader,
                                                           trainloader, train_set, testloader,
                                                           support_loader, query_loader, args)

        self.best_model_dict = deepcopy(self.model.state_dict())
        save_model_dir_meta = os.path.join(args.save_path, 'session' + str(session) + '_meta_pretrained.pth')
        torch.save({'params': self.best_model_dict}, save_model_dir_meta)
        logging.info(f'Meta-learning pretraining completed. Model saved to: {save_model_dir_meta}')

        self.fine_tune_and_evaluate(trainloader, train_set, testloader, args, session,
                                    result_list, args.epochs_postrain, mask)

    def _standard_base_train(self, train_set, trainloader, testloader, args, session, result_list):
        """Plain supervised base-session training for ``args.epochs_base`` epochs."""
        logging.info(f'new classes for this session:{np.unique(train_set.targets)}')
        optimizer, scheduler = get_optimizer(args, self.model)
        for epoch in range(args.epochs_base):
            start_time = time.time()

            tl, ta = base_train(self.model, trainloader, train_set, optimizer, scheduler, epoch, args, session)
            tsl, tsa = test(self.model, testloader, epoch, args, session, result_list=result_list)

            # save better model
            if (tsa * 100) >= self.trlog['max_acc'][session]:
                self.trlog['max_acc'][session] = float('%.3f' % (tsa * 100))
                self.trlog['max_acc_epoch'] = epoch
                save_model_dir = os.path.join(args.save_path, 'session' + str(session) + '_max_acc.pth')
                torch.save(dict(params=self.model.state_dict()), save_model_dir)
                torch.save(optimizer.state_dict(), os.path.join(args.save_path, 'optimizer_best.pth'))
                self.best_model_dict = deepcopy(self.model.state_dict())
                logging.info('********A better model is found!!**********')
                logging.info('Saving model to :%s' % save_model_dir)
            logging.info('best epoch {}, best test acc={:.3f}'.format(
                self.trlog['max_acc_epoch'], self.trlog['max_acc'][session]))

            self.trlog['train_loss'].append(tl)
            self.trlog['train_acc'].append(ta)
            self.trlog['test_loss'].append(tsl)
            self.trlog['test_acc'].append(tsa)
            lrc = scheduler.get_last_lr()[0]

            logging.info(
                'epoch:%03d,lr:%.4f,training_loss:%.5f,training_acc:%.5f,test_loss:%.5f,test_acc:%.5f' % (
                    epoch, lrc, tl, ta, tsl, tsa))
            elapsed = time.time() - start_time
            # `epoch` has just finished, so epochs_base - epoch - 1 epochs remain.
            print('This epoch takes %d seconds' % elapsed,
                  '\n still need around %.2f mins to finish this session' % (
                          elapsed * (args.epochs_base - epoch - 1) / 60))
            scheduler.step()

        logging.info('>>> Finish Base Train <<<')
        result_list.append('Session {}, Test Best Epoch {},\nbest test Acc {:.4f}\n'.format(
            session, self.trlog['max_acc_epoch'], self.trlog['max_acc'][session]))

    def _init_fc_with_prototypes(self, args, session, result_list):
        """Replace the base fc with average embeddings, store prototypes and dummy/old classifiers."""
        train_set, _, testloader = get_dataloader(args, session)
        mode_context.set('normal')
        self.model.load_state_dict(self.best_model_dict)
        # Evaluated once before the fc replacement; the result is only recorded by test() itself.
        test(self.model, testloader, 0, args, session, result_list=result_list)
        self.model = replace_base_fc(train_set, testloader.dataset.transform, self.model, args)

        protos = calculate_class_prototypes(train_set, session, testloader.dataset.transform, self.model, args)
        # NOTE(review): the two branches store different forms (raw result vs. torch.stack);
        # kept as-is — confirm the return type of calculate_class_prototypes.
        self.prototypes = protos if self.prototypes is None else torch.stack(protos, dim=0)

        best_model_dir = os.path.join(args.save_path, 'session' + str(session) + '_max_acc_r.pth')
        logging.info('Replace the fc with average embedding, and save it to :%s' % best_model_dir)
        self.best_model_dict = deepcopy(self.model.state_dict())
        torch.save(dict(params=self.model.state_dict()), best_model_dir)

        self.model.module.mode = args.base_mode

        _, tsa = test(self.model, testloader, 0, args, session, result_list=result_list)
        if (tsa * 100) >= self.trlog['max_acc'][session]:
            self.trlog['max_acc'][session] = float('%.3f' % (tsa * 100))
            logging.info('The new best test acc of base session={:.3f}'.format(
                self.trlog['max_acc'][session]))

        # Split the fc weight into base-class ("old") rows and the remaining ("dummy") rows.
        # Both are sliced from the full weight; slicing old_classifiers out of the already
        # truncated dummy tensor would have returned novel-class rows instead.
        fc_weight = deepcopy(self.model.module.fc.weight.detach())
        self.dummy_classifiers = F.normalize(fc_weight[self.args.base_class:, :], p=2, dim=-1)
        self.old_classifiers = F.normalize(fc_weight[:self.args.base_class, :], p=2, dim=-1)

    def _train_base_session(self, train_set, trainloader, testloader, args, session, result_list, mask):
        """Session 0: obtain base weights (meta, standard or pre-loaded), then build the prototype fc."""
        if args.only_do_incre:
            logging.info('>>> Load Model &&& Finish base train...')
            assert args.model_dir is not None
        elif args.metabase and args.project == 'meta_adapter':
            # With meta_train_is_done the loaded meta checkpoint is used as-is.
            if not args.meta_train_is_done:
                self._meta_pretrain_base(train_set, trainloader, testloader, args, session, result_list, mask)
        else:
            self._standard_base_train(train_set, trainloader, testloader, args, session, result_list)

        if not args.not_data_init:
            self._init_fc_with_prototypes(args, session, result_list)

    def _train_incremental_session(self, trainloader, testloader, args, session, result_list):
        """Session >= 1: add new-class prototypes, run incremental_train, evaluate and checkpoint."""
        logging.info("training session: [%d]" % session)
        self.model.module.mode = self.args.new_mode
        # New-class samples are processed with the (non-augmenting) test transform.
        trainloader.dataset.transform = testloader.dataset.transform

        if hasattr(self.model.module, 'encoder') and hasattr(self.model.module.encoder, 'train_meta_block'):
            self.model.module.encoder.train_meta_block()
            self.model.module.encoder.train_novel = True

        optimizer, _ = get_optimizer_novel(args, self.model)

        # fc / prototype rows belonging to this session's new classes.
        start = args.base_class + (session - 1) * args.way
        end = args.base_class + session * args.way

        self.prototypes = torch.cat((self.prototypes,
                                     calculate_class_prototypes_incremental(session,
                                                                            testloader.dataset.transform,
                                                                            self.model,
                                                                            args)), dim=0)
        self.model.module.fc.weight.data[start:end] = self.prototypes[start:end]

        incremental_train(self.prototypes, self.model, trainloader, optimizer, session, args,
                          self.best_model_dict)

        # Recompute this session's prototypes with the updated model, then write all
        # prototypes seen so far back into the classifier.
        self.prototypes[start:end] = calculate_class_prototypes_incremental(
            session, testloader.dataset.transform, self.model, args)
        self.model.module.fc.weight.data[:end] = self.prototypes[:end]

        self.optimal_params = store_optimal_params(self.model)
        self.model.module.mode = self.args.new_mode
        _, (seenac, unseenac, avgac) = test(self.model, testloader, 0, args, session, result_list=result_list)

        # update results and save model
        self.trlog['seen_acc'].append(float('%.3f' % (seenac * 100)))
        self.trlog['unseen_acc'].append(float('%.3f' % (unseenac * 100)))
        self.trlog['max_acc'][session] = float('%.3f' % (avgac * 100))
        model_dir = os.path.join(args.save_path, f'session{session}_max_acc.pth')
        torch.save({'params': self.model.state_dict()}, model_dir)

        self.best_model_dict = deepcopy(self.model.state_dict())

        logging.info(f"Session {session} ==> Seen Acc:{self.trlog['seen_acc'][-1]} "
                     f"Unseen Acc:{self.trlog['unseen_acc'][-1]} Avg Acc:{self.trlog['max_acc'][session]}")
        result_list.append('Session {}, test Acc {:.3f}\n'.format(session, self.trlog['max_acc'][session]))

    def train(self, ):
        """Run sessions ``args.start_session`` .. ``args.sessions - 1`` and write the results.

        Each session starts from ``self.best_model_dict``. Per-session checkpoints and
        ``results.txt`` are written under ``args.save_path``; ``save_result`` is also
        called unless ``args.debug`` is set.
        """
        args = self.args
        t_start_time = time.time()
        result_list = [args]

        # Random (base_class x num_classes) mask: every non-base column gets `masknum`
        # randomly picked base-class rows set to 1. Passed to the meta fine-tuning stage.
        masknum = 3
        mask = np.zeros((args.base_class, args.num_classes))
        for col in range(args.base_class, args.num_classes):
            picked_dummy = np.random.choice(args.base_class, masknum, replace=False)
            mask[picked_dummy, col] = 1
        mask = torch.tensor(mask).cuda()

        for session in range(args.start_session, args.sessions):
            train_set, trainloader, testloader = get_dataloader(args, session)
            self.model.load_state_dict(self.best_model_dict)
            if session == 0:
                self._train_base_session(train_set, trainloader, testloader, args, session, result_list, mask)
            else:
                self._train_incremental_session(trainloader, testloader, args, session, result_list)

        # All sessions finished: save results.
        result_list, hmeans = postprocess_results(result_list, self.trlog)
        save_list_to_txt(os.path.join(args.save_path, 'results.txt'), result_list)
        if not self.args.debug:
            save_result(args, self.trlog, hmeans)

        total_time = (time.time() - t_start_time) / 60
        logging.info(f"Base Session Best epoch:{self.trlog['max_acc_epoch']}")
        logging.info('Total time used %.2f mins' % total_time)
        logging.info(self.args.time_str)
