import argparse
from cmath import exp
import os, sys
import os.path as osp
import torchvision
import numpy as np
import torch
import clip
import scipy.io
import re
import glob
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import network, loss_function
from torch.utils.data import DataLoader
import random, pdb, math, copy
import torch.nn.functional as F
from tqdm import tqdm
import pandas as pd
from scipy.spatial.distance import cdist
from sklearn.metrics import confusion_matrix
from sklearn.cluster import KMeans
import scipy.stats as stats
from torch.optim.lr_scheduler import StepLR
from scipy.io import loadmat
import loss_function
from utils import *
import argparse
import os, sys
import os.path as osp
import torchvision
import numpy as np
import torch
import clip
import scipy.io
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import network, loss_function
from torch.utils.data import DataLoader
import random, pdb, math, copy
from loss_function import CrossEntropyLabelSmooth
import warnings
from mmaction.structures import ActionDataSample
import os.path as osp
from mmengine.fileio import list_from_file
from mmengine.dataset import BaseDataset
from mmaction.registry import DATASETS
from mmengine.runner import Runner
from mmaction.registry import MODELS
import pdb
from mmaction.utils import register_all_modules
from mmaction.registry import METRICS
from collections import OrderedDict
from mmengine.evaluator import BaseMetric
from mmaction.evaluation import top_k_accuracy
import torch.optim as optim
from mmengine import track_iter_progress
from tqdm import tqdm
from network import *
from aggretation import *
from model_selection_tools import *
from collections import deque
from numpy import linalg as LA

def calculate_threshold(current_epoch, max_epoch, start_threshold=0.8, end_threshold=0.95):
    """Linearly anneal a threshold from ``start_threshold`` to ``end_threshold``.

    Epoch 1 returns ``start_threshold`` and epoch ``max_epoch`` returns
    ``end_threshold`` exactly. Epochs in between are linearly interpolated.

    Args:
        current_epoch: 1-based epoch index.
        max_epoch: last epoch of the schedule.
        start_threshold: value at epoch 1 (default 0.8, the historical constant).
        end_threshold: value at ``max_epoch`` (default 0.95, the historical constant).

    Returns:
        The threshold for ``current_epoch``.
    """
    if current_epoch == 1:
        return start_threshold
    # A schedule with at most one epoch has no interpolation range; returning the
    # end value avoids dividing by (max_epoch - 1) == 0.
    if current_epoch == max_epoch or max_epoch <= 1:
        return end_threshold
    return ((end_threshold - start_threshold) / (max_epoch - 1)) * (current_epoch - 1) + start_threshold


def print_args(args):
    """Render every attribute of ``args`` as ``name:value`` lines under a separator.

    Despite its name this does not print; it returns the formatted string.
    """
    header = "==========================================\n"
    body = "".join("{}:{}\n".format(name, value) for name, value in vars(args).items())
    return header + body

def extract_with_split(s):
    """Return the second ``_``-separated field of ``s``.

    Returns None unless ``s`` has at least three fields, e.g.
    ``'Sports-DA_c3d_0_1.mat'`` gives ``'c3d'``.
    """
    fields = s.split("_")
    if len(fields) <= 2:
        return None
    return fields[1]

def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    """Apply inverse-decay to every param group's learning rate, in place.

    Each group's ``lr`` becomes ``lr0 * (1 + gamma * iter_num / max_iter) ** -power``.
    ``lr0`` must already exist (see ``op_copy``). The groups' ``weight_decay``,
    ``momentum`` and ``nesterov`` entries are also (re)set to fixed values.

    Returns:
        The same optimizer object.
    """
    factor = pow(1 + gamma * iter_num / max_iter, -power)
    for group in optimizer.param_groups:
        group.update(
            lr=group['lr0'] * factor,
            weight_decay=1e-3,
            momentum=0.9,
            nesterov=True,
        )
    return optimizer

def read_text_file(file_path):
    """Return the lines of ``file_path`` with surrounding whitespace stripped.

    If the file does not exist, a "file not found" message (in Chinese) is
    printed and None is returned instead of raising.
    """
    try:
        with open(file_path, 'r') as handle:
            return [row.strip() for row in handle]
    except FileNotFoundError:
        print(f"文件 '{file_path}' 未找到。")
        return None

