"""
Builds upon: https://github.com/tntek/TPDS
Corresponding paper: https://link.springer.com/article/10.1007/s11263-023-01892-w
"""
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from src.utils import IID_losses,loss
from src.models import network
from torch.utils.data import DataLoader
from src.data.data_list import ImageList, ImageList_idx, ImageList_idx_aug, ImageList_idx_aug_fix
from scipy.spatial.distance import cdist
from sklearn.metrics import confusion_matrix
import torch.nn.functional as F
from numpy import linalg as LA
from src.utils.utils import *
from tqdm import tqdm

# Module-scoped logger; handlers/levels are inherited from the root logger.
logger = logging.getLogger(name=__name__)

def op_copy(optimizer):
    """Remember the starting learning rate of every param group.

    Copies ``group['lr']`` into ``group['lr0']`` for each entry of
    ``optimizer.param_groups`` so that ``lr_scheduler`` can later decay
    from that base value. The same optimizer object is returned.
    """
    for group in optimizer.param_groups:
        group.update(lr0=group['lr'])
    return optimizer

def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    """Apply inverse-power LR decay and fixed SGD hyper-parameters in place.

    Every param group gets ``lr = lr0 * (1 + gamma * iter_num / max_iter) ** -power``
    (``lr0`` must have been stored beforehand, e.g. by ``op_copy``), plus
    ``weight_decay=1e-3``, ``momentum=0.9`` and ``nesterov=True``.

    Returns:
        The same optimizer object, for chaining.
    """
    scale = (1 + gamma * iter_num / max_iter) ** (-power)
    for group in optimizer.param_groups:
        group.update(
            lr=group['lr0'] * scale,
            weight_decay=1e-3,
            momentum=0.9,
            nesterov=True,
        )
    return optimizer

def image_train(resize_size=256, crop_size=224, alexnet=False):
    """Return the training-time image transform.

    Pipeline: square resize to ``resize_size`` -> random ``crop_size`` crop
    -> random horizontal flip -> tensor -> ImageNet mean/std normalisation.

    Args:
        resize_size: side length of the square resize applied first.
        crop_size: side length of the random square crop.
        alexnet: AlexNet mean-file normalisation is not implemented.

    Raises:
        NotImplementedError: if ``alexnet`` is True.
    """
    if alexnet:
        # The meanfile-based AlexNet normalisation was never implemented, so
        # fail with a clear message instead of an unbound ``normalize``.
        raise NotImplementedError(
            "image_train: alexnet=True (meanfile normalisation) is not supported")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

def image_test(resize_size=256, crop_size=224, alexnet=False):
    """Return the evaluation-time image transform.

    Pipeline: square resize to ``resize_size`` -> centre ``crop_size`` crop
    -> tensor -> ImageNet mean/std normalisation (deterministic, no flip).

    Args:
        resize_size: side length of the square resize applied first.
        crop_size: side length of the centre square crop.
        alexnet: AlexNet mean-file normalisation is not implemented.

    Raises:
        NotImplementedError: if ``alexnet`` is True.
    """
    if alexnet:
        # The meanfile-based AlexNet normalisation was never implemented, so
        # fail with a clear message instead of an unbound ``normalize``.
        raise NotImplementedError(
            "image_test: alexnet=True (meanfile normalisation) is not supported")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize,
    ])

def data_load(cfg):
    """Build the target-domain train/test DataLoaders.

    The image list at ``cfg.t_dset_path`` is read once and used for both:

    * ``"target"``: training transform (``image_train``), shuffled, batch
      size ``cfg.TEST.BATCH_SIZE``;
    * ``"test"``: evaluation transform (``image_test``), not shuffled, batch
      size ``3 * cfg.TEST.BATCH_SIZE``.

    Both use ``cfg.NUM_WORKERS`` workers and keep the last partial batch.

    Returns:
        dict mapping ``"target"`` and ``"test"`` to ``DataLoader`` objects.
    """
    dsets = {}
    dset_loaders = {}
    train_bs = cfg.TEST.BATCH_SIZE
    # Read the list once (and close the file); the test split is the same list.
    with open(cfg.t_dset_path) as f:
        txt_tar = f.readlines()
    txt_test = list(txt_tar)

    # if not cfg.da == 'uda':
    #     label_map_s = {}
    #     for i in range(len(cfg.src_classes)):
    #         label_map_s[cfg.src_classes[i]] = i

    #     new_tar = []
    #     for i in range(len(txt_tar)):
    #         rec = txt_tar[i]
    #         reci = rec.strip().split(' ')
    #         if int(reci[1]) in cfg.tar_classes:
    #             if int(reci[1]) in cfg.src_classes:
    #                 line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'
    #                 new_tar.append(line)
    #             else:
    #                 line = reci[0] + ' ' + str(len(label_map_s)) + '\n'
    #                 new_tar.append(line)
    #     txt_tar = new_tar.copy()
    #     txt_test = txt_tar.copy()

    dsets["target"] = ImageList_idx(txt_tar, transform=image_train())
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True,
                                        num_workers=cfg.NUM_WORKERS, drop_last=False)
    dsets["test"] = ImageList_idx(txt_test, transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs * 3, shuffle=False,
                                      num_workers=cfg.NUM_WORKERS, drop_last=False)
    # dsets["test_aug"] = ImageList_idx_aug(txt_test, transform=image_test())
    # dset_loaders["test_aug"] = DataLoader(dsets["test_aug"], batch_size=train_bs*3, shuffle=False, num_workers=cfg.worker, drop_last=False)

    return dset_loaders


