import argparse
import os, sys
import os.path as osp
import torchvision
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
import network
import loss
from torch.utils.data import DataLoader
from data_list import ImageList, ImageList_idx
import random, pdb, math, copy
from tqdm import tqdm
from loss import CrossEntropyLabelSmooth
from scipy.spatial.distance import cdist
from sklearn.metrics import confusion_matrix
from sklearn.metrics.cluster import normalized_mutual_info_score
from utils_func import compute_accuracy, get_model, rand_bbox

def op_copy(optimizer):
    """Remember each param group's starting learning rate.

    Copies the current ``'lr'`` of every param group into ``'lr0'`` so a
    scheduler can later rescale from the original value.

    Args:
        optimizer: a ``torch.optim`` optimizer.

    Returns:
        The same optimizer object, mutated in place.
    """
    for group in optimizer.param_groups:
        group.update(lr0=group['lr'])
    return optimizer

def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    """Apply inverse-polynomial learning-rate decay to every param group.

    Each group's ``lr`` becomes
    ``lr0 * (1 + gamma * iter_num / max_iter) ** (-power)``, where ``lr0`` is
    the base rate stored by ``op_copy``. Weight decay (1e-3), momentum (0.9)
    and Nesterov momentum are (re)set on every call.

    Args:
        optimizer: optimizer whose param groups carry an ``'lr0'`` entry.
        iter_num: current iteration.
        max_iter: total number of iterations.
        gamma: decay speed.
        power: decay exponent.

    Returns:
        The same optimizer object, mutated in place.
    """
    progress = 1 + gamma * iter_num / max_iter
    factor = progress ** (-power)
    for group in optimizer.param_groups:
        group.update(
            lr=group['lr0'] * factor,
            weight_decay=1e-3,
            momentum=0.9,
            nesterov=True,
        )
    return optimizer

def image_train(resize_size=256, crop_size=224, alexnet=False):
    """Build the training-time image transform.

    Resize to a ``resize_size`` square, take a random ``crop_size`` crop,
    apply a random horizontal flip, convert to a tensor and normalise with
    the ImageNet per-channel mean/std.

    Args:
        resize_size: side length images are resized to before cropping.
        crop_size: side length of the random crop.
        alexnet: request AlexNet-style mean-file normalisation.

    Returns:
        A ``transforms.Compose`` pipeline.

    Raises:
        NotImplementedError: if ``alexnet`` is true. That path previously
            referenced a ``Normalize(meanfile=...)`` class that is neither
            defined nor imported in this file, so it could only fail.
    """
    if alexnet:
        raise NotImplementedError(
            "alexnet mean-file normalization ('./ilsvrc_2012_mean.npy') "
            "is not implemented in this script")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

def image_test(resize_size=256, crop_size=224, alexnet=False):
    """Build the evaluation-time image transform.

    Resize to a ``resize_size`` square, take the centre ``crop_size`` crop,
    convert to a tensor and normalise with the ImageNet per-channel mean/std.

    Args:
        resize_size: side length images are resized to before cropping.
        crop_size: side length of the centre crop.
        alexnet: request AlexNet-style mean-file normalisation.

    Returns:
        A ``transforms.Compose`` pipeline.

    Raises:
        NotImplementedError: if ``alexnet`` is true. That path previously
            referenced a ``Normalize(meanfile=...)`` class that is neither
            defined nor imported in this file, so it could only fail.
    """
    if alexnet:
        raise NotImplementedError(
            "alexnet mean-file normalization ('./ilsvrc_2012_mean.npy') "
            "is not implemented in this script")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize,
    ])

