import argparse
from cmath import exp
import os, sys
import os.path as osp
import torchvision
import numpy as np
import torch
import torch.nn.functional as F
import scipy.io
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import network, loss_function
from torch.utils.data import DataLoader
import random, pdb, math, copy
from tqdm import tqdm
import pandas as pd
from scipy.spatial.distance import cdist
from sklearn.metrics import confusion_matrix
from sklearn.cluster import KMeans
import scipy.stats as stats
from torch.optim.lr_scheduler import StepLR
from mmd import mmd_rbf
from a_distance import calculate_a_distance
from scipy.io import loadmat
import re
import statistics
from accuracy import *

from transfer_metrics import LogME,NCE,LEEP,NCE_ours_addpz
# Module-level LogME scorer in classification mode (regression=False).
# NOTE(review): LogME_pse_principle constructs its own local LogME instance, so this
# object is not referenced by any function visible in this file.
logme = LogME(regression=False)

def find_file(search_str, directory):
    """Return the path of the first file under *directory* whose name contains *search_str*.

    The tree is walked top-down with ``os.walk``. Within each directory, files are
    checked in the order ``os.walk`` lists them. Returns ``None`` when nothing matches.
    """
    for dirpath, _subdirs, filenames in os.walk(directory):
        hit = next((name for name in filenames if search_str in name), None)
        if hit is not None:
            return os.path.join(dirpath, hit)
    return None

def extract_with_split(s):
    """Return the second ``_``-separated field of *s*, or ``None`` if *s* has fewer than three fields."""
    pieces = s.split("_")
    if len(pieces) <= 2:
        return None
    return pieces[1]


def custom_normalize(data):
    """Min-max scale a tensor into [0, 1] with the direction inverted.

    The maximum maps to 0 and the minimum maps to 1: ``(max - x) / (max - min)``.
    A constant tensor (``max == min``) returns all zeros rather than NaNs from 0/0.
    The result uses the floating dtype that true division would produce.
    """
    max_val = torch.max(data)
    min_val = torch.min(data)
    value_range = max_val - min_val
    if value_range == 0:
        # Avoid 0/0; every element is both the max and the min.
        return torch.zeros_like(data, dtype=torch.result_type(data, 1.0))
    return (max_val - data) / value_range

def normalize(lst):
    """Min-max scale an iterable of numbers into [0, 1], returning a new list.

    An empty input returns ``[]``. A constant input returns all zeros instead of
    raising ZeroDivisionError. The input is materialised once, so generators work too.
    """
    values = list(lst)
    if not values:
        return []
    lst_min = min(values)
    lst_max = max(values)
    span = lst_max - lst_min
    if span == 0:
        return [0.0 for _ in values]
    return [(x - lst_min) / span for x in values]

def pse_principle(all_output, all_label, all_fea, args,tau=0.5):
    """Convert numpy outputs/labels/features to tensors and return ``obtain_label`` pseudo-labels.

    Args:
        all_output: numpy array of model outputs (converted with ``torch.from_numpy``).
        all_label: numpy array of labels.
        all_fea: numpy array of features.
        args: namespace forwarded to ``obtain_label``.
        tau: accepted for interface compatibility; not used.

    Returns:
        Integer numpy array of pseudo-labels from ``obtain_label``.
    """
    # The inputs must be numpy arrays here; torch.from_numpy rejects anything else.
    all_output = torch.from_numpy(all_output)
    all_label = torch.from_numpy(all_label)
    all_fea = torch.from_numpy(all_fea)
    return obtain_label(all_output, all_label, all_fea, args)

def obtain_label(all_output, all_label, all_fea, args):
    """Refine argmax predictions into pseudo-labels by nearest-centroid clustering.

    1. Soft centroids are built from softmax probabilities.
    2. Only classes predicted more than ``args.threshold`` times keep a centroid.
    3. Each sample is assigned to its nearest remaining centroid (``args.distance``
       metric for ``scipy.spatial.distance.cdist``).
    4. One refinement round repeats the assignment with hard (one-hot) centroids.

    Args:
        all_output: tensor of logits, classes along dim 1.
        all_label: ground-truth labels; not used by the computation (kept for
            interface compatibility).
        all_fea: tensor of features, one row per sample.
        args: namespace with ``distance`` (a cdist metric name; ``'cosine'`` also
            appends a bias column and L2-normalises rows) and ``threshold``.

    Returns:
        numpy int array of pseudo-labels, one per sample.

    Raises:
        ValueError: no class is predicted more than ``args.threshold`` times.
    """
    probs = nn.Softmax(dim=1)(all_output)
    _, predict = torch.max(probs, 1)

    if args.distance == 'cosine':
        # Append a constant bias column, then L2-normalise each row.
        all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
    all_fea = all_fea.float().cpu().numpy()

    num_classes = probs.size(1)
    aff = probs.float().cpu().numpy()
    # Probability-weighted class centroids.
    initc = aff.T.dot(all_fea)
    initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
    cls_count = np.eye(num_classes)[predict.cpu().numpy()].sum(axis=0)
    labelset = np.where(cls_count > args.threshold)[0]
    if labelset.size == 0:
        raise ValueError("No class is predicted more than args.threshold times; cannot build centroids.")

    dd = cdist(all_fea, initc[labelset], args.distance)
    pred_label = labelset[dd.argmin(axis=1)]

    # One refinement round with hard (one-hot) centroids.
    aff = np.eye(num_classes)[pred_label]
    initc = aff.T.dot(all_fea)
    initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
    dd = cdist(all_fea, initc[labelset], args.distance)
    pred_label = labelset[dd.argmin(axis=1)]

    return pred_label.astype('int')

def get_accuracy_from_file(file_path, model_name):
    """Look up *model_name* in a ``"<name> - acc/mean1: <value>"`` report and return its accuracy.

    Returns the float value of the first matching line. If that line's value cannot
    be parsed, an error string is returned. If no line matches, a not-found string
    is returned.
    """
    separator = ' - acc/mean1: '
    with open(file_path, 'r') as handle:
        content = handle.readlines()

    for raw in content:
        fields = raw.split(separator)
        if len(fields) != 2:
            continue
        name, value_text = fields
        if name.strip() != model_name:
            continue
        try:
            return float(value_text.strip())
        except ValueError:
            return "Error: Accuracy value is not a valid number."

    return "Model name not found in the file."