def cal_acc(loader, netF, netB, netC, flag=False):
    """Evaluate ``netC(netB(netF(x)))`` over every batch of ``loader``.

    Runs under ``torch.no_grad()``; putting the networks in eval mode is the
    caller's responsibility.

    Args:
        loader: iterable of batches whose first two items are
            ``(inputs, labels)``.
        netF, netB, netC: feature extractor, bottleneck and classifier.
        flag: if True, return per-class accuracy instead of overall accuracy.

    Returns:
        If ``flag`` is False: ``(accuracy_percent, mean_entropy)``, where
        ``mean_entropy`` is the mean of ``loss.Entropy`` applied to the
        softmax outputs.
        If ``flag`` is True: ``(mean_per_class_acc_percent, per_class_str)``,
        where ``per_class_str`` space-joins the per-class accuracies rounded
        to 2 decimals. Rows come from sklearn's ``confusion_matrix``; a class
        that appears only among predictions yields ``nan``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    output_chunks, label_chunks = [], []
    with torch.no_grad():
        for data in loader:
            inputs, labels = data[0], data[1]
            outputs = netC(netB(netF(inputs.cuda())))
            output_chunks.append(outputs.float().cpu())
            label_chunks.append(labels.float())
    if not output_chunks:
        raise ValueError("cal_acc: loader yielded no batches")
    # Concatenate once instead of re-allocating on every batch.
    all_output = torch.cat(output_chunks, 0)
    all_label = torch.cat(label_chunks, 0)

    _, predict = torch.max(all_output, 1)
    accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
    mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).item()

    if flag:
        matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
        acc = matrix.diagonal() / matrix.sum(axis=1) * 100
        aacc = acc.mean()
        acc = ' '.join(str(np.round(i, 2)) for i in acc)
        return aacc, acc
    return accuracy * 100, mean_ent
#python image_target_of_oh_vs.py --cfg "cfgs/office-home/tpds.yaml" SETTING.S 0 SETTING.T 1
def _unlearning_sgd(netF, netB, netC, learning_rate):
    """Build the plain SGD optimizer used by ``machine``'s impair/repair stages.

    netF parameters get ``0.1 * learning_rate``; netB and netC parameters get
    ``learning_rate``. No momentum or weight decay is set here. ``op_copy``
    stores each group's initial lr under ``'lr0'``.
    """
    param_group = [{'params': v, 'lr': learning_rate * 0.1} for _, v in netF.named_parameters()]
    param_group += [{'params': v, 'lr': learning_rate} for _, v in netB.named_parameters()]
    param_group += [{'params': v, 'lr': learning_rate} for _, v in netC.named_parameters()]
    return op_copy(optim.SGD(param_group))


def _predict_logits(loader, netF, netB, netC):
    """Return CPU float logits of ``netC(netB(netF(x)))`` for every batch of ``loader``.

    Batches are unpacked as ``(images, labels, index)`` triples (as produced
    by ``ImageList_idx``); only the images are used. Runs under
    ``torch.no_grad()`` with a tqdm progress bar.
    """
    outputs = []
    with torch.no_grad():
        data_bar = tqdm(loader)
        for step, (images, _, _) in enumerate(data_bar):
            data_bar.set_description("MainBranch : Step:{}".format(step))
            outputs.append(netC(netB(netF(images.cuda()))).float().cpu())
    return torch.cat(outputs, 0)


def machine(cfg):
    """Pseudo-label, "forget" a set of classes, then repair an adapted model.

    All paths and class splits are hard-coded for a single experiment
    (currently the A->C TPDS model, forget split ``CLIZ``, pseudo-label file
    ``ATCZ.txt``); switch experiments by toggling the commented alternatives.

    Stages:
      1. Load ``target_{F,B,C}_tpds.pt`` from the A->C TPDS output folder.
      2. Predict on ``Clipart_list.txt``. For every class in ``classes``,
         write the 7 top-scoring images (path taken from
         ``Clipart_list_data.txt``) with that class id to ``filename``.
      3. For each class in ``classes_to_forget``, optimise a 32x3x224x224
         noise tensor to maximise cross-entropy against that class, with an
         L2 penalty.
      4. Impair: 5 epochs of SGD on the noise (labelled with its forget
         class) mixed with up to 7 retained pseudo-labelled images per
         non-forgotten class.
      5. Repair: 60 epochs of SGD on the retained images only.
      6. Save the weights as ``source_{F,B,C}.pt`` in the A->C output folder
         and print stage timings.

    Args:
        cfg: config object; ``MODEL.ARCH``, ``bottleneck`` and ``class_num``
            are read.

    Raises:
        ValueError: if ``cfg.MODEL.ARCH`` is neither a ``res*`` nor ``vgg*`` name.
    """
    mid1 = time.time()
    if cfg.MODEL.ARCH[0:3] == 'res':
        netF = network.ResBase(res_name=cfg.MODEL.ARCH).cuda()
    elif cfg.MODEL.ARCH[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=cfg.MODEL.ARCH).cuda()
    else:
        raise ValueError("Unsupported MODEL.ARCH: {}".format(cfg.MODEL.ARCH))

    netB = network.feat_bottleneck(type='bn', feature_dim=netF.in_features, bottleneck_dim=cfg.bottleneck).cuda()
    netC = network.feat_classifier(type='wn', class_num=cfg.class_num, bottleneck_dim=cfg.bottleneck).cuda()
    modelpath = r'C:\Users\UserX\Desktop\source-free-domain-adaptation-main\output\uda\office-home\AC\tpds\target_F_tpds.pt'
    netF.load_state_dict(torch.load(modelpath))
    modelpath = r'C:\Users\UserX\Desktop\source-free-domain-adaptation-main\output\uda\office-home\AC\tpds\target_B_tpds.pt'
    netB.load_state_dict(torch.load(modelpath))
    modelpath = r'C:\Users\UserX\Desktop\source-free-domain-adaptation-main\output\uda\office-home\AC\tpds\target_C_tpds.pt'
    netC.load_state_dict(torch.load(modelpath))
    dset_loaders = {}
    print('模型在目标域')
    # Evaluation split (pick one); only consumed by the commented-out cal_acc below.
    #test_dset_path = r"C:\Users\UserX\Desktop\SHOT-master\data\office-home\toartX.txt"  #TARTX
    #test_dset_path = r"C:\Users\UserX\Desktop\SHOT-master\data\office-home\toartY.txt"  #TARTY
    #test_dset_path = r"C:\Users\UserX\Desktop\SHOT-master\data\office-home\toartZ.txt"  #TARTZ
    test_dset_path = r"C:\Users\UserX\Desktop\SHOT-master\data\office-home\tocliX.txt"  #TCliX
    #test_dset_path = r"C:\Users\UserX\Desktop\SHOT-master\data\office-home\tocliY.txt"  #TCliY
    #test_dset_path = r"C:\Users\UserX\Desktop\SHOT-master\data\office-home\tocliZ.txt"  #TCliZ
    #test_dset_path = r"C:\Users\UserX\Desktop\SHOT-master\data\office-home\toproX.txt"  #TProX
    #test_dset_path = r"C:\Users\UserX\Desktop\SHOT-master\data\office-home\toproY.txt"  #TProY
    #test_dset_path = r"C:\Users\UserX\Desktop\SHOT-master\data\office-home\toproZ.txt"  #TProZ
    #test_dset_path = r"C:\Users\UserX\Desktop\SHOT-master\data\office-home\toReaX.txt"   #TReaX
    #test_dset_path = r"C:\Users\UserX\Desktop\SHOT-master\data\office-home\toReaY.txt"   #TReaY
    #test_dset_path = r"C:\Users\UserX\Desktop\SHOT-master\data\office-home\toReaZ.txt"   #TReaZ
    with open(test_dset_path) as f_test:
        txt_test = f_test.readlines()
    target_dataset = ImageList_idx(txt_test, transform=image_test())
    dset_loaders["target_images"] = DataLoader(target_dataset, batch_size=64, shuffle=True,
                                               num_workers=4, drop_last=False, pin_memory=False)
    for net in (netF, netB, netC):
        net.eval()

    #MB_log = cal_acc(dset_loaders["target_images"], netF, netB, netC, False)
    #print(MB_log)

    # ---------------- Stage 2: pseudo-labels ----------------
    mid2 = time.time()
    test = r"C:\Users\UserX\Desktop\SHOT-master\data\office-home\Clipart_list.txt"
    with open(test) as f_test:
        txt_test = f_test.readlines()
    target_dataset = ImageList_idx(txt_test, transform=image_test())
    dset_loaders["target_images"] = DataLoader(target_dataset, batch_size=64, shuffle=False,
                                               num_workers=4, drop_last=False, pin_memory=False)
    loader = dset_loaders["target_images"]
    all_output_clean = _predict_logits(loader, netF, netB, netC)
    _, predict_clean = torch.max(all_output_clean, 1)
    # NOTE: these are the max raw logits (no softmax applied); they are only
    # used to rank images within a predicted class.
    max_probs = all_output_clean[range(len(all_output_clean)), predict_clean]
    filename = r"C:\Users\UserX\Desktop\source-free-domain-adaptation-main\data\office-home\fake\ATCZ.txt"

    #classes = [46, 60, 12, 58, 23]   #PRX
    #classes = [46, 60, 12, 58, 23,47,19,11,62,52,5,42,43,17,33] #PRY
    # classes =[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 13, 14, 15, 16, 18, 20, 21, 22, 24, 25, 26,
    #  27, 28, 29, 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 41, 44, 45, 48, 49, 50, 51,
    #  53, 54, 55, 56, 57, 59, 61, 63, 64]#PRZ
    #classes = [53, 62, 1, 17, 58]  # CX
    #classes = [53, 62, 1, 17, 58,60,9,35,11,44,2,15,3,52,43]  # CY
    classes = [0, 4, 5, 6, 7, 8, 10, 12, 13, 14, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26,
               27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 45, 46, 47, 48, 49, 50, 51,
               54, 55, 56, 57, 59, 61, 63, 64]  # CZ
    #classes =[34, 33, 12, 26, 23,35,24,15,40,63,17,55,43,45,38]#AY
    # classes =[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14,  16, 18, 19, 20, 21, 22,  25,
    #             27,28, 29, 30, 31, 32, 36, 37,  39,  41, 42, 44,  46, 47, 48, 49, 50, 51,
    #             52, 53,54, 56, 57, 58, 59, 60, 61, 62, 64]
    #classes = [11, 62, 44, 33, 12] #RX
    #classes =[11, 62, 44, 33, 12,38,61,41,5,39,34,16,63,43,23]#RY
    # classes =[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 13, 14, 15, 17, 18, 19, 20, 21, 22, 24, 25, 26,
    #  27, 28, 29, 30, 31, 32, 35, 36, 37, 40, 42, 45, 46, 47, 48, 49, 50, 51,
    #  52, 53, 54, 55, 56, 57, 58, 59, 60, 64] #RZ

    # Read the image-path list once instead of once per class.
    with open(r"C:\Users\UserX\Desktop\SHOT-master\data\office-home\Clipart_list_data.txt") as f:
        data_lines = f.readlines()
    with open(filename, 'w') as result_file:
        for class_id in classes:
            # nonzero(as_tuple=True)[0] is always 1-D, even when exactly one
            # image is predicted as class_id (squeeze() would give a 0-d tensor
            # that cannot be sliced).
            idx = (predict_clean == class_id).nonzero(as_tuple=True)[0]
            _, sorted_indices = torch.sort(max_probs[idx], descending=True)
            topk_img_idx = idx[sorted_indices[:7]].tolist()
            for i in topk_img_idx:
                # NOTE(review): ``i`` is a 0-based position in Clipart_list.txt but
                # the path is read from line ``i - 1`` (index 0 wraps to the last
                # line). Confirm Clipart_list_data.txt is offset by one on purpose.
                line = data_lines[int(i) - 1].strip()
                result_file.write(f"{line} {class_id}\n")  # image path + class id
    # Drop blank lines from the pseudo-label file.
    with open(filename) as f:
        lines = [line for line in f if line.strip()]
    with open(filename, 'w') as f:
        f.writelines(lines)
    mid3 = time.time()

    # ---------------- Stage 3: unlearning (UL) framework ----------------
    classes = list(range(65))
    # classes_to_forget = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,  24, 25,
    #            27,28, 29, 30, 31, 32, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
    #            52, 53,54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64]  # ARTX
    # classes_to_forget = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14,  16, 18, 19, 20, 21, 22,  25,
    #            27,28, 29, 30, 31, 32, 36, 37,  39,  41, 42, 44,  46, 47, 48, 49, 50, 51,
    #            52, 53,54, 56, 57, 58, 59, 60, 61, 62, 64]  # ARTY
    #classes_to_forget = [34, 33, 12, 26, 23,35,24,15,40,63,17,55,43,45,38] # ARTZ
    # classes_to_forget = [0,  2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,  18, 19, 20, 21, 22, 23, 24, 25, 26,
    #            27,28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
    #            52, 54, 55, 56, 57,  59, 60, 61, 63, 64]  # CLIX
    # classes_to_forget = [0, 4, 5, 6, 7, 8,  10,  12, 13, 14,  16,  18, 19, 20, 21, 22, 23, 24, 25, 26,
    #            27,28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42,   45, 46, 47, 48, 49, 50, 51,
    #             54, 55, 56, 57,  59,  61, 63, 64]  # CLIY
    classes_to_forget = [53, 62, 1, 17, 58, 60, 9, 35, 11, 44, 2, 15, 3, 52, 43]  #CLIZ
    # classes_to_forget =[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,  24, 25, 26,
    #            27,28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,  47, 48, 49, 50, 51,
    #            52, 53,54, 55, 56, 57,  59,  61, 62, 63, 64] #PROX
    # classes_to_forget = [0, 1, 2, 3, 4,  6, 7, 8, 9, 10,  13, 14, 15, 16,  18,  20, 21, 22, 24, 25, 26,
    #                      27, 28, 29, 30, 31, 32,  34, 35, 36, 37, 38, 39, 40, 41,  44, 45, 48, 49, 50, 51,
    #                      53, 54, 55, 56, 57, 59, 61,  63, 64]  # PROY
    #classes_to_forget =[5,11,12,17,19, 23,33, 42,43,46,47, 52,58, 60,62] #PROZ
    # classes_to_forget = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
    #  27, 28, 29, 30, 31, 32,  34, 35, 36, 37, 38, 39, 40, 41, 42, 43,  45, 46, 47, 48, 49, 50, 51,
    #  52, 53, 54, 55, 56, 57, 58, 59, 60, 61,  63, 64]#REAX
    #classes_to_forget = [5, 11, 12, 16, 23, 33, 34, 38, 39, 41, 43, 44, 61, 62, 63] #REAZ
    # classes_to_forget = [0, 1, 2, 3, 4,  6, 7, 8, 9, 10,  13, 14, 15, 17, 18, 19, 20, 21, 22, 24, 25, 26,
    #            27,28, 29, 30, 31, 32,  35, 36, 37,  40,  42,  45, 46, 47, 48, 49, 50, 51,
    #            52, 53,54, 55, 56, 57, 58, 59, 60,  64] #REAY
    num_classes = 65

    # Load the pseudo-labelled images and group them by label.
    with open(filename) as f_test:
        txt_test = f_test.readlines()
    sub = ImageList(txt_test, transform=image_test())
    classwise_test = {i: [] for i in range(num_classes)}
    for img, label in sub:
        classwise_test[label].append((img, label))
    # Retained (non-forgotten) samples: up to 7 per class.
    retain_samples = []
    for i, class_id in enumerate(classes):
        if class_id not in classes_to_forget:
            retain_samples += classwise_test[class_id][:7]
        print(f"加入Class {i}后，训练noise的数据集共有: {len(retain_samples)} 张伪标签")

    class Noise(nn.Module):
        """Learnable input tensor of shape ``dim``, initialised from N(0, 1)."""

        def __init__(self, *dim):
            super().__init__()
            self.noise = torch.nn.Parameter(torch.randn(*dim), requires_grad=True)

        def forward(self):
            return self.noise

    # Learn one noise batch per forget class that maximises cross-entropy
    # against that class, regularised by 0.1 * mean per-sample squared L2 norm.
    # NOTE: gradients also accumulate in the model parameters here; they are
    # cleared by optimizer.zero_grad() before the impair stage uses them.
    noises = {}
    for cls in classes_to_forget:
        print("Optiming loss for class {}".format(cls))
        noises[cls] = Noise(32, 3, 224, 224).cuda()
        opt = torch.optim.Adam(noises[cls].parameters(), lr=0.1)

        num_epochs = 2
        num_steps = 16
        for epoch in range(num_epochs):
            total_loss = []
            for step in range(num_steps):
                inputs = noises[cls].noise
                labels = torch.zeros(32).cuda() + cls  # every noise sample labelled cls
                outputs = netC(netB(netF(inputs)))
                noise_loss = -F.cross_entropy(outputs, labels.long()) + 0.1 * torch.mean(
                    torch.sum(torch.square(inputs), [1, 2, 3]))
                opt.zero_grad()
                noise_loss.backward()
                opt.step()
                total_loss.append(noise_loss.cpu().detach().numpy())
            print("Loss: {}".format(np.mean(total_loss)))

    # ---------------- Stage 4: impair (forget) ----------------
    num_batches = 1  # noise batches per forget class
    noisy_data = []
    for cls in classes_to_forget:
        for _ in range(num_batches):
            batch = noises[cls]().cpu().detach()
            noisy_data.extend((batch[i], torch.tensor(cls)) for i in range(batch.size(0)))
    other_samples = [(img.cpu(), torch.tensor(label)) for img, label in retain_samples]

    noisy_data += other_samples
    noisy_loader = DataLoader(noisy_data, batch_size=32, shuffle=True, num_workers=4, drop_last=False,
                              pin_memory=False)

    optimizer = _unlearning_sgd(netF, netB, netC, learning_rate=1e-2)
    # forget_lookup[c] is True iff class c is being forgotten.
    forget_lookup = torch.zeros(num_classes, dtype=torch.bool)
    forget_lookup[classes_to_forget] = True
    forget_lookup = forget_lookup.cuda()
    for net in (netF, netB, netC):
        net.train()
    for epoch in range(5):
        for inputs, labels in noisy_loader:
            inputs, labels = inputs.cuda(), labels.cuda()
            optimizer.zero_grad()
            outputs = netC(netB(netF(inputs)))
            # Up-weight (x5) the labelled-class logit of every forget-class sample.
            forget_rows = forget_lookup[labels].nonzero(as_tuple=True)[0]
            outputs[forget_rows, labels[forget_rows]] *= 5
            impair_loss = F.cross_entropy(outputs, labels)
            impair_loss.backward()
            optimizer.step()

    # ---------------- Stage 5: repair on retained samples ----------------
    mid4 = time.time()
    print("自学习步骤")
    heal_loader = DataLoader(other_samples, batch_size=32, shuffle=True)
    optimizer = _unlearning_sgd(netF, netB, netC, learning_rate=1e-2)
    for net in (netF, netB, netC):
        net.train()
    for epoch in range(60):
        for inputs, labels in heal_loader:
            inputs, labels = inputs.cuda(), labels.cuda()
            optimizer.zero_grad()
            outputs = netC(netB(netF(inputs)))
            repair_loss = F.cross_entropy(outputs, labels)
            repair_loss.backward()
            optimizer.step()

    output_dir_src = r'C:\Users\UserX\Desktop\source-free-domain-adaptation-main\output\uda\office-home\AC'
    torch.save(netF.state_dict(), osp.join(output_dir_src, "source_F.pt"))
    torch.save(netB.state_dict(), osp.join(output_dir_src, "source_B.pt"))
    torch.save(netC.state_dict(), osp.join(output_dir_src, "source_C.pt"))
    mid5 = time.time()
    # Forget time = model loading (mid1->mid2) + noise learning and impair (mid3->mid4).
    forgettime = 'F Running time: %s Seconds' % (round(mid4 - mid3 + mid2 - mid1, 2))
    print(forgettime)
    # Adapt time = loading and pseudo-labelling (mid1->mid3) + repair and saving (mid4->mid5).
    adapttime = 'A Running time: %s Seconds' % (round(mid5 - mid4 + mid3 - mid1, 2))
    print(adapttime)
    ourtime = 'O Running time: %s Seconds' % (round(mid5 - mid1, 2))
    print(ourtime)


def count(cfg):
    """Print accuracy and per-class prediction counts for a hard-coded model/split.

    Loads ``source_{F,B,C}.pt`` from the office-home ``RP`` output folder and
    evaluates them on ``toproZ.txt``. ``cal_acc``'s (accuracy, mean entropy)
    tuple is printed first. Then the number of images predicted as each class
    in ``classes`` is printed, in list order, followed by their sum.

    Args:
        cfg: config object; ``MODEL.ARCH``, ``bottleneck`` and ``class_num``
            are read.

    Raises:
        ValueError: if ``cfg.MODEL.ARCH`` is neither a ``res*`` nor ``vgg*`` name.
    """
    from collections import Counter

    if cfg.MODEL.ARCH[0:3] == 'res':
        netF = network.ResBase(res_name=cfg.MODEL.ARCH).cuda()
    elif cfg.MODEL.ARCH[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=cfg.MODEL.ARCH).cuda()
    else:
        raise ValueError("Unsupported MODEL.ARCH: {}".format(cfg.MODEL.ARCH))

    netB = network.feat_bottleneck(type='bn', feature_dim=netF.in_features, bottleneck_dim=cfg.bottleneck).cuda()
    netC = network.feat_classifier(type='wn', class_num=cfg.class_num, bottleneck_dim=cfg.bottleneck).cuda()
    #modelpath = r'C:\Users\UserX\Desktop\source-free-domain-adaptation-main\output\uda\office-home\RP\tpds\target_F_tpds.pt'
    modelpath = r'C:\Users\UserX\Desktop\source-free-domain-adaptation-main\output\uda\office-home\RP\source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    #modelpath = r'C:\Users\UserX\Desktop\source-free-domain-adaptation-main\output\uda\office-home\RP\tpds\target_B_tpds.pt'
    modelpath = r'C:\Users\UserX\Desktop\source-free-domain-adaptation-main\output\uda\office-home\RP\source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    #modelpath = r'C:\Users\UserX\Desktop\source-free-domain-adaptation-main\output\uda\office-home\RP\tpds\target_C_tpds.pt'
    modelpath = r'C:\Users\UserX\Desktop\source-free-domain-adaptation-main\output\uda\office-home\RP\source_C.pt'
    netC.load_state_dict(torch.load(modelpath))

    dset_loaders = {}
    test_low = r"C:\Users\UserX\Desktop\SHOT-master\data\office-home\toproZ.txt"
    with open(test_low) as f_test:
        txt_test = f_test.readlines()
    target_dataset = ImageList_idx(txt_test, transform=image_test())
    dset_loaders["low"] = DataLoader(target_dataset, batch_size=16, shuffle=False,
                                     num_workers=4, drop_last=False, pin_memory=False)
    loader = dset_loaders["low"]

    for net in (netF, netB, netC):
        net.eval()

    MB_log = cal_acc(loader, netF, netB, netC, False)
    print(MB_log)

    # Collect logits for every image (batches are (images, labels, index)).
    output_chunks = []
    with torch.no_grad():
        data_bar = tqdm(loader)  # progress bar
        for step, (images, _, _) in enumerate(data_bar):
            data_bar.set_description("MainBranch : Step:{}".format(step))
            output_chunks.append(netC(netB(netF(images.cuda()))).float().cpu())
    all_output_clean = torch.cat(output_chunks, 0)
    _, predict_clean = torch.max(all_output_clean, 1)

    #classes = [34,33,12,26,23] #AX
    #classes = [34, 33, 12, 26, 23,35,24,15,40,63,17,55,43,45,38] #AY  AZ
    #classes=[53,62,1,17,58] #CX
    #classes = [53, 62, 1, 17, 58, 60, 9, 35, 11, 44, 2, 15, 3, 52, 43]  # CY CZ
    #classes = [46, 60, 12, 58, 23]  # PX
    classes = [46, 60, 12, 58, 23, 47, 19, 11, 62, 52, 5, 42, 43, 17, 33]  # PY PZ
    #classes =[11,62,44,12,33] #RX
    #classes = [11, 62, 44, 33, 12, 38, 61, 41, 5, 39, 34, 16, 63, 43, 23]  # RY RZ

    # One pass over the predictions instead of a list.index() per prediction.
    predicted = Counter(predict_clean.tolist())
    class_counts = [predicted[c] for c in classes]
    print(class_counts)
    print("和是:", sum(class_counts))











def train_target(cfg):
    """Adapt the source model to the target domain with the TPDS objective.

    Loads ``source_{F,B,C}.pt`` from ``cfg.output_dir_src`` and runs
    ``cfg.TEST.MAX_EPOCH`` epochs' worth of iterations over the ``"target"``
    loader. Each step minimises ``IID_loss(p(x), p(x_N))``, where ``x_N`` is
    the first output of ``obtain_nearest_trace`` passed through netB/netC. For
    office/domainnet126 (weight 1.0) and office-home (weight 0.5) it also
    subtracts the entropy of the batch-mean prediction.

    Every ``interval_iter`` iterations the stored features/labels
    (``feas_all``, ``label_confi``) are refreshed with ``obtain_label_ts`` and
    accuracy on ``dset_loaders['test']`` is logged. If ``cfg.ISSAVE``, the
    adapted weights are saved under ``cfg.output_dir``.

    Returns:
        ``(netF, netB, netC)`` after adaptation.

    Raises:
        ValueError: if ``cfg.MODEL.ARCH`` is neither a ``res*`` nor ``vgg*`` name.
    """
    dset_loaders = data_load(cfg)
    ## set base network
    if cfg.MODEL.ARCH[0:3] == 'res':
        netF = network.ResBase(res_name=cfg.MODEL.ARCH).cuda()
    elif cfg.MODEL.ARCH[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=cfg.MODEL.ARCH).cuda()
    else:
        raise ValueError("Unsupported MODEL.ARCH: {}".format(cfg.MODEL.ARCH))

    netB = network.feat_bottleneck(type='bn', feature_dim=netF.in_features, bottleneck_dim=cfg.bottleneck).cuda()
    netC = network.feat_classifier(type='wn', class_num=cfg.class_num, bottleneck_dim=cfg.bottleneck).cuda()

    modelpath = cfg.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    modelpath = cfg.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    modelpath = cfg.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath))

    # A module is trained only if its LR multiplier is > 0, otherwise frozen.
    # netC shares netF's multiplier (LR_DECAY1).
    param_group = []
    for k, v in netF.named_parameters():
        if cfg.OPTIM.LR_DECAY1 > 0:
            param_group += [{'params': v, 'lr': cfg.OPTIM.LR * cfg.OPTIM.LR_DECAY1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if cfg.OPTIM.LR_DECAY2 > 0:
            param_group += [{'params': v, 'lr': cfg.OPTIM.LR * cfg.OPTIM.LR_DECAY2}]
        else:
            v.requires_grad = False
    for k, v in netC.named_parameters():
        if cfg.OPTIM.LR_DECAY1 > 0:
            param_group += [{'params': v, 'lr': cfg.OPTIM.LR * cfg.OPTIM.LR_DECAY1}]
        else:
            v.requires_grad = False

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = cfg.TEST.MAX_EPOCH * len(dset_loaders["target"])
    # Never 0: with max_iter < INTERVAL the modulo below would otherwise
    # raise ZeroDivisionError.
    interval_iter = max(1, max_iter // cfg.TEST.INTERVAL)
    iter_num = 0
    iter_num_update = 0
    iter_test = iter(dset_loaders["target"])

    while iter_num < max_iter:
        try:
            inputs_test, _, _ = next(iter_test)
        except StopIteration:
            # End of an epoch: restart the target loader.
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, _ = next(iter_test)

        # Skip single-sample batches.
        if inputs_test.size(0) == 1:
            continue

        # Refresh the stored features/labels periodically (always at iter 0,
        # so feas_all/label_confi exist before first use).
        if iter_num % interval_iter == 0:
            iter_num_update += 1
            netF.eval()
            netB.eval()
            netC.eval()
            _, feas_all, label_confi, _, _ = obtain_label_ts(dset_loaders['test'], netF, netB, netC, cfg, iter_num_update)
            netF.train()
            netB.train()
            netC.train()

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        # -----------------------------------data--------------------------------
        inputs_test = inputs_test.cuda()
        features_test_F = netF(inputs_test)
        features_test = netB(features_test_F)
        outputs_test = netC(features_test)
        softmax_out = nn.Softmax(dim=1)(outputs_test)

        features_test_N, _, _ = obtain_nearest_trace(features_test_F, feas_all, label_confi)
        features_test_N = features_test_N.cuda()
        features_test_N = netB(features_test_N)
        outputs_test_N = netC(features_test_N)
        softmax_out_hyper = nn.Softmax(dim=1)(outputs_test_N)

        # -------------------------------objective------------------------------
        classifier_loss = torch.tensor(0.0).cuda()
        iic_loss = IID_losses.IID_loss(softmax_out, softmax_out_hyper)
        classifier_loss = classifier_loss + 1.0 * iic_loss

        if cfg.SETTING.DATASET == "office" or cfg.SETTING.DATASET == "domainnet126":
            msoftmax = softmax_out.mean(dim=0)
            gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + cfg.TPDS.EPSILON))
            gentropy_loss = gentropy_loss * 1.0
            classifier_loss = classifier_loss - gentropy_loss

        elif cfg.SETTING.DATASET == "office-home":
            msoftmax = softmax_out.mean(dim=0)
            gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + cfg.TPDS.EPSILON))
            gentropy_loss = gentropy_loss * 0.5
            classifier_loss = classifier_loss - gentropy_loss

        # --------------------------------------------------------------------
        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            netC.eval()
            if cfg.SETTING.DATASET == 'VISDA-C':
                acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(cfg.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list
            else:
                acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(cfg.name, iter_num, max_iter, acc_s_te)

            logger.info(log_str)
            netF.train()
            netB.train()
            netC.train()

    if cfg.ISSAVE:
        torch.save(netF.state_dict(), osp.join(cfg.output_dir, "target_F_" + cfg.savename + ".pt"))
        torch.save(netB.state_dict(), osp.join(cfg.output_dir, "target_B_" + cfg.savename + ".pt"))
        torch.save(netC.state_dict(), osp.join(cfg.output_dir, "target_C_" + cfg.savename + ".pt"))

    return netF, netB, netC

def print_cfg(cfg):
    """Format every attribute of ``cfg`` as ``name:value`` lines under a banner.

    Despite its name this does not print anything; it returns the string
    (a line of ``=`` characters followed by one ``name:value`` line per entry
    of ``cfg.__dict__``, in insertion order).
    """
    banner = "==========================================\n"
    body = "".join("{}:{}\n".format(name, value) for name, value in cfg.__dict__.items())
    return banner + body

def obtain_label_ts(loader, netF, netB, netC, cfg, iter_num_update_f, device="cuda"):
    """Compute centroid-refined pseudo-labels and a 0/1 confidence mask for a dataset.

    The whole ``loader`` is run through ``netF -> netB -> netC`` without gradients.
    Then:

    1. softmax predictions and their entropy are computed;
    2. class centroids are built from the softmax-weighted ``netB`` features
       (after appending a ones column and L2-normalising when the distance is
       ``'cosine'``), keeping only classes predicted more than
       ``cfg.TPDS.THRESHOLD`` times; every sample is re-labelled by its nearest
       kept centroid under ``cfg.TPDS.DISTANCE``;
    3. a sample is marked unconfident (0) when it is BOTH among the
       ``int(0.5 * N)`` highest-entropy samples AND not among the
       ``int(0.6 * N)`` samples closest to their nearest centroid; every other
       sample is confident (1).

    Args:
        loader: iterable of batches; item 0 of each batch is the input tensor and
            item 1 the label tensor.
        netF, netB, netC: modules applied in sequence (features -> bottleneck ->
            logits).
        cfg: config object providing ``TPDS.EPSILON``, ``TPDS.DISTANCE`` and
            ``TPDS.THRESHOLD``.
        iter_num_update_f: value printed at the start of the log line (``{:.1f}``).
        device: device each input batch is moved to before the forward pass.
            The default ``"cuda"`` matches the previous hard-coded ``.cuda()``.

    Returns:
        tuple ``(pred_label, all_fea_F, label_confi, all_label, all_output)``:
        centroid pseudo-labels (int ndarray), ``netF`` outputs (CPU float tensor),
        confidence mask (int64 ndarray of 0/1), labels (float tensor) and softmax
        probabilities (CPU float tensor).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    fea_F_chunks, fea_chunks, logit_chunks, label_chunks = [], [], [], []
    with torch.no_grad():
        for data in loader:
            inputs = data[0].to(device)
            labels = data[1]
            feas_F = netF(inputs)
            feas = netB(feas_F)
            outputs = netC(feas)
            fea_F_chunks.append(feas_F.float().cpu())
            fea_chunks.append(feas.float().cpu())
            logit_chunks.append(outputs.float().cpu())
            label_chunks.append(labels.float())

    if not logit_chunks:
        raise ValueError("obtain_label_ts: loader yielded no batches")

    # Concatenate once instead of re-allocating with torch.cat on every batch.
    all_fea_F = torch.cat(fea_F_chunks, 0)
    all_fea = torch.cat(fea_chunks, 0)
    all_output = nn.Softmax(dim=1)(torch.cat(logit_chunks, 0))
    all_label = torch.cat(label_chunks, 0)
    num_sam = all_output.size(0)

    ent = torch.sum(-all_output * torch.log(all_output + cfg.TPDS.EPSILON), dim=1)
    _, predict = torch.max(all_output, 1)
    accuracy = torch.sum(predict.float() == all_label).item() / float(all_label.size(0))

    # Entropy criterion: the int(0.5 * N) highest-entropy samples are candidates for "unconfident".
    is_high_ent = np.zeros(num_sam, dtype=bool)
    is_high_ent[ent.topk(int(num_sam * 0.5), largest=True)[1].numpy()] = True

    if cfg.TPDS.DISTANCE == 'cosine':
        all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
    all_fea = all_fea.float().cpu().numpy()

    # Soft class centroids, weighted by the softmax probabilities.
    K = all_output.size(1)
    aff = all_output.float().cpu().numpy()
    initc = aff.transpose().dot(all_fea)
    initc = initc / (1e-8 + aff.sum(axis=0)[:, None])

    # Only classes predicted more than THRESHOLD times keep a centroid.
    cls_count = np.bincount(predict.numpy(), minlength=K)
    labelset = np.where(cls_count > cfg.TPDS.THRESHOLD)[0]

    dd = cdist(all_fea, initc[labelset], cfg.TPDS.DISTANCE)
    pred_label = labelset[dd.argmin(axis=1)]

    # Distance criterion: the int(0.6 * N) samples closest to their nearest centroid are "close".
    dd_min = torch.from_numpy(dd.min(axis=1))
    is_close = np.zeros(num_sam, dtype=bool)
    is_close[dd_min.topk(int(num_sam * 0.6), largest=False)[1].numpy()] = True

    # Unconfident (0) only when flagged by BOTH criteria; all other samples are confident (1).
    label_confi = np.ones(num_sam, dtype="int64")
    label_confi[is_high_ent & ~is_close] = 0

    acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
    log_str = '{:.1f} AccuracyEpoch = {:.2f}% -> {:.2f}%'.format(iter_num_update_f, accuracy * 100, acc * 100)

    logging.info(log_str)

    return pred_label.astype('int'), all_fea_F, label_confi, all_label, all_output



