# coding=utf-8
import re
import numpy as np
import torch
import faiss
import torch.nn as nn
import sklearn.model_selection as ms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from scipy.spatial.distance import cdist
from sklearn.metrics import confusion_matrix
from sklearn.neighbors import KNeighborsClassifier 

import datautil.imgdata.util as imgutil
from datautil.imgdata.imgdataload import ImageDataset
from datautil.mydataloader import InfiniteDataLoader
from utils.util import log_print, train_valid_target_eval_names
import Replay.utils as RPutils

def assign_pseudo_label(args, dataloader, replay_dataset, taskid, model, epoch, cur=False):
    '''
    Assign pseudo labels to the data of a new domain using the current model.

    Args:
        args: run configuration. Reads pLabelAlg, targetAlg, pseudo_tau, max_epoch,
            alpha_tau, batch_size, N_WORKERS and (proxy variants only) pseudo_scale.
        dataloader: loader of the current domain. Its dataset must expose
            get_raw_data() -> (image_dict, clabel, dlabel) and a loader callable.
        replay_dataset: replay buffer exposing get_raw_data(); only read for the
            MSHOT-family label algorithms and when args.targetAlg == 'MFPL'.
        taskid: index of the current task; task 0 keeps its ground-truth labels.
        model: network used for prediction.
        epoch: current epoch, used to decay the confidence threshold.
        cur: unused; kept for backward compatibility with existing callers.

    Returns:
        (dataloader, None) unchanged when taskid == 0 or args.pLabelAlg == 'ground';
        otherwise (pseudo_dataloader, pacc_dict), where pseudo_dataloader is an
        InfiniteDataLoader over the pseudo-labelled samples and pacc_dict holds
        pseudo-label accuracy statistics.

    Raises:
        ValueError: if args.pLabelAlg is not one of the supported algorithms.

    Side effects:
        Moves the model to CUDA, runs it in eval mode and leaves it in train mode.
        Sets model.replay_proxy when args.targetAlg == 'MFPL'.
    '''
    if taskid == 0 or args.pLabelAlg == 'ground':
        return dataloader, None

    # Confidence threshold decays linearly with the epoch. It is computed after the
    # early return so that path does not depend on the schedule parameters.
    pseudo_tau = args.pseudo_tau * (1 - epoch / (args.max_epoch * args.alpha_tau))

    softmax = nn.Softmax(dim=1)

    def make_replay_eval_loader(dataset):
        # The replay buffer normally carries the training transform; re-wrap its raw
        # data with the test-time transform for deterministic feature extraction.
        rimages, rclabel, rdlabel = dataset.get_raw_data()
        eval_dataset = RPutils.ReplayDataset(rimages, rclabel, rdlabel, transform=imgutil.image_test(args))
        eval_loader = DataLoader(dataset=eval_dataset,
                                 shuffle=False,
                                 batch_size=args.batch_size,
                                 num_workers=args.N_WORKERS)
        return eval_dataset, eval_loader

    def proxy_probabilities(inputs):
        # Softmax over scaled cosine similarity between projected features and
        # projected class proxies (only used by the PCL-style algorithms).
        feats = model.encoder(model.featurizer(inputs))
        feature_proj = F.normalize(model.fea_proj(feats))                        # (N, dim)
        proxy_proj = F.normalize(F.linear(model.classifier, model.fc_proj))     # (C, dim)
        score = torch.matmul(feature_proj, proxy_proj.T) * args.pseudo_scale    # (N, C)
        return softmax(score)

    image_dict, clabel, dlabel = dataloader.dataset.get_raw_data()
    images = [dataloader.dataset.loader(item) for item in image_dict]       # list of PIL image

    pseudo_image_dict = []
    pseudo_clabel = []
    pseudo_dlabel = []

    curr_dataset = RPutils.ReplayDataset(images, clabel, dlabel, transform=imgutil.image_test(args))
    curr_dataloader = DataLoader(dataset=curr_dataset,
                                 shuffle=False,
                                 batch_size=args.batch_size,
                                 num_workers=args.N_WORKERS)
    # Original author's note: without this explicit eval() the evaluation steps that
    # run after this function are affected, even if they call model.eval() themselves.
    model.eval().cuda()

    if args.pLabelAlg in ['SHOT', 'SHOT_PCL']:
        pseudo_clabel, pacc_dict = SHOT_label(args, curr_dataloader, model, pseudo_tau)
        pseudo_image_dict = image_dict
        pseudo_dlabel = dlabel

    elif args.pLabelAlg in ['BMD', 'knn']:
        pseudo_clabel, pacc_dict = BMD_label(args, curr_dataloader, model)
        pseudo_image_dict = image_dict
        pseudo_dlabel = dlabel

    elif args.pLabelAlg in ['topkSHOTknn']:
        pseudo_clabel, pacc_dict, bool_index = topkSHOTknn_label(args, curr_dataloader, model, pseudo_tau)
        # Keep only the samples that the label function retained.
        kept = [i for i, keep in enumerate(bool_index) if keep]
        pseudo_image_dict = [image_dict[i] for i in kept]
        pseudo_dlabel = [dlabel[i] for i in kept]

    elif args.pLabelAlg in ['MSHOT', 'mknn', 'MSHOTknn', 'knnMSHOT']:
        replay_dataset, replay_dataloader = make_replay_eval_loader(replay_dataset)
        pseudo_clabel, pacc_dict, bool_index = MSHOT_label(args, curr_dataloader, replay_dataloader, model, pseudo_tau)
        kept = [i for i, keep in enumerate(bool_index) if keep]
        pseudo_image_dict = [image_dict[i] for i in kept]
        pseudo_dlabel = [dlabel[i] for i in kept]

    elif args.pLabelAlg in ['softmax', 'proxy', 'softmax_proxy', 'proxy_softmax']:
        # The combined variants additionally require classifier and proxy to agree.
        require_agreement = args.pLabelAlg in ['softmax_proxy', 'proxy_softmax']
        correct = 0
        for i, data in enumerate(curr_dataloader):
            x = data[0].cuda()
            proxy_pred = None
            with torch.no_grad():
                if args.pLabelAlg == 'softmax':
                    logist = softmax(model(x))
                    pred = torch.argmax(logist, dim=1)
                elif args.pLabelAlg == 'proxy':
                    logist = proxy_probabilities(x)
                    pred = torch.argmax(logist, dim=1)
                else:
                    classifier_logist = softmax(model(x))
                    proxy_logist = proxy_probabilities(x)
                    pred = torch.argmax(classifier_logist, dim=1)
                    proxy_pred = torch.argmax(proxy_logist, dim=1)
                    # 'softmax_proxy' thresholds on the classifier confidence,
                    # 'proxy_softmax' on the proxy confidence.
                    logist = classifier_logist if args.pLabelAlg == 'softmax_proxy' else proxy_logist

            # Select samples whose prediction confidence reaches pseudo_tau.
            for j, idx in enumerate(pred):
                if require_agreement and idx != proxy_pred[j]:
                    continue
                if logist[j, idx] >= pseudo_tau:
                    # The loader is not shuffled, so batch position maps back to raw index.
                    pseudo_image_dict.append(image_dict[i * args.batch_size + j])
                    pseudo_clabel.append(idx.item())
                    pseudo_dlabel.append(taskid)

                    if idx.item() == data[1][j]:
                        correct += 1
        # Guard against an empty selection: nothing passing the threshold must not
        # crash the accuracy bookkeeping.
        pacc_dict = {'pc0': round(correct / len(pseudo_clabel), 3) if pseudo_clabel else 0.0}

    else:
        raise ValueError('unsupported pLabelAlg: {!r}'.format(args.pLabelAlg))

    if args.targetAlg == 'MFPL':   # compute the replay center
        replay_dataset, replay_dataloader = make_replay_eval_loader(replay_dataset)
        model.replay_proxy = compute_replay_center(args, model, replay_dataloader)

    model.train()
    pseudo_dataset = utilDataset(pseudo_image_dict, np.array(pseudo_clabel), np.array(pseudo_dlabel), loader=dataloader.dataset.loader, transform=imgutil.image_train(args))
    pseudo_dataloader = InfiniteDataLoader(dataset=pseudo_dataset, weights=None, batch_size=args.batch_size, num_workers=args.N_WORKERS)

    return pseudo_dataloader, pacc_dict