def data_load(args):
    """Build every DataLoader used for training and evaluation.

    The source list files (``args.s_dset_path_list``) are concatenated in
    order. The first ``holdout_per_class`` (3) samples of each class are held
    out as the ``source_test`` split; the rest form the source training
    split. Each source sample remembers the index of the list file it came
    from (its domain id). Every target/test sample gets domain id 0.

    Each list line is expected to be ``"<path> <label> ..."`` separated by
    single spaces (the label is read from field 1).

    Returns:
        dict of ``DataLoader`` keyed by:
        ``'source_tr'``  - source training split, training transform, shuffled
                           (``ImageList_idx``);
        ``'source_te'``  - the same samples, test transform, shuffled;
        ``'source_test'``- held-out source samples, test transform;
        ``'target'``     - target list, training transform, shuffled
                           (``ImageList_idx``);
        ``'target_te'``  - target list, test transform, not shuffled;
        ``'test'``       - test list, test transform, batch ``2 * batch_size``.
    """
    holdout_per_class = 3  # source samples per class reserved for 'source_test'
    train_bs = args.batch_size

    def read_lines(path):
        # Context manager so the list file is closed right after reading.
        with open(path) as f:
            return f.readlines()

    def make_loader(dataset, batch_size, shuffle):
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                          num_workers=args.worker, drop_last=False)

    txt_src = []
    src_domain = []  # src_domain[i] = index of the source list line i came from
    for domain, path in enumerate(args.s_dset_path_list):
        lines = read_lines(path)
        txt_src.extend(lines)
        src_domain.extend([domain] * len(lines))
    txt_tar = read_lines(args.t_dset_path)
    txt_test = read_lines(args.test_dset_path)

    tar_to_domain_id = {i: 0 for i in range(len(txt_tar))}
    test_to_domain_id = {i: 0 for i in range(len(txt_test))}

    count = np.zeros(args.class_num)
    tr_txt, te_txt = [], []
    tr_to_domain_id, te_to_domain_id = {}, {}
    for line, domain in zip(txt_src, src_domain):
        label = int(line.strip().split(' ')[1])
        if count[label] < holdout_per_class:
            count[label] += 1
            te_to_domain_id[len(te_txt)] = domain
            te_txt.append(line)
        else:
            tr_to_domain_id[len(tr_txt)] = domain
            tr_txt.append(line)

    # Datasets are created in the same order as before, in case the
    # (project-defined) subsampling consumes random state.
    return {
        "source_tr": make_loader(
            ImageList_idx(tr_txt, idx_to_domain=tr_to_domain_id, transform=image_train(), subsample=args.sub_s),
            train_bs, True),
        "source_te": make_loader(
            ImageList(tr_txt, idx_to_domain=tr_to_domain_id, transform=image_test(), subsample=args.sub_s),
            train_bs, True),
        "source_test": make_loader(
            ImageList(te_txt, idx_to_domain=te_to_domain_id, transform=image_test(), subsample=args.sub_s),
            train_bs, False),
        "target": make_loader(
            ImageList_idx(txt_tar, idx_to_domain=tar_to_domain_id, transform=image_train(), subsample=args.sub_t),
            train_bs, True),
        "target_te": make_loader(
            ImageList(txt_tar, idx_to_domain=tar_to_domain_id, transform=image_test(), subsample=args.sub_t),
            train_bs, False),
        "test": make_loader(
            ImageList(txt_test, idx_to_domain=test_to_domain_id, transform=image_test(), subsample=args.sub_t),
            train_bs * 2, False),
    }