def obtain_nearest_trace(data_q, data_all, lab_confi):
    """Follow a chain of nearest neighbours from each query until a confident sample is hit.

    Each row of ``data_q`` repeatedly hops to its nearest row of ``data_all``
    (cosine distance, via ``get_nearest_sam_idx``). The indices already stepped
    through are passed along so they can be excluded. At most ``LN_MEM`` hops
    are made. For every query, the first hop whose target has
    ``lab_confi == 1`` is selected. Queries that never reach such a sample
    within the budget fall back to hop 0, i.e. their first neighbour.

    Args:
        data_q: (n, d) tensor of query features.
        data_all: (m, d) tensor of reference features; rows of it are returned.
        lab_confi: array of per-reference-sample confidence flags (1 = confident),
            indexed by reference index. Presumably the ``label_confi`` mask from
            ``obtain_label_ts`` -- confirm against the caller.

    Returns:
        tuple ``(data_re, idx_nn_re, idx_nn_step)``: the selected rows of
        ``data_all``, their indices (int64 ndarray of length n) and the hop used
        for each query (list of int).
    """
    LN_MEM = 70  # maximum number of hops per query

    data_q_ = data_q.detach().cpu().numpy()
    data_all_ = data_all.detach().cpu().numpy()
    num_sam = data_q.shape[0]
    indices_row = np.arange(0, num_sam, 1)

    mtx_mem_rlt = -3 * np.ones((num_sam, LN_MEM), dtype='int64')  # neighbour reached at each hop
    mtx_mem_ignore = np.zeros((num_sam, LN_MEM), dtype='int64')   # "self" index excluded at each hop
    mtx_log = np.zeros((num_sam, LN_MEM), dtype='int64')          # lab_confi of the neighbour at each hop
    idx_left = np.arange(0, num_sam, 1)                           # set to -1 once a query is resolved
    nearest_idx_last = np.array([-7])  # placeholder; not used while memory mode is off (first hop)

    for ctr_oper in range(LN_MEM):
        is_mem = 1 if ctr_oper > 0 else 0
        nearest_idx_tmp, idx_last_tmp = get_nearest_sam_idx(
            data_q_, data_all_, is_mem, ctr_oper, mtx_mem_ignore, nearest_idx_last)
        nearest_idx_last = nearest_idx_tmp

        mtx_mem_rlt[:, ctr_oper] = nearest_idx_tmp
        mtx_mem_ignore[:, ctr_oper] = idx_last_tmp

        lab_confi_tmp = lab_confi[nearest_idx_tmp]
        # Log every hop, including the last one. Previously the final hop was not
        # recorded, so queries resolved exactly there were sent back to hop 0.
        mtx_log[:, ctr_oper] = lab_confi_tmp
        idx_left[np.where(lab_confi_tmp == 1)[0]] = -1

        if not np.any(idx_left >= 0):
            break
        # Next hop starts from the neighbour just reached.
        data_q_ = data_all_[nearest_idx_tmp, :]
    else:
        # Hop budget exhausted: unresolved queries fall back to their first neighbour.
        mtx_log[idx_left >= 0, 0] = 1

    # First hop with a confident neighbour (np.argmax returns the first True).
    # Every row now holds at least one 1, so no per-row exception handling is needed.
    idx_nn_step = np.argmax(mtx_log == 1, axis=1)
    idx_nn_re = mtx_mem_rlt[indices_row, idx_nn_step]
    data_re = data_all[idx_nn_re, :]

    return data_re, idx_nn_re, idx_nn_step.tolist()