def obtain_label_cpu(all_output, all_label, all_fea, args):
    """CPU variant of nearest-centroid pseudo-labelling (same algorithm as ``obtain_label``).

    Soft centroids come from softmax probabilities. Classes predicted more than
    ``args.threshold`` times are kept. Each sample is assigned to its nearest kept
    centroid under ``args.distance``, then one refinement round uses hard centroids.

    Args:
        all_output: tensor of logits, classes along dim 1.
        all_label: ground-truth labels; not used (kept for interface compatibility).
        all_fea: tensor of features, one row per sample.
        args: namespace with ``distance`` (cdist metric; ``'cosine'`` also appends a
            bias column and L2-normalises rows) and ``threshold``.

    Returns:
        numpy int array of pseudo-labels, one per sample.

    Raises:
        ValueError: no class is predicted more than ``args.threshold`` times.
    """
    probs = nn.Softmax(dim=1)(all_output)
    _, predict = torch.max(probs, 1)

    if args.distance == 'cosine':
        all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
    all_fea = all_fea.float().cpu().numpy()

    num_classes = probs.size(1)
    aff = probs.float().cpu().numpy()
    initc = aff.T.dot(all_fea)
    initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
    cls_count = np.eye(num_classes)[predict.cpu().numpy()].sum(axis=0)
    labelset = np.where(cls_count > args.threshold)[0]
    if labelset.size == 0:
        raise ValueError("No class is predicted more than args.threshold times; cannot build centroids.")

    dd = cdist(all_fea, initc[labelset], args.distance)
    pred_label = labelset[dd.argmin(axis=1)]

    # One refinement round with hard (one-hot) centroids.
    aff = np.eye(num_classes)[pred_label]
    initc = aff.T.dot(all_fea)
    initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
    dd = cdist(all_fea, initc[labelset], args.distance)
    pred_label = labelset[dd.argmin(axis=1)]

    return pred_label.astype('int')



class LogisticRegressionModel(nn.Module):
    """Binary classifier: two stacked linear layers followed by a sigmoid.

    There is no nonlinearity between the two linear layers, so the pre-sigmoid map
    is affine. The output has one probability per input row (last dim of size 1).
    """

    def __init__(self,feature_dim):
        super().__init__()
        self.linear1 = nn.Linear(feature_dim, 256)
        self.linear2 = nn.Linear(256, 1)

    def forward(self, x):
        hidden = self.linear1(x)
        return torch.sigmoid(self.linear2(hidden))