def cal_acc(loader, netF, netB, netC, flag=False):
    """Evaluate the F -> (B) -> C pipeline on every batch of ``loader``.

    Args:
        loader: iterable of batches whose item 0 is the input tensor and
            item 1 the labels.
        netF: feature extractor.
        netB: bottleneck applied after ``netF``, or ``None`` to skip it.
        netC: classifier; its output width is taken as the class count.
        flag: unused; kept for call-site compatibility.

    Returns:
        ``(mean_ent, nmi, cluster_acc)``:
        mean prediction entropy normalised by ``log(num_classes)``;
        normalised mutual information between predictions and labels;
        cluster accuracy from ``compute_accuracy``, as a percentage.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    logits_chunks, label_chunks = [], []
    with torch.no_grad():
        # Iterate the loader directly: DataLoader iterators have no `.next()`
        # method in recent PyTorch releases.
        for data in loader:
            feats = netF(data[0].cuda())
            if netB is not None:
                feats = netB(feats)
            logits_chunks.append(netC(feats).float().cpu())
            label_chunks.append(data[1].float())
    if not logits_chunks:
        raise ValueError("cal_acc received an empty loader")

    # Concatenate once instead of growing a tensor with torch.cat per batch.
    all_output = nn.Softmax(dim=1)(torch.cat(logits_chunks, 0))
    all_label = torch.cat(label_chunks, 0)
    # Derived from the classifier output rather than the script-level global
    # `args`, so the function also works when this module is imported.
    num_classes = all_output.size(1)

    _, predict = torch.max(all_output, 1)
    # Entropy of a distribution over C classes is at most log(C); normalising
    # by log(#samples) (as before) did not give a value in [0, 1].
    mean_ent = torch.mean(loss.Entropy(all_output)).item() / np.log(num_classes)
    nmi = normalized_mutual_info_score(all_label, predict)
    cluster_acc = compute_accuracy(predict, all_label, num_classes, num_classes)
    return mean_ent, nmi, cluster_acc * 100

def train_source_simp(args):
    """Train the source model F -> B -> C on the pooled source domains.

    Training uses no source labels: the pseudo-labels are the model's own
    argmax predictions.

    1. Save the freshly initialised networks to
       ``args.output_dir_src/source_{F,B,C}.pt`` and evaluate them on the
       target test split via ``test_target_simp``.
    2. Optimise ``lambda_im * IM + lambda_proto * prototype loss`` on the
       ``'source_tr'`` loader. IM is the mean per-sample entropy minus the
       entropy of the batch-mean prediction. When ``args.mix > 0``, add a
       CutMix consistency term weighted by ``args.mix``.
    3. About every tenth of training, log NMI / cluster accuracy on the
       held-out ``'source_test'`` split. Keep a snapshot of the
       best-scoring weights, which are saved over the initial checkpoints
       at the end.

    Returns:
        ``(netF, netB, netC)`` in their final (not best) state.
    """
    dset_loaders = data_load(args)
    source_loader = dset_loaders["source_tr"]

    netF = get_model(model_name=args.net_src, init_type=args.source_init).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck).cuda()

    # Save and evaluate the untrained networks so the log records the
    # starting point (test_target_simp reloads these checkpoints).
    torch.save(netF.state_dict(), osp.join(args.output_dir_src, "source_F.pt"))
    torch.save(netB.state_dict(), osp.join(args.output_dir_src, "source_B.pt"))
    torch.save(netC.state_dict(), osp.join(args.output_dir_src, "source_C.pt"))
    test_target_simp(args)

    # The backbone gets a 10x smaller learning rate than the new layers.
    learning_rate = args.lr
    param_group = [{'params': v, 'lr': learning_rate * 0.1} for v in netF.parameters()]
    param_group += [{'params': v, 'lr': learning_rate} for v in netB.parameters()]
    param_group += [{'params': v, 'lr': learning_rate} for v in netC.parameters()]
    optimizer = op_copy(optim.SGD(param_group))

    max_iter = args.max_epoch * len(source_loader)
    # Evaluate ~10 times; max(1, ...) avoids a modulo by zero when fewer
    # than 10 iterations are scheduled.
    interval_iter = max(1, max_iter // 10)
    iter_num = 0
    acc_init = 0.0
    best_states = None

    model = nn.Sequential(netF, netB, netC).cuda()

    args.epochs = args.max_epoch
    args.iters_per_epoch = len(source_loader)
    proto_criterion = loss.ProtoMultiDomainLoss(num_domains=len(args.s), num_classes=args.class_num, ot_reg=args.ot_reg, args=args)
    model.train()

    iter_source = iter(source_loader)
    while iter_num < max_iter:
        # Cycle through the loader indefinitely; only exhaustion is handled
        # (a bare except here used to hide real errors).
        try:
            inputs_source, _, _, domain_id = next(iter_source)
        except StopIteration:
            iter_source = iter(source_loader)
            inputs_source, _, _, domain_id = next(iter_source)

        # Single-sample batches are skipped (presumably to avoid BatchNorm
        # failing on batch size 1).
        if inputs_source.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter, power=1.5)
        inputs_source = inputs_source.cuda()
        outputs_feature = model[1](model[0](inputs_source))
        outputs_source = torch.nn.Softmax(dim=1)(model[2](outputs_feature))
        # Pseudo-labels for the prototype loss: the model's own predictions.
        pred_labels = torch.argmax(outputs_source, dim=1).detach()

        optimizer.zero_grad()

        # IM loss: confident per-sample predictions, diverse batch predictions.
        entropy_loss = torch.mean(loss.Entropy(outputs_source))
        msoftmax = outputs_source.mean(dim=0)
        gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
        im_loss = entropy_loss - gentropy_loss
        prototypes = model[-1].fc.weight
        proto_loss = proto_criterion(prototypes, outputs_feature, pred_labels, domain_id)
        classifier_loss = args.lambda_im * im_loss + args.lambda_proto * proto_loss
        classifier_loss.backward()

        if args.mix > 0:
            # CutMix: paste a random box from a shuffled copy of the batch and
            # make the prediction match the area-weighted mix of the two
            # original (detached) predictions.
            alpha = 0.3
            lam = np.random.beta(alpha, alpha)
            index = torch.randperm(inputs_source.size()[0]).cuda()
            bbx1, bby1, bbx2, bby2 = rand_bbox(inputs_source.size(), lam)
            inputs_source[:, :, bbx1:bbx2, bby1:bby2] = inputs_source[index, :, bbx1:bbx2, bby1:bby2]
            # Recompute lam from the area of the box rand_bbox actually returned.
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (inputs_source.size()[-1] * inputs_source.size()[-2]))
            mixed_output = (lam * outputs_source + (1 - lam) * outputs_source[index, :]).detach()

            update_batch_stats(model, False)
            outputs_source_m = model(inputs_source)
            update_batch_stats(model, True)
            outputs_source_m = torch.nn.Softmax(dim=1)(outputs_source_m)
            mix_loss = args.mix * nn.KLDivLoss(reduction='batchmean')(outputs_source_m.log(), mixed_output)
            mix_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            model.eval()
            _, nmi, cluster_acc = cal_acc(dset_loaders['source_test'], model[0], model[1], model[2], False)
            log_str = 'Task: {}, Iter:{}/{}; NMI = {:.2f}; Cluster Acc = {:.2f}'.format(args.name_src, iter_num, max_iter, nmi, cluster_acc)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if cluster_acc >= acc_init:
                acc_init = cluster_acc
                # state_dict() tensors share storage with the live parameters,
                # so without a deep copy later optimizer steps would overwrite
                # the "best" snapshot and the final weights would be saved.
                best_states = [copy.deepcopy(m.state_dict()) for m in model]

            model.train()

    if best_states is None:
        # No evaluation happened (e.g. max_iter == 0): keep the current weights.
        best_states = [m.state_dict() for m in model]
    for state, tag in zip(best_states, ("F", "B", "C")):
        torch.save(state, osp.join(args.output_dir_src, "source_{}.pt".format(tag)))

    return netF, netB, netC

def test_target_simp(args):
    """Evaluate the saved source model on the target test split.

    Rebuilds the F/B/C networks, loads ``source_F.pt``, ``source_B.pt`` and
    ``source_C.pt`` from ``args.output_dir_src``, and puts them in eval mode.
    It then runs ``cal_acc`` on the ``'test'`` loader and writes the
    NMI / cluster-accuracy line to ``args.out_file`` and stdout.
    As a side effect, ``args.modelpath`` is left pointing at the classifier
    checkpoint.
    """
    loaders = data_load(args)
    feat_net = get_model(model_name=args.net_src, init_type=args.source_init).cuda()
    bottleneck = network.feat_bootleneck(type=args.classifier, feature_dim=feat_net.in_features, bottleneck_dim=args.bottleneck).cuda()
    classifier = network.feat_classifier(type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck).cuda()

    stages = (('F', feat_net), ('B', bottleneck), ('C', classifier))
    for tag, net in stages:
        args.modelpath = args.output_dir_src + '/source_' + tag + '.pt'
        net.load_state_dict(torch.load(args.modelpath))
    for _, net in stages:
        net.eval()

    _, nmi, cluster_acc = cal_acc(loaders['test'], feat_net, bottleneck, classifier, False)
    message = '\nTask: {}, NMI = {:.2f}, Cluster Acc = {:.2f}'.format(args.name, nmi, cluster_acc)

    args.out_file.write(message + '\n')
    args.out_file.flush()
    print(message + '\n')

def _predict_softmax(model, loader):
    """Return softmax(model(x)) for every batch of ``loader``, concatenated in loader order.

    Runs under ``torch.no_grad()``, with inputs moved to CUDA. Raises
    ``ValueError`` if ``loader`` yields no batches.
    """
    chunks = []
    with torch.no_grad():
        for data in loader:
            logits = model(data[0].cuda())
            chunks.append(nn.Softmax(dim=1)(logits).float())
    if not chunks:
        raise ValueError("cannot collect predictions from an empty loader")
    return torch.cat(chunks, 0)


def copy_target_simp(args):
    """Distil the saved source model into a freshly built target model.

    1. Load the source model F -> B -> C from ``args.output_dir_src``.
    2. Build a memory bank ``mem_P`` with one row per target sample. Each row
       is the source model's hard prediction, label-smoothed by
       ``args.epsilon``.
    3. Train a new ``args.net`` model on the ``'target'`` loader. The loss is
       KL(mem_P[idx] || prediction) + ``lambda_im`` * IM loss +
       ``lambda_proto`` * prototype loss. When ``args.mix > 0``, a CutMix
       consistency term weighted by ``args.mix`` is added.
    4. When ``args.ema < 1``, about every tenth of training blend the current
       model's predictions into ``mem_P`` with momentum ``args.ema``.
    5. Log NMI / cluster accuracy on the ``'test'`` loader about every tenth
       of training. At the end, save the final weights to
       ``args.output_dir/source_{F,B,C}.pt``.

    Assumes the index returned by the ``'target'`` loader matches the row
    order of the unshuffled ``'target_te'`` loader (both read the same
    target list) — confirm if subsampling is enabled.
    """
    dset_loaders = data_load(args)
    target_loader = dset_loaders["target"]

    netF = get_model(model_name=args.net_src, init_type=args.source_init).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    source_model = nn.Sequential(netF, netB, netC).cuda()
    source_model.eval()

    # Fresh target networks (possibly a different backbone, args.net).
    netF = get_model(model_name=args.net, init_type=args.target_init).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck).cuda()

    # The backbone gets a 10x smaller learning rate than the new layers.
    learning_rate = args.lr
    param_group = [{'params': v, 'lr': learning_rate * 0.1} for v in netF.parameters()]
    param_group += [{'params': v, 'lr': learning_rate} for v in netB.parameters()]
    param_group += [{'params': v, 'lr': learning_rate} for v in netC.parameters()]
    optimizer = op_copy(optim.SGD(param_group))

    max_iter = args.max_epoch * len(target_loader)
    # max(1, ...) avoids a modulo by zero when fewer than 10 iterations are scheduled.
    interval_iter = max(1, max_iter // 10)
    iter_num = 0

    model = nn.Sequential(netF, netB, netC).cuda()

    # Memory bank of soft targets: the source model's hard labels, smoothed.
    source_probs = _predict_softmax(source_model, dset_loaders["target_te"])
    hard = F.one_hot(torch.argmax(source_probs, dim=1), num_classes=args.class_num)
    mem_P = ((1 - args.epsilon) * hard + args.epsilon / hard.shape[1]).float().detach()

    model.train()
    args.epochs = args.max_epoch
    args.iters_per_epoch = len(target_loader)
    proto_criterion = loss.ProtoMultiDomainLoss(num_domains=1, num_classes=args.class_num, ot_reg=args.ot_reg, args=args)

    iter_target = iter(target_loader)
    last_ema_iter = 0
    while iter_num < max_iter:
        # EMA refresh of the memory bank with the current model's predictions.
        # last_ema_iter prevents a second refresh at the same iter_num when a
        # single-sample batch is skipped below (iter_num does not advance).
        if args.ema < 1.0 and iter_num > last_ema_iter and iter_num % interval_iter == 0:
            model.eval()
            all_output = _predict_softmax(model, dset_loaders["target_te"])
            mem_P = mem_P * args.ema + all_output.detach() * (1 - args.ema)
            model.train()
            last_ema_iter = iter_num

        # Cycle through the loader indefinitely; only exhaustion is handled.
        try:
            inputs_target, _, tar_idx, domain_id = next(iter_target)
        except StopIteration:
            iter_target = iter(target_loader)
            inputs_target, _, tar_idx, domain_id = next(iter_target)

        if inputs_target.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter, power=1.5)
        inputs_target = inputs_target.cuda()
        # Distillation targets for this batch, looked up by sample index.
        outputs_target_by_source = mem_P[tar_idx, :]
        outputs_feature = model[1](model[0](inputs_target))
        outputs_target = torch.nn.Softmax(dim=1)(model[2](outputs_feature))
        pred_labels = torch.argmax(outputs_target, dim=1).detach()
        classifier_loss = nn.KLDivLoss(reduction='batchmean')(outputs_target.log(), outputs_target_by_source)
        optimizer.zero_grad()

        # IM loss: mean per-sample entropy minus entropy of the mean prediction.
        entropy_loss = torch.mean(loss.Entropy(outputs_target))
        msoftmax = outputs_target.mean(dim=0)
        gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
        entropy_loss = entropy_loss - gentropy_loss
        prototypes = model[-1].fc.weight
        proto_loss = proto_criterion(prototypes, outputs_feature, pred_labels, domain_id)
        classifier_loss = classifier_loss + args.lambda_im * entropy_loss + args.lambda_proto * proto_loss

        classifier_loss.backward()

        if args.mix > 0:
            # CutMix consistency against the area-weighted mix of predictions.
            alpha = 0.3
            lam = np.random.beta(alpha, alpha)
            index = torch.randperm(inputs_target.size()[0]).cuda()
            bbx1, bby1, bbx2, bby2 = rand_bbox(inputs_target.size(), lam)
            inputs_target[:, :, bbx1:bbx2, bby1:bby2] = inputs_target[index, :, bbx1:bbx2, bby1:bby2]
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (inputs_target.size()[-1] * inputs_target.size()[-2]))
            mixed_output = (lam * outputs_target + (1 - lam) * outputs_target[index, :]).detach()

            update_batch_stats(model, False)
            outputs_target_m = model(inputs_target)
            update_batch_stats(model, True)
            outputs_target_m = torch.nn.Softmax(dim=1)(outputs_target_m)
            mix_loss = args.mix * nn.KLDivLoss(reduction='batchmean')(outputs_target_m.log(), mixed_output)
            mix_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            model.eval()
            mean_ent, nmi, cluster_acc = cal_acc(dset_loaders['test'], netF, netB, netC, False)
            log_str = 'Task: {}, Iter:{}/{}; NMI = {:.2f}, Cluster Acc = {:.2f}, Ent = {:.4f}'.format(args.name, iter_num, max_iter, nmi, cluster_acc, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            model.train()

    torch.save(netF.state_dict(), osp.join(args.output_dir, "source_F.pt"))
    torch.save(netB.state_dict(), osp.join(args.output_dir, "source_B.pt"))
    torch.save(netC.state_dict(), osp.join(args.output_dir, "source_C.pt"))

def update_batch_stats(model, flag):
    """Set ``update_batch_stats = flag`` on every ``nn.BatchNorm2d`` inside ``model``.

    Only ``BatchNorm2d`` instances (and subclasses) are touched. Whether the
    attribute has any effect depends on the BatchNorm implementation that
    reads it — stock ``nn.BatchNorm2d`` does not appear to; TODO confirm.
    """
    bn_layers = (m for m in model.modules() if isinstance(m, nn.BatchNorm2d))
    for layer in bn_layers:
        layer.update_batch_stats = flag

def print_args(args):
    """Format every attribute of ``args`` as ``name:value`` lines.

    Returns:
        A separator line of ``=`` characters, followed by one
        ``"{name}:{value}\\n"`` line per entry of ``args.__dict__``, in
        insertion order.
    """
    rows = ["{}:{}\n".format(name, value) for name, value in args.__dict__.items()]
    return "==========================================\n" + "".join(rows)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='DINE')
    parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    parser.add_argument('--t', type=int, default=0, help="target")
    parser.add_argument('--max_epoch', type=int, default=20, help="max training epochs")
    parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
    parser.add_argument('--worker', type=int, default=4, help="number of workers")
    parser.add_argument('--dset', type=str, default='office-home', choices=['VISDA-C', 'office', 'image-clef', 'office-home', 'office-caltech', 'PACS'])
    parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
    parser.add_argument('--net', type=str, default='resnet50', help="alexnet, vgg16, resnet18, resnet34, resnet50, resnet101")
    parser.add_argument('--output', type=str, default='san')
    parser.add_argument('--lr_src', type=float, default=1e-2, help="learning rate")
    parser.add_argument('--net_src', type=str, default='resnet50', help="alexnet, vgg16, resnet18, resnet34, resnet50, resnet101")
    parser.add_argument('--output_src', type=str, default='san')

    parser.add_argument('--seed', type=int, default=2020, help="random seed")
    parser.add_argument('--bottleneck', type=int, default=256)
    parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
    parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
    parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda'])
    parser.add_argument('--topk', type=int, default=1)

    parser.add_argument('--distill', action='store_true')
    parser.add_argument('--ema', type=float, default=0.6)
    parser.add_argument('--mix', type=float, default=1.0)

    parser.add_argument('--source_init', type=str, default="sup", choices=["ssl", "sup"])
    parser.add_argument('--target_init', type=str, default="sup", choices=["ssl", "sup"])

    parser.add_argument('--epsilon', type=float, default=0.1,
                        help='label smoothing parameter')
    parser.add_argument('--lambda_im', type=float, default=1.0,
                        help='weight for im loss.')
    parser.add_argument('--lambda_proto', type=float, default=1.0,
                        help='weight for proto loss.')
    parser.add_argument('--ot_reg', type=float, default=0.01,
                        help='sinkhorn regularization')
    parser.add_argument('--beta', type=float, default=0.9999,
                        help='smoothing parameter')
    parser.add_argument('--sub_s', action='store_true',
                        help='whether to subsample source data')
    parser.add_argument('--sub_t', action='store_true',
                        help='whether to subsample target data')

    args = parser.parse_args()

    # Domain names and class counts for the datasets this script can lay out.
    # The remaining --dset choices have no entry yet and previously crashed
    # with NameError on `names`; report that clearly instead.
    dataset_info = {
        'office-home': (['Art', 'Clipart', 'Product', 'RealWorld'], 65),
        'office': (['amazon', 'dslr', 'webcam'], 31),
        'PACS': (['art_painting', 'cartoon', 'photo', 'sketch'], 7),
    }
    if args.dset not in dataset_info:
        parser.error("--dset {} is not supported yet (supported: {})".format(
            args.dset, ', '.join(dataset_info)))
    names, args.class_num = dataset_info[args.dset]
    # An out-of-range (e.g. negative) --t would silently make the target also
    # one of the sources, or fail later with IndexError.
    if not 0 <= args.t < len(names):
        parser.error("--t must be in [0, {}) for --dset {}".format(len(names), args.dset))

    # Must be set before CUDA is initialised to take effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    SEED = args.seed
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    # For bitwise-reproducible runs, also set: torch.backends.cudnn.deterministic = True

    # Every domain except the target is a source domain.
    args.s = [x for x in range(len(names)) if x != args.t]
    folder = './data/'
    args.s_dset_path_list = []
    for s in args.s:
        args.s_dset_path = folder + args.dset + '/' + names[s] + '_list.txt'
        args.s_dset_path_list.append(args.s_dset_path)
    args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
    args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'

    args.name_src = "".join([names[s][0].upper() for s in args.s])
    args.output_dir_src = osp.join(args.output_src, args.net_src, str(args.seed), 'uda', args.dset, args.name_src)
    # makedirs instead of shelling out to `mkdir -p` with an unquoted path.
    os.makedirs(args.output_dir_src, exist_ok=True)

    if not args.distill:
        print(args.output_dir_src + '/source_F.pt')
        with open(osp.join(args.output_dir_src, 'log.txt'), 'w') as args.out_file:
            args.out_file.write(print_args(args) + '\n')
            args.out_file.flush()
            # args.t is the single held-out target domain; name the task after it.
            args.name = args.name_src + names[args.t][0].upper()
            train_source_simp(args)

        with open(osp.join(args.output_dir_src, 'log_test.txt'), 'w') as args.out_file:
            test_target_simp(args)
    else:
        args.name = args.name_src + names[args.t][0].upper()
        args.output_dir = osp.join(args.output, args.net_src + '_' + args.net, str(args.seed), args.da, args.dset, args.name)
        os.makedirs(args.output_dir, exist_ok=True)

        with open(osp.join(args.output_dir, 'log_tar.txt'), 'w') as args.out_file:
            test_target_simp(args)
            copy_target_simp(args)