def get_nearest_sam_idx(Q, X, is_mem_f, step_num, mtx_ignore, nearest_idx_last_f):
    """Find, for each query row, its nearest reference row under cosine distance, skipping excluded indices.

    ``Q`` and ``X`` hold one sample per row. A per-query "self" index is always
    excluded:

    - when ``is_mem_f == 1``, it is ``nearest_idx_last_f`` (the previous hop's
      result), and the indices in the first ``step_num`` columns of
      ``mtx_ignore`` are excluded as well;
    - otherwise it is the plain nearest neighbour of the query.

    Args:
        Q: (n, d) query array.
        X: (m, d) reference array.
        is_mem_f: 1 to enable memory mode (see above).
        step_num: number of leading ``mtx_ignore`` columns excluded in memory mode.
        mtx_ignore: (n, k) integer array of previously excluded indices per query.
        nearest_idx_last_f: per-query indices from the previous hop.

    Returns:
        tuple ``(indices_min_cur, indices_self)``: the nearest non-excluded
        reference index per query, and the self index excluded in this call.
    """
    norm_q = LA.norm(Q, axis=1)[:, None]
    norm_x = LA.norm(X, axis=1)[None, :]
    dist = 1 - Q.dot(X.T) / (norm_q * norm_x)  # cosine distance, (n, m)

    rows = np.arange(0, Q.shape[0], 1)
    self_idx = np.argmin(dist, axis=1)
    if is_mem_f == 1:
        # In memory mode the previous hop's result is this query's own position.
        self_idx[:] = nearest_idx_last_f
    dist[rows, self_idx] = 1000

    if is_mem_f == 1:
        # Exclude every index stepped through on earlier hops.
        dist[rows[:, None], mtx_ignore[:, np.arange(step_num)]] = 1000

    return np.argmin(dist, axis=1), self_idx