def compute_replay_center(args, model, replay_loader):
    '''
    Compute the mean feature vector of every class over the replay data.

    Args:
        args: run configuration. Reads targetAlg, and pLabelAlg for the PCL-family
            target algorithms, to decide how features are extracted from the model.
        model: network providing the sub-modules used by the selected extraction
            path (featurizer / encoder / bottleneck / classifier / fea_proj / fc_proj).
        replay_loader: iterable of batches where batch[0] is the input tensor and
            batch[1] the integer class labels.

    Returns:
        Float tensor on the CUDA device of shape (C, dim), where C is the number of
        columns of the model output. Row c is the mean feature of the replay samples
        labelled c; classes without replay samples get a zero row.

    Raises:
        ValueError: if the (targetAlg, pLabelAlg) combination has no extraction path,
            or if replay_loader yields no batches.
    '''
    # Resolve the extraction path once, before touching the GPU, so an unsupported
    # configuration fails with a clear message instead of an unbound local.
    target_alg = args.targetAlg
    if target_alg in ['ERM', 'supcon', 'ERMPCL']:
        def forward(inputs):
            feas = model.featurizer(inputs)
            return feas, model.classifier(feas)
    elif target_alg in ['PCL', 'PCL2', 'SupPCL', 'SupPCL2', 'FP']:
        plabel_alg = args.pLabelAlg
        if plabel_alg in ['SHOT', 'MSHOT', 'mknn', 'MSHOTknn', 'knnMSHOT']:
            def forward(inputs):
                feas = model.encoder(model.featurizer(inputs))
                return feas, F.linear(feas, model.classifier)
        elif plabel_alg == 'SHOT_PCL':
            def forward(inputs):
                features = model.encoder(model.featurizer(inputs))
                feas = model.fea_proj(features)                     # (N, dim)
                proxy = F.linear(model.classifier, model.fc_proj)   # (C, dim)
                return feas, F.linear(feas, proxy)
        else:
            raise ValueError('compute_replay_center: unsupported pLabelAlg {!r} for targetAlg {!r}'.format(plabel_alg, target_alg))
    elif target_alg in ['ERM_bot']:
        def forward(inputs):
            feas = model.bottleneck(model.featurizer(inputs))
            return feas, model.classifier(feas)
    else:
        raise ValueError('compute_replay_center: unsupported targetAlg {!r}'.format(target_alg))

    fea_chunks, label_chunks = [], []
    num_classes = None
    with torch.no_grad():
        for data in replay_loader:
            feas, outputs = forward(data[0].cuda())
            fea_chunks.append(feas.float().cpu())
            label_chunks.append(data[1].float())
            # The outputs are only needed for the number of classes.
            num_classes = outputs.shape[1]

    if not fea_chunks:
        raise ValueError('compute_replay_center: replay_loader yielded no batches')

    replay_fea = torch.cat(fea_chunks, dim=0).numpy()
    replay_label = torch.cat(label_chunks, dim=0).numpy().astype('int')

    # Class centers as one-hot-weighted feature sums divided by the class counts.
    onehot_replay_label = np.eye(num_classes)[replay_label]             # (N, C)
    replay_center = onehot_replay_label.transpose().dot(replay_fea)     # (C, dim)
    replay_center = replay_center / (1e-8 + onehot_replay_label.sum(axis=0)[:, None])   # (C, dim) / (C, 1)
    return torch.from_numpy(replay_center).to(torch.float).cuda()