def load_mat(file):
    """Load a ``.mat`` feature dump and return ``(feature, output, label, pse)``.

    * ``feature``: the ``ft`` array (must be 4-D) with axes 1 and 2 swapped.
    * ``output``: the ``output`` array as stored.
    * ``label``: the ``label`` array averaged over its last axis.
    * ``pse``: the first row of ``pse``.

    NOTE(review): an identical ``load_mat`` is defined again later in this file,
    and that later definition is the one bound at import time.
    """
    mat = loadmat(file)
    raw_ft = mat['ft']
    _n, _d1, _d2, _d3 = raw_ft.shape  # rejects anything that is not 4-D
    feature = np.transpose(raw_ft, (0, 2, 1, 3))
    label = np.mean(mat['label'], axis=-1)
    output = mat['output']
    pse = mat['pse'][0]
    return feature, output, label, pse

def obtain_label(all_output, all_label, all_fea, args):
    """Refine target pseudo-labels by nearest-centroid clustering on features.

    Softmax probabilities of ``all_output`` weight an initial centroid per
    class. Only classes that are the argmax prediction for more than
    ``args.threshold`` samples keep a centroid. Each sample is assigned to its
    nearest kept centroid (``scipy`` ``cdist`` with ``args.distance``). One
    extra round then recomputes centroids from those hard assignments and
    re-assigns.

    Args:
        all_output: logits, one row per sample; softmax is taken over dim 1.
        all_label: ground-truth labels, used only for the printed accuracies.
        all_fea: per-sample feature matrix (2-D).
        args: must provide ``distance`` ('euclidean' or 'cosine') and ``threshold``.

    Returns:
        numpy int array with one refined pseudo-label per sample.
    """
    all_output = nn.Softmax(dim=1)(all_output)

    _, predict = torch.max(all_output, 1)
    # Work on CPU so comparisons with labels and numpy indexing never mix devices.
    predict = predict.cpu()
    all_label = torch.squeeze(all_label).float().cpu()
    accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
    print(accuracy)
    if args.distance == 'cosine':
        # Append a constant bias dimension, then L2-normalise every row.
        bias = torch.ones(all_fea.size(0), 1, dtype=all_fea.dtype, device=all_fea.device)
        all_fea = torch.cat((all_fea, bias), 1)
        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()

    all_fea = all_fea.float().cpu().numpy()

    K = all_output.size(1)
    aff = all_output.float().cpu().numpy()
    # Soft (probability-weighted) class centroids.
    initc = aff.transpose().dot(all_fea)
    initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
    # Keep only classes predicted for more than args.threshold samples.
    cls_count = np.eye(K)[predict.numpy()].sum(axis=0)
    labelset = np.where(cls_count > args.threshold)[0]

    dd = cdist(all_fea, initc[labelset], args.distance)
    pred_label = labelset[dd.argmin(axis=1)]

    # One refinement round with hard (one-hot) assignments.
    aff = np.eye(K)[pred_label]
    initc = aff.transpose().dot(all_fea)
    initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
    dd = cdist(all_fea, initc[labelset], args.distance)
    pred_label = labelset[dd.argmin(axis=1)]

    acc = np.sum(pred_label == all_label.numpy()) / len(all_fea)
    log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100)

    print(log_str + '\n')

    return pred_label.astype('int')


def construct_net(net,args):
    """Build the MMAction2 recogniser for backbone tag ``net``.

    ``net`` is the short backbone tag parsed from a feature-file name, e.g. by
    ``extract_with_split``. The matching module-level model config dict is
    selected. These names are presumably provided by the star imports at the
    top of this file; confirm. Its ``cls_head.num_classes`` is overwritten in
    place with ``args.class_num``, and the model is built via ``MODELS.build``.

    Per-backbone test pipeline / dataloader configs used to be assembled here
    but were never used, so they are not built.

    Raises:
        ValueError: if ``net`` is not a supported backbone tag.
    """
    if net == 'i3d':
        model_cfg = i3d_nldot_model
    elif net == 'c3d':
        model_cfg = c3d_model
    elif net == 'slowfast':
        model_cfg = slow_fast_model
    elif net == 'slowonly':
        model_cfg = slowonly_model
    elif net == 'swint':
        model_cfg = swint_model
    elif net == 'swins':
        model_cfg = swins_model
    elif net == 'swinb':
        model_cfg = swinb_model
    elif net == 'swinl':
        model_cfg = swinl_model
    else:
        raise ValueError("Unsupported backbone: {!r}".format(net))
    # The shared config dict is mutated in place, as before.
    model_cfg['cls_head']['num_classes'] = args.class_num
    return MODELS.build(model_cfg)

