import json
import math
import multiprocessing
from multiprocessing import shared_memory
import torch.multiprocessing as mp
import os
import os.path as osp
import random
import time
from cmath import isnan
from collections import defaultdict
from copy import deepcopy
from multiprocessing import Manager, Process, Value, Queue
from pickletools import optimize
from posixpath import split

import numpy as np
import torch
import torch.nn.functional as F
from torch import autograd, nn
from torch.autograd import Variable
from torch.nn import DataParallel

from .loss_function import *
from .models import CosClassifier
from .models.utils.load_parameters import (load_ckpt,
                                                model_load_state_dict,
                                                save_ckpt)
from .models.utils.lr_strategies import get_lr_strategy
from .utils import may_make_dir, to_scalar, calculate_cos_affinity, calculate_single_cos_affinity, norm_2
from .utils.meters import AverageMeter
from .utils.smooth_label import smooth_label
from .base_trainer import Memory_Trainer
from .utils.roate_tensor import multi_calculate_R, generate_raw_feature




class Phase_Trainer(Memory_Trainer):
    def __init__(self, backbone_model, data_loader, tester, TVT, TMO, device_ids=None, **kwargs):
        """Set up the phase trainer from mandatory keyword options.

        The positional arguments and ``device_ids`` are forwarded unchanged to
        ``Memory_Trainer.__init__``. Every option listed below is required; a
        missing one raises ``KeyError``.

        Keyword Args:
            backbone_feature_loss: loss callable invoked as ``loss(scores, labels)``.
            backbone_local_loss: loss callable invoked on (new, old) features.
            backbone_local_loss_weight: multiplier used for the distillation term.
            backbone_optimizer: optimizer exposing ``param_groups``.
            backbone_module_optimize: dict whose ``'mod'`` entry is a list of modules.
            lr_satus: factor applied to the first base lr to get the classifier lr.
            lr_decay_type: name handed to ``get_lr_strategy``.
            lr_decay_kwargs: extra keyword arguments for the lr strategy.
            resume: truthy when training continues from saved files.
            steps_per_log: logging period in steps.
            epochs_per_val: validation period in epochs (``-1`` disables it).
            current_train_path: checkpoint path of the model being trained.
            memory_file: path of the saved memory.
        """
        super().__init__(backbone_model, data_loader, tester, TVT, TMO, device_ids=device_ids)

        # general + learning-rate options, copied verbatim onto the instance
        for option in ('backbone_feature_loss',
                       'backbone_local_loss',
                       'backbone_local_loss_weight',
                       'backbone_optimizer',
                       'backbone_module_optimize',
                       'lr_satus',
                       'lr_decay_type',
                       'lr_decay_kwargs'):
            setattr(self, option, kwargs[option])
        # callable that rewrites the optimizer lr for a given epoch
        self.change_lr = get_lr_strategy(self.lr_decay_type)

        # training-process options and the checkpoint location
        for option in ('resume', 'steps_per_log', 'epochs_per_val', 'current_train_path'):
            setattr(self, option, kwargs[option])
        # checkpoint of the previous model lives next to the current one
        path_stem = os.path.splitext(self.current_train_path)[0]
        self.old_train_path = path_stem + '_old.pth'
        self.memory_file = kwargs['memory_file']

        # remember the backbone-only sizes so the per-task classifier entries
        # appended later can be stripped again
        self.backbone_baselr = [group['lr'] for group in self.backbone_optimizer.param_groups]
        self.backbone_baselr_len = len(self.backbone_baselr)
        self.backbone_opt_len = len(self.backbone_optimizer.param_groups)
        self.backbone_module_opt_len = len(self.backbone_module_optimize['mod'])


    def _init_classfier(self):
        """Create the cosine classifier for the coming task and register it.

        A fresh ``CosClassifier`` is built, previously learned ``fc`` rows are
        copied into its leading rows (from the checkpoint when resuming at
        epoch 0, otherwise from the live classifier), and then:

        * the optimizer is rebuilt as SGD (momentum 0.9) over the backbone
          parameter groups plus one classifier group,
        * ``self.backbone_baselr`` gets the classifier lr appended,
        * the classifier is wrapped in ``DataParallel``, handed to ``self.TMO``
          and stored as the last entry of ``backbone_module_optimize['mod']``.
        """
        in_dim = self.feature_in_dim
        # number of memory slots that have already been written
        seen_num = torch.sum((self.initial_memory==False).int())

        if self.resume:
            new_head = CosClassifier(in_dim, seen_num)
            if self.start_epoch == 0:
                saved = torch.load(self.current_train_path,
                                   map_location=lambda storage, loc: storage)
                saved_weight = saved['mod_state_dicts'][-1]['module.fc.weight'].data
                if self.data_loader.part != 'expert':
                    stop = seen_num
                else:
                    stop = seen_num-self.phase
                with torch.no_grad():
                    new_head.fc.weight[:stop].copy_(saved_weight)
        else:
            if seen_num == 0:
                out_dim = self.init_class_number
            else:
                out_dim = seen_num+self.phase
            new_head = CosClassifier(in_dim, out_dim)
            # keep the rows of the old tasks; assumes the labels are in order
            if self.classfier is not None:
                previous_weight = self.classfier.state_dict()['module.fc.weight'].data
                with torch.no_grad():
                    new_head.fc.weight[:seen_num].copy_(previous_weight)
                # drop the old classifier before the new one is installed
                del self.classfier
                torch.cuda.empty_cache()
        self.classfier = new_head

        # classifier learning rate is derived from the first backbone base lr
        head_lr = self.backbone_baselr[0]*self.lr_satus
        head_group = {'params': list(self.classfier.parameters()), 'lr': head_lr}
        backbone_groups = self.backbone_optimizer.param_groups[:self.backbone_opt_len]
        self.backbone_baselr = self.backbone_baselr[:self.backbone_baselr_len]+[head_lr]
        self.backbone_optimizer = torch.optim.SGD(backbone_groups+[head_group], momentum=0.9)

        self.classfier = DataParallel(self.classfier, device_ids=self.device_ids)
        self.TMO([self.classfier])
        backbone_modules = self.backbone_module_optimize['mod'][:self.backbone_module_opt_len]
        self.backbone_module_optimize['mod'] = backbone_modules+[self.classfier]


    def _train_init(self):
        """Prepare trainer state before the task loop starts.

        When resuming, the checkpoint, memory and old-model files must exist;
        memory and parameters are then loaded and the start task is taken from
        the data loader. Otherwise the counters are zeroed and an empty memory
        is created. In both cases task metadata is copied from the data loader
        and the ``BestModel`` directory next to the checkpoint is created.
        """
        if self.resume:
            required_files = (
                (self.current_train_path, "The weight file haven't been saved!"),
                (self.memory_file, "The memory file can't be loaded!"),
                (self.old_train_path, "The old model weight file haven't been saved!"),
            )
            for file_path, message in required_files:
                assert os.path.exists(file_path), message
            self._init_memory(memory_file=self.memory_file)
            self._load_parameters(self.backbone_optimizer, self.backbone_module_optimize, self.current_train_path)
            # checked once more after loading when not in the 'expert' part
            if self.data_loader.part != 'expert':
                assert os.path.exists(
                    self.old_train_path), "The old model weight file haven't been saved!"
            self.start_task = self.data_loader.task_ptr-1
        else:
            self.step = 0
            self.start_epoch = 0
            self.start_task = 0
            self._init_memory()

        # task layout as described by the data loader
        loader = self.data_loader
        self.task_number = loader.task_number
        self.feature_per_step = loader.use_ims_per_id
        self.phase = loader.phase
        self.init_class_number = loader.init_class_number

        # folder for the best models, next to the training checkpoint
        checkpoint_dir = os.path.dirname(self.current_train_path)
        self.best_model_path = os.path.join(checkpoint_dir, 'BestModel')
        may_make_dir(self.best_model_path)



    def _init_task(self):
        """Get the models and the optimizer ready for one task.

        Rebuilds the classifier, switches classifier and backbone to train
        mode and clears gradients. When resuming from a later epoch the saved
        parameters are reloaded; on a non-resumed run every parameter group is
        reset to its base learning rate.
        """
        self._init_classfier()
        for module in (self.classfier, self.backbone_model):
            module.train()
        self.backbone_optimizer.zero_grad()

        if self.resume and self.start_epoch > 0:
            self._load_parameters(self.backbone_optimizer, self.backbone_module_optimize, self.current_train_path)
        if not self.resume:
            for group_id, group in enumerate(self.backbone_optimizer.param_groups):
                group['lr'] = self.backbone_baselr[group_id]
                


    def _update_memory(self, ims_var, labels_var):
        """Refresh the memory rows of the labels present in one mini batch.

        Labels whose slot is still flagged in ``self.initial_memory`` are first
        seeded with the normalised feature of their first sample. Afterwards
        each label's row becomes the normalised mean of its old row and the
        batch features of that label, every feature being scaled by its clamped
        dot product with the stored row.

        Args:
            ims_var: image batch accepted by ``self.backbone_model``.
            labels_var: label tensor aligned with ``ims_var``; used to index
                ``self.memory``.
        """
        embedding_feat, _ = self.backbone_model(ims_var)
        batch_labels = torch.unique_consecutive(labels_var)
        # one boolean sample mask per label, reused for seeding and updating
        label_masks = [labels_var==label for label in batch_labels]

        # seed untouched slots with the first feature of their label
        needs_seed = self.initial_memory[batch_labels]
        first_feature = torch.stack([embedding_feat[mask][0] for mask in label_masks])
        if True in needs_seed:
            seeded_labels = batch_labels[needs_seed]
            self.memory[seeded_labels] = norm_2(first_feature[needs_seed])
            self.initial_memory[batch_labels] = False

        # affinity of every sample with the stored row of its own label
        normed_feature = norm_2(embedding_feat)
        sample_memory = self.memory[labels_var]
        batch_memory = self.memory[batch_labels]
        epsilon = 1e-7
        batch_affinity = torch.clamp(
            torch.sum(sample_memory*normed_feature, dim=-1, keepdim=True),
            -1 + epsilon, 1 - epsilon)

        # affinity-weighted features are averaged together with the old row
        weighted_feature = normed_feature*batch_affinity
        refreshed_rows = []
        for position, mask in enumerate(label_masks):
            members = torch.cat((weighted_feature[mask], batch_memory[position].unsqueeze(0)))
            refreshed_rows.append(norm_2(torch.mean(members, dim=-2)))

        self.memory[batch_labels] = torch.stack(refreshed_rows)


    def _update_cos_mean(self, ims_var, labels_var):
        """Accumulate a per-label histogram of feature/memory affinities.

        The last axis of ``self.memory_mean`` is split into equal-width bins
        over (0, 1]. For every label in the batch, the number of samples whose
        clamped affinity with the stored row falls into bin ``i`` (lower edge
        exclusive, upper edge inclusive) is added to
        ``self.memory_mean[label][i]``. Affinities <= 0 are not counted.

        Args:
            ims_var: image batch accepted by ``self.backbone_model``.
            labels_var: label tensor aligned with ``ims_var``.
        """
        embedding_feat, _ = self.backbone_model(ims_var)
        batch_labels = torch.unique_consecutive(labels_var)

        # affinity of every sample with the stored row of its own label
        normed_feature = norm_2(embedding_feat)
        sample_memory = self.memory[labels_var]
        epsilon = 1e-7
        batch_affinity = torch.clamp(
            torch.sum(sample_memory*normed_feature, dim=-1, keepdim=True),
            -1 + epsilon, 1 - epsilon)

        # sample masks depend only on the label, so build them once
        labelled_masks = [(label, labels_var==label) for label in batch_labels]

        slices_num = self.memory_mean.shape[-1]
        slices = 1.0/slices_num
        for bin_id in range(slices_num):
            min_cos, max_cos = bin_id*slices, (bin_id+1)*slices
            in_bin = (batch_affinity<=max_cos) & (batch_affinity>min_cos)
            for label, mask in labelled_masks:
                hits = torch.sum(in_bin[mask].squeeze(-1)==True)
                self.memory_mean[label][bin_id] += hits
    

    def _global_update_backbone(self, ims, labels, old_model=None, use_prototype=True, release=True):
        """Run one forward/backward pass and leave the gradients on the parameters.

        The optimizer is NOT stepped here; the caller decides when to step.

        Args:
            ims: numpy array of images (converted with ``torch.from_numpy``).
            labels: numpy array of labels aligned with ``ims``.
            old_model: previous backbone used as distillation teacher, or
                ``None`` to disable distillation and memory replay.
            use_prototype: when true, add the feature-vs-memory loss.
            release: when true (and ``use_prototype`` is true and ``old_model``
                is given), features sampled by ``release_memory`` are appended
                to the classifier input.

        Returns:
            ``[loss, f_loss, loss_list]`` where ``loss = f_loss + kd_loss``,
            ``f_loss = current_f_loss + class_loss`` and ``loss_list`` is
            ``[current_f_loss, class_loss, l_loss, cos_loss, mse_loss]``.
            ``l_loss`` and ``cos_loss`` are reported only; they do not enter
            ``loss``.
        """
        # replay is only possible when prototypes are in use (needs `_index` below)
        release = use_prototype & release
        ims_var = self.TVT(torch.from_numpy(ims).float())
        labels_var = self.TVT(torch.from_numpy(labels).long())
        # normalised backbone embedding of the current batch
        embedding_feat, _ = self.backbone_model(ims_var)
        embedding_feat = norm_2(embedding_feat)
        all_labels = self.TVT(torch.arange(len(self.initial_memory)))
        # labels whose memory slot has already been written
        trained_label = all_labels[self.initial_memory==False]
        # transfer labels
        img_per_id = self.data_loader.use_ims_per_id
        trained_number = len(trained_label)
        phase_num = self.data_loader.phase
        # number of written slots minus one phase; used for the gradient damping below
        train_memory_num = trained_number-phase_num
        if use_prototype:
            memory = self.memory[self.initial_memory==False]
            _index = self.TVT(torch.arange(trained_number))
            batch_labels = self.TVT(torch.unique_consecutive(labels_var))
            # position of every batch label inside the list of trained labels
            transfered_index = torch.cat([_index[trained_label==i] for i in batch_labels])
            # assumes each batch label owns exactly `img_per_id` consecutive samples — TODO confirm
            transed_labels_var = transfered_index.unsqueeze(-1).repeat(1, img_per_id)
            transed_labels_var = self.TVT(transed_labels_var.view(-1))

            f_m_affi = calculate_cos_affinity(embedding_feat, memory, smooth=False)
            current_f_loss = self.backbone_feature_loss(f_m_affi, transed_labels_var)
        else:
            current_f_loss = torch.tensor(0, dtype=torch.float)

        # ------- for classifier training --------- #
        if old_model is not None and release:
            class_losses = []
            # train random memory only once
            # start_time = time.time()
            sample_dict = self.release_memory(self.feature_per_step)
            # sample_dict = self.repeat_memory(self.feature_per_step)
            # print('Release Memory Time: %.4f'%(time.time()-start_time))
            memo_label_var = torch.tensor(list(sample_dict.keys()))
            transfered_index = torch.cat([_index[trained_label==i] for i in memo_label_var])
            # every replayed class contributes `feature_per_step` generated features
            memo_label_var = transfered_index.unsqueeze(-1).repeat(1, self.feature_per_step)
            memo_label_var = self.TVT(memo_label_var.view(-1))
            memo_feature_var = torch.cat(list(sample_dict.values()))
            memo_feature_var = self.TVT(memo_feature_var)

            merged_memo = torch.cat([embedding_feat, memo_feature_var])
            # NOTE(review): batch labels are used raw here while the replayed labels are
            # remapped positions; this matches only if labels equal their position — confirm
            merged_label = torch.cat([labels_var, memo_label_var])
            f_c_affi = self.classfier(merged_memo)
            class_loss = self.backbone_feature_loss(f_c_affi, merged_label)

            # the list holds a single entry, so the average equals that entry
            class_losses.append(class_loss)
            class_loss = sum(class_losses)/len(class_losses)
        else:
            f_c_affi = self.classfier(embedding_feat)
            class_loss = self.backbone_feature_loss(f_c_affi, labels_var)

        if old_model is not None:
            # teacher features are computed without building a graph
            with torch.no_grad():
                old_embedding_feat, _ = old_model(ims_var)
            affinity = calculate_single_cos_affinity(embedding_feat, old_embedding_feat, smooth=False)
            # -> normaled version
            scale = 16
            new_norm_f = scale*norm_2(embedding_feat)
            old_norm_f = scale*norm_2(old_embedding_feat)
            l_loss = self.backbone_local_loss(new_norm_f, old_norm_f)
            mse_loss = self.mse_loss(new_norm_f, old_norm_f)

            cos_loss = torch.mean(1-affinity)
            local_loss_weight = self.backbone_local_loss_weight
            # only the MSE term is optimised; l_loss and cos_loss are for logging
            kd_loss = local_loss_weight*30*mse_loss
        else:
            kd_loss = torch.tensor(0, dtype=torch.float)
            mse_loss = torch.tensor(0, dtype=torch.float)
            cos_loss = torch.tensor(0, dtype=torch.float)
            l_loss = torch.tensor(0, dtype=torch.float)

        f_loss = current_f_loss + class_loss
        loss = f_loss + kd_loss
        loss_list = [current_f_loss, class_loss, l_loss, cos_loss, mse_loss]
        # update model gradient
        loss.backward()
        # ----- for old task fc ----- #
        if old_model is not None:
            # damp the gradient of the first `train_memory_num` classifier rows
            # NOTE(review): applied to the accumulated .grad, so with gradient
            # accumulation earlier contributions are scaled again — confirm intent
            self.classfier.module.fc.weight.grad[:train_memory_num]*=0.001
            # self.classfier.module.fc.weight.grad[:train_memory_num]*=0.0
        # ----- end ----- #
        return [loss, f_loss, loss_list]

    def release_memory(self, feature_number):
        """Generate replay features for every trained class outside the current task.

        For each selected class, ``feature_number`` values are drawn from the
        class histogram in ``self.memory_mean`` (L1-normalised into a
        probability vector), turned into raw features by
        ``generate_raw_feature`` and multiplied by the transposed matrix
        stored in ``self.R_dict``.

        The class selection is computed on a private copy of the
        initialisation flags. The previous implementation flipped
        ``self.initial_memory`` in place and then reset the current task's
        flags to ``False`` unconditionally, which silently marked
        uninitialised slots as initialised.

        Args:
            feature_number: number of features to generate per class.

        Returns:
            dict mapping class index to a ``(feature_number, dim)`` tensor.
        """
        # `==` yields a new tensor, so editing the mask leaves self.initial_memory intact
        old_class_mask = self.initial_memory.data.cpu()==False
        old_class_mask[self.data_loader.current_task_label] = False
        all_labels = torch.arange(len(self.initial_memory))
        choiced_memo_indx = list(all_labels[old_class_mask].numpy())

        # lower edges of the histogram bins; identical for every class
        cos_templat = torch.arange(0, 1, 1/self.memory_mean.shape[-1])

        trans_feature_dict = dict()
        for i in choiced_memo_indx:
            # sample bin edges following the recorded distribution of class i
            distribute = F.normalize(self.memory_mean[i].data.cpu(), p=1, dim=-1)
            cos_list = cos_templat[torch.multinomial(distribute, feature_number, replacement=True)]
            dim = self.memory[i].shape[-1]
            released_raw_features = generate_raw_feature(cos_list, dim)
            # map the raw features with the matrix computed for class i
            trans_feature_dict[i] = torch.mm(released_raw_features, self.R_dict[i].t())
        return trans_feature_dict

    def repeat_memory(self, feature_number):
        """Replay every trained class outside the current task by tiling its memory row.

        The class selection is computed on a private copy of the
        initialisation flags. The previous implementation toggled
        ``self.initial_memory`` in place and reset the current task's flags to
        ``False`` unconditionally, which could mark uninitialised slots as
        initialised.

        Args:
            feature_number: number of copies of each memory row.

        Returns:
            dict mapping class index to a ``(feature_number, dim)`` tensor.
        """
        # `==` yields a new tensor, so editing the mask leaves self.initial_memory intact
        old_class_mask = self.initial_memory.data.cpu()==False
        old_class_mask[self.data_loader.current_task_label] = False
        all_labels = torch.arange(len(self.initial_memory))
        choiced_memo_indx = list(all_labels[old_class_mask].numpy())

        return {
            i: self.memory[i].unsqueeze(0).repeat(feature_number, 1)
            for i in choiced_memo_indx
        }
    

    def distill_model(self, scale, distill_factor=1.0):
        '''
        NOTE:
        Blend the current backbone weights with the weights stored in
        ``self.old_train_path`` and load the result into the backbone:

            factor = distill_factor * sin(pi/2 * e**(-scale))
            weight = old + (current - old) * factor

        The old checkpoint must exist (checked with an assertion).
        '''
        # snapshot the trained weights before the old ones overwrite the model
        current_weight = deepcopy(self.backbone_model.state_dict())
        assert os.path.exists(self.old_train_path), "The old model weight file haven't been saved!"
        old_ckpt = torch.load(self.old_train_path, map_location=lambda storage, loc: storage)
        model_load_state_dict(self.backbone_model, old_ckpt['mod_state_dicts'][0])
        old_weight = self.backbone_model.state_dict()

        # a larger `scale` gives a smaller factor, i.e. stays closer to the old model
        negated_scale = scale * -1
        blend = distill_factor * math.sin(1/2*math.pi*math.e**negated_scale)

        blended_weight = {}
        for name in current_weight:
            previous = old_weight[name]
            blended_weight[name] = previous + (current_weight[name] - previous) * blend
        self.backbone_model.load_state_dict(blended_weight)


    def generate_memory(self, update_memory_mean=False):
        '''
        NOTE:
        Rebuild the memory of the current task from one pass over the data
        loader and persist it with ``self._save_memory()``.

        Args:
            update_memory_mean: when true, a second pass over the data loader
                also accumulates the affinity histogram
                (``self._update_cos_mean``).

        The "Updated the memory mean" message is printed only when that second
        pass actually ran (it used to be printed unconditionally).
        '''
        def run_pass(update_fn):
            # feed every batch of one epoch to `update_fn` without gradients
            with torch.no_grad():
                while not self.epoch_done:
                    self.step += 1
                    # dataloader: enhanced_ims, enhanced_im_ids, mirrored, enhanced_im_names, im_enhanced_labels, epoch_done
                    ims, _, _, _, labels, self.epoch_done = self.data_loader.next_batch()
                    ims_var = self.TVT(torch.from_numpy(ims).float())
                    labels_var = self.TVT(torch.from_numpy(labels).long())
                    update_fn(ims_var, labels_var)

        self._clear_current_memory()
        run_pass(self._update_memory)
        print('*'*5+'Updated the memory'+5*'*')

        if update_memory_mean:
            # NOTE(review): presumably _clear_record() resets self.epoch_done,
            # otherwise the second pass would be skipped — confirm in the base class
            self._clear_record()
            run_pass(self._update_cos_mean)
            print('*'*5+'Updated the memory mean'+5*'*')

        # To save the memory
        self._save_memory()
        print('---All memories are updated!---')
        self._clear_record()

    def get_R_matrxi(self, skip_current_task=True):
        """Compute the per-class matrices in worker processes and store them in ``self.R_dict``.

        The indices and memory rows of the selected classes are appended to
        ``self.indx_memo_l``. Ten ``multi_calculate_R`` workers are started on
        first use and signalled through ``self.event``. The call blocks until
        the event has been cleared again and ``self.trans_feature_dict`` holds
        one entry per selected class. The results are then passed through
        ``self.TVT``.

        Args:
            skip_current_task: when true, classes of the current task are
                excluded. Either way ``self.initial_memory`` is left untouched
                (previously the current task's flags were forced to ``False``
                even when this argument was false).
        """
        start_time = time.time()
        self.R_dict.clear()
        # for multiprocess memory manager
        self._reset_memory_manager()

        # `==` yields a new tensor, so editing the mask leaves self.initial_memory intact
        trained_mask = self.initial_memory.data.cpu()==False
        if skip_current_task:
            trained_mask[self.data_loader.current_task_label] = False
        all_labels = torch.arange(len(self.initial_memory))
        memory_indx = list(all_labels[trained_mask].numpy())
        trained_memory = self.memory[memory_indx]

        # hand the work description to the workers
        self.indx_memo_l.append(memory_indx)
        self.indx_memo_l.append(trained_memory.cpu())
        if self.all_process is None:
            self.all_process = [mp.Process(target=multi_calculate_R, args=(
                self.event, self.lock, self.memo_index_num, self.indx_memo_l, self.trans_feature_dict)) for _ in range(10)]
            for tmp_process in self.all_process:
                tmp_process.start()
        self.event.set()

        # wait for the workers; sleep between polls instead of spinning on a core
        while self.event.is_set() or len(self.trans_feature_dict) != len(memory_indx):
            time.sleep(0.001)

        self.R_dict = dict(self.trans_feature_dict)
        for key in self.R_dict:
            self.R_dict[key] = self.TVT(self.R_dict[key])
        print('Getting R Matrix Time: %.4f'%(time.time()-start_time))


    def train_extractor(self, task, train_mini_batch=1, epochs_per_task=200):
        '''
        NOTE:
        Train the backbone (and classifier) on the current task.

        Args:
            task: index of the current task; for ``task > 0`` a frozen copy of
                the backbone loaded from ``self.old_train_path`` is used as
                distillation teacher.
            train_mini_batch: the optimizer steps whenever ``self.step`` is a
                multiple of this value.
            epochs_per_task: epochs run from ``self.start_epoch`` up to (not
                including) this value.

        Each epoch regenerates the memory, refreshes the R matrices when a
        teacher exists, applies the lr strategy, trains over the data loader,
        dumps the order file and saves a checkpoint.
        '''
        # snapshot the model entering this task; it becomes the "old" model
        if self.start_epoch == 0:
            save_ckpt(self.backbone_module_optimize, task, self.old_train_path)

        old_model = None
        if task > 0:
            teacher_ckpt = torch.load(self.old_train_path,
                                      map_location=lambda storage, loc: storage)
            old_model = deepcopy(self.backbone_model)
            model_load_state_dict(old_model, teacher_ckpt['mod_state_dicts'][0])
            old_model.eval()

        step_log = 'Step:%05d\tLoss:%.3f\tFLoss:%.3f\tCLoss:%.3f\tLloss:%.3f\tCosloss:%.3f\tMloss:%.3f\tBase_LR:%.1e\tTime:%.1fs'
        for epoch in range(self.start_epoch, epochs_per_task):
            epoch_num = epoch+1
            self._clear_record()
            self.generate_memory()
            if old_model is not None:
                self.get_R_matrxi()
            self.backbone_mse_loss.clear()
            self.change_lr(self.backbone_optimizer, epoch_num, **self.lr_decay_kwargs)

            while not self.epoch_done:
                step_begin = time.time()
                self.step += 1
                # dataloader: ims, im_ids, mirrored, im_names, im_labels, epoch_done
                ims, _, _, _, labels, self.epoch_done = self.data_loader.next_batch()
                loss, f_loss, loss_list = self._global_update_backbone(ims, labels, old_model=old_model)

                # gradients accumulate until the step counter hits the mini-batch period
                if self.step % train_mini_batch == 0:
                    self.backbone_optimizer.step()
                    self.backbone_optimizer.zero_grad()

                # update recorder
                current_f_loss, class_loss, l_loss, cos_loss, mse_loss = loss_list
                self.backbone_f_loss.update(to_scalar(f_loss))
                self.backbone_l_loss.update(to_scalar(l_loss))
                self.backbone_mse_loss.update(to_scalar(mse_loss))

                # print log
                if self.step % self.steps_per_log == 0 or self.epoch_done:
                    base_lr = self.backbone_optimizer.param_groups[0]['lr']
                    print(step_log % (
                        self.step, to_scalar(loss), to_scalar(current_f_loss),
                        to_scalar(class_loss), to_scalar(l_loss), to_scalar(cos_loss),
                        to_scalar(mse_loss), base_lr, time.time()-step_begin))

            # ==> save order file
            self.data_loader.dump_order_file()
            print('Epoch:%03d\tAvgFLoss:%.4f\tAvgMseLoss:%.4f\t' % (epoch_num, self.backbone_f_loss.avg, self.backbone_mse_loss.avg))
            # save model and optimizer parameters
            save_ckpt(self.backbone_module_optimize, epoch_num, self.current_train_path)
            if epoch_num%self.epochs_per_val == 0 and self.epochs_per_val != -1:
                self.part_test(task)
        self._clear_record()


    def phase_train(self, train_mini_batch=1, **kwargs):
        """Run the whole incremental training schedule, task by task.

        For every task: the classifier/optimizer are re-initialised, the
        extractor is trained when the data loader is in its ``'expert'`` part,
        the model is tested, blended with the previous model
        (``distill_model``), the memory and its histogram are regenerated, the
        checkpoint is saved and the data loader is advanced to the next task.

        Args:
            train_mini_batch: forwarded to ``train_extractor``.

        Keyword Args:
            epochs_per_task: epochs per task, forwarded to ``train_extractor``.
            update_scale: base value of the ``scale`` passed to ``distill_model``.
            distill_factor: forwarded to ``distill_model``.

        Raises:
            KeyError: if one of the keyword options is missing.

        Worker processes started by ``get_R_matrxi`` are terminated when the
        method leaves, including when training raises.
        """
        mini_batch = train_mini_batch
        epochs_per_task = kwargs['epochs_per_task']
        update_scale = kwargs['update_scale']
        distill_factor = kwargs['distill_factor']
        self.data_loader.change_task(self.resume)
        self._train_init()
        try:
            for task in range(self.start_task, self.task_number):
                print('-'*20+'Task '+str(task+1)+' Start'+'-'*20)
                self._init_task()
                # change the data loader task data
                if self.data_loader.part == 'expert':
                    print('-'*10+'Training Exactor Network'+'-'*10)
                    self.train_extractor(task, mini_batch, epochs_per_task)
                    self.start_epoch = 0
                    self.generate_memory()
                    # ==> save model
                    save_ckpt(self.backbone_module_optimize, 0, self.current_train_path)
                    self.part_test(task)
                    # ==> save order file
                    self.data_loader.part = 'meta'
                    self.data_loader.dump_order_file()
                    print('-'*10+'Model has been saved!'+'-'*10)
                    self.start_epoch = 0

                print('-'*10+'Before Distill Testing task '+str(task+1)+'-'*10)
                self.test(task)
                self._clear_record()
                # NOTE(review): always true inside this loop; kept as written
                if task <= self.task_number-1:
                    print('-'*10+'Distill & Save Backbone Model'+'-'*10)
                    # the scale ratio grows linearly from 0.8 at task 0 to 1.0 at
                    # task `task_number-2`
                    span = self.task_number-2
                    if span != 0:
                        ratio = (0.8*(span-task)+1.0*task)/span
                    else:
                        # two-task schedule: the formula would divide by zero,
                        # so use the end points of the same ramp (0.8, then 1.0)
                        ratio = 0.8+0.2*task
                    final_update_scale = update_scale*ratio
                    print('Current Update Scale: '+str(final_update_scale))
                    self.distill_model(final_update_scale, distill_factor)
                    print('-'*10+'Generating Memory'+'-'*10)
                    self.generate_memory(update_memory_mean=True)
                    self.start_epoch = 0
                    # ==> save distilled model
                    save_ckpt(self.backbone_module_optimize, 0, self.current_train_path)

                    self.part_test(task)
                    print('-'*10+'After Distill Testing task '+str(task+1)+'-'*10)
                    self.test(task)
                    print('-'*10+'Distilled model has been saved!'+'-'*10)
                # resuming only applies to the first task processed
                if self.resume:
                    self.resume = False
                self.data_loader.part = 'expert'
                self.data_loader.change_task(self.resume)
                self.data_loader.dump_order_file()
                self.backbone_mse_loss.clear()
        finally:
            # protect main process: never leave the workers running
            if self.all_process is not None:
                for tmp_process in self.all_process:
                    tmp_process.terminate()


    @torch.no_grad()
    def part_test(self, current_task):
        """Print nearest-memory accuracy on the test data of ``current_task``.

        Two accuracies are reported: "global", where each feature is matched
        against every written memory row, and "local", where it is matched
        only against the rows of the current task's labels.

        Args:
            current_task: task index handed to ``self.tester.test``.
        """
        def count_hits(targets, predicted):
            # number of positions where the prediction equals the target
            return len(targets[targets==predicted])

        written = self.initial_memory==False
        slot_ids = self.TVT(torch.arange(len(self.memory)))
        trained_indx = slot_ids[written]
        trained_memory = self.memory[written]
        # current_task_label is list
        task_index = self.TVT(torch.tensor(self.data_loader.current_task_label).long())
        task_memory = self.memory[task_index]

        feats, labels = self.tester.test(current_task)
        global_scores = calculate_cos_affinity(feats, trained_memory, smooth=False)
        local_scores = calculate_cos_affinity(feats, task_memory, smooth=False)
        global_predict = trained_indx[torch.max(global_scores, dim=-1)[1]]
        local_predict = task_index[torch.max(local_scores, dim=-1)[1]]

        labels = self.TVT(labels)
        sample_num = len(labels)
        global_acc = count_hits(labels, global_predict)/sample_num
        local_acc = count_hits(labels, local_predict)/sample_num
        print('The task '+str(current_task+1)+': Global memory acc is %.4f, Local memory acc is %.4f' % (global_acc, local_acc))
        print('*'*10+'Part Test Finished'+'*'*10)

    @torch.no_grad()
    def test(self, current_task=0):
        """Print accuracies for every task from 0 up to ``current_task``.

        Per task three numbers are printed: accuracy against the memory rows
        of that task only, accuracy of the classifier, and accuracy against
        all written memory rows. Finally the per-task average and the
        sample-weighted ("global") accuracy are printed for the classifier and
        for the memory.

        Args:
            current_task: index of the last task to evaluate.
        """
        def count_hits(targets, predicted):
            # number of positions where the prediction equals the target
            return len(targets[targets==predicted])

        cls_hits, cls_samples, cls_acc_sum = 0, 0, 0
        mem_hits, mem_samples, mem_acc_sum = 0, 0, 0
        written = self.initial_memory==False
        slot_ids = self.TVT(torch.arange(len(self.memory)))
        trained_indx = slot_ids[written]
        trained_memory = self.memory[written]

        for task in range(current_task+1):
            feats, labels = self.tester.test(task)
            labels = self.TVT(labels)
            sample_num = len(labels)

            #### test in local task: only the rows that belong to this task
            if task == 0:
                start = 0
            else:
                start = self.init_class_number+(task-1)*self.phase
            end = self.init_class_number+task*self.phase
            local_scores = calculate_cos_affinity(feats, trained_memory[start:end], smooth=False)
            local_predict = trained_indx[start:end][torch.max(local_scores, dim=-1)[1]]
            tmp_acc = count_hits(labels, local_predict)/sample_num
            print('The acc of task '+str(task+1)+' in task '+str(task+1)+' is %.4f' % (tmp_acc))

            #### test in global task with the classifier (evaluated in eval mode)
            was_training = self.classfier.training
            self.classfier.eval()
            logits = self.classfier(feats)
            self.classfier.train(was_training)
            cls_predict = trained_indx[torch.max(logits, dim=-1)[1]]
            hits = count_hits(labels, cls_predict)
            cls_hits += hits
            cls_samples += sample_num
            task_acc = hits/sample_num
            cls_acc_sum += task_acc

            #### test in global task with all written memory rows
            global_scores = calculate_cos_affinity(feats, trained_memory, smooth=False)
            mem_predict = trained_indx[torch.max(global_scores, dim=-1)[1]]
            hits = count_hits(labels, mem_predict)
            mem_hits += hits
            mem_samples += sample_num
            m_task_acc = hits/sample_num
            mem_acc_sum += m_task_acc
            print('Use Memory: The acc of task '+str(task+1)+' in task '+str(current_task+1)+' is %.4f' % (m_task_acc))
            print('Use Classifier: The acc of task '+str(task+1)+' in task '+str(current_task+1)+' is %.4f' % (task_acc))

        task_count = current_task+1
        avg_acc = cls_acc_sum/task_count
        global_acc = cls_hits/cls_samples
        m_avg_acc = mem_acc_sum/task_count
        m_global_acc = mem_hits/mem_samples
        print('In Classifier: The average acc is %.4f, the global acc is %.4f' % (avg_acc, global_acc))
        print('In Memory: The average acc is %.4f, the global acc is %.4f' % (m_avg_acc, m_global_acc))
        print('*'*10+'Task Test Finished'+'*'*10)