def SHOT_label(args, loader, model, pseudo_tau):
    '''
    Pseudo-label every sample of loader by clustering features around class centers.

    The softmax predictions initialise soft class centers in an augmented,
    L2-normalised feature space. Samples are then re-assigned to the nearest center
    (cosine distance) and the centers recomputed from the hard assignments, twice.

    Args:
        args: run configuration. Reads targetAlg, pLabelAlg (PCL-family only) and
            log_file.
        loader: iterable of batches where batch[0] is the input tensor and batch[1]
            the ground-truth labels (used only to report accuracy).
        model: network providing the sub-modules of the selected extraction path.
        pseudo_tau: unused here; accepted so all label functions share a signature.

    Returns:
        (predict, acc_dict): predict is an int array with one class index per sample;
        acc_dict is {'pa0': accuracy of the raw softmax prediction,
        'pa1': accuracy after clustering}, both rounded to 3 decimals.

    Raises:
        ValueError: if the (targetAlg, pLabelAlg) combination has no extraction path,
            or if loader yields no batches.
    '''
    # Resolve the extraction path up front so an unsupported configuration fails
    # with a clear message instead of an unbound local inside the loop.
    target_alg = args.targetAlg
    if target_alg in ['ERM', 'supcon', 'ERMPCL']:
        def forward(inputs):
            feas = model.featurizer(inputs)
            return feas, model.classifier(feas)
    elif target_alg in ['PCL', 'PCL2', 'SupPCL', 'SupPCL2', 'FP']:
        if args.pLabelAlg == 'SHOT':
            def forward(inputs):
                feas = model.encoder(model.featurizer(inputs))
                return feas, F.linear(feas, model.classifier)
        elif args.pLabelAlg == 'SHOT_PCL':
            def forward(inputs):
                features = model.encoder(model.featurizer(inputs))
                feas = model.fea_proj(features)                     # (N, dim)
                proxy = F.linear(model.classifier, model.fc_proj)   # (C, dim)
                return feas, F.linear(feas, proxy)
        else:
            raise ValueError('SHOT_label: unsupported pLabelAlg {!r} for targetAlg {!r}'.format(args.pLabelAlg, target_alg))
    elif target_alg in ['ERM_bot']:
        def forward(inputs):
            feas = model.bottleneck(model.featurizer(inputs))
            return feas, model.classifier(feas)
    else:
        raise ValueError('SHOT_label: unsupported targetAlg {!r}'.format(target_alg))

    fea_chunks, out_chunks, label_chunks = [], [], []
    with torch.no_grad():
        for data in loader:
            feas, outputs = forward(data[0].cuda())
            fea_chunks.append(feas.float().cpu())
            out_chunks.append(outputs.float().cpu())
            label_chunks.append(data[1].float())

    if not fea_chunks:
        raise ValueError('SHOT_label: loader yielded no batches')

    all_fea = torch.cat(fea_chunks)
    all_output = nn.Softmax(dim=1)(torch.cat(out_chunks))
    all_label = torch.cat(label_chunks)

    _, predict = torch.max(all_output, 1)
    accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])

    # Cosine clustering: append a constant 1 to every feature and L2-normalise the rows.
    distance = 'cosine'
    all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)   # (N, dim+1)
    all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()       # (N, dim+1) / norm(N, dim+1)

    all_fea = all_fea.float().cpu().numpy()   # (N, dim+1)
    K = all_output.size(1)
    aff = all_output.float().cpu().numpy()    # (N, C)
    predict = predict.cpu().numpy()           # explicit conversion before NumPy indexing
    label_np = all_label.float().numpy()

    for _ in range(2):
        # Affinity-weighted class centers (soft in the first round, one-hot afterwards).
        initc = aff.transpose().dot(all_fea)                 # (C, dim+1)
        initc = initc / (1e-8 + aff.sum(axis=0)[:, None])    # (C, dim+1) / (C, 1)

        # Only classes that currently own at least one prediction compete.
        cls_count = np.eye(K)[predict].sum(axis=0)           # (C,)
        labelset = np.where(cls_count > 0)[0]

        dd = cdist(all_fea, initc[labelset], distance)       # (N, len(labelset))
        predict = labelset[dd.argmin(axis=1)]                # (N)

        aff = np.eye(K)[predict]                             # one-hot (N, C)

    acc = np.sum(predict == label_np) / len(all_fea)
    log_print('Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100), args.log_file, p=False)

    return predict.astype('int'), {'pa0':round(accuracy,3), 'pa1':round(acc,3)}



def topkSHOTknn_label(args, loader, model, pseudo_tau):
    '''
    Pseudo-label the confident samples with top-k SHOT clustering followed by kNN.

    Steps:
      1. Extract features and softmax outputs; keep samples whose maximum softmax
         probability exceeds pseudo_tau.
      2. Build class centers from the top-k most confident samples of each class and
         refine them for args.SHOT_step rounds of nearest-center assignment.
      3. Fit a kNN classifier on the samples closest to each center and use it to
         predict the final label of every kept sample.

    Args:
        args: run configuration. Reads targetAlg, num_classes, topk_alpha,
            topk_alpha2, SHOT_step, distance and log_file.
        loader: iterable of batches where batch[0] is the input tensor and batch[1]
            the ground-truth labels (used only to report accuracy).
        model: network providing the sub-modules of the selected extraction path.
        pseudo_tau: confidence threshold applied to the maximum softmax probability.

    Returns:
        (predict, acc_dict, bool_index): predict is an int array of class indices for
        the kept samples; acc_dict maps 'pa0', 'pa1', ... to the accuracy after each
        stage (softmax, every SHOT round, kNN), rounded to 3 decimals; bool_index is
        the boolean tensor over all samples marking which ones were kept.

    Raises:
        ValueError: if args.targetAlg has no extraction path, args.SHOT_step < 1,
            loader yields no batches, or no sample exceeds pseudo_tau.
    '''
    # Resolve the extraction path up front so an unsupported configuration fails
    # with a clear message instead of an unbound local inside the loop.
    target_alg = args.targetAlg
    if target_alg in ['ERM', 'supcon', 'ERMPCL']:
        def forward(inputs):
            feas = model.featurizer(inputs)
            return feas, model.classifier(feas)
    elif target_alg in ['PCL', 'PCL2', 'SupPCL', 'SupPCL2', 'FP', 'MFP']:
        def forward(inputs):
            feas = model.encoder(model.featurizer(inputs))
            return feas, F.linear(feas, model.classifier)
    elif target_alg in ['ERM_bot']:
        def forward(inputs):
            feas = model.bottleneck(model.featurizer(inputs))
            return feas, model.classifier(feas)
    else:
        raise ValueError('topkSHOTknn_label: unsupported targetAlg {!r}'.format(target_alg))

    # The kNN stage consumes the distances of the last SHOT round, so at least one
    # round is required.
    if args.SHOT_step < 1:
        raise ValueError('topkSHOTknn_label: args.SHOT_step must be >= 1, got {!r}'.format(args.SHOT_step))

    fea_chunks, out_chunks, label_chunks = [], [], []
    with torch.no_grad():
        for data in loader:
            feas, outputs = forward(data[0].cuda())
            fea_chunks.append(feas.float().cpu())
            out_chunks.append(outputs.float().cpu())
            label_chunks.append(data[1].float())

    if not fea_chunks:
        raise ValueError('topkSHOTknn_label: loader yielded no batches')

    all_fea = torch.cat(fea_chunks, dim=0)
    all_output = nn.Softmax(dim=1)(torch.cat(out_chunks, dim=0))
    all_label = torch.cat(label_chunks, dim=0)

    # Keep only the confident samples.
    ov, _ = torch.max(all_output, 1)
    bool_index = ov > pseudo_tau
    all_output = all_output[bool_index]
    all_fea = all_fea[bool_index]
    all_label = all_label[bool_index]

    num_kept = all_label.size(0)
    if num_kept == 0:
        raise ValueError('topkSHOTknn_label: no sample has softmax confidence above pseudo_tau={}'.format(pseudo_tau))

    acc_list = []

    # softmax predict
    _, predict = torch.max(all_output, 1)
    acc_list.append((predict.float() == all_label).sum().item() / float(num_kept))
    predict = predict.cpu().numpy()           # explicit conversion before NumPy indexing

    all_fea = all_fea / torch.norm(all_fea, p=2, dim=1, keepdim=True)
    all_fea = all_fea.float().cpu()           # (N, dim)
    all_fea_np = all_fea.numpy()
    label_np = all_label.float().numpy()
    K = all_output.size(1)
    aff = all_output.float().cpu()            # (N, C)

    # int() keeps k integral even when the alpha hyper-parameters are floats.
    topk_fit_num = max(int(num_kept // (args.num_classes * args.topk_alpha2)), 1)
    topk_num = max(int(num_kept // (args.num_classes * args.topk_alpha)), 1)

    # top k features for SHOT: most confident samples of every class
    top_fea_chunks, top_aff_chunks = [], []
    for cls_idx in range(args.num_classes):
        feat_samp_idx = torch.topk(aff[:, cls_idx], topk_fit_num)[1]
        top_fea_chunks.append(all_fea[feat_samp_idx, :])
        top_aff_chunks.append(aff[feat_samp_idx, :])

    top_aff = torch.cat(top_aff_chunks, dim=0)
    top_predict = torch.max(top_aff, 1)[1].numpy()
    top_aff = top_aff.numpy()
    top_fea = torch.cat(top_fea_chunks, dim=0).numpy()

    # SHOT
    for _ in range(args.SHOT_step):
        # Centers come from the top-k subset only.
        initc = top_aff.transpose().dot(top_fea)                   # (C, dim)
        initc = initc / (1e-8 + top_aff.sum(axis=0)[:, None])      # (C, dim) / (C, 1)

        # Re-assign every kept sample to the nearest center among predicted classes.
        cls_count = np.eye(K)[predict].sum(axis=0)                 # (C,)
        labelset = np.where(cls_count > 0)[0]
        dd = cdist(all_fea_np, initc[labelset], args.distance)     # (N, len(labelset))
        predict = labelset[dd.argmin(axis=1)]                      # (N)

        # Re-assign the top-k subset the same way; it defines the next centers.
        top_cls_count = np.eye(K)[top_predict].sum(axis=0)
        top_labelset = np.where(top_cls_count > 0)[0]
        top_dd = cdist(top_fea, initc[top_labelset], args.distance)
        top_predict = top_labelset[top_dd.argmin(axis=1)]

        top_aff = np.eye(K)[top_predict]                           # one-hot (N_top, C)
        acc_list.append(np.sum(predict == label_np) / len(all_fea))

    # knn on distance of each features and cluster center
    dd_tensor = torch.from_numpy(dd)
    top_sample, top_label = [], []
    # Iterate over the columns of dd: some classes may have no prediction at all,
    # so kNN labels are positions into labelset and are mapped back afterwards.
    for cls_idx in range(len(labelset)):
        feat_samp_idx = torch.topk(dd_tensor[:, cls_idx], topk_fit_num, largest=False)[1]
        top_sample.append(all_fea[feat_samp_idx, :])
        top_label.append(torch.zeros([len(feat_samp_idx)]).fill_(cls_idx))
    top_sample = torch.cat(top_sample, dim=0).cpu().numpy()
    top_label = torch.cat(top_label, dim=0).cpu().numpy()

    # n_neighbors may not exceed the number of fitted samples.
    knn = KNeighborsClassifier(n_neighbors=min(topk_num, len(top_sample)))
    knn.fit(top_sample, top_label)

    knn_predict = [int(i) for i in knn.predict(all_fea_np).tolist()]
    predict = labelset[knn_predict]
    acc_list.append(np.sum(predict == label_np) / len(all_fea))

    log_print("acc:" + " --> ".join("{:.3f}".format(acc) for acc in acc_list), args.log_file, p=False)
    acc_dict = {'pa{}'.format(i): round(acc, 3) for i, acc in enumerate(acc_list)}

    return predict.astype('int'), acc_dict, bool_index


# --------------------------------------------------------------------------------------------------------------- #
# --------------------------------------------------------------------------------------------------------------- #


# --------------------------------------------------------------------------------------------------------------- #
# --------------------------------------------------------------------------------------------------------------- #

def MSHOT_label(args, loader, replay_loader, model, pseudo_tau):
    start_test = True
    with torch.no_grad():
        for i, data in enumerate(replay_loader):
            inputs = data[0].cuda()
            labels = data[1]

            if args.targetAlg in ['ERM', 'supcon', 'ERMPCL']:
                feas = model.featurizer(inputs)
                outputs = model.classifier(feas)
            elif args.targetAlg in ['PCL','PCL2', 'SupPCL', 'SupPCL2', 'FP']:
                if args.pLabelAlg in ['SHOT', 'MSHOT', 'mknn', 'MSHOTknn', 'knnMSHOT']:
                    feas = model.encoder(model.featurizer(inputs))
                    outputs = F.linear(feas, model.classifier)
                elif args.pLabelAlg == 'SHOT_PCL':
                    features = model.encoder(model.featurizer(inputs))
                    feas = model.fea_proj(features)                     # (N, dim)
                    proxy = F.linear(model.classifier, model.fc_proj)   # (C, dim)
                    outputs = F.linear(feas, proxy)
            elif args.targetAlg in ['ERM_bot']:
                feas = model.bottleneck(model.featurizer(inputs))
                outputs = model.classifier(feas)

            if start_test:
                replay_fea = [feas.float().cpu()]
                replay_output = [outputs.float().cpu()]
                replay_label = [labels.float()]
                start_test = False
            else:
                replay_fea.append(feas.float().cpu()) #= torch.cat((replay_fea, feas.float().cpu()), 0)
                replay_output.append(outputs.float().cpu()) #= torch.cat((replay_output, outputs.float().cpu()), 0)
                replay_label.append(labels.float()) #= torch.cat((replay_label, labels.float()), 0)
        replay_fea = torch.cat(replay_fea, dim=0)
        replay_output = torch.cat(replay_output, dim=0)
        replay_label = torch.cat(replay_label, dim=0)

        # compute replay data class center
        replay_fea, replay_output, replay_label = replay_fea.numpy(), replay_output.numpy(), replay_label.numpy().astype('int')
        onehot_replay_label = np.eye(replay_output.shape[1])[replay_label]   # (N, C)
        replay_center = onehot_replay_label.transpose().dot(replay_fea)     # (C, dim)
        replay_center = replay_center / (1e-8 + onehot_replay_label.sum(axis=0)[:,None])   # (C, dim) / (C, 1)
        replay_center = torch.from_numpy(replay_center).to(torch.float).cuda()
        model.replay_center = replay_center

        start_test = True
        for i, data in enumerate(loader):
            inputs = data[0].cuda()
            labels = data[1]

            if args.targetAlg in ['ERM', 'supcon', 'ERMPCL']:
                feas = model.featurizer(inputs)
                outputs = model.classifier(feas)
                re_outputs = F.linear(feas, replay_center)
            elif args.targetAlg in ['PCL','PCL2', 'SupPCL', 'SupPCL2']:
                if args.pLabelAlg in ['SHOT', 'MSHOT', 'mknn', 'MSHOTknn', 'knnMSHOT']:
                    feas = model.encoder(model.featurizer(inputs))
                    outputs = F.linear(feas, model.classifier)
                    re_outputs = F.linear(feas, replay_center)
                elif args.pLabelAlg == 'SHOT_PCL':
                    features = model.encoder(model.featurizer(inputs))
                    feas = model.fea_proj(features)                     # (N, dim)
                    proxy = F.linear(model.classifier, model.fc_proj)   # (C, dim)
                    outputs = F.linear(feas, proxy)
                    re_outputs = F.linear(feas, replay_center)
            elif args.targetAlg in ['ERM_bot']:
                feas = model.bottleneck(model.featurizer(inputs))
                outputs = model.classifier(feas)
                re_outputs = F.linear(feas, replay_center)

            if start_test:
                all_fea = [feas.float().cpu()]
                all_output = [outputs.float().cpu()]
                allre_output = [re_outputs.float().cpu()]
                all_label = [labels.float()]
                start_test = False
            else:
                all_fea.append(feas.float().cpu()) #= torch.cat((all_fea, feas.float().cpu()), 0)
                all_output.append(outputs.float().cpu()) #= torch.cat((all_output, outputs.float().cpu()), 0)
                allre_output.append(re_outputs.float().cpu()) #= torch.cat((allre_output, re_outputs.float().cpu()), 0)
                all_label.append(labels.float()) #= torch.cat((all_label, labels.float()), 0)
        all_fea = torch.cat(all_fea, dim=0)
        all_output = torch.cat(all_output, dim=0)
        allre_output = torch.cat(allre_output, dim=0)
        all_label = torch.cat(all_label, dim=0)
    
    # epsilon = 1e-5
    # ent = torch.sum(-all_output * torch.log(all_output + epsilon), dim=1)
    # unknown_weight = 1 - ent / np.log(args.num_classes)
    acc_list = []
    _, predict = torch.max(all_output, 1)
    accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
    acc_list.append(accuracy)
    
    if args.pLabelAlg in ['MSHOT']:
        all_output = nn.Softmax(dim=1)(all_output)
        allre_output = nn.Softmax(dim=1)(allre_output)
        # all_output = args.MSHOT_tau * all_output + (1-args.MSHOT_tau) * allre_output    # curr and replay

        ov, idx = torch.max(all_output, 1)
        bool_index = ov > pseudo_tau
        all_output = all_output[bool_index]
        allre_output = allre_output[bool_index]
        all_fea = all_fea[bool_index]
        all_label = all_label[bool_index]
        
        if args.distance == 'cosine':
            all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)   # (N, dim+1)
            all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()       # (N, dim+1) / norm(N, dim+1)

        all_fea = all_fea.float().cpu().numpy()   # (N, dim+1)
        K = all_output.size(1)
        aff = all_output.float().cpu().numpy()    # (N, C)
        re_aff = allre_output.float().cpu().numpy()

        re_initc = re_aff.transpose().dot(all_fea)
        re_initc = re_initc / (1e-8 + re_aff.sum(axis=0)[:,None])
        for _ in range(args.SHOT_step):
            initc = aff.transpose().dot(all_fea)  # (C, dim+1) molecule of equation (4)
            initc = initc / (1e-8 + aff.sum(axis=0)[:,None])  # (C, dim+1) / (C, 1)   cluster center of each class
            initc = args.MSHOT_tau * initc + (1-args.MSHOT_tau) * re_initc

            cls_count = np.eye(K)[predict].sum(axis=0)   # (1, C) each element representation the number of prediction of that class
            labelset = np.where(cls_count>0)    
            labelset = labelset[0]    # 1D index of which class has been assign pseudo label. 

            dd = cdist(all_fea, initc[labelset], args.distance)   # (N,C) distance of each features and cluster center
            pred_label = dd.argmin(axis=1)   # (N)
            predict = labelset[pred_label]   # (N)

            aff = np.eye(K)[predict]         # one-hot (N, C)
            
    elif args.pLabelAlg in ['MSHOTknn']:
        all_output = nn.Softmax(dim=1)(all_output)
        allre_output = nn.Softmax(dim=1)(allre_output)
        # all_output = args.MSHOT_tau * all_output + (1-args.MSHOT_tau) * allre_output    # curr and replay

        ov, idx = torch.max(all_output, 1)
        bool_index = ov > pseudo_tau
        all_output = all_output[bool_index]
        allre_output = allre_output[bool_index]
        all_fea = all_fea[bool_index]
        all_label = all_label[bool_index]
        
        all_fea = all_fea / torch.norm(all_fea, p=2, dim=1, keepdim=True)
        
        all_fea = all_fea.float().cpu().numpy()   # (N, dim)
        K = all_output.size(1)
        aff = all_output.float().cpu().numpy()    # (N, C)
        re_aff = allre_output.float().cpu().numpy()

        re_initc = re_aff.transpose().dot(all_fea)
        re_initc = re_initc / (1e-8 + re_aff.sum(axis=0)[:,None])
        for _ in range(args.SHOT_step):
            initc = aff.transpose().dot(all_fea)  # (C, dim+1) molecule of equation (4)
            initc = initc / (1e-8 + aff.sum(axis=0)[:,None])  # (C, dim) / (C, 1)   cluster center of each class
            initc = args.MSHOT_tau * initc + (1-args.MSHOT_tau) * re_initc

            cls_count = np.eye(K)[predict].sum(axis=0)   # (1, C) each element representation the number of prediction of that class
            labelset = np.where(cls_count>0)    
            labelset = labelset[0]    # 1D index of which class has been assign pseudo label. 

            dd = cdist(all_fea, initc[labelset], args.distance)   # (N,C) distance of each features and cluster center
            pred_label = dd.argmin(axis=1)   # (N)
            predict = labelset[pred_label]   # (N)

            aff = np.eye(K)[predict]         # one-hot (N, C)
        
        acc_list.append(np.sum(predict == all_label.float().numpy()) / len(all_fea))
        
        # knn on distance of each features and cluster center
        top_sample = []
        top_label = []
        topk_num = max(all_fea.shape[0] // (args.num_classes * args.topk_alpha), 1)
        all_fea = torch.from_numpy(all_fea).float()
        initc = torch.from_numpy(initc).float()
        all_output = torch.matmul(all_fea, initc.T)    # use SHOT center to compute output
        if args.knn_softmax:
            all_output = nn.Softmax(dim=1)(all_output)
        
        for cls_idx in range(args.num_classes):
            feat_samp_idx = torch.topk(all_output[:, cls_idx], topk_num)[1]
                
            feat_cls_sample = all_fea[feat_samp_idx, :]
            feat_cls_label = torch.zeros([len(feat_samp_idx)]).fill_(cls_idx)

            top_sample.append(feat_cls_sample)
            top_label.append(feat_cls_label)
        top_sample = torch.cat(top_sample, dim=0).cpu().numpy()
        top_label = torch.cat(top_label, dim=0).cpu().numpy()

        knn = KNeighborsClassifier(n_neighbors=topk_num)
        knn.fit(top_sample, top_label)

        predict = knn.predict(all_fea.cpu().numpy())
        
    elif args.pLabelAlg in ['knnMSHOT']:
        all_output = nn.Softmax(dim=1)(all_output)
        allre_output = nn.Softmax(dim=1)(allre_output)
        # all_output = args.MSHOT_tau * all_output + (1-args.MSHOT_tau) * allre_output    # curr and replay

        ov, idx = torch.max(all_output, 1)
        bool_index = ov > pseudo_tau
        all_output = all_output[bool_index]
        allre_output = allre_output[bool_index]
        all_fea = all_fea[bool_index]
        all_label = all_label[bool_index]
        
        all_fea = all_fea / torch.norm(all_fea, p=2, dim=1, keepdim=True)
        
        all_fea = all_fea.float().cpu()  # (N, dim)
        K = all_output.size(1)
        aff = all_output.float().cpu()   # (N, C)
        re_aff = allre_output.float().cpu()
        
        # knn initial center 1
        topk_num = max(all_fea.shape[0] // (args.num_classes * args.topk_alpha), 1) 
        top_aff, re_top_aff = [], []
        top_fea, re_top_fea = [], []
        
        for cls_idx in range(args.num_classes):
            feat_samp_idx = torch.topk(aff[:, cls_idx], topk_num)[1]
            re_feat_samp_idx = torch.topk(re_aff[:, cls_idx], topk_num)[1]
                  
            top_fea.append(all_fea[feat_samp_idx, :])
            re_top_fea.append(all_fea[re_feat_samp_idx, :])
            
            top_aff.append(aff[feat_samp_idx, :])
            re_top_aff.append(re_aff[feat_samp_idx, :])
            
        top_aff = torch.cat(top_aff, dim=0).numpy()
        re_top_aff = torch.cat(re_top_aff, dim=0).numpy()
        top_fea = torch.cat(top_fea, dim=0).numpy()
        re_top_fea = torch.cat(re_top_fea, dim=0).numpy()
            
        re_initc = top_aff.transpose().dot(top_fea)
        re_initc = re_initc / (1e-8 + top_aff.sum(axis=0)[:,None])
        for _ in range(args.SHOT_step):
            initc = top_aff.transpose().dot(top_fea)  # (C, dim+1) molecule of equation (4)
            initc = initc / (1e-8 + top_aff.sum(axis=0)[:,None])  # (C, dim) / (C, 1)   cluster center of each class
            initc = args.MSHOT_tau * initc + (1-args.MSHOT_tau) * re_initc

            cls_count = np.eye(K)[predict].sum(axis=0)   # (1, C) each element representation the number of prediction of that class
            labelset = np.where(cls_count>0)    
            labelset = labelset[0]    # 1D index of which class has been assign pseudo label. 

            dd = cdist(all_fea, initc[labelset], args.distance)   # (N,C) distance of each features and cluster center
            pred_label = dd.argmin(axis=1)   # (N)
            predict = labelset[pred_label]   # (N)

            top_aff = np.eye(K)[predict]         # one-hot (N, C)
        
        # knn initial center 2
        # topk_num = max(all_fea.shape[0] // (args.num_classes * args.topk_alpha), 1)
        # top_fea, re_top_fea = [], []
        # top_aff, re_top_aff = [], []
        # initc, re_initc = [], []
        
        # for cls_idx in range(args.num_classes):
        #     feat_samp_idx = torch.topk(aff[:, cls_idx], topk_num)[1]
        #     re_feat_samp_idx = torch.topk(re_aff[:, cls_idx], topk_num)[1]
                  
        #     feat_cls_sample = all_fea[feat_samp_idx, :]
        #     re_feat_cls_sample = all_fea[re_feat_samp_idx, :]
            
        #     initc.append(feat_cls_sample.mean(dim=0).unsqueeze(0))
        #     re_initc.append(re_feat_cls_sample.mean(dim=0).unsqueeze(0))
        
        # initc = torch.cat(initc, dim=0).numpy()
        # re_initc = torch.cat(re_initc, dim=0).numpy()
        
        # initc = args.MSHOT_tau * initc + (1-args.MSHOT_tau) * re_initc

        # cls_count = np.eye(K)[predict].sum(axis=0)   # (1, C) each element representation the number of prediction of that class
        # labelset = np.where(cls_count>0)    
        # labelset = labelset[0]    # 1D index of which class has been assign pseudo label. 

        # dd = cdist(all_fea, initc[labelset], args.distance)   # (N,C) distance of each features and cluster center
        # pred_label = dd.argmin(axis=1)   # (N)
        # predict = labelset[pred_label]   # (N)

        # aff = np.eye(K)[predict]         # one-hot (N, C)
        
        acc_list.append(np.sum(predict == all_label.float().numpy()) / len(all_fea))
    
    elif args.pLabelAlg in ['mknn']:
        ov, idx = torch.max(nn.Softmax(dim=1)(all_output), 1)
        bool_index = ov > pseudo_tau
        all_output = all_output[bool_index]
        allre_output = allre_output[bool_index]
        all_fea = all_fea[bool_index]
        all_label = all_label[bool_index]
        
        topk_num = max(all_fea.shape[0] // (args.num_classes * args.topk_alpha), 1)
        all_fea = all_fea / torch.norm(all_fea, p=2, dim=1, keepdim=True)
        top_sample = []
        top_label = []
                
        for cls_idx in range(args.num_classes):
            feat_samp_idx = torch.topk(all_output[:, cls_idx], topk_num)[1]
                
            feat_cls_sample = all_fea[feat_samp_idx, :]
            feat_cls_label = torch.zeros([len(feat_samp_idx)]).fill_(cls_idx)

            top_sample.append(feat_cls_sample)
            top_label.append(feat_cls_label)
        top_sample = torch.cat(top_sample, dim=0).cpu().numpy()
        top_label = torch.cat(top_label, dim=0).cpu().numpy()

        # add memory sample
        top_sample = np.concatenate([top_sample, replay_fea], axis=0)
        top_label = np.concatenate([top_label, replay_label], axis=0)
        knn = KNeighborsClassifier(n_neighbors=topk_num)
        knn.fit(top_sample, top_label)

        predict = knn.predict(all_fea.cpu().numpy())

    acc = np.sum(predict == all_label.float().numpy()) / len(all_fea)
    acc_list.append(acc)
    log_print("acc:" + " --> ".join("{:.3f}".format(acc) for acc in acc_list), args.log_file, p=False)
    acc_dict = {}
    for i in range(len(acc_list)):
        acc_dict['pa{}'.format(i)] = round(acc_list[i],3)

    return predict.astype('int'), acc_dict, bool_index

# --------------------------------------------------------------------------------------------------------------- #
# --------------------------------------------------------------------------------------------------------------- #

def BMD_label(args, dataloader, model):
    """Pseudo-label every sample yielded by ``dataloader`` with BMD and kNN.

    Three labelings are computed and each is scored against the labels that
    come out of the dataloader:

    * ``pa0`` -- argmax of the classifier output,
    * ``pa1`` -- BMD: nearest of several k-means centroids per class,
    * ``pa2`` -- kNN vote fitted on each class's most confident samples.

    Args:
        args: run configuration. Reads ``targetAlg``, ``num_classes``,
            ``topk_alpha``, ``pLabelAlg`` and ``log_file``.
        dataloader: iterable of batches; item 0 is the input tensor and
            item 1 the class-label tensor.
        model: network exposing ``featurizer`` and ``classifier`` (and
            ``encoder`` for the PCL-style algorithms). It is switched to eval
            mode here; presumably it already lives on the GPU, since the
            inputs are moved there -- confirm against the caller.

    Returns:
        ``(labels, acc_dict)`` where ``labels`` is a numpy array holding the
        BMD labels when ``args.pLabelAlg == 'BMD'`` or the kNN labels when it
        is ``'knn'``, and ``acc_dict`` maps ``'pa0'``..``'pa2'`` to the
        accuracies rounded to three decimals. Any other ``pLabelAlg`` yields
        ``None``.

    Raises:
        ValueError: if ``args.targetAlg`` is not one of the supported
            algorithms.
    """
    erm_algs = ('ERM', 'supcon', 'ERMPCL')
    pcl_algs = ('PCL', 'PCL2', 'SupPCL', 'SupPCL2', 'FP')
    if args.targetAlg not in erm_algs + pcl_algs:
        # Fail early with a readable message instead of an UnboundLocalError
        # on `feas` inside the loop.
        raise ValueError(
            "BMD_label does not support targetAlg={!r}".format(args.targetAlg))

    model.eval()
    emd_feat_stack, cls_out_stack, gt_label_stack = [], [], []

    with torch.no_grad():
        for data in dataloader:
            inputs = data[0].cuda()
            labels = data[1].cuda()

            if args.targetAlg in erm_algs:
                feas = model.featurizer(inputs)
                outputs = model.classifier(feas)
            else:
                # PCL-style models keep the classifier as a weight matrix.
                feas = model.encoder(model.featurizer(inputs))
                outputs = F.linear(feas, model.classifier)
            emd_feat_stack.append(feas)
            cls_out_stack.append(outputs)
            gt_label_stack.append(labels)

    all_gt_label = torch.cat(gt_label_stack, dim=0)
    all_cls_out = torch.cat(cls_out_stack, dim=0)
    all_emd_feat = torch.cat(emd_feat_stack, dim=0)
    # L2-normalise so the dot products below are cosine similarities.
    all_emd_feat = all_emd_feat / torch.norm(all_emd_feat, p=2, dim=1, keepdim=True)

    # Number of most-confident samples taken per class (at least one).
    # int() keeps this usable with torch.topk / KNeighborsClassifier even if
    # topk_alpha is given as a float.
    topk_num = max(int(all_emd_feat.shape[0] // (args.num_classes * args.topk_alpha)), 1)

    def _accuracy(labels_hat):
        # Fraction of samples whose predicted label matches the loader's label.
        return (torch.sum(labels_hat == all_gt_label) / len(all_gt_label)).item()

    # ---- baseline: classifier argmax ------------------------------------ #
    _, all_psd_label = torch.max(all_cls_out, dim=1)
    acc_list = [_accuracy(all_psd_label)]

    # ---- BMD centroids and kNN training set, one pass per class ---------- #
    # faiss cannot train with fewer points than centroids, and every class
    # contributes exactly topk_num points, so cap the centroid count.
    multi_cent_num = min(4, topk_num)
    feat_dim = all_emd_feat.size(1)
    feat_multi_cent = torch.zeros((args.num_classes, multi_cent_num, feat_dim)).cuda()
    faiss_kmeans = faiss.Kmeans(feat_dim, multi_cent_num, niter=100, verbose=False,
                                min_points_per_centroid=1)

    top_sample, top_label = [], []
    for cls_idx in range(args.num_classes):
        # Indices of the samples the classifier scores highest for this class.
        feat_samp_idx = torch.topk(all_cls_out[:, cls_idx], topk_num)[1]
        feat_cls_sample = all_emd_feat[feat_samp_idx, :]

        faiss_kmeans.train(feat_cls_sample.cpu().numpy())
        feat_multi_cent[cls_idx, :] = torch.from_numpy(faiss_kmeans.centroids).cuda()

        # The kNN training label is the class index itself, not the loader's
        # label; dtype/device follow all_gt_label as before.
        top_sample.append(feat_cls_sample)
        top_label.append(torch.full((topk_num,), cls_idx,
                                    dtype=all_gt_label.dtype,
                                    device=all_gt_label.device))

    # ---- BMD: score each sample by its best-matching centroid per class -- #
    feat_dist = torch.einsum("cmk, nk -> ncm", feat_multi_cent, all_emd_feat)  # [N, C, M]
    feat_dist, _ = torch.max(feat_dist, dim=2)   # [N, C]
    feat_dist = torch.softmax(feat_dist, dim=1)  # [N, C]
    _, all_psd_label = torch.max(feat_dist, dim=1)
    acc_list.append(_accuracy(all_psd_label))

    # ---- kNN over the confident samples ---------------------------------- #
    top_sample = torch.cat(top_sample, dim=0).cpu().numpy()
    top_label = torch.cat(top_label, dim=0).cpu().numpy()
    knn = KNeighborsClassifier(n_neighbors=topk_num)
    knn.fit(top_sample, top_label)

    pred = torch.from_numpy(knn.predict(all_emd_feat.cpu().numpy())).cuda()
    acc_list.append(_accuracy(pred))

    log_print("acc:" + " --> ".join("{:.3f}".format(a) for a in acc_list),
              args.log_file, p=False)
    acc_dict = {'pa{}'.format(i): round(a, 3) for i, a in enumerate(acc_list)}

    if args.pLabelAlg == 'BMD':
        return all_psd_label.cpu().numpy(), acc_dict
    if args.pLabelAlg == 'knn':
        return pred.cpu().numpy(), acc_dict
    # Other pLabelAlg values have no result defined here; keep returning None.
    return None

class utilDataset(Dataset):
    '''
    Pseudo-labelled dataset built from raw image records.

    Each entry of ``images_dict`` is handed to ``loader`` on access; the
    loaded sample is then passed through ``transform`` when one is given.
    ``class_labels`` and ``domain_labels`` are indexed in parallel with
    ``images_dict``. ``target_transform`` is accepted for signature
    compatibility but is not used.
    '''
    def __init__(self, images_dict, class_labels, domain_labels, loader, transform=None, target_transform=None):
        self.transform = transform
        self.loader = loader
        self.x = images_dict            # raw records consumed by ``loader``
        self.labels = class_labels      # class label per record
        self.dlabels = domain_labels    # domain label per record

    def __len__(self):
        # Dataset size is defined by the number of class labels.
        return len(self.labels)

    def __getitem__(self, index):
        # Load the raw record first, then optionally transform it.
        sample = self.loader(self.x[index])
        if self.transform is not None:
            sample = self.transform(sample)
        class_label = self.labels[index]
        domain_label = self.dlabels[index]
        return sample, class_label, domain_label

    def get_raw_data(self):
        # Hand back the stored containers themselves (no copies).
        return self.x, self.labels, self.dlabels