def load_mat(file):
    """Read ``ft``, ``output``, ``label`` and ``pse`` from a ``.mat`` file.

    Returns ``(feature, output, label, pse)``:
    * ``feature`` is the 4-D ``ft`` array with its second and third axes exchanged.
    * ``output`` is returned unchanged.
    * ``label`` is the mean of ``label`` over its last axis.
    * ``pse`` is row 0 of ``pse``.
    """
    contents = loadmat(file)

    ft = contents['ft']
    _batch, _dim1, _dim2, _dim3 = ft.shape  # unpacking enforces a 4-D array
    ft = ft.transpose(0, 2, 1, 3)

    mean_label = contents['label'].mean(axis=-1)
    scores = contents['output']
    pseudo = contents['pse'][0]

    return ft, scores, mean_label, pseudo

def load_mat2(file):
    """Load ``(feature, output, label)`` from a ``.mat`` dump named like ``<dset>_<net>_...``.

    The backbone tag is the second ``_``-separated field of the file's basename.
    * For ``slowfast``, ``ft`` (4-D) has axes 1 and 2 swapped.
    * For other backbones, ``ft`` is flattened to 3-D ``(rows, cols, rest)``.
      With ``num`` = number of label rows, it is then reshaped to
      ``(num, cols, rows // num, rest)``.

    ``label`` is averaged over its last axis; ``output`` is returned as stored.

    NOTE(review): the non-slowfast branch is a plain reshape, not a transpose.
    If ``rows`` enumerates clips sample-major and ``cols > 1``, this interleaves
    the ``cols`` and clip axes. Confirm the stored layout.
    """
    contents = loadmat(file)
    backbone = os.path.basename(file).split('_')[1]

    ft = contents['ft']
    if backbone == 'slowfast':
        ft = ft.transpose(0, 2, 1, 3)
    else:
        rows, cols = ft.shape[0], ft.shape[1]
        ft = ft.reshape(rows, cols, -1)
        n_videos = contents['label'].shape[0]
        ft = ft.reshape(n_videos, cols, rows // n_videos, ft.shape[2])

    labels = np.mean(contents['label'], axis=-1)
    return ft, contents['output'], labels


def op_copy(optimizer):
    """Record each param group's current ``lr`` under ``lr0``.

    ``lr_scheduler`` later scales from this base value. Returns the same
    optimizer.
    """
    for group in optimizer.param_groups:
        group.update(lr0=group['lr'])
    return optimizer

def model_initial_tc(net, path_list, lr=None):
    """Load source checkpoints into ``net.netF`` and build an AdamW optimizer.

    For each ``i``, ``net.netF[i]`` is loaded from the ``"state_dict"`` entry of
    ``torch.load(path_list[i])``.
    * netF parameters whose name contains ``"cls_head"`` stay trainable; every
      other netF parameter is frozen.
    * If ``net`` has a ``netQ`` attribute, all of its parameters are optimised too.
    * Every optimised parameter gets its own param group with learning rate
      ``lr``. ``op_copy`` then stores that rate as ``lr0`` for ``lr_scheduler``.

    Args:
        net: aggregate model exposing an indexable ``netF`` (and optionally ``netQ``).
        path_list: checkpoint paths, one per entry of ``net.netF``.
        lr: learning rate. When None (the historical behaviour), the module-level
            ``args.lr`` created in ``__main__`` is used.

    Returns:
        ``(net, optimizer)``.
    """
    if lr is None:
        # Backward compatibility: callers used to rely on the global ``args``.
        lr = args.lr
    param_group = []

    for i, modelpath in enumerate(path_list):
        net.netF[i].load_state_dict(torch.load(modelpath)["state_dict"])

        for k, v in net.netF[i].named_parameters():
            if "cls_head" in k:
                v.requires_grad = True
                param_group.append({'params': v, 'lr': lr})
            else:
                v.requires_grad = False

    if hasattr(net, 'netQ'):
        param_group.extend({'params': v, 'lr': lr} for _, v in net.netQ.named_parameters())

    optimizer = optim.AdamW(param_group)
    optimizer = op_copy(optimizer)
    return net, optimizer

def aggregation_model(netF,model_name_set,model_path,tran_tran,model_features,model_outputs,model_label,model_features_t,model_features_s,model_timelist_set,args):
    """Wrap the selected source models in an aggregator and build its optimizer.

    Only ``args.aggregation_method == 'IMVMA'`` is supported. The ``IMVMA``
    aggregator is built with 3 experts, then passed to ``model_initial_tc``,
    which loads the checkpoints in ``model_path`` and creates the optimizer.
    ``tran_tran`` is accepted for interface compatibility and is unused.

    Returns:
        ``(model, optimizer)``.

    Raises:
        ValueError: for any other aggregation method.
    """
    if args.aggregation_method == 'IMVMA':
        num_experts = 3
        model = IMVMA(netF, args.class_num, num_experts, model_features, model_name_set, model_features_t, model_features_s, model_timelist_set, model_outputs, model_label)
    else:
        raise ValueError("Unsupported aggregation method: {!r}".format(args.aggregation_method))

    model, optimizer = model_initial_tc(model, model_path)

    return model, optimizer

def cal_acc(test_data_loader, model, flag=False):
    """Evaluate ``model`` on ``test_data_loader`` without gradients.

    For every batch, the sample indices and the model's ``tran`` output are
    printed.

    Args:
        test_data_loader: iterable of dicts with ``inputs``, ``idx1`` and ``data_samples``.
        model: object whose ``forward(inputs, idx, data_samples, gpu_id)`` returns
            ``(features, logits, labels, tran)``.
        flag: unused; kept for interface compatibility.

    Returns:
        ``(accuracy_percent, mean_class_accuracy_percent, mean_entropy, predictions)``.
        The first value is a 0-d tensor; ``predictions`` is a numpy int array.

    NOTE(review): reads the module-level ``gpu_id`` set in ``__main__``.
    """
    outputs, labels = [], []
    with torch.no_grad():
        for data_batch in tqdm(test_data_loader, desc='Processing batches'):
            _, output, gt_labels, tran = model.forward(data_batch['inputs'], data_batch['idx1'], data_batch['data_samples'], gpu_id)
            print("idx: " + ' '.join(map(str, data_batch['idx1'])))
            print("tran: " + ' '.join(map(str, tran)))
            outputs.append(output.float().cpu())
            labels.append(gt_labels.float().cpu())

        all_output = nn.Softmax(dim=1)(torch.cat(outputs, 0))
        all_label = torch.cat(labels, 0)
        _, predict = torch.max(all_output, 1)

    pred_acc = torch.sum(torch.squeeze(predict).float() == torch.squeeze(all_label)) / float(all_label.size()[0])
    # all_output already holds probabilities. Applying softmax again (as before)
    # flattens the distribution and inflates the entropy estimate.
    mean_ent = torch.mean(loss_function.Entropy(all_output)).cpu().data.item()
    int_labels = np.round(all_label.numpy()).astype(np.int64).tolist()
    mean1 = mean_class_accuracy(all_output.numpy().tolist(), int_labels)
    pred_label = np.array(predict)

    print('Accuracy = {:.2f}% '.format(pred_acc * 100))
    print('Accuracy_mean1 = {:.2f}% '.format(mean1 * 100))

    return pred_acc * 100, mean1 * 100, mean_ent, pred_label.astype('int')

def _collect_target_outputs(model, data_loader, gpu_id):
    """Run ``model.forward`` over ``data_loader`` without gradients.

    Returns the concatenated features, outputs and labels (float, CPU), in loader order.
    """
    feas, outputs, labels = [], [], []
    with torch.no_grad():
        for data_batch in tqdm(data_loader, desc='Processing batches'):
            fea, output, gt_labels, _ = model.forward(data_batch['inputs'], data_batch['idx1'], data_batch['data_samples'], gpu_id)
            feas.append(fea.float().cpu())
            outputs.append(output.float().cpu())
            labels.append(gt_labels.float().cpu())
    return torch.cat(feas, 0), torch.cat(outputs, 0), torch.cat(labels, 0)


def _memory_bank_kl(memory_bank, idx, outputs, gpu_id):
    """Return the mean KL term between stored outputs for ``idx`` and ``outputs``.

    The result is a constant tensor (built from ``.item()`` values), so it
    carries no gradient.
    """
    list_kl = []
    for past_output in memory_bank:
        output_old = past_output[idx].cuda(gpu_id)
        kl = F.kl_div(F.log_softmax(output_old, dim=0), F.softmax(outputs, dim=0), reduction='sum')
        list_kl.append(kl.item())
    return torch.tensor(np.mean(list_kl)).cuda(gpu_id)


def train_target_STHC(model,test_dataloader,train_data_loader,optimizer,gpu_id,total_SUTE_I,args):
    """Adapt the aggregated ``model`` on unlabeled target data.

    Pseudo-label refresh: every ``interval_iter`` iterations, when
    ``args.cls_par > 0``, the whole ``test_dataloader`` is run through
    ``model.forward``. ``obtain_label`` then refreshes the clustering
    pseudo-labels, and the raw outputs are pushed into a 2-entry memory bank.

    Each training step minimises the sum of:
      * cross-entropy against the pseudo-labels, scaled by ``args.cls_par``;
      * optionally the entropy / diversity terms (``args.ent`` / ``args.gent``);
      * 0.1 x the mean of KL terms between the main output and
          - a temporally masked view (``forward_t``),
          - a view flipped along dim 3 (``forward_s``),
          - the memory bank, when it is non-empty;
      * 0.01 x the summed L2 distance between each pass's ``trans`` output and
        the batch's columns of ``total_SUTE_I``.

    The model is evaluated with ``cal_acc`` and saved to
    ``args.output_dir/model.pt`` every ``interval_iter`` iterations and at the end.

    Returns:
        The trained model.
    """
    max_iter = args.max_epoch * len(train_data_loader)
    # max(1, ...) prevents a ZeroDivisionError in the ``%`` checks when
    # args.interval exceeds max_iter.
    interval_iter = max(1, max_iter // args.interval)
    iter_num = 0
    memory_bank = deque(maxlen=2)
    iter_test = iter(train_data_loader)
    while iter_num < max_iter:
        try:
            batch = next(iter_test)
        except StopIteration:
            iter_test = iter(train_data_loader)
            batch = next(iter_test)

        if batch['inputs'][0].size(0) == 1:
            continue

        if iter_num % interval_iter == 0 and args.cls_par > 0:
            all_fea, all_output, all_label = _collect_target_outputs(model, test_dataloader, gpu_id)
            # NOTE(review): bare attribute access with no effect. If
            # ``update_timefeature`` is a method, it was probably meant to be called.
            model.update_timefeature
            mem_label = obtain_label(all_output, all_label, all_fea, args)
            memory_bank.append(all_output)

        kl_result = []
        iter_num += 1
        batch_sutei = total_SUTE_I[:, batch['idx1']].cuda(gpu_id).transpose(0, 1)
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
        _, outputs_test, _, trans = model.forward(batch['inputs'], batch['idx1'], batch['data_samples'], gpu_id)
        l2_loss = torch.norm(trans - batch_sutei, p=2)

        inputs = batch['inputs']

        # Temporal view: zero up to 10 random indices of dim 2 of every sample
        # tensor. The pipelines use the NCTHW format, so dim 2 is time. The
        # augmentation is applied to copies of the *training* batch (it used to
        # hit a stale test batch, leaving these inputs unmodified).
        time_length = inputs[0].shape[2]
        sample_size = min(10, time_length - 1)
        sampled_idx = np.random.choice(time_length, size=sample_size, replace=False)
        masked_inputs = [x.clone() for x in inputs]
        for x in masked_inputs:
            x[:, :, sampled_idx, :, :] = 0
        _, outputs_t_test, _, trans = model.forward_t(masked_inputs, batch['idx1'], batch['data_samples'], gpu_id)
        kl_result.append(F.kl_div(F.log_softmax(outputs_t_test, dim=0), F.softmax(outputs_test, dim=0), reduction='sum'))
        l2_loss_1 = torch.norm(trans - batch_sutei, p=2)

        # Flipped view of the original (unmasked) training inputs, along dim 3.
        flipped_inputs = [torch.flip(x, [3]) for x in inputs]
        _, outputs_s_test, _, trans = model.forward_s(flipped_inputs, batch['idx1'], batch['data_samples'], gpu_id)
        kl_result.append(F.kl_div(F.log_softmax(outputs_s_test, dim=0), F.softmax(outputs_test, dim=0), reduction='sum'))
        l2_loss_2 = torch.norm(trans - batch_sutei, p=2)

        # Only add the history term when there is history.
        # (It previously raised NameError when the memory bank was empty.)
        if memory_bank:
            kl_result.append(_memory_bank_kl(memory_bank, batch['idx1'], outputs_test, gpu_id))

        if args.cls_par > 0:
            pred = torch.from_numpy(mem_label[batch['idx1']]).cuda(gpu_id)
            classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)
            classifier_loss *= args.cls_par
            if iter_num < interval_iter and args.dset == "VISDA-C":
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda(gpu_id)

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss_function.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        average_kl = torch.mean(torch.stack(kl_result))
        l2_loss = l2_loss + l2_loss_1 + l2_loss_2
        classifier_loss = classifier_loss + 0.1 * average_kl + 0.01 * l2_loss

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            model.eval()

            acc_s_te, mean1, _, _ = cal_acc(test_dataloader, model, False)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, max_iter, acc_s_te)
            log_str1 = 'Task: {}, Iter:{}/{}; Accuracy_mean1 = {:.2f}%'.format(args.name, iter_num, max_iter, mean1)

            torch.save(model.state_dict(), osp.join(args.output_dir, "model" + ".pt"))
            model.train()

            args.out_file.write(log_str + '\n')
            args.out_file.write(log_str1 + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            # Keep source backbones in eval mode; only classifier heads train.
            for i in range(model.source_model_number):
                model.netF[i].backbone.eval()
                model.netF[i].cls_head.train()

    torch.save(model.state_dict(), osp.join(args.output_dir, "model" + ".pt"))

    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='CAiDA')
    parser.add_argument('--t', type=int, default=3, help="target")
    parser.add_argument('--gpu_id', type=str, nargs='?', default='2', help="device id to run")
    parser.add_argument('--max_epoch', type=int, default=10, help="max iterations")
    parser.add_argument('--interval', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=128, help="batch_size")
    parser.add_argument('--worker', type=int, default=4, help="number of workers")
    parser.add_argument('--dset', type=str, default='Sports-DA', choices=[])
    parser.add_argument('--lr', type=float, default=1e-3, help="learning rate")
    parser.add_argument('--seed', type=int, default=2022, help="random seed")
    parser.add_argument('--bottleneck', type=int, default=256)
    parser.add_argument('--epsilon', type=float, default=1e-5)
    parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
    parser.add_argument('--smooth', type=float, default=0.1)

    parser.add_argument('--output_dir', type=str, default=)
    parser.add_argument('--fix', type=bool, default=False)
    parser.add_argument('--distance', type=str, default='euclidean', choices=["euclidean", "cosine"])  
    parser.add_argument('--threshold', type=int, default=0)
    parser.add_argument('--feature_path', type=str, default=)
    parser.add_argument('--featuret_path', type=str, default=)
    parser.add_argument('--features_path', type=str, default=)
    parser.add_argument('--model_config_path', type=str, default=)
    parser.add_argument('--transferability_save_path', type=str, default=)
    parser.add_argument('--config_file', type=str, default=, choices=[])
    parser.add_argument('--trans_method', type=str, default="", choices=[])
    parser.add_argument('--aggregation_method', type=str, default="IMVMA", choices=[])
    parser.add_argument('--models', type=int, default=3)
    parser.add_argument('--diversity', type=bool, default=True)
    parser.add_argument("--r", type=int, default=3)

    parser.add_argument('--gent', type=bool, default=True)
    parser.add_argument('--ent', type=bool, default=True)
    parser.add_argument('--cls_par', type=float, default=0.3)
    parser.add_argument('--ent_par', type=float, default=1)
    
    args = parser.parse_args()
    register_all_modules(init_default_scope=True)
    file_client_args = dict(io_backend='disk')
    
    dataset_type = 'VideoDataset'

    aggregation_method = args.aggregation_method


    if args.dset == 'Sports-DA':
        names = ['sports1m', 'UCF101']
        args.class_num = 23



    gpu_id = int(args.gpu_id)

   
    if args.dset == 'Sports-DA':
        feature_path= 
 

    if args.dset == 'Sports-DA':
        model_path = 

    test_pipeline = [
                dict(type='DecordInit', **file_client_args),
                dict(
                    type='SampleFrames',
                    clip_len=16,
                    frame_interval=1,
                    num_clips=10,
                    test_mode=True),
                dict(type='DecordDecode'),
                dict(type='Resize', scale=(-1, 128)),
                dict(type='CenterCrop', crop_size=112),
                dict(type='FormatShape', input_format='NCTHW'),
                dict(type='PackActionInputs')
            ]


    test_dataloader = dict(
    batch_size=8,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_test,
        data_prefix=dict(video=data_root_val),
        pipeline=test_pipeline,
        test_mode=True))

    train_pipeline = [
                dict(type='DecordInit', **file_client_args),
                dict(
                    type='SampleFrames',
                    clip_len=16,
                    frame_interval=1,
                    num_clips=10,
                    test_mode=True),
                dict(type='DecordDecode'),
                dict(type='Resize', scale=(-1, 128)),
                dict(type='CenterCrop', crop_size=112),
                dict(type='FormatShape', input_format='NCTHW'),
                dict(type='PackActionInputs')
            ]

    train_dataloader = dict(
    batch_size=8,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_test,
        data_prefix=dict(video=data_root_val),
        pipeline=test_pipeline))

    # agg_type = 'all'
    agg_type = 'top'
    top_number = "3"
    output_dir_ori = args.output_dir
    method = args.trans_method
    strategy = ''
    for method in ["SUTE_I"]:
        folder = '/data/liruizhe/'
        for tt in range(len(names)):
            for s in range(len(names)):
                args.output_dir = output_dir_ori
                model_list = {}
                model_pth_set = []
                model_name_set = {}
                model_timelist_set = {}
                model_features={}
                model_features_s={}
                model_features_t={}
                model_outputs={}
                model_label = {}
                agg_model = nn.ModuleList()
                if str(tt) == str(s):
                    continue
                args.s = s
                args.t=tt
                args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
                args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
                model_config_file = args.model_config_path+args.config_file+'/'+str(args.s)+'_'+str(args.t)+'.txt'
                file_contents = read_text_file(model_config_file)
                print(model_config_file)
                if not os.path.exists(args.transferability_save_path+args.config_file):
                    os.makedirs(args.transferability_save_path+args.config_file)
                transferability_output_filename = args.transferability_save_path+args.config_file+'/'+method+'/'+str(args.s)+'_'+str(args.t)+'.txt'
                if not os.path.exists(args.transferability_save_path+args.config_file+'/'+method):
                        os.makedirs(args.transferability_save_path+args.config_file+'/'+method)
                if agg_type == 'top':
                    if method == "SUTE_I":
                        ins_trans_list,pred_acc,all_output,ins_trans_list_i,pred_acc_set,line_set = transfer_calcualte_for_individual_models(model_config_file,feature_path,transferability_output_filename,method,strategy,args=args)
                        sorted_model_names, sorted_transferability_metrics,sorted_transferability_i,sorted_output = sort_lists_by_transferability_1(line_set, ins_trans_list,ins_trans_list_i,all_output)
                        sorted_transferability_i = np.array(sorted_transferability_i)
                        sorted_transferability_i_list = []
                        for row in sorted_transferability_i:
                            # 将每行转换为PyTorch张量
                            tensor = torch.tensor(row)
                            # 将张量添加到列表中
                            sorted_transferability_i_list.append(tensor)
                    else:
                        all_output=transfer_calcualte_for_individual_models(model_config_file,feature_path,transferability_output_filename,method,strategy,args=args)
                        model_names,transferability_metrics,accuracies= read_transferability_text_file(transferability_output_filename)
                        sorted_model_names, sorted_transferability_metrics,sorted_output = sort_lists_by_transferability(model_names, transferability_metrics,all_output)
                  

                elif agg_type == 'all':
                    model_select = file_contents
                # model_select = transfer_model_set[:3]
                # model_select = random.sample(file_contents[1:], 3)
                # pdb.set_trace()
                # model_select = ['Daily-DA_slowfast_1_0.mat', 'Daily-DA_swint_1_0.mat', 'Daily-DA_c3d_1_0.mat']
                # model_select = {0: 'i3d', 1: 'swint', 2: 'c3d'}
                for i,selected_model in enumerate(model_select):
                    source= selected_model.split(".mat")[0][-3]
                    target = selected_model.split(".mat")[0][-1]
                    
                    elif args.dset == 'Sports-DA':
                        if source == '0':
                            source_file ='sports1m'
                        elif source == '1':
                            source_file = 'UCF101'
                        if target == '0':
                            target_name = 'sports1m'
                        elif target == '1':
                            target_name = 'UCF101'
                   

                    net = extract_with_split(selected_model)
                    model_time_list = []
                    fea,output,label,_=load_mat(feature_path+selected_model)
                    fea_t,_,_=load_mat2(featuret_path+selected_model)
                    fea_s,_,_=load_mat2(features_path+selected_model) 
                    for j in range(len(featuret_path_list)):
                        model_time_list.append(featuret_path_list[j]+selected_model)
                    model_timelist_set[i] = model_time_list

                    model_features[i]=torch.from_numpy(fea)
                    model_features_t[i]=torch.from_numpy(fea_t)
                    model_features_s[i]=torch.from_numpy(fea_s)
                    model_outputs[i]=torch.from_numpy(output)
                    model_label[i] = torch.from_numpy(label)
                    model_name_set[i] = net 
                    model_pth = osp.join(model_path, net, source_file)
                    pth_files = glob.glob(os.path.join(model_pth, '**/*.pth'), recursive=True)
                    model = construct_net(net,args)
                    model.cls_head.num_classes = args.class_num
                    pth_files = pth_files[0]
                    model_list[i] = model
                    model_pth_set.append(pth_files)
                    agg_model.append(model_list[i])

                print("selected models",model_name_set)
                print("lenth:",len(model_name_set))
                tran_tran=[1,]*len(model_name_set)
                model,optimizer=aggregation_model(agg_model,model_name_set,model_pth_set,tran_tran,model_features,model_outputs,model_label,model_features_t,model_features_s,model_timelist_set,args)
                
                mode=

                args.output_dir = osp.join(args.output_dir, args.dset,args.config_file,aggregation_method,agg_type,top_number,method,mode,names[args.s][0].upper(), names[args.t][0].upper())


                if not os.path.exists(args.output_dir):
                    os.makedirs(args.output_dir)

                args.name = "->"+names[args.t][0].upper()

                if not osp.exists(args.output_dir):
                    os.system('mkdir -p ' + args.output_dir)
                if not osp.exists(args.output_dir):
                    os.mkdir(args.output_dir)

                args.out_file = open(osp.join(args.output_dir, 'log' + '.txt'), 'w')
                args.out_file.write(print_args(args)+'\n')
                args.out_file.flush()

                elif target_name =='sports1m':
                    test_dataloader['dataset']['ann_file'] = '/data/liruizhe/Sports-DA/sports1m_msda_test.txt'
                    test_dataloader['dataset']['data_prefix']['video'] = '/data/liruizhe/Sports-DA/sports1m/'
                    test_data_loader = Runner.build_dataloader(dataloader=test_dataloader)
                    train_dataloader['dataset']['ann_file'] = '/data/liruizhe/Sports-DA/sports1m_msda_test.txt'
                    train_dataloader['dataset']['data_prefix']['video'] = '/data/liruizhe/Sports-DA/sports1m/'
                    train_data_loader = Runner.build_dataloader(dataloader=train_dataloader)
                    model=model.cuda(gpu_id)
                    train_target_STHC(model,test_data_loader,train_data_loader,optimizer,gpu_id,transferability_i_select,args)


                elif target_name =='UCF101':

                    test_dataloader['dataset']['ann_file'] = '/data/liruizhe/Sports-DA/ucf101_msda_test.txt'
                    test_dataloader['dataset']['data_prefix']['video'] = '/data/liruizhe/Sports-DA/UCF101/test/'
                    test_data_loader = Runner.build_dataloader(dataloader=test_dataloader)
                    train_dataloader['dataset']['ann_file'] = '/data/liruizhe/Sports-DA/ucf101_msda_test.txt'
                    train_dataloader['dataset']['data_prefix']['video'] = '/data/liruizhe/Sports-DA/UCF101/test/'
                    train_data_loader = Runner.build_dataloader(dataloader=train_dataloader)
                    model=model.cuda(gpu_id)
                    train_target_STHC(model,test_data_loader,train_data_loader,optimizer,gpu_id,transferability_i_select,args)
             