def get_reweighted_weights(data,data_no_shuffle,model,epoches=3,device="cuda"):
    """Train a binary classifier on ``data`` and return its predictions on ``data_no_shuffle``.

    Args:
        data: iterable of ``(inputs, labels)`` batches used for training; labels are
            0/1 targets for ``nn.BCELoss``.
        data_no_shuffle: iterable of ``(inputs, labels)`` batches to predict on; the
            labels are ignored.
        model: module mapping inputs to probabilities of shape (batch, 1), e.g. a
            ``LogisticRegressionModel``. It is moved to ``device`` and trained in place.
        epoches: number of passes over ``data``.
        device: device for the model and batches (default ``"cuda"`` as before).

    Returns:
        Tensor of concatenated predictions on ``device``, in the order of
        ``data_no_shuffle``. An empty tensor of shape (0, 1) is returned when
        ``data_no_shuffle`` yields nothing.
    """
    criterion = nn.BCELoss()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.01)
    scheduler = StepLR(optimizer, step_size=100, gamma=0.1)

    model.train()
    loss = None
    for epoch in range(epoches):
        for inputs, labels in data:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels.float().unsqueeze(1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Step the scheduler after the epoch's optimizer steps (PyTorch >= 1.1
        # ordering); stepping first decays the LR one epoch early.
        scheduler.step()
        if (epoch + 1) % 100 == 0 and loss is not None:
            print(f'Epoch [{epoch+1}/{epoches}], Loss: {loss.item():.4f}')

    model.eval()
    batch_outputs = []
    with torch.no_grad():
        for inputs, _labels in data_no_shuffle:
            batch_outputs.append(model(inputs.to(device)).float())

    if not batch_outputs:
        return torch.empty(0, 1, device=device)
    # A single concatenation avoids the quadratic cost of growing a tensor per batch.
    return torch.cat(batch_outputs, 0)

def GroupN_principle(all_output, all_label, all_fea, pse_label,args,group_size=6):
    """Group-level transferability: spread of each sample's nearest-neighbour distances.

    Features get a constant bias column and are L2-normalised, then pairwise cosine
    distances are computed once. For each sample, ``min_dist`` is the gap between
    its ``group_size``-th smallest and its smallest distance; the smallest is
    normally the ~0 self-distance.

    Args:
        all_output: model outputs (numpy array or tensor), classes along dim 1.
        all_label: ground-truth labels (numpy array or tensor); rounded to ints for
            ``mean_class_accuracy``.
        all_fea: features, one row per sample.
        pse_label, args: unused; kept for interface compatibility.
        group_size: neighbourhood size (default 6, the previous hard-coded value).

    Returns:
        ``(min_dist, group_trans, pred_acc)``: per-sample gaps (numpy float array),
        their mean, and ``mean_class_accuracy`` of the outputs.

    Raises:
        ValueError: fewer samples than ``group_size``.
    """
    all_output = torch.as_tensor(all_output)
    all_label = torch.squeeze(torch.as_tensor(all_label)).float()
    all_fea = torch.as_tensor(all_fea)

    all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
    all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
    all_fea = all_fea.float().cpu().numpy()

    n = all_fea.shape[0]
    if n < group_size:
        raise ValueError(f"GroupN needs at least group_size={group_size} samples, got {n}")

    # O(n^2) distance matrix: compute it exactly once.
    dd = cdist(all_fea, all_fea, 'cosine')
    nearest = np.sort(dd, axis=1)[:, :group_size]
    min_dist = nearest[:, group_size - 1] - nearest[:, 0]
    group_trans = min_dist.mean()

    output_list = all_output.numpy().tolist()
    label_list = np.round(all_label.numpy()).astype(np.int64).tolist()
    pred_acc = mean_class_accuracy(output_list, label_list)

    return min_dist, group_trans, pred_acc


def ce_pse_principle(all_output, all_label, all_fea, pse_label, args):
    """Mean cross-entropy of the model logits against pseudo-labels.

    Args:
        all_output: logits (numpy array or tensor), classes along dim 1.
        all_label: ground-truth labels, used only for ``pred_acc``.
        all_fea, args: unused; kept for interface compatibility.
        pse_label: integer pseudo-labels (numpy array or tensor).

    Returns:
        ``(loss, pred_acc)``: the loss as a Python float, and the top-1 accuracy
        against ``all_label`` as a 0-d tensor.
    """
    all_output = torch.as_tensor(all_output)
    all_label = torch.as_tensor(all_label)
    pse_label = torch.as_tensor(pse_label).long()

    _, predict = torch.max(all_output, 1)
    pred_acc = torch.sum(predict == all_label) / float(all_output.size()[0])

    # Logits and targets must share a device; fall back to CPU when CUDA is absent.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        loss = nn.CrossEntropyLoss()(all_output.to(device), pse_label.to(device))

    return loss.item(), pred_acc

def ce_nce_pse_principle(all_output, all_label, all_fea, pse_label, args):
    """Cross-entropy between one-hot argmax predictions and one-hot pseudo-labels.

    The one-hot predictions are fed to ``nn.CrossEntropyLoss`` as logits. The
    one-hot pseudo-labels are used as probability targets. The loss is averaged
    over samples.

    Args:
        all_output: logits (numpy array or tensor), classes along dim 1.
        all_label: ground-truth labels, used only for ``pred_acc``.
        all_fea, args: unused; kept for interface compatibility.
        pse_label: integer pseudo-labels.

    Returns:
        ``(loss, pred_acc)``: the loss as a Python float, and the top-1 accuracy
        against ``all_label`` as a 0-d tensor.
    """
    all_output = torch.as_tensor(all_output)
    all_label = torch.as_tensor(all_label)
    pse_label = torch.as_tensor(pse_label).long()

    _, predict = torch.max(all_output, 1)
    pred_acc = torch.sum(predict == all_label) / float(all_output.size()[0])

    num_classes = all_output.size(1)
    pse_onehot = F.one_hot(pse_label, num_classes).float()
    pred_onehot = F.one_hot(predict, num_classes).float()

    with torch.no_grad():
        per_sample = nn.CrossEntropyLoss(reduction='none')(pred_onehot, pse_onehot)
        loss = torch.sum(per_sample) / pred_onehot.shape[0]

    return loss.item(), pred_acc

def get_one_hot(inputs,targets):
    """Return a float32 CPU one-hot tensor shaped like *inputs*, with ones at *targets*.

    Args:
        inputs: tensor whose size gives the output shape (classes along dim 1).
        targets: integer class index per row.
    """
    # Only the shape of `inputs` is needed; taking its log first was wasted work.
    one_hot = torch.zeros(inputs.size())
    return one_hot.scatter_(1, targets.unsqueeze(1).long(), 1)


def LogME_pse_principle(all_output, all_label, all_fea, pse_label, args):
    """LogME evidence of the features with respect to pseudo-labels.

    Args:
        all_output: model outputs (numpy array or tensor), used only for ``pred_acc``.
        all_label: ground-truth labels; rounded to ints for ``mean_class_accuracy``.
        all_fea: features, one row per sample (``[N, D]``).
        pse_label: integer pseudo-labels (``[N]``).
        args: unused; kept for interface compatibility.

    Returns:
        ``(loss, pred_acc)``: the value returned by ``LogME.fit`` and the
        ``mean_class_accuracy`` of the outputs.
    """
    all_output = torch.as_tensor(all_output)
    all_label = torch.as_tensor(all_label)
    all_fea = torch.as_tensor(all_fea)
    pse_label = torch.as_tensor(pse_label).long()

    output_list = all_output.numpy().tolist()
    label_list = np.round(all_label.numpy()).astype(np.int64).tolist()
    pred_acc = mean_class_accuracy(output_list, label_list)

    # The softmax that used to precede this was applied to a Python list (always
    # raising) and its result was never used, so it has been removed.
    logme = LogME(regression=False)
    loss = logme.fit(np.array(all_fea), np.array(pse_label))

    return loss, pred_acc


def Leep_pse_principle(all_output, all_label, all_fea, pse_label, args):
    """LEEP score of the model's softmax predictions with respect to pseudo-labels.

    Args:
        all_output: logits (numpy array or tensor), classes along dim 1.
        all_label: ground-truth labels; rounded to ints for ``mean_class_accuracy``.
        all_fea, args: unused; kept for interface compatibility.
        pse_label: integer pseudo-labels.

    Returns:
        ``(loss, pred_acc)``: the value returned by ``LEEP`` and the
        ``mean_class_accuracy`` of the outputs.
    """
    all_output = torch.as_tensor(all_output)
    all_label = torch.as_tensor(all_label)
    pse_label = torch.as_tensor(pse_label).long()

    output_list = all_output.numpy().tolist()
    label_list = np.round(all_label.numpy()).astype(np.int64).tolist()
    pred_acc = mean_class_accuracy(output_list, label_list)

    # Softmax must be taken on the tensor: applying nn.Softmax to the Python list
    # built above raises AttributeError.
    probs = np.array(nn.Softmax(dim=1)(all_output))
    loss = LEEP(probs, np.array(pse_label))

    return loss, pred_acc


def entropy_loss(p):
    """Return the Shannon entropy (natural log) of each row of probability tensor *p*.

    Classes are summed along dim 1. ``1e-5`` inside the log keeps zero probabilities finite.
    """
    eps = 1e-5
    return (-p * torch.log(p + eps)).sum(dim=1)

def entropy_principle(all_output, all_label, all_fea,pse_label, args):
    """Score a model by the negative mean prediction entropy on the target set.

    Returns:
        ``(score, pred_acc)``: ``score`` is a 0-d CUDA tensor equal to
        ``-mean_i H(softmax(all_output)_i)``. ``pred_acc`` is
        ``mean_class_accuracy`` of the outputs against the rounded labels.
    """
    # Best-effort numpy -> tensor conversion; stops at the first input that is not a numpy array.
    try:
        all_output = torch.from_numpy(all_output)
        all_label = torch.from_numpy(all_label)
        all_fea = torch.from_numpy(all_fea)
        pse_label = torch.from_numpy(pse_label).long()
    except:
        pass

    _, predict = torch.max(all_output, 1)
    probs = nn.Softmax(dim=1)(all_output).cuda()

    outputs_list = all_output.numpy().tolist()
    labels_list = np.round(all_label.numpy()).astype(np.int64).tolist()
    pred_acc = mean_class_accuracy(outputs_list, labels_list)

    mean_entropy = torch.sum(entropy_loss(probs)) / (probs.shape[0])
    return -mean_entropy, pred_acc


def MI_principle(all_output, all_label, all_fea,pse_label, args):
    """Score a model by information maximisation: H(mean prediction) - mean H(prediction).

    Per-sample entropy comes from ``loss_function.Entropy``; the marginal entropy uses a
    ``1e-5`` offset inside the log.

    Returns:
        ``(score, pred_acc)``: ``score`` is a 0-d CUDA tensor. ``pred_acc`` is
        ``mean_class_accuracy`` of the outputs against the rounded labels.
    """
    # Best-effort numpy -> tensor conversion; stops at the first input that is not a numpy array.
    try:
        all_output = torch.from_numpy(all_output)
        all_label = torch.from_numpy(all_label)
        all_fea = torch.from_numpy(all_fea)
        pse_label = torch.from_numpy(pse_label).long()
    except:
        pass

    _, predict = torch.max(all_output, 1)
    probs = nn.Softmax(dim=1)(all_output).cuda()
    objective = torch.mean(loss_function.Entropy(probs))
    marginal = probs.mean(dim=0)
    objective -= torch.sum(-marginal * torch.log(marginal + 1e-5))

    outputs_list = all_output.numpy().tolist()
    labels_list = np.round(all_label.numpy()).astype(np.int64).tolist()
    pred_acc = mean_class_accuracy(outputs_list, labels_list)

    return -objective, pred_acc



def MDE_principle(all_output, all_label, all_fea,pse_label, args,T=1):
    """Score a model from the distribution of per-sample free energies.

    The energy of each sample is ``-T * logsumexp(logits / T)``. The score is
    ``log(-mean(log_softmax(energies over samples)))`` as a Python float.

    Returns:
        ``(score, pred_acc)`` where ``pred_acc`` is ``mean_class_accuracy`` of the outputs.
    """
    # Best-effort numpy -> tensor conversion; stops at the first input that is not a numpy array.
    try:
        all_output = torch.from_numpy(all_output)
        all_label = torch.from_numpy(all_label)
        all_fea = torch.from_numpy(all_fea)
        pse_label = torch.from_numpy(pse_label).long()
    except:
        pass

    _, predict = torch.max(all_output, 1)
    energies = -T * torch.logsumexp(all_output / T, dim=1)
    mean_log_prob = torch.log_softmax(energies, dim=0).mean()
    score = torch.log(-mean_log_prob).item()

    outputs_list = all_output.numpy().tolist()
    labels_list = np.round(all_label.numpy()).astype(np.int64).tolist()
    pred_acc = mean_class_accuracy(outputs_list, labels_list)

    return score, pred_acc



def ce_principle(all_output, all_label, all_fea, pse_label,args):
    """Mean cross-entropy of the logits against the model's own argmax predictions.

    Args:
        all_output: logits (numpy array or tensor), classes along dim 1.
        all_label: ground-truth labels, used only for ``pred_acc``.
        all_fea, args: unused; kept for interface compatibility.
        pse_label: ignored; the argmax predictions are used as targets instead.

    Returns:
        ``(loss, pred_acc)``: the loss as a Python float, and the top-1 accuracy
        against ``all_label`` as a 0-d tensor.
    """
    all_output = torch.as_tensor(all_output)
    all_label = torch.as_tensor(all_label)

    _, predict = torch.max(all_output, 1)
    pred_acc = torch.sum(predict == all_label) / float(all_output.size()[0])

    # Use CUDA when present; the scalar loss does not depend on the device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        loss = nn.CrossEntropyLoss()(all_output.to(device), predict.to(device))

    return loss.item(), pred_acc

def Entropy(input_):
    """Return the per-row Shannon entropy (natural log) of a probability tensor.

    Classes are summed along dim 1. ``1e-5`` is added inside the log so zero
    probabilities stay finite.
    """
    epsilon = 1e-5
    return torch.sum(-input_ * torch.log(input_ + epsilon), dim=1)
def MI(pred):
    """Return ``(mean per-sample entropy, entropy of the mean prediction)`` for probabilities *pred*.

    The two terms are returned separately; the caller combines them.
    """
    mean_entropy = torch.mean(Entropy(pred))
    marginal = pred.mean(dim=0)
    diversity = torch.sum(-marginal * torch.log(marginal + 1e-5))
    return mean_entropy, diversity


def MI2(pred):
    """Return ``(mean per-sample entropy, entropy of the hard-prediction class histogram)``.

    The second term uses the fraction of samples whose argmax falls in each class,
    rather than the mean soft prediction. The class count is taken from
    ``pred.size(1)``; it used to be hard-coded to 23, which raised IndexError for
    wider outputs. Both terms use a ``1e-5`` offset inside the log.
    """
    softmax_out = pred
    per_sample = torch.sum(-softmax_out * torch.log(softmax_out + 1e-5), dim=1)
    entropy_loss = torch.mean(per_sample)

    _, hard_pred = torch.max(softmax_out, 1)
    # one_hot runs on hard_pred's device, so this also works for non-CPU inputs.
    hard_onehot = F.one_hot(hard_pred, softmax_out.size(1)).float()
    msoftmax = hard_onehot.mean(dim=0)
    gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
    return entropy_loss, gentropy_loss


def Temporal_principle(all_output, all_label, all_fea, pse_label,feature_path_t,args,feature_source,line):
    """KL divergence between this model's outputs and outputs loaded from a matching file.

    The comparison outputs are read with ``load_mat2`` from the first file under
    ``feature_path_t`` (searched recursively) whose name contains ``line``.

    Args:
        all_output: logits (numpy array or tensor).
        all_label: ground-truth labels; rounded to ints for ``mean_class_accuracy``.
        all_fea, pse_label, args, feature_source: unused; kept for interface compatibility.
        feature_path_t: directory to search.
        line: substring identifying the comparison file.

    Returns:
        ``(kl_t, pred_acc)``. ``kl_t`` is a 0-d tensor
        ``KL(softmax(all_output) || softmax(all_output_s))``, with both softmaxes
        taken over dim 0. ``pred_acc`` is ``mean_class_accuracy`` of ``all_output``.

    Raises:
        FileNotFoundError: no file under ``feature_path_t`` contains ``line``.
    """
    all_output = torch.as_tensor(all_output)
    all_label = torch.as_tensor(all_label)

    file_path = find_file(line, feature_path_t)
    if file_path is None:
        raise FileNotFoundError(f"No file containing {line!r} found under {feature_path_t!r}")
    _, all_output_s, _ = load_mat2(file_path)
    all_output_s = torch.as_tensor(all_output_s)

    # NOTE(review): softmax is taken over dim=0 (across samples), not over classes;
    # confirm that is the intended normalisation.
    kl_t = F.kl_div(F.log_softmax(all_output_s, dim=0), F.softmax(all_output, dim=0), reduction='sum')
    print(kl_t)

    output_list = all_output.numpy().tolist()
    label_list = np.round(all_label.numpy()).astype(np.int64).tolist()
    pred_acc = mean_class_accuracy(output_list, label_list)
    return kl_t, pred_acc

    
def SUTE_principle(all_output, all_label, all_fea, pse_label,args):
    """Return the three SUTE components plus accuracy for one model.

    Returns:
        ``(nce, neg_entropy, diversity, pred_acc)``. ``nce`` is
        ``NCE(argmax predictions, pseudo-labels)``. ``neg_entropy`` and ``diversity``
        come from ``MI2`` on the softmax outputs (the first is negated). ``pred_acc``
        is ``mean_class_accuracy`` of the outputs.
    """
    # Best-effort numpy -> tensor conversion; stops at the first input that is not a numpy array.
    try:
        all_output = torch.from_numpy(all_output)
        all_label = torch.from_numpy(all_label)
        all_fea = torch.from_numpy(all_fea)
        pse_label = torch.from_numpy(pse_label).long()
    except:
        pass

    _, predict = torch.max(all_output, 1)
    probs = nn.Softmax(dim=1)(all_output)
    cond_entropy, diversity_loss = MI2(probs)

    outputs_list = all_output.numpy().tolist()
    labels_list = np.round(all_label.numpy()).astype(np.int64).tolist()
    pred_acc = mean_class_accuracy(outputs_list, labels_list)

    nce = NCE(np.array(predict), np.array(pse_label))
    return nce, -cond_entropy, diversity_loss, pred_acc
def calculate_entropy(n, c):
    """Return ``n * (1/c) * log(c)`` as a Python float.

    This is n times the per-class term of a uniform distribution over c classes; it
    is computed in float32 via torch.
    """
    p = 1 / c
    return (n * p * -torch.log(torch.tensor(p))).item()
def read_text_file(file_path):
    """Return the whitespace-stripped lines of *file_path*.

    If the file does not exist, a message is printed and ``None`` is returned.
    """
    try:
        with open(file_path, 'r') as file:
            return [entry.strip() for entry in file]
    except FileNotFoundError:
        print(f"文件 '{file_path}' 未找到。")
    return None


def load_mat1(file):
    """Load raw ``ft``, ``output``, averaged ``label`` and ``pse`` arrays from a ``.mat`` file.

    ``label`` is averaged over its last axis. ``pse`` is the first row of the stored
    array (``loadmat`` returns 1-D vectors as 2-D rows).

    Returns:
        ``(feature, output, label, pse)``.
    """
    contents = loadmat(file)
    return (
        contents['ft'],
        contents['output'],
        np.mean(contents['label'], axis=-1),
        contents['pse'][0],
    )

def load_mat2(file):
    """Load raw ``ft``, ``output`` and averaged ``label`` arrays from a ``.mat`` file.

    ``label`` is averaged over its last axis.

    Returns:
        ``(feature, output, label)``.
    """
    contents = loadmat(file)
    averaged_label = np.mean(contents['label'], axis=-1)
    return contents['ft'], contents['output'], averaged_label

def load_mat(file):
    """Load a ``.mat`` file and average its 4-D ``ft`` array down to one value per (sample, channel).

    ``ft`` of shape ``(a, b, c, d)`` is rearranged to ``(a, c, b*d)`` and averaged
    over the last axis, giving ``(a, c)``. This equals the mean over axes 1 and 3.
    The numpy mean is exactly what the former ``AdaptiveAvgPool1d(1)`` GPU
    round-trip computed, so CUDA is no longer required.

    Returns:
        ``(feature_pool, output, label, pse)``: ``label`` is averaged over its last
        axis and ``pse`` is the first row of the stored array.

    Raises:
        ValueError: ``ft`` is not 4-D.
    """
    result = loadmat(file)
    feature = result['ft']
    a, _, c, _ = feature.shape
    feature = feature.transpose(0, 2, 1, 3).reshape(a, c, -1)
    feature_pool = feature.mean(axis=-1)
    label = np.mean(result['label'], axis=-1)
    output = result['output']
    pse = result['pse'][0]
    return feature_pool, output, label, pse

def load_mat3(file):
    """Load a ``.mat`` file and average ``ft`` to one value per (sample, channel).

    The model name is the second ``_``-separated token of the file's basename.

    * ``slowfast``: ``ft`` is 4-D ``(n, b, c, d)``. It is averaged over axes b and d,
      giving ``(n, c)``.
    * otherwise: ``ft`` is flattened to ``(rows, channels, -1)``, and the rows are
      split into ``len(label)`` groups of equal size. Each group is averaged over
      its rows and the trailing axis, giving ``(len(label), channels)``.

    Returns:
        ``(feature, output, label)`` with ``label`` averaged over its last axis.
    """
    result = loadmat(file)
    model_name = os.path.basename(file).split('_')[1]
    raw = result['ft']

    if model_name == 'slowfast':
        n, _, c, _ = raw.shape
        merged = raw.transpose(0, 2, 1, 3).reshape(n, c, -1)
    else:
        flat = raw.reshape(raw.shape[0], raw.shape[1], -1)
        rows, channels, width = flat.shape
        num = result['label'].shape[0]
        clips = rows // num
        grouped = flat.reshape(num, clips, channels, width)
        merged = grouped.transpose(0, 2, 1, 3).reshape(num, channels, -1)

    pooled = np.mean(merged, axis=-1)
    return pooled, result['output'], np.mean(result['label'], axis=-1)


def transfer_calcualte_for_a_model(method,output,label,feature,pse,args,feature_source=0,line='',feature_path_t=None):
    """Dispatch one model's outputs to the transferability scorer named by *method*.

    Supported methods: ``entropy``, ``MI``, ``logme_pse``, ``Leep_pse``, ``MDE``,
    ``Temporal``, ``SUTE`` and ``GroupN``. For ``SUTE`` the score is
    ``nce + negative entropy + diversity``.

    Args:
        method: scorer name.
        output, label, feature, pse: model outputs, labels, features, pseudo-labels.
        args: namespace forwarded to the scorer.
        feature_source, line: forwarded to the ``Temporal`` scorer.
        feature_path_t: directory searched by the ``Temporal`` scorer (required for it).

    Returns:
        ``(transfer_ability, pred_acc)``, unwrapped to Python scalars where possible.
        For ``GroupN`` it returns ``(min_dist, transfer_ability, pred_acc)``, where
        ``min_dist`` holds the per-sample group distances; this is the 3-tuple
        ``SUTE_IG`` unpacks.

    Raises:
        ValueError: unknown *method*, or ``Temporal`` without ``feature_path_t``.
    """
    def _to_scalar(value):
        # Unwrap 0-d tensors / numpy scalars. Plain numbers (and values that cannot
        # be unwrapped) are returned unchanged. Each value is handled independently,
        # so one failure no longer skips converting the other.
        try:
            return value.item()
        except (AttributeError, ValueError, RuntimeError):
            return value

    if method == "entropy":
        transfer_ability, pred_acc = entropy_principle(output, label, feature, pse, args)
    elif method == "MI":
        transfer_ability, pred_acc = MI_principle(output, label, feature, pse, args)
    elif method == "logme_pse":
        transfer_ability, pred_acc = LogME_pse_principle(output, label, feature, pse, args)
    elif method == "Leep_pse":
        transfer_ability, pred_acc = Leep_pse_principle(output, label, feature, pse, args)
    elif method == "MDE":
        transfer_ability, pred_acc = MDE_principle(output, label, feature, pse, args, T=1)
    elif method == "Temporal":
        if feature_path_t is None:
            raise ValueError("method 'Temporal' requires feature_path_t")
        transfer_ability, pred_acc = Temporal_principle(output, label, feature, pse, feature_path_t, args, feature_source, line)
    elif method == 'SUTE':
        nce, nentropy_loss, diversity_loss, pred_acc = SUTE_principle(output, label, feature, pse, args)
        transfer_ability = 1 * nce + 1 * nentropy_loss + 1 * diversity_loss
        print(nce, nentropy_loss, diversity_loss)
    elif method == "GroupN":
        min_dist, transfer_ability, pred_acc = GroupN_principle(output, label, feature, pse, args)
        return min_dist, _to_scalar(transfer_ability), _to_scalar(pred_acc)
    else:
        raise ValueError(f"Unknown transferability method: {method!r}")

    return _to_scalar(transfer_ability), _to_scalar(pred_acc)

def transfer_calcualte_for_a_model_Temporal(method,output,label,feature,pse,feature_path_t,args,feature_source=0,line=''):
    """Run the ``Temporal`` scorer for one model and return Python-scalar results.

    Returns:
        ``(transfer_ability, pred_acc)`` from ``Temporal_principle``, each unwrapped
        with ``.item()`` when possible.

    Raises:
        ValueError: *method* is not ``"Temporal"`` (previously this surfaced as an
        UnboundLocalError).
    """
    if method != "Temporal":
        raise ValueError(f"transfer_calcualte_for_a_model_Temporal only supports 'Temporal', got {method!r}")

    transfer_ability, pred_acc = Temporal_principle(output, label, feature, pse, feature_path_t, args, feature_source, line)

    def _to_scalar(value):
        # Unwrap 0-d tensors / numpy scalars independently of each other.
        try:
            return value.item()
        except (AttributeError, ValueError, RuntimeError):
            return value

    return _to_scalar(transfer_ability), _to_scalar(pred_acc)



def transfer_calcualte_for_individual_models(model_config_file,feature_path,transferability_output_filename,method,strategy,args):
    """Score every model listed in *model_config_file* with *method* and report results.

    Each line of the config file names a feature file, read as ``feature_path + line``
    (plain string concatenation, so ``feature_path`` must end with a separator).

    For method ``'SUTE_I'`` the work is delegated to ``SUTE_I`` and its 6-tuple is
    returned. NOTE(review): ``SUTE_I`` is not defined in this file; ``SUTE_IG`` below
    has a matching signature and return arity. Confirm which one is intended.

    For any other method, each model is scored with ``transfer_calcualte_for_a_model``:
    * models whose features or score contain NaN are skipped;
    * ``"<line> <score> <acc>"`` is written per model to
      *transferability_output_filename*;
    * the Spearman correlation between scores and accuracies is printed;
    * the list of raw model outputs is returned.

    Returns ``None`` when the config file is missing or empty.

    Args:
        strategy: unused.
    """
    print(model_config_file)
    print(transferability_output_filename)

    file_contents = read_text_file(model_config_file)
    # Only loss_set, pred_acc_set and all_output are used below; the rest are unused.
    loss_set=[]
    pred_acc_set=[]
    pred_acc_set_all=[]
    all_output=[]

    nce_set = []
    nentropy_loss_set = []
    diversity_loss_set = []
    mde_set = []

    # pdb.set_trace()
    if file_contents:
        if method == 'SUTE_I':
            ins_trans_list,pred_acc,all_output,ins_trans_list_i,pred_acc_set,line_set=SUTE_I(transferability_output_filename,file_contents,feature_path,method,args)
            return ins_trans_list,pred_acc,all_output,ins_trans_list_i,pred_acc_set,line_set
        else:
            with open(transferability_output_filename, 'w') as output_file:
                for line in file_contents:
                    # `source` is not used afterwards (IndexError if the stem is < 3 chars).
                    source=line.split(".mat")[0][-3]
                
                    # Same file loaded twice: pooled features via load_mat, raw `ft` via load_mat1.
                    feature,output,label,pse=load_mat(feature_path+line)
                    feature_source,_,_,_=load_mat1(feature_path+line)
                    contains_nan = np.isnan(feature).any()
                    if contains_nan:
                        continue
                    transfer_ability,pred_acc=transfer_calcualte_for_a_model(method,output,label,feature,pse,args,feature_source,line)
                    if np.isnan(transfer_ability):
                        continue
                    print(line,transfer_ability,pred_acc)
                    all_output.append(output)
                    output_file.write(line+" "+str(transfer_ability)+" "+str(pred_acc)+ '\n')
                    loss_set.append(transfer_ability)
                    pred_acc_set.append(pred_acc)

                # Rank correlation between transferability scores and accuracies.
                a=pd.Series(loss_set)
                b=pd.Series(pred_acc_set)
                print("Spearman",stats.spearmanr(a,b))

                return all_output


def print_acc_of_each_model(model_config_file,feature_path,transferability_output_filename,args):
    """Print and record the accuracy of every model listed in *model_config_file*.

    Each model is evaluated with the ``"entropy"`` scorer. Only its accuracy is kept
    and written as ``"<line> Acc:<acc>"`` to *transferability_output_filename*. The
    paired source file is named by replacing the last five characters of the line
    with the third-from-last character of its stem plus ``.mat``.

    Returns:
        Always ``1``.
    """
    print(model_config_file)
    print(transferability_output_filename)
    method = "entropy"
    model_lines = read_text_file(model_config_file)
    if not model_lines:
        return 1

    with open(transferability_output_filename, 'w') as report:
        for line in model_lines:
            source = line.split(".mat")[0][-3]
            source_line = line[:-5] + source + '.mat'
            feature, output, label, pse = load_mat(feature_path + line)
            feature_source = load_mat(feature_path + source_line)[0]
            _, pred_acc = transfer_calcualte_for_a_model(method, output, label, feature, pse, args, feature_source)
            print(line, 'Acc:', pred_acc)
            report.write(line + " Acc:" + str(pred_acc) + '\n')

    return 1


def ins_transferability(output,feature,pse):
    """Return the per-sample negative prediction entropy (a CUDA tensor) for *output*.

    *output* may be a numpy array or a tensor of logits. *feature* and *pse* are
    not used by the computation.
    """
    try:
        output = torch.from_numpy(output)
    except:
        pass
    probs = nn.Softmax(dim=1)(output).cuda()
    return -entropy_loss(probs)




def SUTE_IG(transferability_output_filename,file_contents,feature_path,method,args,threshold=1):
    """Score every listed model at dataset, group and instance level and print rank correlations.

    Processing steps:
    1. For each feature file named in ``file_contents`` (read as
       ``feature_path + line``), compute the dataset-level ``'SUTE'`` score and the
       group-level ``'GroupN'`` score via ``transfer_calcualte_for_a_model``.
    2. Print Spearman correlations of both scores against the per-model accuracies.
    3. Fuse three per-instance signals into one score per sample per model:
       * per-instance entropy (``Entropy1``),
       * per-sample group distance,
       * the temperature-softmaxed SUTE weight.
    4. Print the instance-level Spearman correlations against per-sample
       cross-entropy.

    Args:
        transferability_output_filename: opened in ``'w'`` mode but never written
            to here, so the file is only created/truncated.
        file_contents: list of feature-file names.
        feature_path: directory prefix, concatenated directly with each name.
        method: overwritten inside the loop; the passed value is not used.
        args: namespace; ``args.class_num`` sets the SUTE skip threshold.
        threshold: unused.

    Returns:
        ``(ins_trans_list, pred_acc, all_output, ins_trans_result, pred_acc_set, line_set)``:
        * SUTE scores of kept models;
        * the last ``pred_acc`` assigned in the loop;
        * the model outputs stacked into one tensor;
        * the fused per-instance scores split by ``chunk_list`` into pieces of
          length n_samples (``chunk_list`` and ``Entropy1`` are not defined in this
          file — presumably they come from ``from accuracy import *``);
        * the per-model accuracies;
        * the kept file names.
    """
    # Several of these lists (line_set_t, list_trans_sute_t, all_output_t,
    # Temporal_trans_set, pred_acc_set_t, all_feature_t) are never read below.
    line_set = []
    line_set_t = []

    all_output=[]
    SUTE_trans_set=[]
    list_trans_sute_t = []
    pred_acc_set=[]
    all_feature=[]

    ins_acc = []
    ins_sute = []

    group_trans_set=[]
    group_dist_set = []
    pred_acc_set_g=[]

    all_output_t=[]
    Temporal_trans_set=[]
    pred_acc_set_t=[]

    all_feature_t=[]

    ins_trans_list=[]
    ins_trans_list_i=[]
    with open(transferability_output_filename, 'w') as output_file:
        for line in file_contents:

            source=line.split(".mat")[0][-3]
        
            feature,output,label,pse=load_mat(feature_path+line)
            # NOTE(review): if any conversion fails, later names (e.g. all_label)
            # may be unset or left from a previous iteration.
            try:
                all_output_1=torch.from_numpy(output)
                all_label=torch.from_numpy(label).long()
                all_fea=torch.from_numpy(feature)
                pse_label=torch.from_numpy(pse).long()
            except:
                pass

            # Skip models whose pooled features contain NaNs.
            if np.isnan(feature).any():
                continue

            # Models whose SUTE score is below -(1.1 * log10(class_num)) are skipped.
            c = args.class_num  
            sute_threshold = -(1.1 * math.log(c,10))    

            all_feature.append(feature)
            feature_source,_,_,_=load_mat1(feature_path+line)
          
            # Dataset-level score for this model.
            method = 'SUTE'
            transfer_ability,pred_acc=transfer_calcualte_for_a_model(method,output,label,feature,pse,args,feature_source,line)
            
            if transfer_ability<sute_threshold:
                continue
            if np.isnan(transfer_ability):
                continue
            ins_trans_list.append(transfer_ability)
            _, predict = torch.max(all_output_1, 1)
            pred=nn.Softmax(dim=1)(all_output_1)
            # Per-sample cross-entropy against the true labels; this is the instance-level
            # reference for the Spearman tests below. NOTE(review): CrossEntropyLoss
            # applies log-softmax itself, so passing the already-softmaxed `pred`
            # applies softmax twice — confirm this is intended.
            ins_accuracy = nn.CrossEntropyLoss(reduction='none')(pred,all_label)
            ins_acc.append(ins_accuracy)

            print(line,transfer_ability,pred_acc)
            all_output.append(output)

            pred_acc_set.append(pred_acc)
            line_set.append(line)
            list_trans = []
            list_trans_sute = []
            list_acc = []

            # Group-level score for this model; three return values are unpacked here.
            method = 'GroupN'
            min_dist,transfer_ability_g,pred_acc=transfer_calcualte_for_a_model(method,output,label,feature,pse,args,feature_source,line)
            # NOTE(review): this `continue` happens after ins_trans_list, ins_acc,
            # all_output, pred_acc_set and line_set were appended, but before the
            # group/SUTE lists below, so a NaN here leaves the lists misaligned.
            if np.isnan(transfer_ability_g):
                continue
            print(line,transfer_ability_g,pred_acc)
            group_trans_set.append(transfer_ability_g)
            group_dist_set.append(min_dist)
            pred_acc_set_g.append(pred_acc)
            list_trans_sute_t = []

            SUTE_trans_set.append(transfer_ability)
            # Per-instance tensor shaped like ins_accuracy, filled with this model's SUTE score.
            empty_tensor = torch.empty_like(ins_accuracy)
            ins_sute_acc = empty_tensor.fill_(transfer_ability)
            ins_sute.append(ins_sute_acc)



    # Model-level rank correlations.
    a=pd.Series(SUTE_trans_set)
    b=pd.Series(pred_acc_set)
    print("SUTE Spearman",stats.spearmanr(a,b))
    
    a=pd.Series(group_trans_set)
    b=pd.Series(pred_acc_set_g)
    print("Class Spearman",stats.spearmanr(a,b))

    
    # Softmax over models with temperature t turns SUTE scores into weights; each
    # model's per-instance tensor in ins_sute is then overwritten with its weight.
    t = 8
    SUTE_trans_set_1=np.array(SUTE_trans_set)
    SUTE_trans_set_1 = torch.from_numpy(SUTE_trans_set_1)

    SUTE_trans_set_1=nn.Softmax(dim=0)(SUTE_trans_set_1/t)
 

    for i in range(len(SUTE_trans_set_1)):
        ins_sute[i].fill_(SUTE_trans_set_1[i])



    num_models=len(all_output)
    num_samples=(all_output[0]).shape[0]
    

    # Single pass (range(1)).
    for iter_num in range(1):
        

        # Stack per-model outputs into one array; requires identical shapes.
        try:
            all_output=np.array(all_output)
        except:
            pass

        all_output=torch.from_numpy(all_output)
        all_output1=torch.mean(all_output,dim=0)
        _,pse1=torch.max(all_output1,1)
        # `label` here is whatever the last loop iteration loaded.
        try:
            label=torch.from_numpy(label)
        except:
            pass


        criterion=nn.CrossEntropyLoss()
        criterion1=nn.CrossEntropyLoss(reduction='none')
        # Per-instance score for each model from Entropy1 (defined outside this file).
        for i in range(num_models):

            ins_trans_i = Entropy1(all_output[i])

            ins_trans_list_i.append(ins_trans_i.cpu())

        combine_sutei = torch.stack(ins_trans_list_i)


        ins_trans_list_i = combine_sutei.tolist()


        # Min-max normalise each model's instance scores (divides by zero if constant).
        ins_trans_list_i = [torch.tensor(lst) for lst in ins_trans_list_i]
        for i in range(len(ins_trans_list_i)):
            ins_trans_list_i[i] = (ins_trans_list_i[i] - ins_trans_list_i[i].min()) / (ins_trans_list_i[i].max() - ins_trans_list_i[i].min())
        # Same min-max normalisation for each model's per-sample group distances.
        group_dist_set = [torch.tensor(lst) for lst in group_dist_set]
        for i in range(len(group_dist_set)):
            group_dist_set[i] = (group_dist_set[i] - group_dist_set[i].min()) / (group_dist_set[i].max() - group_dist_set[i].min())
        trans_group_i = []
        trans_ins_i = []
        # Flatten per-model tensors into flat Python lists for the Spearman tests.
        for tensor in ins_trans_list_i:

            nested_list = tensor.tolist()

            for sublist in nested_list:

                if isinstance(sublist, list):
                    trans_ins_i.extend(sublist)
                else:
                    trans_ins_i.append(sublist)

        for tensor in group_dist_set:

            nested_list = tensor.tolist()

            for sublist in nested_list:

                if isinstance(sublist, list):
                    trans_group_i.extend(sublist)
                else:
                    trans_group_i.append(sublist)

        lists = [ins_trans_list_i,group_dist_set,ins_sute]
       
        
        # Fuse: group * (1 + log(1 + sute_weight)), then instance * (1 + log(1 + that)).
        # zip() truncates to the shortest list if the per-model lists are misaligned.
        tmp_lists = [ b*(1+torch.log(1+a)) for a, b in zip(ins_sute,group_dist_set)]
        new_lists = [ b*(1+torch.log(1+a)) for a, b in zip(tmp_lists,ins_trans_list_i)]

        ins_trans_result = []
        for tensor in new_lists:

            nested_list = tensor.tolist()

            for sublist in nested_list:
                if isinstance(sublist, list):
                    ins_trans_result.extend(sublist)
                else:
                    ins_trans_result.append(sublist)
        
        ins_acc_result = []
        for tensor in ins_acc:

            nested_list = tensor.tolist()

            for sublist in nested_list:

                if isinstance(sublist, list):
                    ins_acc_result.extend(sublist)
                else:
                    ins_acc_result.append(sublist)

        # Instance-level rank correlations against per-sample cross-entropy.
        a=pd.Series(trans_ins_i)

        b=pd.Series(ins_acc_result)

        print("Instance level Spearman",stats.spearmanr(a,b))

        a=pd.Series(trans_group_i)

        b=pd.Series(ins_acc_result)

        print("Group Instance level Spearman",stats.spearmanr(a,b))

        a=pd.Series(ins_trans_result)

        b=pd.Series(ins_acc_result)

        print("Instance Group Dataset Spearman",stats.spearmanr(a,b))

        # Re-split the flat fused scores into per-model chunks of n samples.
        n = ins_trans_list_i[0].shape[0]
        ins_trans_result = chunk_list(ins_trans_result,n)

    return ins_trans_list,pred_acc,all_output,ins_trans_result,pred_acc_set,line_set


def read_transferability_text_file(file_path):
    """Parse a whitespace-separated ``<model> <transferability> <accuracy>`` report.

    Blank or whitespace-only lines, such as a trailing empty line, are skipped
    instead of raising IndexError.

    Returns:
        ``(model_names, transferability_metrics, accuracies)``: three parallel lists
        holding the names as strings and the two metrics as floats.
    """
    model_names = []
    transferability_metrics = []
    accuracies = []

    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            columns = line.split()
            if not columns:
                continue
            model_names.append(columns[0])
            transferability_metrics.append(float(columns[1]))
            accuracies.append(float(columns[2]))
    return model_names, transferability_metrics, accuracies







    