from operator import pos
import random
from numpy.core.fromnumeric import sort
import torch_geometric
from torch_geometric.data import Dataset as GDataset
from torch_geometric.data import Data
from torch.nn import MarginRankingLoss
import torch
from torch.utils.data import WeightedRandomSampler
import torch.nn as nn
import os, itertools, tqdm, json, time
import numpy as np
from scipy.stats import spearmanr
from torch.distributed import ReduceOp
from torch.cuda.amp import autocast
import torch.nn.functional as F
'''
File - utils.py
Datasets, loss modules and training loops for graph-based ranking models
built on PyTorch Geometric.
'''

class GeometricDataset3(GDataset):
    '''Graph dataset returning a graph together with three parallel labels.

    Each entry of ``labels1``, ``labels0`` and ``labels_true`` is indexed as
    ``entry[0]`` (key) and ``entry[1]`` (label value). The graph files are located
    through the key of ``labels_true``, which has the form
    ``"<data_name>|||<problemType>"``; ``problemType`` is parsed as a float.

    For sample ``data_name`` two files are read from ``data_dir``:
    ``<data_name>.npz`` (array ``node_rep``) and ``<data_name>_Edges.npz``
    (arrays ``edges_index`` and ``edges_attr``).

    Args:
        labels1, labels0, labels_true: parallel label lists (``len`` comes from ``labels1``).
        data_dir: directory containing the ``.npz`` files.
        edge_sets: stored for interface compatibility; not used by this class.
        should_cache: when True, loaded samples are memoised per index.
    '''
    def __init__(self, labels1, labels0, labels_true, data_dir, edge_sets, should_cache=False):
        self.labels1 = labels1
        self.labels0 = labels0
        self.labels_true = labels_true
        self.data_dir = data_dir
        self.edge_sets = edge_sets
        self.cache = dict() if should_cache else None

    def __len__(self):
        return len(self.labels1)

    def _load_sample(self, idx):
        '''Read sample ``idx`` from disk without consulting the cache.'''
        key_parts = self.labels_true[idx][0].split("|||")
        data_name = key_parts[0]

        # Context managers close the underlying zip handles; indexing an
        # NpzFile reads the array into memory before the file is closed.
        with np.load(os.path.join(self.data_dir, data_name + "_Edges.npz")) as edges:
            edges_index = torch.from_numpy(edges['edges_index'])
            edges_attr = torch.from_numpy(edges['edges_attr'])
        with np.load(os.path.join(self.data_dir, data_name + ".npz")) as nodes:
            tokens = torch.from_numpy(nodes['node_rep'])

        label1 = torch.tensor(self.labels1[idx][1])
        label0 = torch.tensor(self.labels0[idx][1])
        label_true = torch.tensor(self.labels_true[idx][1])
        problemType = torch.tensor([float(key_parts[1])])

        graph = Data(x=tokens.float(), edge_index=edges_index, edge_attr=edges_attr,
                     problemType=problemType, data_name=str(data_name))
        return graph, label1, label0, label_true

    def __getitem__(self, idx):
        '''Return ``(graph, label1, label0, label_true)`` for sample ``idx``.'''
        if self.cache is None:
            return self._load_sample(idx)
        # Compare against None rather than truthiness: an empty dict is falsy,
        # which previously meant the cache was never populated.
        if idx not in self.cache:
            self.cache[idx] = self._load_sample(idx)
        return self.cache[idx]

class GeometricDataset2(GDataset):
    '''Graph dataset returning ``(graph, label)`` pairs loaded from ``.npz`` files.

    Each entry of ``labels`` is indexed as ``entry[0]`` (key of the form
    ``"<data_name>|||<problemType>"``, ``problemType`` parsed as a float) and
    ``entry[1]`` (label value).

    For sample ``data_name`` two files are read from ``data_dir``:
    ``<data_name>.npz`` (array ``node_rep``) and ``<data_name>_Edges.npz``
    (arrays ``edges_index`` and ``edges_attr``).

    Args:
        labels: list of ``(key, label)`` entries.
        data_dir: directory containing the ``.npz`` files.
        edge_sets: stored for interface compatibility; not used by this class.
        should_cache: when True, loaded samples are memoised per index.
    '''
    def __init__(self, labels, data_dir, edge_sets, should_cache=False):
        self.labels = labels
        self.data_dir = data_dir
        self.edge_sets = edge_sets
        self.cache = dict() if should_cache else None

    def __len__(self):
        return len(self.labels)

    def _load_sample(self, idx):
        '''Read sample ``idx`` from disk without consulting the cache.'''
        key_parts = self.labels[idx][0].split("|||")
        data_name = key_parts[0]

        # Context managers close the underlying zip handles; indexing an
        # NpzFile reads the array into memory before the file is closed.
        with np.load(os.path.join(self.data_dir, data_name + "_Edges.npz")) as edges:
            edges_index = torch.from_numpy(edges['edges_index'])
            edges_attr = torch.from_numpy(edges['edges_attr'])
        with np.load(os.path.join(self.data_dir, data_name + ".npz")) as nodes:
            tokens = torch.from_numpy(nodes['node_rep'])

        label = torch.tensor(self.labels[idx][1])
        problemType = torch.tensor([float(key_parts[1])])

        graph = Data(x=tokens.float(), edge_index=edges_index, edge_attr=edges_attr,
                     problemType=problemType, data_name=str(data_name))
        return graph, label

    def __getitem__(self, idx):
        '''Return ``(graph, label)`` for sample ``idx``.'''
        if self.cache is None:
            return self._load_sample(idx)
        # Compare against None rather than truthiness: an empty dict is falsy,
        # which previously meant the cache was never populated.
        if idx not in self.cache:
            self.cache[idx] = self._load_sample(idx)
        return self.cache[idx]

class GeometricDataset(GDataset):
    '''Graph dataset whose edges are assembled from several named edge sets.

    Each entry of ``labels`` is indexed as ``entry[0]`` (key of the form
    ``"<data_name>|||<problemType>"``, ``problemType`` parsed as a float) and
    ``entry[1]`` (label value).

    For sample ``data_name`` two files are read from ``data_dir``:
    ``<data_name>.npz`` (array ``node_rep``) and ``<data_name>Edges.npz`` (one
    array per name in ``edge_sets``). Each edge-set array is concatenated along
    dim 0 and then transposed into ``edge_index``. ``edge_attr`` holds, for every
    edge, the position of its edge set in ``edge_sets`` (as a float column).
    If ``"AST"`` is one of the edge sets, its edges are duplicated with the two
    columns flipped (i.e. made bidirectional).

    Args:
        labels: list of ``(key, label)`` entries.
        data_dir: directory containing the ``.npz`` files.
        edge_sets: sequence of array names to read from the edges file.
        should_cache: when True, loaded samples are memoised per index.
    '''
    def __init__(self, labels, data_dir, edge_sets, should_cache=False):
        self.labels = labels
        self.data_dir = data_dir
        self.edge_sets = edge_sets
        self.cache = dict() if should_cache else None

    def __len__(self):
        return len(self.labels)

    def _load_sample(self, idx):
        '''Read sample ``idx`` from disk without consulting the cache.'''
        key_parts = self.labels[idx][0].split("|||")
        data_name = key_parts[0]

        with np.load(os.path.join(self.data_dir, data_name + "Edges.npz")) as edges:
            edge_tensors = [torch.from_numpy(edges[edge_set]) for edge_set in self.edge_sets]

        if "AST" in self.edge_sets:
            # Mirror the AST edges wherever "AST" sits in edge_sets; the old code
            # always flipped position 0, which is wrong if "AST" is not first.
            ast_pos = list(self.edge_sets).index("AST")
            ast_edges = edge_tensors[ast_pos]
            edge_tensors[ast_pos] = torch.cat((ast_edges, ast_edges.flip([1])))

        # One float label per edge: the index of the edge set it came from.
        edge_labels = torch.cat(
            [torch.full((len(t), 1), set_idx) for set_idx, t in enumerate(edge_tensors)], dim=0
        ).float()
        edge_index = torch.cat(edge_tensors).transpose(0, 1).long()

        with np.load(os.path.join(self.data_dir, data_name + ".npz")) as nodes:
            tokens = torch.from_numpy(nodes['node_rep'])

        label = torch.tensor(self.labels[idx][1])
        problemType = torch.tensor([float(key_parts[1])])

        graph = Data(x=tokens.float(), edge_index=edge_index, edge_attr=edge_labels,
                     problemType=problemType)
        return graph, label

    def __getitem__(self, idx):
        '''Return ``(graph, label)`` for sample ``idx``.'''
        if self.cache is None:
            return self._load_sample(idx)
        # Compare against None rather than truthiness: an empty dict is falsy,
        # which previously meant the cache was never populated.
        if idx not in self.cache:
            self.cache[idx] = self._load_sample(idx)
        return self.cache[idx]

class ModifiedMarginRankingLoss(nn.Module):
    '''Pairwise margin ranking loss with a margin that grows with rank distance.

    ``scores`` and ``labels`` are indexed as ``(row, position)``. In every row,
    positions are ordered by ascending label. For each pair of ranks ``i < j``
    the score at rank ``j`` (larger label) should exceed the score at rank ``i``
    by at least ``margin * (j - i)``. For each pair the hinge loss is averaged
    over rows; the pair losses are then summed. Returns a tensor of shape ``(1,)``.

    Args:
        margin: base margin, multiplied by the rank distance of each pair.
        gpu: kept for backward compatibility; tensors are created on
            ``scores.device``.
    '''
    def __init__(self, margin=0, gpu=0):
        super(ModifiedMarginRankingLoss, self).__init__()
        self.margin = margin
        self.gpu = gpu

    def forward(self, scores, labels):
        # Rank order does not depend on the pair, so compute it once.
        order = labels.argsort()
        loss = torch.zeros(1, device=scores.device)
        for i, j in itertools.combinations(range(len(labels[0])), 2):
            lower = scores.gather(1, order[:, i].unsqueeze(1))
            higher = scores.gather(1, order[:, j].unsqueeze(1))
            # combinations() always yields i < j, so the target is always -1
            # ("second input should be larger"). Build it with the same (B, 1)
            # shape as the inputs so no B x B broadcast happens.
            target = torch.full_like(lower, -1.0)
            loss = loss + F.margin_ranking_loss(lower, higher, target,
                                                margin=float(self.margin * (j - i)))
        return loss

class topKLoss(nn.Module):
    '''MSE between the top-``k`` label values and the scores at those positions.

    For each row of ``labels`` the ``k`` largest values and their positions are
    taken (``Tensor.topk`` on the last dim); the scores at the same positions
    are compared to those label values with ``nn.MSELoss``.

    Args:
        k: number of top label positions to compare.
    '''
    def __init__(self, k=1):
        super(topKLoss, self).__init__()
        self.lossFn = nn.MSELoss()
        self.k = k

    def forward(self, scores, labels):
        topKVals, topKIdxs = labels.topk(self.k)
        if topKIdxs.dim() == 1:
            # 1-D labels: the same top-k columns are selected for every score row.
            picked = scores[:, topKIdxs]
        else:
            # Per-row selection; the old ``scores[:, idx.squeeze(0)]`` was only
            # correct for a batch of one and produced (B, B, k) otherwise.
            picked = scores.gather(-1, topKIdxs)
        return self.lossFn(picked, topKVals)

class RankLoss(nn.Module):
    '''Sum of pairwise hinge (margin ranking) losses over a flattened score vector.

    ``scores`` and ``labels`` are flattened and cast to float32. Positions are
    ordered by ascending label; for every pair of ranks ``i < j`` the score at
    rank ``j`` should exceed the score at rank ``i`` by at least ``margin``.
    Each pair contributes ``max(0, score_i - score_j + margin)`` and the
    contributions are summed into a 0-d tensor.

    Args:
        margin: required score gap between every ordered pair.
        gpu: kept for backward compatibility; computation follows
            ``scores.device``.
    '''
    def __init__(self, margin=0, gpu=0):
        super(RankLoss, self).__init__()
        self.gpu = gpu
        self.margin = margin

    def forward(self, scores, labels):
        scores = scores.view(-1).type(torch.float32)
        labels = labels.view(-1).type(torch.float32)
        ranked = scores[labels.argsort()]  # scores ordered by ascending label
        n = ranked.numel()
        # All rank pairs (lo, hi) with lo < hi, vectorised instead of a Python
        # loop that built a new loss module and target tensor per pair.
        lo, hi = torch.triu_indices(n, n, offset=1, device=scores.device)
        hinge = torch.clamp(ranked[lo] - ranked[hi] + self.margin, min=0)
        # Start from a grad-requiring zero so the result always supports
        # .backward(), as before (also for fewer than two elements).
        start = torch.tensor(0.0, device=scores.device, requires_grad=True)
        return start + hinge.sum()

class MseLoss(nn.Module):
    '''Mean-squared error computed only where the predicted sign is wrong.

    Inputs are flattened and cast to float32. Elements whose ``torch.sign``
    differs between ``scores`` and ``labels`` are selected and compared with
    ``nn.MSELoss``. When every sign agrees, a 0.0 leaf tensor with
    ``requires_grad=True`` is returned.

    Args:
        gpu: stored on the module; not used by ``forward``.
    '''
    def __init__(self, gpu=0):
        super().__init__()
        self.gpu = gpu
        self.lossFn = nn.MSELoss()

    def forward(self, scores, labels):
        flat_scores = scores.view(-1).to(torch.float32)
        flat_labels = labels.view(-1).to(torch.float32)
        wrong_sign = torch.sign(flat_scores).ne(torch.sign(flat_labels))
        if not wrong_sign.any():
            return torch.tensor(0.0, device=flat_scores.device, requires_grad=True)
        return self.lossFn(flat_scores[wrong_sign], flat_labels[wrong_sign])

class SignLoss(nn.Module):
    '''Penalty on elements whose predicted sign disagrees with the label sign.

    Inputs are flattened and cast to float32. For elements where
    ``torch.sign(score) != torch.sign(label)``, both values are squashed with
    ``x / (1 + |x|)`` and the mean squared difference is taken. The result is
    multiplied by 10. When every sign agrees, the base term is a 0.0 leaf
    tensor with ``requires_grad=True`` (so the returned value is ``10 * 0.0``).

    Args:
        margin: stored on the module; not used by ``forward``.
        gpu: stored on the module; not used by ``forward``.
    '''
    def __init__(self, margin=1, gpu=0):
        super().__init__()
        self.gpu = gpu
        self.margin = margin
        # Kept as an attribute for compatibility; forward does not use it.
        self.lossFn = nn.MSELoss()

    def forward(self, scores, labels):
        s = scores.view(-1).to(torch.float32)
        y = labels.view(-1).to(torch.float32)
        wrong = torch.sign(s) != torch.sign(y)
        if wrong.any():
            s_bad = s[wrong]
            y_bad = y[wrong]
            squashed_gap = s_bad / (1 + s_bad.abs()) - y_bad / (1 + y_bad.abs())
            penalty = squashed_gap.pow(2).mean()
        else:
            penalty = torch.tensor(0.0, device=s.device, requires_grad=True)
        return 10 * penalty


def _rounded_mean(total, count):
    '''Return ``total / count`` rounded to 4 decimals, or NaN when ``count`` is 0.

    Used for progress metrics so an empty denominator (e.g. no sample with a
    positive label, or every validation batch skipped) yields NaN instead of
    raising ZeroDivisionError in the middle of training.
    '''
    return round(total / count, 4) if count else float('nan')


def _new_metric_stats():
    '''Return a fresh accumulator for the per-sample metrics of one pass.'''
    return {'corr_sum': 0.0, 'corr_count': 0, 'topk_acc': 0.0, 'samples': 0,
            'success_acc': 0.0, 'success_counter': 0}


def _accumulate_sample_metrics(scores, labels, k, stats):
    '''Add the metrics of every row of a batch to ``stats`` (mutated in place).

    Per row ``j``:
      * Spearman correlation between ``labels[j]`` and ``scores[j]``; only
        finite values are summed (and counted in ``corr_count``).
      * top-k hit: whether ``labels[j].argmax()`` is among the ``k`` highest scores.
      * ``success_counter`` counts rows whose largest label is > 0, and
        ``success_acc`` counts rows whose label at the score argmax is > 0.

    Returns:
        List of row indices whose Spearman correlation was NaN or Inf.
    '''
    bad_rows = []
    for j in range(len(labels)):
        corr, _ = spearmanr(labels[j].cpu().detach(), scores[j].cpu().detach().tolist())
        if np.isfinite(corr):
            stats['corr_sum'] += corr
            stats['corr_count'] += 1
        else:
            bad_rows.append(j)
        _, score_topk = scores[j].topk(k)
        stats['topk_acc'] += labels[j].argmax() in score_topk
        stats['samples'] += 1
        stats['success_counter'] += (labels[j].max() > 0).item()
        stats['success_acc'] += (labels[j][scores[j].argmax()] > 0).item()
    return bad_rows


def _format_epoch_metrics(phase, epoch, loss_sum, n_batches, stats):
    '''Build the one-line progress message printed by ``train_model``.'''
    return (f"{phase}-epoch {epoch}, "
            f"Avg-Loss: {_rounded_mean(loss_sum, n_batches)}, "
            f"Avg-Corr:{_rounded_mean(stats['corr_sum'], stats['corr_count'])}, "
            f"TopK-Acc:{_rounded_mean(stats['topk_acc'], stats['samples'])}, "
            f"Success-Acc:{_rounded_mean(stats['success_acc'], stats['success_counter'])}")


def train_model(model, loss_fn, batchSize, trainset, valset, optimizer, scheduler, num_epochs, gpu, task, k=1, trainWeights=None, valWeights=None):
    '''
    Train ``model`` for up to ``num_epochs`` epochs, validating after each epoch.

    Args:
        model: network called as ``model(x, edge_index, edge_attr, batch, problemType)``.
        loss_fn: for ``task == "rank"`` called as ``loss_fn(scores, labels)``; for
            ``"topk"``/``"success"`` called with ``log_softmax(scores, dim=1)`` and
            ``labels.argmax(dim=1)``.
        batchSize: mini-batch size for both loaders.
        trainset, valset: datasets yielding ``(graph, label)`` pairs.
        optimizer: stepped after every training batch.
        scheduler: stepped once per epoch with the mean validation batch loss
            (presumably a metric-driven scheduler such as ReduceLROnPlateau).
        num_epochs: maximum number of epochs; training also stops early once the
            first param group's learning rate drops below 1e-7.
        gpu: device the batches are moved to.
        task: one of ``"rank"``, ``"topk"``, ``"success"``.
        k: ``k`` used by the top-k accuracy metric.
        trainWeights, valWeights: optional per-sample weights. When
            ``trainWeights`` is given, both loaders use a ``WeightedRandomSampler``
            (so ``valWeights`` must be given too); otherwise both shuffle.

    Returns:
        ``(train_accuracies, train_losses, val_accuracies, val_losses)``: mean
        Spearman correlation and mean batch loss, recorded at every training
        progress report and once per validation pass. A metric with an empty
        denominator is recorded as NaN.

    Raises:
        ValueError: if ``task`` is not supported.
    '''
    if task not in ("rank", "topk", "success"):
        # Previously this surfaced as an UnboundLocalError on ``loss`` at the
        # first backward pass; fail fast with a clear message instead.
        raise ValueError(f"train_model: unsupported task {task!r}; expected 'rank', 'topk' or 'success'")

    if trainWeights is None:
        train_loader = torch_geometric.loader.DataLoader(dataset=trainset, batch_size=batchSize, shuffle=True)
        val_loader = torch_geometric.loader.DataLoader(dataset=valset, batch_size=batchSize, shuffle=True)
    else:
        trainSampler = WeightedRandomSampler(weights=trainWeights, num_samples=len(trainWeights))
        valSampler = WeightedRandomSampler(weights=valWeights, num_samples=len(valWeights))

        train_loader = torch_geometric.loader.DataLoader(dataset=trainset, batch_size=batchSize, sampler=trainSampler)
        val_loader = torch_geometric.loader.DataLoader(dataset=valset, batch_size=batchSize, sampler=valSampler)

    def run_model(graphs):
        return model(graphs.x, graphs.edge_index, graphs.edge_attr, graphs.batch, graphs.problemType)

    def batch_loss(scores, labels):
        # "rank" losses take raw scores; "topk"/"success" losses take
        # log-probabilities and the index of the best label.
        if task == "rank":
            return loss_fn(scores, labels)
        return loss_fn(nn.functional.log_softmax(scores, dim=1), labels.argmax(dim=1))

    train_accuracies = []; val_accuracies = []
    train_losses = []; val_losses = []

    # Progress is printed roughly every 10% of an epoch, based on the batch
    # count rounded to hundreds (a result of 0 disables the periodic report).
    report_every = round(len(trainset) // batchSize, -2)

    for epoch in range(num_epochs):
        # ---------------- training pass ----------------
        model.train()
        cum_loss = 0.0
        n_batches = 0
        stats = _new_metric_stats()

        for i, (graphs, labels) in enumerate(tqdm.tqdm(train_loader)):
            graphs = graphs.to(device=gpu)
            labels = labels.to(device=gpu)
            assert graphs.x.size(0) == graphs.batch.size(0)

            with autocast():
                scores = run_model(graphs)
                loss = batch_loss(scores, labels)
            cum_loss += loss.detach().item()
            n_batches += 1

            _accumulate_sample_metrics(scores, labels, k, stats)

            optimizer.zero_grad()
            loss.backward()
            model.float()
            optimizer.step()

            at_checkpoint = report_every != 0 and (((i + 1) / report_every) * 100) % 10 == 0
            if at_checkpoint or (i + 1) == len(train_loader):
                print(_format_epoch_metrics("Train", epoch, cum_loss, n_batches, stats))
                train_accuracies.append(_rounded_mean(stats['corr_sum'], stats['corr_count']))
                train_losses.append(_rounded_mean(cum_loss, n_batches))

        # ---------------- validation pass ----------------
        model.eval()
        cum_loss = 0.0
        n_batches = 0
        stats = _new_metric_stats()

        for i, (graphs, labels) in enumerate(val_loader):
            graphs = graphs.to(device=gpu)
            labels = labels.to(device=gpu)
            if torch.isnan(graphs.x).any() or torch.isinf(graphs.x).any():
                print(f"graphs.x NaN or Inf found in input data at batch {i} with data_name: {graphs.data_name}")
                continue  # skip this batch

            with autocast(), torch.no_grad():
                scores = run_model(graphs)
                if torch.isnan(scores).any() or torch.isinf(scores).any():
                    print(f"scores NaN or Inf found in input data at batch {i} with data_name: {graphs.data_name}")
                    continue  # skip this batch
                try:
                    loss = batch_loss(scores, labels)
                except Exception as e:
                    # Deliberate best-effort: report and skip a batch whose loss fails.
                    print(f"Error during loss calculation at batch {i}, data_name: {graphs.data_name}: {e}")
                    continue  # skip this batch
                if torch.isnan(loss).any() or torch.isinf(loss).any():
                    print(f"loss NaN or Inf found in loss at batch {i} with data_name: {graphs.data_name}")
                    continue  # skip this batch
                cum_loss += loss.detach().item()
            n_batches += 1

            for j in _accumulate_sample_metrics(scores, labels, k, stats):
                print(f"corr NaN or Inf found in correlation at batch {i}, data {j} with data_name: {graphs.data_name}")

        # Average only over batches that actually contributed a loss.
        if n_batches:
            scheduler.step(cum_loss / n_batches)
        else:
            print(f"Valid-epoch {epoch}: no valid batches, scheduler not stepped")

        val_accuracies.append(_rounded_mean(stats['corr_sum'], stats['corr_count']))
        val_losses.append(_rounded_mean(cum_loss, n_batches))

        print(_format_epoch_metrics("Valid", epoch, cum_loss, n_batches, stats))
        if optimizer.param_groups[0]['lr'] < 1e-7:
            break

    return train_accuracies, train_losses, val_accuracies, val_losses

def train_model2(model, loss_fn, batchSize, trainset, valset, optimizer, scheduler, num_epochs, gpu, task, k=1,
                trainWeights=None, valWeights=None):
    '''
    Function used to train networks
    '''
    if trainWeights is None:
        train_loader = torch_geometric.loader.DataLoader(dataset=trainset, batch_size=batchSize, shuffle=True)
        val_loader = torch_geometric.loader.DataLoader(dataset=valset, batch_size=batchSize, shuffle=True)
    else:
        trainSampler = WeightedRandomSampler(weights=trainWeights, num_samples=len(trainWeights))
        valSampler = WeightedRandomSampler(weights=valWeights, num_samples=len(valWeights))

        train_loader = torch_geometric.loader.DataLoader(dataset=trainset, batch_size=batchSize, sampler=trainSampler)
        val_loader = torch_geometric.loader.DataLoader(dataset=valset, batch_size=batchSize, sampler=valSampler)

    train_accuracies = [];
    val_accuracies = []
    train_losses = [];
    val_losses = []

    for epoch in range(0, num_epochs):
        corr_sum = 0.0
        cum_loss = 0.0

        topk_acc = 0.0
        success_acc = 0.0
        
        hit_count = 0
        total_positive = 0
        total_select = 0
        hit_best = 0
        total_best = 0
        select_best = 0
        a = 0
        algorithm_correct = 0.0
        algorithm_correct_sum = 0
        algorithm_best = 0.0
        algorithm_best_sum = 0.0
        model.train()
        torch.enable_grad()
       # torch.set_grad_enabled(True)
        success_counter = 0
        for (i, (graphs, labels)) in enumerate(tqdm.tqdm(train_loader)):
            graphs = graphs.to(device=gpu)
            labels = labels.to(device=gpu)
      #      labels.requires_grad = True
           # graphs.x.requires_grad = True
           # graphs.edge_index.requires_grad = True
          #  graphs.edge_attr.requires_grad = True
           # graphs.batch.requires_grad = True
           # graphs.problemType.requires_grad = True
          #  print("Graphs requires_grad:", graphs.edge_index.requires_grad)
          #  print("Labels requires_grad:", labels.requires_grad)
            assert graphs.x.size(0) == graphs.batch.size(0)

            if labels.max() <= 0:
                a += 1
                pass

#            if not torch.any((graphs.edge_attr[:, 0] == 16) | (graphs.edge_attr[:, 0] == 17)):
#                continue

            with autocast():
                scores = model(graphs.x, graphs.edge_index, graphs.edge_attr, graphs.batch, graphs.problemType)
              #  print(scores.dtype)
              #  print(labels.dtype)

              #  print(scores.requires_grad)
              #  print(labels.requires_grad)
                if task == "rank":
                    loss = loss_fn(scores, labels)
                    cum_loss += loss.cpu().detach().item()
                elif task == "topk" or task == "success":
                    total_loss = loss_fn(nn.functional.log_softmax(scores, dim=1), labels.argmax(dim=1))
                    cum_loss += total_loss.cpu().detach().item()
                elif task == "algorithm":
                   # with torch.autograd.detect_anomaly():
                    #    loss = loss_fn(scores, labels)
                 #   sign_loss = loss_fn(scores, labels)
                 #   rank_loss_fn = RankLoss(gpu=gpu).to(device=scores.device) 
                  #  rank_loss = rank_loss_fn(scores, labels)
                 #   mse_loss_fn = MseLoss(gpu=gpu).to(device=scores.device)
                 #   mse_loss = mse_loss_fn(scores, labels)
                #  i print(sign_loss)
                #    print(mse_loss)
                #    print(rank_loss)
                 #   total_loss = mse_loss
                 #   cum_loss += total_loss.cpu().detach().item()
                #    print("Loss requires_grad:", loss.requires_grad)
          #          if labels.max() <= 0 :
          #              a += 1
          #              continue
               #     print(scores)
                #    print(labels)
        #            positive_indices = (labels[0] >= 0).nonzero(as_tuple=True)[0]
                  #  print(positive_indices)
                  #  print(labels.argmax(dim=1))
        #            if labels.min() < 0:
        #                labels = labels + torch.abs(labels.min())
    #                print(labels)
    #                labels = F.softmax(labels,dim=1)
           #         print(labels)
                    total_sum = labels.sum()
                    sorted_labels, sorted_indices = torch.sort(labels, descending=True)
                    cumulative_sum = torch.cumsum(sorted_labels, dim=1)
                    threshold = total_sum * 0.8
                    selected_indices = sorted_indices[cumulative_sum < threshold]
               #     print(selected_indices)
                    if cumulative_sum[0][-1] > threshold:  # å¦æç´¯å åçæåä¸ä¸ªå¼å¤§äºéå¼
                        selected_indices = torch.cat((selected_indices, sorted_indices[cumulative_sum > threshold][[0]]))
            #        print(selected_indices)
        #            selected_indices = selected_indices[torch.isin(selected_indices, positive_indices)]
               #     print(selected_indices)
        #            NLloss = torch.tensor(0.0, device = scores.device, requires_grad = True)
        #            for idx in selected_indices:
        #                target = torch.zeros(1, dtype=torch.long).to(scores.device)
        #                target[0] = idx  # å¯¹åºç´¢å¼è®¾ç½®ä¸º 1
                     #   print(target)
        #                NLloss = NLloss + loss_fn(nn.functional.log_softmax(scores, dim=1), target)
                    updated_labels = torch.zeros_like(labels).float()
                    updated_labels[0,selected_indices] = 1
                #    scores_signs = torch.sign(scores)
                #    labels_signs = torch.sign(labels)
                #    mask = scores_signs != labels_signs
                #    if mask.any():
                #        smooth_scores = scores[mask] / (1 + torch.abs(scores[mask]))
                #        smooth_labels = labels[mask] / (1 + torch.abs(labels[mask]))
                #        sign_loss = torch.mean((smooth_scores - smooth_labels) ** 2)
                #    else:
                #        sign_loss = torch.tensor(0.0, device = scores.device, requires_grad = True)
              #      print(scores)
              #      print(labels)
           #         binary_labels = (labels > 0).float()
              #      print(binary_labels)
                    mul_loss_fn = nn.BCEWithLogitsLoss(reduction = 'sum')
                    mul_loss = mul_loss_fn(scores, updated_labels)
              #      cosine_loss_fn = nn.CosineEmbeddingLoss()
              #      y = torch.ones(scores.shape[0],device=scores.device)
                #    print(y)
              #      cosine_loss = cosine_loss_fn(scores, updated_labels, y)
             #       print(binary_labels)
                  #  print(mul_loss)
                  #  print(sign_loss)
                   # positive_indices = (labels > 0).nonzero(as_tuple=True)[0]
                    # 2. ä»æ­£å¼ç´¢å¼ä¸­ç§»é¤ selected_indices
              #      topk_loss_fn = topKLoss(k=selected_indices.size(0)).to(device=scores.device)
                   # print(scores)
                   # print(selected_indices)
                   # print(selected_indices.size(0))
              #      topKloss = topk_loss_fn(scores, labels)
                    
              #      print(positive_indices)
              #      print(selected_indices)
       #             remaining_positive_indices = positive_indices[~torch.isin(positive_indices, selected_indices)].long().to(scores.device)
               #     print(remaining_positive_indices)
       #             if remaining_positive_indices.any():
       #                 extra_loss = torch.relu(torch.sigmoid(scores[0][remaining_positive_indices])).sum()
       #             else:
       #                 extra_loss = torch.tensor(0.0, device = scores.device, requires_grad = True)
              #      print(NLloss)
            #        print(mul_loss)
                 #   print(extra_loss)
               #     scores_pos_count = (torch.sigmoid(scores) > 0.5).sum().float()
               #     true_pos_count = selected_indices.size(0)
               #     reg_loss = torch.abs(scores_pos_count - true_pos_count)
            #        print(reg_loss)
                    total_loss = mul_loss
                    cum_loss+=total_loss.cpu().detach().item()
            
            for j in range(len(labels)):
                positive_mask = labels[j] > 0
                positive_scores = scores[j] > 0
            # #   print(positive_mask)
                positive_true = positive_mask & positive_scores
                positiveScores = scores[j][positive_true]
                positiveLabels = labels[j][positive_true]
            #    print(positiveScores)
            #    print(positiveLabels)
                if positiveScores.numel() == 0 or positiveLabels.numel() == 0 or torch.all(positiveLabels == positiveLabels[0]) or torch.all(positiveScores == positiveScores[0]):
                    a += 1
                    corr = torch.tensor(0.0, device = scores[j].device, requires_grad = True)
                else:
                    corr, _ = spearmanr(positiveScores.cpu().detach(), positiveLabels.cpu().detach().tolist())
           #     print(corr)
                corr_sum += corr
                # assert corr <=1, str(scores) + " " + str(labels)
                _, scoreTopk = scores[j].topk(k)
                labelTopk = labels[j].argmax()
                topk_acc += labelTopk in scoreTopk
                # exit()

                success_counter += (labels[j].max() > 0).item()
                success_acc += (labels[j][scores[j].argmax()] > 0).item()
                if labels[j].max() > 0:
                    positive_indices = (scores[j] > 0).nonzero(as_tuple=True)[0]
              #      positive_indices_labels = (labels[j] > 0).nonzero(as_tuple=True)[0]
             #       if labels.min() < 0:
             #           labels = labels + torch.abs(labels.min())
                    total_sum = labels.sum()
                    sorted_labels, sorted_indices = torch.sort(labels, descending=True)
                    cumulative_sum = torch.cumsum(sorted_labels, dim=1)
                    threshold = total_sum * 0.8
                    selected_indices = sorted_indices[cumulative_sum < threshold]
                    if cumulative_sum[0][-1] > threshold:  # å¦æç´¯å åçæåä¸ä¸ªå¼å¤§äºéå¼
                        selected_indices = torch.cat((selected_indices, sorted_indices[cumulative_sum > threshold][[0]]))
             #       selected_indices = selected_indices[torch.isin(selected_indices, positive_indices_labels)]
                    common_indices = torch.isin(positive_indices, selected_indices)

                    # è®¡ç® common_indices ä¸­ä¸º True çåç´ æ°é
                    hit_len = common_indices.sum().item()
                 #   print(positive_indices) 
            #        label_len = (labels[j] > 0).sum().item()
                   # print(label_len)
            #        hit_len = (labels[j][positive_indices] > 0).sum().item()
                #    print(hit_len)
                #    print(selected_indices.size(0))
                 #   print(positive_indices)
                 #   print(selected_indices)
                 #   print(common_indices)
                   # print(hit_len)
                   # print(label_len)
                    if hit_len == positive_indices.size(0):
                        hit_best += 1
                    
                    if hit_len == selected_indices.size(0) and hit_len == positive_indices.size(0):
                        select_best += 1
                    
                    total_best += 1

                   # print(predictions_above_zero.sum().item())
                   # print(len(positive_indices))
                   # print(hit_best)
                   # print(total_best)
                    hit_count += hit_len  # é¢æµæ­£ç¡®çæ°é
                    total_positive += positive_indices.size(0)
                    total_select += selected_indices.size(0)
                   # print(hit_count)
                   # print(total_positive)
               # print(scores)
               # print(labels)
               # print(scores[j])
               # print(labels[j])
               # print(len(labels[0]))
               # print(len(labels[j]))
                scores_signs = torch.sign(scores)
                labels_signs = torch.sign(labels)
               # print(scores_signs)
               # print(labels_signs)
                # ç¬¦å·å¹éçæ¦ç
                algorithm_correct = (scores_signs == labels_signs).float()
               # print(algorithm_correct.sum())
                algorithm_correct_sum += algorithm_correct.sum()
                # æ£æ¥ææç¬¦å·æ¯å¦é½å¹é
                algorithm_best = (scores_signs == labels_signs).all(dim=1).float()  # æ¯ä¸è¡é½è¦å¹é
               # print(algorithm_best.sum())
                algorithm_best_sum += algorithm_best.sum()
               # print("Scores requires_grad:", scores.requires_grad)
               # print("Labels requires_grad:", labels.requires_grad)
               # print("Scores_signs requires_grad:", scores_signs.requires_grad)
               # print("Labels_signs requires_grad:", labels_signs.requires_grad)
           # print("Loss:", loss.item())  # ç¡®ä¿è¿ä¸ªæ¯æ é
           # print("Loss requires_grad:", loss.requires_grad)  # è¿ä¸ªä¹åºè¯¥æ¯ True
           
           # print(scores.requires_grad)
            optimizer.zero_grad()
        #    scores[j].retain_grad()
        #    scores[j].requires_grad_()
       #     print("sign_weight_pre:",sign_weight)
       #     print("mse_weight_pre:",mse_weight)
       #     print("rank_weight_pre:",rank_weight)
            total_loss.backward()
           # sign_grad = sign_weight.grad.clone().detach()
           # mse_grad = mse_weight.grad.clone().detach()
           # rank_grad = rank_weight.grad.clone().detach()
            
       #     print("sign_grad:",sign_grad)
       #     print("mse_grad:",mse_grad)
       #     print("rank_grad:",rank_grad)
          #  sign_norm = torch.norm(sign_grad)
          #  mse_norm = torch.norm(mse_grad)
          #  rank_norm = torch.norm(rank_grad)
            
        #    print("sign_norm:",sign_norm)
        #    print("mse_norm:",mse_norm)
        #    print("rank_norm:",rank_norm)
       #     average_norm = (sign_norm + mse_norm + rank_norm) / 3.0
       #     sign_target = (sign_norm / average_norm) * sign_weight
       #     mse_target = (mse_norm / average_norm) * mse_weight
       #     rank_target = (rank_norm / average_norm) * rank_weight
            
        #    print("average_norm:",average_norm)
        #    print("sign_target:",sign_target)
        #    print("mse_target:",mse_target)
        #    print("rank_target:",rank_target)

         #   weight_loss = torch.abs(sign_target - sign_norm) + torch.abs(mse_target - mse_norm) + torch.abs(rank_target - rank_norm)
         #   weight_loss.backward()

        #    print("sign_weight_aft:",sign_weight)
        #    print("mse_weight_aft:",mse_weight)
        #    print("rank_weight_aft:",rank_weight)

        #    scores_signs = torch.sign(scores[j])
        #    labels_signs = torch.sign(labels[j])
           # with torch.no_grad():
       #     if scores[j].grad is not None:
       #         scores[j].grad[scores_signs == labels_signs] = 0
            model.float()
            optimizer.step()
         #   print("sign_weight_aft:",sign_weight)
         #   print("mse_weight_aft:",mse_weight)
         #   print("rank_weight_aft:",rank_weight)
            if round(len(trainset) // batchSize, -2) != 0:
                condition = (((i + 1) / round(len(trainset) // batchSize, -2)) * 100) % 10 == 0
            else:
                condition = False
            if condition or (i + 1) == len(train_loader):
             #   print(hit_count)
             #   print(total_positive)
                mystr = "Train-epoch " + str(epoch) + ", Avg-Loss: " + str(
                    round(cum_loss / (i * batchSize), 4)) + ", Avg-Corr:" + str(
                    round(corr_sum.item() / ((i-a) * batchSize), 4)) + ", TopK-Acc:" + str(
                    round(topk_acc / (i * batchSize), 4)) + ", AlgPositive-Acc:" + str(
                    round(hit_count / total_positive, 4)) + ", AlgSelect:" + str(round(hit_count / total_select, 4)) + ", AlgPositive-Best:" + str(
                    round(hit_best / total_best,4)) + ", AlgContains:" + str(
                    round(select_best / total_best,4)) + ", Algorithm-Corr:" + str(
                    round(algorithm_correct_sum.item() / (i * len(labels[0]) * batchSize), 4)) + ", Algorithm-Best:"+str(
                    round(algorithm_best_sum.item() / (i * batchSize), 4))
                print(mystr)
                train_accuracies.append(round(corr_sum.item() / (i-a), 4))
                train_losses.append(round(cum_loss / i, 4))

        corr_sum = 0.0
        cum_loss = 0.0
        model.eval()

        topk_acc = 0.0
        topk_loss = 0.0
        success_acc = 0.0
        success_counter = 0
        hit_count = 0
        total_positive = 0
        total_select = 0
        hit_best = 0
        total_best = 0
        select_best = 0
        a = 0
        algorithm_correct = 0.0
        algorithm_best = 0.0
        algorithm_correct_sum = 0.0
        algorithm_best_sum = 0.0

        for (i, (graphs, labels)) in enumerate((val_loader)):
            graphs = graphs.to(device=gpu)
            labels = labels.to(device=gpu)
            if labels.max() <= 0:
                a += 1
                pass
            
#            if not torch.any((graphs.edge_attr[:, 0] == 16) | (graphs.edge_attr[:, 0] == 17)):
#                continue

            if torch.isnan(graphs.x).any() or torch.isinf(graphs.x).any():
                print(f"graphs.x NaN or Inf found in input data at batch {i} with data_name: {graphs.data_name}")
                continue  # è·³è¿è¯¥æ¹æ¬¡

            with autocast():
                with torch.no_grad():
                    scores = model(graphs.x, graphs.edge_index, graphs.edge_attr, graphs.batch, graphs.problemType)
                    #  print(f"Scores:{scores}")
                    if torch.isnan(scores).any() or torch.isinf(scores).any():
                        print(f"scores NaN or Inf found in input data at batch {i} with data_name: {graphs.data_name}")
                        continue  # è·³ï¿½~Gè¯¥ï¿½~Iï¿½æ¬¡

                    try:
                        if task == "rank":
                            loss = loss_fn(scores, labels)
                        elif task == "topk" or task == "success":
                            loss = loss_fn(nn.functional.log_softmax(scores, dim=1), labels.argmax(dim=1))
                        elif task == "algorithm":
                        #    loss = loss_fn(scores, labels)
                          #  rank_loss_fn = RankLoss(gpu=gpu).to(device=scores.device)
                          #  rank_loss = rank_loss_fn(scores, labels)
                           # mse_loss_fn = MseLoss(gpu=gpu).to(device=scores.device)
                           # mse_loss = mse_loss_fn(scores, labels)
                           # loss = mse_loss
               #             if labels.max() <= 0:
               #                 a += 1
               #                 continue
                  #          positive_indices = (labels[0] >= 0).nonzero(as_tuple=True)[0]
                  #          if labels.min() < 0:
                  #              labels = labels + torch.abs(labels.min())
                            total_sum = labels.sum()
                            sorted_labels, sorted_indices = torch.sort(labels, descending=True)
                            cumulative_sum = torch.cumsum(sorted_labels, dim=1)
                            threshold = total_sum * 0.8
                            selected_indices = sorted_indices[cumulative_sum < threshold]
                            if cumulative_sum[0][-1] > threshold:  # å¦æç´¯å åçæåä¸ä¸ªå¼å¤§äºéå¼
                                selected_indices = torch.cat((selected_indices, sorted_indices[cumulative_sum > threshold][[0]]))
                  #          selected_indices = selected_indices[torch.isin(selected_indices, positive_indices)]
                #            NLloss = torch.tensor(0.0, device = scores.device, requires_grad = True)
                #           for idx in selected_indices:
                #                target = torch.zeros(1, dtype=torch.long).to(scores.device)
                #                target[0] = idx  # å¯¹~Tç´¢~Uè®¾ç½®ä¸º 1
                #                NLloss = NLloss + loss_fn(nn.functional.log_softmax(scores, dim=1), target)

                           # positive_indices = (labels > 0).nonzero(as_tuple=True)[0]
                            # 2. ~Næ­£~@ç´¢~Uä¸­ç§»~Y selected_indices
                     #       topk_loss_fn = topKLoss(k=selected_indices.size(0)).to(device=scores.device)
                         # # print(scores)
                          # print(selected_indices)
                          # print(selected_indices.size(0))
                     #       topKloss = topk_loss_fn(scores, labels)
                 #           remaining_positive_indices = positive_indices[~torch.isin(positive_indices, selected_indices)].long().to(scores.device)
                 #           if remaining_positive_indices.any():
                 #               extra_loss = torch.relu(scores[0][remaining_positive_indices]).sum()
                 #           else:
                 #               extra_loss = torch.tensor(0.0, device = scores.device, requires_grad = True)
                           # extra_loss = torch.relu(scores[remaining_positive_indices],dim=1).sum()
                    #        scores_signs = torch.sign(scores)
                    #        labels_signs = torch.sign(labels)
                    #        mask = scores_signs != labels_signs
                    #        if mask.any():
                    #            smooth_scores = scores[mask] / (1 + torch.abs(scores[mask]))
                    #            smooth_labels = labels[mask] / (1 + torch.abs(labels[mask]))
                    #            sign_loss = torch.mean((smooth_scores - smooth_labels) ** 2)
                    #        else:
                    #            sign_loss = torch.tensor(0.0, device = scores.device, requires_grad = True)
              #              binary_labels = (labels > 0).float()
                            updated_labels = torch.zeros_like(labels).float()
                            updated_labels[0,selected_indices] = 1
                            mul_loss_fn = nn.BCEWithLogitsLoss(reduction = 'sum')
                            mul_loss = mul_loss_fn(scores, updated_labels)
             #               remaining_positive_indices = positive_indices[~torch.isin(positive_indices, selected_indices)].long().to(scores.device)
                       #     print(remaining_positive_indices)
             #               if remaining_positive_indices.any():
             #                   extra_loss = torch.relu(torch.sigmoid(scores[0][remaining_positive_indices])).sum()
             #               else:
             #                   extra_loss = torch.tensor(0.0, device = scores.device, requires_grad = True)
                  #          cosine_loss_fn = nn.CosineEmbeddingLoss()
                  #          y = torch.ones(scores.shape[0],device=scores.device)
                          #  print(y)
                  #          cosine_loss = cosine_loss_fn(scores, updated_labels, y)
                      #      scores_pos_count = (torch.sigmoid(scores) > 0.5).sum().float()
                      #      true_pos_count = selected_indices.size(0)
                      #      reg_loss = torch.abs(scores_pos_count - true_pos_count)
                            loss = mul_loss
                        if torch.isnan(loss).any() or torch.isinf(loss).any():
                            print(f"loss NaN or Inf found in loss at batch {i} with data_name: {graphs.data_name}")
                            continue  # è·³è¿è¯¥æ¹æ¬¡
                        cum_loss += loss.cpu().detach().item()
                    except Exception as e:
                        print(f"Error during loss calculation at batch {i}, data_name: {graphs.data_name}: {e}")
                        continue  # è·³è¿è¯¥æ¹æ¬¡

            for j in range(len(labels)):
                positive_mask = labels[j] > 0
                positive_scores = scores[j] > 0
             #   print(positive_mask)
                positive_true = positive_mask & positive_scores
                positiveScores = scores[j][positive_true]
                positiveLabels = labels[j][positive_true]
             #   print(positiveScores)
             #   print(positiveLabels)
                if positiveScores.numel() == 0 or positiveLabels.numel() == 0 or torch.all(positiveLabels == positiveLabels[0]) or torch.all(positiveScores == positiveScores[0]):
                    a += 1
                    corr = torch.tensor(0.0, device = scores[j].device, requires_grad = True)
                else:
                    corr, _ = spearmanr(positiveScores.cpu().detach(), positiveLabels.cpu().detach().tolist())
              #  if not np.isfinite(corr):
              #      print(
              #          f"corr NaN or Inf found in correlation at batch {i}, data {j} with data_name: {graphs.data_name}")
              #      continue  # è·³è¿è¯¥æ°æ®
                corr_sum += corr
                # assert corr <=1, str(scores) + " " + str(labels)
                _, scoreTopk = scores[j].topk(k)
                labelTopk = labels[j].argmax()
                topk_acc += labelTopk in scoreTopk
                # exit()

                success_counter += (labels[j].max() > 0).item()
                success_acc += (labels[j][scores[j].argmax()] > 0).item()
                if labels[j].max() > 0:
                  #  positive_indices = (scores[j] > 0).nonzero(as_tuple=True)[0]

                 #   corresponding_scores = scores[j][positive_indices]
                 #   print(corresponding_scores)
                  #  predictions_above_zero = corresponding_scores > 0
                   # label_len = (labels[j] > 0).sum().item()
                   # hit_len = (labels[j][positive_indices] > 0).sum().item()
                   # if positive_indices.size(0) == label_len and hit_len == label_len:
                   #     hit_best += 1
                    positive_indices = (scores[j] > 0).nonzero(as_tuple=True)[0]
       #             positive_indices_labels = (labels[j] >= 0).nonzero(as_tuple=True)[0]
       #             if labels.min() < 0:
       #                 labels = labels + torch.abs(labels.min())
                    total_sum = labels.sum()
                    sorted_labels, sorted_indices = torch.sort(labels, descending=True)
                    cumulative_sum = torch.cumsum(sorted_labels, dim=1)
                    threshold = total_sum * 0.8
                    selected_indices = sorted_indices[cumulative_sum < threshold]
                    if cumulative_sum[0][-1] > threshold:  # å¦æç´¯å åçæåä¸ä¸ªå¼å¤§äºéå¼
                        selected_indices = torch.cat((selected_indices, sorted_indices[cumulative_sum > threshold][[0]]))
       #             selected_indices = selected_indices[torch.isin(selected_indices, positive_indices_labels)]
                    common_indices = torch.isin(positive_indices, selected_indices)

                    # è®¡~W common_indices ä¸­ä¸º True ~Z~D~E~C| ~U~G~O
                    hit_len = common_indices.sum().item()
                   # print(positive_indices) 
              #      label_len = (labels[j] > 0).sum().item()
                   # print(label_len)
              #      hit_len = (labels[j][positive_indices] > 0).sum().item()
                   # print(hit_len)
                   # print(hit_len)
                   # print(label_len)
                    if hit_len == positive_indices.size(0):
                        hit_best += 1

                    if hit_len == selected_indices.size(0) and hit_len == positive_indices.size(0):
                        select_best += 1

                    total_best += 1
                    hit_count += hit_len  # ï¿½~Dï¿½~Kæ­£ç¡®ï¿½~Z~Dï¿½~Uï¿½ï¿½~G~O
                    total_positive += positive_indices.size(0)
                    total_select += selected_indices.size(0)

                scores_signs = torch.sign(scores)
                labels_signs = torch.sign(labels) 

                # ç¬¦å·å¹éçæ¦ç
                algorithm_correct = (scores_signs == labels_signs).float()
                algorithm_correct_sum += algorithm_correct.sum()
                # æ£æ¥ææç¬¦å·æ¯å¦é½å¹é
                algorithm_best = (scores_signs == labels_signs).all(dim=1).float()  # æ¯ä¸è¡é½è¦å¹é
                algorithm_best_sum += algorithm_best.sum()

        scheduler.step(cum_loss / (i + 1))

        val_accuracies.append(round(corr_sum.item() / (i-a), 4))
        val_losses.append(round(cum_loss / i, 4))
       # print(hit_count)
       # print(total_positive)
        mystr = ("Valid-epoch " + str(epoch) + ", Avg-Loss: " + str(
            round(cum_loss / (i * batchSize), 4)) + ", Avg-Corr:" + str(
            round(corr_sum.item() / ((i-a) * batchSize), 4)) + ", TopK-Acc:" + str(
            round(topk_acc / (i * batchSize), 4)) + ", AlgPositive-Acc:" + str(round(hit_count / total_positive, 4)) + ", AlgSelect:" + str(round(hit_count / total_select, 4))
            + ", AlgPositive-Best:" + str(round(hit_best / total_best,4)) + ", AlgContains:" + str(round(select_best / total_best,4))
            + ", Algorithm-Corr:" + str(round(algorithm_correct_sum.item() / (i * len(labels[0]) * batchSize), 4))
            + ", Algorithm-Best:" + str(round(algorithm_best_sum.item() / (i * batchSize), 4)))
        print(mystr)
        if optimizer.param_groups[0]['lr'] < 1e-7:
            break

    return train_accuracies, train_losses, val_accuracies, val_losses

def _tool_combine_features(algPositiveScores, algNegativeScores):
    '''
    Build the input of the tool model from the outputs of the two algorithm models.

    Both score tensors are squashed with a sigmoid and concatenated along dim=1,
    so the result has algPositiveScores.size(1) + algNegativeScores.size(1) columns.
    '''
    return torch.cat((torch.sigmoid(algPositiveScores), torch.sigmoid(algNegativeScores)), dim=1)


def _tool_batch_metrics(scores, labels_true, k):
    '''
    Count the hits of one batch of tool-model scores against the true labels.

    ``scores`` and ``labels_true`` are both reduced along dim=1, i.e. treated as
    (batch, candidate) matrices of the same shape.

    Returns a tuple of Python ints (topk_hits, top1_hits, success_hits, success_total):
      topk_hits     -- rows whose arg-max label index is among the k highest scores
      top1_hits     -- rows whose arg-max score index equals the arg-max label index
      success_hits  -- rows whose arg-max score index has a label > 0
      success_total -- rows that have at least one label > 0
    '''
    label_best = labels_true.argmax(dim=1)
    score_best = scores.argmax(dim=1)
    score_topk = scores.topk(k, dim=1).indices
    topk_hits = (score_topk == label_best.unsqueeze(1)).any(dim=1).sum().item()
    top1_hits = (score_best == label_best).sum().item()
    success_hits = (labels_true.gather(1, score_best.unsqueeze(1)) > 0).sum().item()
    success_total = (labels_true.max(dim=1).values > 0).sum().item()
    return topk_hits, top1_hits, success_hits, success_total


def _tool_progress_line(phase, epoch, cum_loss, n_batches, batchSize, topk_acc, top1_acc, n_samples,
                        success_acc, success_counter):
    '''
    Format the progress line printed by the train and validation loops of train_modelTool.

    Avg-Loss keeps the historical normalisation (summed batch losses divided by
    batches * batchSize); the accuracies are divided by the exact number of samples seen.
    Empty denominators yield 0 instead of raising ZeroDivisionError.
    '''
    avg_loss = cum_loss / (n_batches * batchSize) if n_batches else 0.0
    topk = topk_acc / n_samples if n_samples else 0.0
    top1 = top1_acc / n_samples if n_samples else 0.0
    success = success_acc / success_counter if success_counter else 0.0
    return (phase + "-epoch " + str(epoch) + ", Avg-Loss: " + str(round(avg_loss, 4))
            + ", TopK-Acc:" + str(round(topk, 4)) + ", Top1-Acc:" + str(round(top1, 4))
            + ", Success-Acc:" + str(round(success, 4)))


def train_modelTool(model, algPositiveModel, algNegativeModel, loss_fn, batchSize, trainset, valset, optimizer, scheduler, num_epochs, gpu, task, k=1,
                trainWeights=None, valWeights=None):
    '''
    Function used to train the tool (algorithm-selector) network.

    For every graph batch the two frozen models ``algPositiveModel`` and
    ``algNegativeModel`` are evaluated under ``torch.no_grad()``; their sigmoid
    outputs are concatenated and fed to ``model``, which is trained with
    ``loss_fn(log_softmax(scores, dim=1), labels_true.argmax(dim=1))``.

    Parameters
    ----------
    model : module being trained, called as ``model(features)``.
    algPositiveModel, algNegativeModel : frozen models called as
        ``m(x, edge_index, edge_attr, batch, problemType)``; switched to eval mode here.
    loss_fn : loss applied to the log-softmax scores and the arg-max label index.
    batchSize : loader batch size (also used to normalise the printed Avg-Loss).
    trainset, valset : datasets yielding ``(graph, labels1, labels0, labels_true)``.
    optimizer, scheduler : ``scheduler.step`` receives the mean validation loss per batch.
    num_epochs : maximum number of epochs; stops early once the lr drops below 1e-7.
    gpu : device passed to ``.to(device=...)``.
    task : only ``"topk"`` is supported.
    k : size of the top-k accuracy window.
    trainWeights, valWeights : optional sample weights; when ``trainWeights`` is given
        both loaders use a ``WeightedRandomSampler``.

    Returns
    -------
    (train_losses, val_losses) : rounded mean batch losses. ``train_losses`` gets one
        entry per progress report (several per epoch); ``val_losses`` one per epoch.

    Raises
    ------
    ValueError : if ``task`` is not ``"topk"``.
    '''
    if task != "topk":
        # Only the "topk" objective is implemented; any other value used to fail later
        # with a NameError on the never-assigned `loss`.
        raise ValueError(f"train_modelTool only supports task='topk', got {task!r}")

    if trainWeights is None:
        train_loader = torch_geometric.loader.DataLoader(dataset=trainset, batch_size=batchSize, shuffle=True)
        val_loader = torch_geometric.loader.DataLoader(dataset=valset, batch_size=batchSize, shuffle=True)
    else:
        trainSampler = WeightedRandomSampler(weights=trainWeights, num_samples=len(trainWeights))
        valSampler = WeightedRandomSampler(weights=valWeights, num_samples=len(valWeights))

        train_loader = torch_geometric.loader.DataLoader(dataset=trainset, batch_size=batchSize, sampler=trainSampler)
        val_loader = torch_geometric.loader.DataLoader(dataset=valset, batch_size=batchSize, sampler=valSampler)

    train_losses = []
    val_losses = []

    # Periodic progress reports fire every ~10% of this batch count (rounded to hundreds);
    # 0 disables them. The end-of-epoch report is always printed.
    report_base = round(len(trainset) // batchSize, -2)
    n_train_batches = len(train_loader)

    for epoch in range(0, num_epochs):
        # ------------------------------ training ------------------------------
        model.train()
        # The algorithm models are only used as frozen feature extractors (always under
        # no_grad); keep them in eval mode so dropout / batch-norm state is not disturbed.
        algPositiveModel.eval()
        algNegativeModel.eval()
        # A bare `torch.enable_grad()` call is a no-op; this actually sets grad mode.
        torch.set_grad_enabled(True)

        cum_loss = 0.0
        topk_acc = top1_acc = 0
        success_acc = success_counter = 0
        n_batches = 0  # batches that contributed to the running averages
        n_samples = 0

        for (i, (graphs, _labels1, _labels0, labels_true)) in enumerate(tqdm.tqdm(train_loader)):
            graphs = graphs.to(device=gpu)
            labels_true = labels_true.to(device=gpu)
            assert graphs.x.size(0) == graphs.batch.size(0)

            # Batches in which no algorithm has a positive label carry no training signal
            # and are skipped; they still advance the progress report below.
            if labels_true.max() > 0:
                with torch.no_grad():
                    algPositiveScores = algPositiveModel(graphs.x, graphs.edge_index, graphs.edge_attr, graphs.batch, graphs.problemType)
                    algNegativeScores = algNegativeModel(graphs.x, graphs.edge_index, graphs.edge_attr, graphs.batch, graphs.problemType)

                with autocast():
                    scores = model(_tool_combine_features(algPositiveScores, algNegativeScores))
                    loss = loss_fn(nn.functional.log_softmax(scores, dim=1), labels_true.argmax(dim=1))
                cum_loss += loss.detach().item()

                batch_topk, batch_top1, batch_success, batch_positive = _tool_batch_metrics(scores, labels_true, k)
                topk_acc += batch_topk
                top1_acc += batch_top1
                success_acc += batch_success
                success_counter += batch_positive
                n_batches += 1
                n_samples += labels_true.size(0)

                optimizer.zero_grad()
                loss.backward()
                # NOTE(review): kept from the original implementation; confirm it is still needed.
                model.float()
                optimizer.step()

            condition = report_base != 0 and (((i + 1) / report_base) * 100) % 10 == 0
            if n_batches and (condition or (i + 1) == n_train_batches):
                print(_tool_progress_line("Train", epoch, cum_loss, n_batches, batchSize,
                                          topk_acc, top1_acc, n_samples, success_acc, success_counter))
                train_losses.append(round(cum_loss / n_batches, 4))

        # ----------------------------- validation -----------------------------
        model.eval()

        cum_loss = 0.0
        topk_acc = top1_acc = 0
        success_acc = success_counter = 0
        n_batches = 0
        n_samples = 0

        # Batches without any positive label are still evaluated (as before) and are now
        # also counted in the denominators, so the averages match their numerators.
        for (i, (graphs, _labels1, _labels0, labels_true)) in enumerate(val_loader):
            graphs = graphs.to(device=gpu)
            labels_true = labels_true.to(device=gpu)
            if torch.isnan(graphs.x).any() or torch.isinf(graphs.x).any():
                print(f"graphs.x NaN or Inf found in input data at batch {i} with data_name: {graphs.data_name}")
                continue  # skip this batch

            with autocast():
                with torch.no_grad():
                    algPositiveScores = algPositiveModel(graphs.x, graphs.edge_index, graphs.edge_attr, graphs.batch, graphs.problemType)
                    algNegativeScores = algNegativeModel(graphs.x, graphs.edge_index, graphs.edge_attr, graphs.batch, graphs.problemType)
                    scores = model(_tool_combine_features(algPositiveScores, algNegativeScores))
                    if torch.isnan(scores).any() or torch.isinf(scores).any():
                        print(f"scores NaN or Inf found in input data at batch {i} with data_name: {graphs.data_name}")
                        continue  # skip this batch

                    try:
                        loss = loss_fn(nn.functional.log_softmax(scores, dim=1), labels_true.argmax(dim=1))
                        cum_loss += loss.detach().item()
                    except Exception as e:
                        # Best-effort validation: report the failing batch and move on.
                        print(f"Error during loss calculation at batch {i}, data_name: {graphs.data_name}: {e}")
                        continue  # skip this batch

            batch_topk, batch_top1, batch_success, batch_positive = _tool_batch_metrics(scores, labels_true, k)
            topk_acc += batch_topk
            top1_acc += batch_top1
            success_acc += batch_success
            success_counter += batch_positive
            n_batches += 1
            n_samples += labels_true.size(0)

        if n_batches:
            mean_val_loss = cum_loss / n_batches
            scheduler.step(mean_val_loss)
        else:
            # Nothing could be evaluated: do not feed a fabricated loss to the scheduler.
            print(f"Valid-epoch {epoch}: no validation batch could be evaluated")
            mean_val_loss = float("nan")
        val_losses.append(round(mean_val_loss, 4))

        print(_tool_progress_line("Valid", epoch, cum_loss, n_batches, batchSize,
                                  topk_acc, top1_acc, n_samples, success_acc, success_counter))
        if optimizer.param_groups[0]['lr'] < 1e-7:
            break

    return train_losses, val_losses

def evaluateTool(model, algPositiveModel, algNegativeModel, test_set, files, gpu=0, k=3):
    '''
    Evaluate the tool-selection ``model`` stacked on the algorithm scorer.

    For every test sample ``(graphs, labels1, labels0, labels_true)`` (loaded
    one at a time, ``batch_size=1``) the positive algorithm model is run on the
    graph. Its sigmoid scores are concatenated along dim 1 with
    ``1 - sigmoid`` of the same scores, and the result is fed to ``model``.
    The model's output is then compared with ``labels_true``.

    NOTE(review): ``algNegativeModel`` is only put in eval mode; its output is
    not used. The previous code computed ``sigmoid(algNegativeModel(...))``
    and immediately overwrote it with ``1 - sigmoid(algPositiveScores)``.
    That effective input is preserved here, and the discarded forward pass is
    skipped. ``evaluate2`` instead feeds ``sigmoid(algNegativeScores)``.
    Confirm which input ``model`` was trained on.

    Args:
        model: tool-selection model, called as ``model(combine)``.
        algPositiveModel: called as
            ``algPositiveModel(x, edge_index, edge_attr, batch, problemType)``.
        algNegativeModel: set to eval mode only (see note above).
        test_set: dataset; ``test_set[0][1].size(0)`` is taken as the number
            of labels.
        files: sequence indexed in loader order; ``files[i]`` keys the raw
            scores stored in ``predicts``.
        gpu: device passed to ``.to(device=...)``.
        k: size of the top-k set for top-k accuracy.

    Returns:
        ``(res, predicted, predicts)``.

        ``res[0]`` is ``[top-k acc, top-1 acc, best-pick acc, success acc,
        predSpot summed over problem types, a]``.

        ``res[t + 1]`` is ``[top-k acc, best-pick acc, success acc,
        predSpot row]`` for problem type ``t`` (0..3).

        ``predicted`` counts argmax picks per label: row 0 is overall, row
        ``t + 1`` is problem type ``t``.

        ``predicts`` maps ``files[i]`` to ``scores.tolist()``.

        Ratios with a zero denominator are reported as 0.0 instead of NaN.
        Assumes ``problemType`` is an integer in 0..3 (it indexes
        length-4 arrays).
    '''
    n_labels = test_set[0][1].size(0)
    topKAcc = np.array([0.0] * 4)
    bestPredicts = np.array([0] * 4)
    correctPredicts = np.array([0] * 4)
    possibleCorrect = np.array([0] * 4)
    # predSpot[t][r]: how often the top-scored label sat at rank r of the
    # descending label ordering, for problem type t.
    predSpot = np.array([[0] * n_labels] * 4)
    probCounter = np.array([0] * 4)

    predicted = np.array([[0] * n_labels] * 5)
    predicts = dict()
    top1_acc = 0.0
    # Number of skipped samples. It is reported in res[0] and subtracted from
    # the overall denominators. Nothing increments it at present.
    a = 0

    def _ratio(num, den):
        # Zero-guarded division, matching the per-type rows below.
        return num / den if den else 0.0

    model.eval()
    algPositiveModel.eval()
    algNegativeModel.eval()

    test_loader = torch_geometric.loader.DataLoader(dataset=test_set, batch_size=1)

    for (i, (graphs, labels1, labels0, labels_true)) in enumerate(tqdm.tqdm(test_loader, leave=False)):
        graphs = graphs.to(device=gpu)
        labels_true = labels_true.to(device=gpu)
        # batch_size=1, so problemType holds a single value.
        ptype = int(graphs.problemType.item())

        with autocast():
            with torch.no_grad():
                algPositiveScores = algPositiveModel(graphs.x, graphs.edge_index, graphs.edge_attr, graphs.batch, graphs.problemType)
                sig_algPScores = torch.sigmoid(algPositiveScores)
                # Identical to the previous effective value 1 - sigmoid(algPositiveScores).
                sig_algNScores = 1 - sig_algPScores
                combine = torch.cat((sig_algPScores, sig_algNScores), dim=1)
                scores = model(combine)

        predicts[files[i]] = scores.tolist()

        for j in range(len(labels_true)):
            _, scoreTopk = scores[j].topk(k)
            scoreTop1 = scores[j].argmax()
            labelTopk = labels_true[j].argmax()
            top1_acc += (scoreTop1 == labelTopk).sum().float().item()
            topKAcc[ptype] += labelTopk in scoreTopk

        predictedIdx = scores.argmax(dim=1)
        bestPredicts[ptype] += (predictedIdx == labels_true.argmax(dim=1)).sum().item()

        for idx in predictedIdx:
            predicted[0][idx.item()] += 1
            predicted[ptype + 1][idx.item()] += 1

        # Label value at the top-scored position.
        gather = labels_true.gather(1, predictedIdx.reshape(len(scores), 1))
        if labels_true.min() <= 0:
            correctPredicts[ptype] += (gather > 0).sum().item()

        # Rank of the top-scored label within the labels sorted descending.
        predSpot[ptype][
            np.where((-labels_true).argsort().cpu().numpy() == scores.argmax().item())[1]] += 1

        # A "success" is only possible when some labels are positive and some are not.
        if labels_true.max() > 0 and labels_true.min() <= 0:
            possibleCorrect[ptype] += 1
        probCounter[ptype] += 1

    evaluated = probCounter.sum() - a
    res = [[], [], [], [], []]

    res[0] = np.array(
        [_ratio(topKAcc.sum(), evaluated), _ratio(top1_acc, evaluated), _ratio(bestPredicts.sum(), evaluated),
         _ratio(correctPredicts.sum(), possibleCorrect.sum()), predSpot.sum(axis=0), a], dtype=object)
    for t in range(len(res) - 1):
        res[t + 1] = np.array([_ratio(topKAcc[t], probCounter[t]),
                               _ratio(bestPredicts[t], probCounter[t]),
                               _ratio(correctPredicts[t], possibleCorrect[t]), predSpot[t]],
                              dtype=object)

    return res, predicted, predicts

def evaluate(model, test_set, files, gpu=0, k=3):
    '''
    Evaluate a single scoring model on the test set.

    The loader yields ``(graphs, labels)`` pairs one sample at a time
    (``batch_size=1``). Samples whose graph contains no edge with
    ``edge_attr[:, 0]`` equal to 16 or 17 are skipped and are not counted
    anywhere.

    Args:
        model: called as ``model(x, edge_index, edge_attr, batch, problemType)``
            and expected to return one score per label.
        test_set: dataset of ``(graph, labels)``; ``test_set[0][1].size(0)``
            is taken as the number of labels.
        files: sequence indexed in loader order; ``files[i]`` keys the raw
            scores stored in ``predicts``.
        gpu: device passed to ``.to(device=...)``.
        k: size of the top-k set for top-k accuracy.

    Returns:
        ``(res, predicted, predicts)``.

        ``res[0]`` contains, in order:

        * mean Spearman correlation over non-degenerate samples;
        * top-k accuracy;
        * top-1 accuracy;
        * success accuracy;
        * hit_count / total_positive;
        * hit_count / total_select;
        * hit_best / total_best;
        * select_best / total_best;
        * per-entry sign-agreement rate;
        * all-signs-agree rate;
        * number of evaluated samples;
        * predSpot summed over problem types.

        ``res[t + 1]`` is ``[top-k acc, top-1 acc, success acc, predSpot row]``
        for problem type ``t`` (0..3).

        ``predicted`` counts argmax picks per label: row 0 is overall, row
        ``t + 1`` is problem type ``t``.

        ``predicts`` maps ``files[i]`` to ``scores.tolist()``.

        Any ratio whose denominator is zero is reported as 0.0. Previously
        this case raised ZeroDivisionError or AttributeError. Assumes
        ``problemType`` is an integer in 0..3 (it indexes length-4 arrays).
    '''
    n_labels = test_set[0][1].size(0)
    topKAcc = np.array([0.0] * 4)
    bestPredicts = np.array([0] * 4)
    correctPredicts = np.array([0] * 4)
    possibleCorrect = np.array([0] * 4)
    # predSpot[t][r]: how often the top-scored label sat at rank r of the
    # descending label ordering, for problem type t.
    predSpot = np.array([[0] * n_labels] * 4)
    probCounter = np.array([0] * 4)

    predicted = np.array([[0] * n_labels] * 5)
    predicts = dict()

    # Plain Python floats, so the final division works even if nothing was
    # accumulated.
    corr_sum = 0.0
    algorithm_correct_sum = 0.0
    algorithm_best_sum = 0.0
    hit_count = 0
    total_positive = 0
    total_select = 0
    hit_best = 0
    select_best = 0
    total_best = 0
    evaluated = 0  # samples that passed the edge-type filter
    a = 0          # rows whose Spearman correlation was undefined (scored as 0)

    def _ratio(num, den):
        # Zero-guarded division used for every reported metric.
        return num / den if den else 0.0

    model.eval()

    test_loader = torch_geometric.loader.DataLoader(dataset=test_set, batch_size=1)

    for (i, (graphs, labels)) in enumerate(tqdm.tqdm(test_loader, leave=False)):
        graphs = graphs.to(device=gpu)
        labels = labels.to(device=gpu)

        # Only evaluate graphs that contain at least one edge of type 16 or 17.
        if not torch.any((graphs.edge_attr[:, 0] == 16) | (graphs.edge_attr[:, 0] == 17)):
            continue

        evaluated += 1
        # batch_size=1, so problemType holds a single value.
        ptype = int(graphs.problemType.item())

        with autocast():
            with torch.no_grad():
                scores = model(graphs.x, graphs.edge_index, graphs.edge_attr, graphs.batch, graphs.problemType)

        predicts[files[i]] = scores.tolist()

        for j in range(len(labels)):
            row_scores = scores[j]
            row_labels = labels[j]

            # Spearman correlation restricted to entries where both the label
            # and the score are positive. It is undefined for empty or
            # constant inputs; those rows count as 0 and are excluded from
            # the mean's denominator.
            both_positive = (row_labels > 0) & (row_scores > 0)
            positiveScores = row_scores[both_positive]
            positiveLabels = row_labels[both_positive]
            if (positiveScores.numel() == 0
                    or torch.all(positiveLabels == positiveLabels[0])
                    or torch.all(positiveScores == positiveScores[0])):
                a += 1
                corr = 0.0
            else:
                corr, _ = spearmanr(positiveScores.cpu().detach(), positiveLabels.cpu().detach().tolist())
            corr_sum += float(corr)

            _, scoreTopk = row_scores.topk(k)
            labelTopk = row_labels.argmax()
            topKAcc[ptype] += labelTopk in scoreTopk

            if row_labels.max() > 0:
                # Labels the model predicts as positive (score > 0).
                positive_indices = (row_scores > 0).nonzero(as_tuple=True)[0]

                # Select labels, in descending order, until their cumulative
                # sum reaches 80% of the row total. This includes the first
                # label that crosses the threshold.
                sorted_labels, sorted_indices = torch.sort(row_labels, descending=True)
                cumulative_sum = torch.cumsum(sorted_labels, dim=0)
                threshold = row_labels.sum() * 0.8
                selected_indices = sorted_indices[cumulative_sum < threshold]
                if cumulative_sum[-1] > threshold:
                    selected_indices = torch.cat(
                        (selected_indices, sorted_indices[cumulative_sum > threshold][:1]))

                # Predicted-positive labels that fall inside the selected set.
                hit_len = torch.isin(positive_indices, selected_indices).sum().item()
                if hit_len == positive_indices.size(0):
                    hit_best += 1
                if hit_len == selected_indices.size(0) and hit_len == positive_indices.size(0):
                    select_best += 1
                total_best += 1

                hit_count += hit_len
                total_positive += positive_indices.size(0)
                total_select += selected_indices.size(0)

            # Sign agreement between scores and labels: per entry, and for
            # the whole row.
            sign_match = torch.sign(row_scores) == torch.sign(row_labels)
            algorithm_correct_sum += sign_match.float().sum().item()
            algorithm_best_sum += float(sign_match.all().item())

        bestPredicts[ptype] += (scores.argmax(dim=1) == labels.argmax(dim=1)).sum().item()

        for idx in scores.argmax(dim=1):
            predicted[0][idx.item()] += 1
            predicted[ptype + 1][idx.item()] += 1

        # Label value at the top-scored position.
        maxScoresIdx = scores.argmax(dim=1).reshape(len(scores), 1)
        gather = labels.gather(1, maxScoresIdx)
        if labels.min() <= 0:
            correctPredicts[ptype] += (gather > 0).sum().item()

        # Rank of the top-scored label within the labels sorted descending.
        predSpot[ptype][np.where((-labels).argsort().cpu().numpy() == scores.argmax().item())[1]] += 1

        # A "success" is only possible when some labels are positive and some are not.
        if labels.max() > 0 and labels.min() <= 0:
            possibleCorrect[ptype] += 1
        probCounter[ptype] += 1

    res = [[], [], [], [], []]

    res[0] = np.array([
        _ratio(corr_sum, evaluated - a),
        _ratio(topKAcc.sum(), probCounter.sum()),
        _ratio(bestPredicts.sum(), probCounter.sum()),
        _ratio(correctPredicts.sum(), possibleCorrect.sum()),
        _ratio(hit_count, total_positive),
        _ratio(hit_count, total_select),
        _ratio(hit_best, total_best),
        _ratio(select_best, total_best),
        _ratio(algorithm_correct_sum, evaluated * n_labels),
        _ratio(algorithm_best_sum, evaluated),
        evaluated,
        predSpot.sum(axis=0),
    ], dtype=object)
    for t in range(len(res) - 1):
        res[t + 1] = np.array([_ratio(topKAcc[t], probCounter[t]),
                               _ratio(bestPredicts[t], probCounter[t]),
                               _ratio(correctPredicts[t], possibleCorrect[t]),
                               predSpot[t]], dtype=object)

    return res, predicted, predicts

def evaluate2(Tmodel, Pmodel,Nmodel, test_set, files, gpu=0, k=3):
    '''
    Function used to evaluate model on test set
    '''
    corr_sum = np.array([0.0] * 4)
    topKAcc = np.array([0.0] * 4)
    bestPredicts = np.array([0] * 4)
    correctPredicts = np.array([0] * 4)
    possibleCorrect = np.array([0] * 4)
    predSpot = np.array([[0] * test_set[0][3].size(0)] * 4)
    probCounter = np.array([0] * 4)
    evaluation_score = 0  # æ°å¢çè¯ä¼°ææåé
    total_predictions = 0  # æ»é¢æµæ¬¡æ°
    valid_max_score = 0
    predicted = np.array([[0] * test_set[0][3].size(0)] * 5)
    predicts = dict()
    mingzhong = 0
    total = 0
    count = 0
    log_file = 'command_results_memory.txt'
    best = 0
   # start_time = time.time()
    # ä¸å¾æ°ç»
   # array = np.array([
   #     [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1],
   #     [0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1],
   #     [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1],
   #     [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1],
   #     [0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
   #     [0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0],
   #     [1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1],
   #     [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0],
   #     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
   #     [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]
   # ])

    array = np.array([
        [0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1]
    ])

#    array = np.array([
#        [0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0],
#        [0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1],
#        [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1],
#        [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
#        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
#        [0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
#        [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1],
#        [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0],
#        [1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1],
#        [1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
#        [1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
#        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
#        [1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1],
#        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#        [0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1],
#        [1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1],
#        [1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1],
#        [0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
#        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
#        ])

   # weight_tools =np.array([
   #     [0.5931579771410678, 0.21118170266836087, 0.49768637532133675, 0.8311752988047809],
   #     [-4.828088304368248, 0.12864040660736975, 0.2683804627249357,  0.7579681274900398],
   #     [0.5573038985439174, 0.16289707750952986, 0.47917737789203285, 0.81050796812749],
   #     [0.6777047126976672, 0.21768742058449808, 0.7005141388174807, 0.7878486055776892],
   #     [0.7577892594332237, 0.5279288437102923, 0.4025706940874036, 0.836902390438247],
   #     [0.46485047753248787,0.4484879288437103, 0.4755784061696658, 0.029133466135458166],
   #     [0.658290277125411, 0.18312579415501906, 0, 0.8314243027888446],
   #     [0.5683419445749178, 0.15135959339263025, 0, 0.5530378486055777],
   #     [0.6701111633004541, 0.57123258285895807, 0.4670951156812339, 0.8095119521912351],
   #     [0.5865821199311101, 0.0839135959339263, 0.4570694087403599, 0.7664342629482072]
   #     ])
    
    weight_tools = np.array([
        [0.4233023937469467,0.4149049093321539, 0.4957519116397621, 0.07915690866510539],
        [0.033341475329750854, 0.1525873507297656, 0.221750212404418,  0.8004683840749415],
        [0.5787738153395212, 0.24586466165413534, 0.4524214103653356, 0.8098360655737705],
        [0.6369687347337567, 0.2191065900044228, 0.3903993203058624, 0.8180327868852459],
        [0.6834391792867611, 0.5720477664750111, 0.5135938827527613, 0.8220140515222483],
        [0.7639838788470933, 0.5313578062804069, 0.33262531860662703, 0.8344262295081967],
        [0.6098558866634098, 0.24900486510393632, 0.4821580288870008, 0.8665105386416861],
        [0.5534929164631167, 0.22799646174259178, 0.6928632115548004, 0.7667447306791569],
        [0.5396922325354176, 0.15643520566121186, 0, 0.5517564402810304],
        [0.658290277125411, 0.19862892525431225, 0, 0.8070257611241218]
    ])

   # weight_tools = np.array([
   #     [0.46485047753248787,0.4484879288437103, 0.4755784061696658, 0.029133466135458166],
   #     [-4.828088304368248, 0.12864040660736975, 0.2683804627249357,  0.7579681274900398],
   #     [0.5573038985439174, 0.16289707750952986, 0.47917737789203285, 0.81050796812749],
   #     [0.5865821199311101, 0.0839135959339263, 0.4570694087403599, 0.7664342629482072],
   #     [0.6701111633004541, 0.57123258285895807, 0.4670951156812339, 0.8095119521912351],
   #     [0.7577892594332237, 0.5279288437102923, 0.4025706940874036, 0.836902390438247],
   #     [0.5824330671676844, 0.0554002541296061, 0.31722365038560413, 0.7385458167330677],
   #     [0.11069359636762173, 0.20284625158831004, 0.3822622107969152, 0.7995517928286853],
   #     [0.6322217003287929, 0.09047013977128335, 0, 0.6563745019920318],
   #     [0.5931579771410678, 0.21118170266836087, 0.49768637532133675, 0.8311752988047809],
   #     [0, 0.07227445997458704, 0, 0],
   #     [0.6777047126976672, 0.21768742058449808, 0.7005141388174807, 0.7878486055776892],
   #     [0.5683419445749178, 0.15135959339263025, 0, 0.5530378486055777],
   #     [0.658290277125411, 0.18312579415501906, 0, 0.8314243027888446],
   #     [0, 0.5959339263024143, 0, 0],
   #     [0, 0.6213468869123253, 0, 0],
   #     [0.5962893377172381, 0.47557814485387545, 0.38046272493573263, 0.8102589641434262],
   #     [0.050023485204321275, 0.11786531130876747, 0.25861182519280207, 0],
   #     [0, 0, 0, 0.8478585657370518],
   #     [0, 0, 0.837789203084833, 0]
   #    ])

    Tmodel.eval()
    Pmodel.eval()
    Nmodel.eval()

    test_loader = torch_geometric.loader.DataLoader(dataset=test_set, batch_size=1)

    for (i, (graphs, labels1, labels0, labels)) in enumerate(tqdm.tqdm(test_loader, leave=False)):
        graphs = graphs.to(device=gpu)
        labels1 = labels1.to(device=gpu)
        labels0 = labels0.to(device=gpu)
        labels = labels.to(device=gpu)
    #    print(labels)
    #    print(len(labels[0]))
        problemTypes = graphs.problemType
#        if not torch.any((graphs.edge_attr[:, 0] == 16) | (graphs.edge_attr[:, 0] == 17)):
#            continue
        if labels.min() > 0:
               # print(labels.min())
            count += 1
            continue
   #     start_time = time.time()
        with autocast():
            with torch.no_grad():
                algPositiveScores = Pmodel(graphs.x, graphs.edge_index, graphs.edge_attr, graphs.batch, graphs.problemType)
                algNegativeScores = Nmodel(graphs.x, graphs.edge_index, graphs.edge_attr, graphs.batch, graphs.problemType)
            #    binary_algPositiveScores = torch.where(algPositiveScores > 0, torch.tensor(1.0, device='cuda:0'), torch.tensor(0.0, device='cuda:0'))
            #    binary_algNegativeScores = torch.where(algNegativeScores > 0, torch.tensor(1.0, device='cuda:0'), torch.tensor(0.0, device='cuda:0'))
                sig_algPScores = torch.sigmoid(algPositiveScores)
                sig_algNScores = torch.sigmoid(algNegativeScores)
 #               sig_algNScores = 1 - torch.sigmoid(algPositiveScores)
            #    select2Scores = torch.where((binary_algPositiveScores == 1) & (binary_algNegativeScores != 1), binary_algPositiveScores, torch.tensor(0.0,device='cuda:0'))
                combine = torch.cat((sig_algPScores, sig_algNScores), dim=1)
               # N_combine = torch.where(binary_algPositiveScores == 0, torch.tensor(1.0, device='cuda:0'), torch.tensor(0.0, device='cuda:0'))
            #    select2Scores = torch.where(binary_algPositiveScores == 0, torch.tensor(1.0, device='cuda:0'), torch.tensor(0.0, device='cuda:0'))
         #       print(binary_algPositiveScores)
         #       print(select2Scores)
                scores = Tmodel(combine)
  #              end_time = time.time()
  #              elapsed_time = end_time - start_time
                _, scoreTopk = scores.topk(k)
                scoreTopk_flat = scoreTopk.cpu().numpy().flatten()
                array_tensor = torch.tensor(array, device='cuda:0')
                NTool_0 = torch.where(array_tensor[scoreTopk_flat[0]] == 0, torch.tensor(1.0, device='cuda:0'), torch.tensor(0.0,device='cuda:0'))
               # print(combine)
               # print(array_tensor[scoreTopk_flat[0]])
               # print(NTool_0)
                NTool_combine = torch.cat((array_tensor[scoreTopk_flat[0]], NTool_0),dim = 0)
               # print(NTool_combine)
                select2Scores = combine - NTool_combine * 0.1
               # print(select2Scores)
                D2scores = Tmodel(select2Scores)
                _, score2Topk = D2scores.topk(k)
                score2Topk_flat = score2Topk.cpu().numpy().flatten()
                if scoreTopk_flat[0] == score2Topk_flat[0]:
                    index = score2Topk_flat[1]
                else:
                    index = score2Topk_flat[0]
                NTool2_0 = torch.where(array_tensor[index] == 0, torch.tensor(1.0, device='cuda:0'), torch.tensor(0.0,device='cuda:0'))
                NTool2_combine = torch.cat((array_tensor[index], NTool2_0), dim = 0)
                select3Scores = select2Scores - NTool2_combine * 0.1
                D3scores = Tmodel(select3Scores)

        predicts[files[i]] = scores.tolist()

        for j in range(len(labels)):
            corr, _ = spearmanr(labels[j].cpu().detach(), scores[j].cpu().detach().tolist())
#            print(labels[j].shape)
#            print(scores[j].shape)
            corr_sum[int(problemTypes.item())] += corr
            _, scoreTopk = scores.topk(k)
            _, score2Topk = D2scores.topk(k)
            _, score3Topk = D3scores.topk(k)
            labelTopk = labels.argmax()
            _,toperror = scores.topk(1)
           # print(labelTopk)
           # print(labelTopk.item())
      #      topKAcc[int(problemTypes.item())] += labelTopk in scoreTopk
            topKAcc[int(problemTypes.item())] += labelTopk in toperror
           # print(scoreTopk) 
           # print(scores[j,0].shape)
          # # print(scores)
            # è®¡ç®é¢æµå¾åå¤§äº0æ¶å¯¹åºåçç´¯å 
           # for idx in scoreTopk.cpu().numpy():
           #     print(idx)
            
            scores_np = scores[j].cpu().detach().numpy().flatten()
            scoreTopk_indices = scoreTopk.cpu().numpy()
            scoreTopk_values = scores[j,scoreTopk_indices].cpu().detach().numpy()  # top-k åæ
            scoreTopk_flat = scoreTopk.cpu().numpy().flatten()  # flatten ç¡®ä¿æ¯ 1D æ°ç»
            score2Topk_flat = score2Topk.cpu().numpy().flatten()
            score3Topk_flat = score3Topk.cpu().numpy().flatten()
         #   print(scoreTopk_flat)
         #   print(score2Topk_flat)
          # scoreTopk_flat = scoreTopk_indices
            min_score = np.min(scoreTopk_values)
            adjustment = -min_score * 2 if min_score < 0 else 0
            scoreTopk_values += adjustment
            scores_sum = np.sum(scoreTopk_values)  # è®¡ç®æ°ç»å
            weights = scoreTopk_values / scores_sum  # è®¡ç®æé
          #  print("weights")
          #  print(weights)
        # valid_topk_idx = [idx for idx in scoreTopk_flat if scores[0,idx].item() > 0]
          #  print("scoreTopk")
          #  print(scoreTopk)
      #      print("scores")
      #      print(scores)
          #  print("valid_topk_idx")
          #  print(valid_topk_idx)
            col_sums = np.zeros(18)
            if len(scoreTopk_flat) > 0:
                a = 0
              #  weights = weights.flatten()
          #      print(weights)
                for idx in scoreTopk_flat:
                    score_value = scores[j, idx].item()
                    scaling_factor = 1 / (2 ** a)  # ç¬¬ä¸è½®ä¸º 1ï¼ä¹åæ¯ä¸è½®ä¾æ¬¡åå°
                    for d in range(len(col_sums)):
                       # if col_sums[d] != 1:
                         #   if score_value > 0:
                            col_sums[d] += array[idx,d] * scaling_factor
                           # else:
                           #     col_sums[d] -= array[idx,d] * scaling_factor
                       # elif col_sums[d] == 1:
                         #   if score_value < 0:
                         #       col_sums[d] -= array[idx,d] * scaling_factor
                    a += 1
        #        print("col_sums")
        #        print(col_sums)
                if col_sums.size > 0:
                    max_value = col_sums.max()

                    if max_value >= 0:
                        max_cols = np.where(col_sums == col_sums.max())[0]  # æ¾å°æå¤§å¼çåç´¢å¼ï¼å¯è½æå¤ä¸ªï¼
                    else:
                        continue
               #print(f"æå¤§åç´¢å¼: {max_cols}")
                else:
                    max_cols = []
                    continue
            else:
                max_cols = []
                continue
         #   print("max_cols")
         #   print(max_cols)
            # è®¡ç®çå®æ ç­¾çç´¯å å
            real_label_sums = np.zeros(18)
            sorted_labelsIndex = np.argsort(-labels[j].cpu().numpy()) 
            label_np = labels[j].cpu().detach().numpy().flatten()
         #   min_label = np.min(label_np)
         #   adjustment_label = -min_label * 2 if min_label < 0 else 0
         #   label_np += adjustment_label
         #   labels_sum = np.sum(label_np)
         #   weights_label = label_np / labels_sum
          #  print("labels")
          #  print(labels)
          #  print("sorted_labelsIndex")
          #  print(sorted_labelsIndex)
        #    weights_label = weights_label.flatten()
          #  print("weights_label")
          #  print(weights_label)
            b = 0
            for label_idx in sorted_labelsIndex:
                label_value = labels[j, label_idx].item()
                label_factor = 1 / (2 ** b)
                for c in range(len(real_label_sums)):
                   # if real_label_sums[c] != 1:
                    if label_value > 0:
                            # ç´¯å å¯¹åºçåå¼
                        real_label_sums[c] += array[label_idx,c] * label_factor
                           # real_label_sums[c] += array[label_idx,c] * weights_label[label_idx]
                    else:
                            # æ ç­¾ä¸ºè´æ°æ¶ï¼å°å¯¹åºåçå¼ä¹ä»¥ -1 ç´¯å 
                        real_label_sums[c] -= array[label_idx,c] * label_factor
                  #  elif real_label_sums[c] == 1:
                    #    if label_value < 0:
                     #       real_label_sums[c] -= array[label_idx,c]
                b += 1

            #print(f"çå®æ ç­¾ç´¯å å: {real_label_sums}")
          #  print("real_label_sums")
          #  print(real_label_sums)
            # å¯¹ç´¯å åçç´¢å¼è¿è¡ä»å¤§å°å°æåº
            sorted_idx = np.argsort(-real_label_sums)  # ä»å¤§å°å°æåºï¼å¹¶å 1ä»¥ä»1å¼å§
            sorted_sums = np.sort(-real_label_sums) * -1  # å¯¹ç´¯å åè¿è¡ä»å¤§å°å°æåº
         #   print("sorted_idx")
         #   print(sorted_idx)
         #   print("sorted_sums")
         #   print(sorted_sums)
            #print(f"æåºåçåç´¢å¼: {sorted_idx}")
            #print(f"å¯¹åºçç´¯å å: {sorted_sums}")

            # ç¡®ä¿ç»æä¸ºé¿åº¦ä¸º[1,10]çæ°ç»
            #sorted_idx = sorted_idx[:10]
            #sorted_sums = sorted_sums[:10]
         #   print("sorted_idx")
         #   print(sorted_idx)
         #   print("sorted_sums")
         #   print(sorted_sums)
            #print(f"æç»ç[1,10]æ°ç»: {sorted_idx}")
            #print(f"å¯¹åºçç´¯å åæ°ç»: {sorted_sums}")

            # è¯ä¼°åçæ§ï¼ä½¿ç¨é¢æµæå¤§åç´¢å¼ä¸ç´¯å åç´¢å¼è¿è¡å¯¹æ¯
            for max_col in max_cols:
               # # è·åè¯¥æå¤§åç´¢å¼å¯¹åºç´¯å åçä½ç½®
                if max_col in sorted_idx:
                    position = np.where(sorted_idx == max_col)[0][0]
              #      print("position")
              #      print(position)
                    total_predictions += 1
                    if col_sums.max() > 0 and sorted_sums[position] > 0:
                        evaluation_score += 1  # åççé¢æµ
                        if position == 0:
                            valid_max_score += 1
                    if col_sums.max() == 0 and sorted_sums[position] >= 0:
                        evaluation_score += 1  # ï¿½~P~Hï¿½~P~Fï¿½~Z~Dï¿½~Dï¿½~K
                        if position == 0:
                            valid_max_score += 1
            num = 0
          #  if labels.max() > 0:
          #      while num < 3:
          #          if num == 0:
          #              num1 = scoreTopk_flat[num]
          #              if labels[j,num1].item() > 0:
          #                  mingzhong += 1
          #                  break;
          #          elif num == 1:
          #              num1 = scoreTopk_flat[0]
          #              col_sums -= array[num1]
          #              max_cols = np.where(col_sums == col_sums.max())[0]
          #              col_sum_per_row = np.sum(array[:,max_cols],axis=1)
          #              max_tool = np.where(col_sum_per_row == col_sum_per_row.max())[0]
          #              num2 = scoreTopk_flat[num]
          #              if num2 in max_tool:
          #                  if labels[j,num2].item() > 0:
          #                      mingzhong += 1
          #                      break
          #              else:
          #                  if labels[j, max_tool[0]].item() > 0:
          #                      mingzhong += 1
          #                      break
          #          else:
          #              zero_cols = np.where(col_sums[0] == 0)[0]
          #              col_sum_per_row = np.sum(array[:,zero_cols], axis=1)
          #              nextTool = col_sum_per_row.max()
          #              if labels[j,nextTool].item() > 0:
          #                  mingzhong += 1
          #                  break;
                       # if num == 3:
                        #    if col_sum_per_row[col_sum_per_row != nextTool].size > 0:
                         #       second_max_value = col_sum_per_row[col_sum_per_row != nextTool].max()
                          #      nextTool = np.where(col_sum_per_row == second_max_value)[0][0]
                           #     if labels[j,nextTool].item() > 0:
                            #        mingzhong += 1
                             #       break
           #         num += 1
           #     total += 1
           # if labels.max() > 0:
     #       tools = np.array([0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19])
            tools = np.array([0,1,2,3,4,5,6,7,8,9])
            intersection = np.intersect1d(scoreTopk_flat, score2Topk_flat)
            useTool = np.array([])
#            if labels.min() > 0:
               # print(labels.min())
#                count += 1
#                continue
              #  print(graphs.data_name)
            if labels.max() > 0:
                problemTypeIndex = int(problemTypes.item())
                while num < 1:
                    if num == 0:
                       # select2Scores_flat = select2Scores.cpu().numpy().flatten()
                       # suitableAlg = np.where(select2Scores_flat > 0)[0]
                       # suitable_values = np.zeros_like(select2Scores_flat)
                       # suitable_values[suitableAlg] = 1
                       # selected_values = array[scoreTopk_flat]
                       # euclidean_distances = []
                       # for i in range(selected_values.shape[0]):
                       #     row_to_compare = selected_values[i]
                       #     
                       #     if len(suitable_values) > len(row_to_compare):
                       #         suitable_values_subset = suitable_values[suitableAlg]
                       #     else:
                       #         suitable_values_subset = suitable_values
                       #
                       #     distance = np.linalg.norm(row_to_compare - suitable_values_subset)
                       #     euclidean_distances.append((distance, scoreTopk_flat[i]))
                       
                        #sorted_distances = sorted(euclidean_distances, key=lambda x: x[0])

                        toolIndex = scoreTopk_flat[num]
                        useTool = np.append(useTool, toolIndex)
                        remaining_tools = np.setdiff1d(tools, toolIndex)
                        remaining_intersection = np.setdiff1d(intersection, toolIndex)
                    elif num == 1:
                       # select2Scores_flat = select2Scores.cpu().numpy().flatten()
                       # suitableAlg = np.where(select2Scores_flat >0)[0]
                       # suitable_values = np.zeros_like(select2Scores_flat)
                       # suitable_values[suitableAlg] = 1
                       # selected_values = array[score2Topk_flat]
                       # euclidean_distances = []
                       # for i in range(selected_values.shape[0]):
                       #     row_to_compare = selected_values[i]
                            
                       #     if len(suitable_values) > len(row_to_compare):
                       #         suitable_values_subset = suitable_values[suitableAlg]
                       #     else:
                       #         suitable_values_subset = suitable_values

                       #     distance = np.linalg.norm(row_to_compare - suitable_values_subset)
                       #     euclidean_distances.append((distance, score2Topk_flat[i]))

                       # sorted_distances = sorted(euclidean_distances, key=lambda x: x[0])
                        if scoreTopk_flat[0] == score2Topk_flat[0]:
                            toolIndex = score2Topk_flat[1]
                            useTool = np.append(useTool, toolIndex)
                     #       if remaining_intersection.size > 0:
                     #           toolIndex = remaining_intersection[0]
                     #       else:
                     #           toolIndex = score2Topk_flat[1]
           #                 if scoreTopk_flat[1] in score2Topk_flat:
           #                     toolIndex = scoreTopk_flat[1]
           #                 elif scoreTopk_flat[2] in score2Topk_flat:
           #                     toolIndex = scoreTopk_flat[2]
           #                 else:
           #                     toolIndex = score2Topk_flat[1]
                        else:
                   #         if scoreTopk_flat[1] in score2Topk_flat:
                   #             toolIndex = scoreTopk_flat[1]
                   #         elif scoreTopk_flat[2] in score2Topk_flat:
                   #             toolIndex = scoreTopk_flat[2]
                   #         else:
                    #        if remaining_intersection.size > 0:
                    #            toolIndex = remaining_intersection[0]
                    #        else:
                    #            toolIndex = score2Topk_flat[0]
                            toolIndex = score2Topk_flat[0]
                            useTool = np.append(useTool, toolIndex)
         #               toolIndex = scoreTopk_flat[num]
                        remaining_tools = np.setdiff1d(tools, toolIndex)
                        remaining_intersection = np.setdiff1d(intersection, toolIndex)
                    else:
                        weights_in_col = weight_tools[remaining_tools, problemTypeIndex]
                        max_value_index = np.argmax(weights_in_col)
                        toolIndex = remaining_tools[max_value_index]
                     #   if remaining_intersection.size > 0:
                     #       toolIndex = remaining_intersection[0]
                     #   else:
    #                    toolIndex = scoreTopk_flat[num]
                        if score3Topk_flat[0] not in useTool:
                            if score3Topk_flat[0] == toolIndex:
                                toolIndex = score3Topk_flat[0]
                            else:
                                toolIndex = score3Topk_flat[0]
                        elif score3Topk_flat[1] not in useTool:
                            if score3Topk_flat[1] == toolIndex:
                                toolIndex = score3Topk_flat[1]
                            else:
                                toolIndex = score3Topk_flat[1]
                        elif score3Topk_flat[2] not in useTool:
                            if score3Topk_flat[2] == toolIndex:
                                toolIndex = score3Topk_flat[2]
                            else:
                                toolIndex = score3Topk_flat[2]
                  #      toolIndex = score3Topk_flat[2]
                      #  elif 
                     #   toolIndex = remaining_tools[max_value_index]
         #               toolIndex = scoreTopk_flat[num]

                    if labels[j,toolIndex].item() > 0 :
                        mingzhong += 1
                      #  problemTypeIndex = int(problemTypes.item())
                        if toolIndex == labelTopk.item():
                            best += 1
                        log_entry = f"{graphs.data_name[0]}|||{problemTypeIndex}:{toolIndex}\n"
                        with open(log_file, 'a') as f:
                            f.write(log_entry)
                        break
              #      elif num == 1:
              #          num2 = score2Topk_flat[0]
                   #     col_sums -= array[num1]
                   #     max_cols = np.where(col_sums == col_sums.max())[0]
                   #     col_sum_per_row = np.sum(array[:,max_cols],axis=1)
                   #     max_tool = np.where(col_sum_per_row == col_sum_per_row.max())[0]
                   #     num2 = scoreTopk_flat[num]
                   #     if num2 in max_tool:
                   #         if labels[j,num2].item() > 0:
                   #             mingzhong += 1
                   #             break
                   #     else:
                   #         if labels[j, max_tool[0]].item() > 0:
                   #             mingzhong += 1
                   #             break
                #    if num == 3:
                     #   zero_cols = np.where(col_sums[0] == 0)[0]
                      #  col_sum_per_row = np.sum(array[:,zero_cols], axis=1)
                      #  nextTool = col_sum_per_row.max()
               #         remaining_tools = np.setdiff1d(tools, scoreTopk_flat)
               #         problemTypeIndex = int(problemTypes.item())
                   # print(problemTypeIndex)
                #        weights_in_col = weight_tools[remaining_tools, problemTypeIndex]
                #        max_value_index = np.argmax(weights_in_col)
                #        nextTool = remaining_tools[max_value_index]
                #        if labels[j,nextTool].item() > 0:
                #            mingzhong += 1
                #            if nextTool == labelTopk.item():
                #                best += 1
                #            log_entry = f"{graphs.data_name[0]}|||{problemTypeIndex}:{nextTool}\n"
                #            with open(log_file, 'a') as f:
                #                f.write(log_entry)
                #            break
           #             if num == 4:
           #                 if col_sum_per_row[col_sum_per_row != nextTool].size > 0:
           #                     second_max_value = col_sum_per_row[col_sum_per_row != nextTool].max()
           #                     nextTool = np.where(col_sum_per_row == second_max_value)[0][0]
           #                     if labels[j,nextTool].item() > 0:
           #                         mingzhong += 1
           #                         break
                    num += 1
                total += 1

 
        bestPredicts[int(problemTypes.item())] += (scores.argmax(dim=1) == labels.argmax(dim=1)).sum().item()

        for idx in scores.argmax(dim=1):
            predicted[0][idx.item()] += 1
            predicted[int(problemTypes.item()) + 1][idx.item()] += 1

        maxScoresIdx = scores.argmax(dim=1).reshape(len(scores), 1)
        gather = labels.gather(1, maxScoresIdx)
        if labels.min() <= 0:
            correctPredicts[int(problemTypes.item())] += (gather > 0).sum().item()

        predSpot[int(problemTypes.item())][
            np.where((-labels).argsort().cpu().numpy() == scores.argmax().item())[1]] += 1

        if labels.max() > 0 and labels.min() <= 0:
            possibleCorrect[int(problemTypes.item())] += 1
        probCounter[int(problemTypes.item())] += 1

    res = [[], [], [], [], []]

    res[0] = np.array(
        [corr_sum.sum() / probCounter.sum(), topKAcc.sum() / probCounter.sum(), bestPredicts.sum() / probCounter.sum(),
         correctPredicts.sum() / possibleCorrect.sum(), predSpot.sum(axis=0),
         evaluation_score / total_predictions if total_predictions > 0 else 0,
         valid_max_score / total_predictions if total_predictions > 0 else 0,
         mingzhong / total if total > 0 else 0,
         total,
         count,
         best / total if total > 0 else 0], dtype=object)
    for i in range(0, len(res) - 1):
        res[i + 1] = np.array([corr_sum[i] / probCounter[i] if probCounter[i] != 0 else 0,
                               topKAcc[i] / probCounter[i] if probCounter[i] != 0 else 0,
                               bestPredicts[i] / probCounter[i] if probCounter[i] != 0 else 0,
                               correctPredicts[i] / possibleCorrect[i] if possibleCorrect[i] != 0 else 0, predSpot[i]],
                              dtype=object)

    return res, predicted, predicts

def evaluate3(test_set, files, gpu=0, k=3):
    '''
    Model-free baseline: for each test instance choose the tool with the
    highest static weight for the instance's problem type and check whether
    that tool has a positive label.

    Parameters
    ----------
    test_set : dataset whose items are (graph, labels1, labels0, labels)
        tuples; only ``graph.problemType`` and ``labels`` are read.
    files : unused; kept so the signature matches the other evaluators.
    gpu : device passed to ``Tensor.to``.
    k : number of highest-weighted tools that are ranked; only the first
        (highest-weighted) one is tried.

    Returns
    -------
    np.ndarray (dtype=object) ``[hit rate, best rate]``:
    - hit rate is the fraction of counted instances on which the chosen tool
      has a positive label;
    - best rate is the fraction on which the chosen tool is also the label
      argmax.
    Instances where every tool succeeds or every tool fails are not counted.
    Both rates are 0 when no instance was counted.
    '''
    mingzhong = 0  # chosen tool has a positive label
    total = 0      # instances that were scored
    best = 0       # chosen tool is also the label argmax
    # The global RNG has always been reseeded here; kept so callers observe
    # the same RNG state after the call.
    random.seed(42)

    # Static per-tool weights: one row per tool, one column per problem type
    # (same layout as the `weight_tools[tools, problemTypeIndex]` lookups
    # elsewhere in this module).
    weight_tools = np.array([
        [0.4233023937469467, 0.4149049093321539, 0.4957519116397621, 0.07915690866510539],
        [0.033341475329750854, 0.1525873507297656, 0.221750212404418, 0.8004683840749415],
        [0.5787738153395212, 0.24586466165413534, 0.4524214103653356, 0.8098360655737705],
        [0.6369687347337567, 0.2191065900044228, 0.3903993203058624, 0.8180327868852459],
        [0.6834391792867611, 0.5720477664750111, 0.5135938827527613, 0.8220140515222483],
        [0.7639838788470933, 0.5313578062804069, 0.33262531860662703, 0.8344262295081967],
        [0.5534929164631167, 0.22799646174259178, 0.6928632115548004, 0.7667447306791569],
        [0.5396922325354176, 0.15643520566121186, 0, 0.5517564402810304],
        [0.658290277125411, 0.19862892525431225, 0, 0.8070257611241218],
        [0, 0.5943387881468377, 0, 0],
        [0, 0.6256523662096417, 0, 0],
        [0.454323400097703, 0.04820875718708536, 0.262107051826678, 0.6946135831381733],
        [0.14951312707999506, 0.3262787381928355, 0.48738033072236725, 0.8129807392307692],
        [0.4931607230092819, 0.07872622733303848, 0, 0.6163934426229508],
        [0.6098558866634098, 0.24900486510393632, 0.4821580288870008, 0.8665105386416861],
        [0, 0.07863777089783282, 0, 0],
        [0.6049096238397655, 0.5159663865546219, 0.43734069668649106, 0.8098360655737705],
        [0.0989769505731542, 0.1919889502762431, 0.21888598781549173, 0],
        [0, 0, 0, 0.8531615925058548],
        [0, 0, 0.8018266779949023, 0]
    ])

    test_loader = torch_geometric.loader.DataLoader(dataset=test_set, batch_size=1)

    for graphs, _labels1, _labels0, labels in tqdm.tqdm(test_loader, leave=False):
        labels = labels.to(device=gpu)
        problemTypes = graphs.problemType

        for j in range(len(labels)):
            labelTopk = labels.argmax()
            # Every tool succeeds: nothing to choose between, skip.
            if labels.min() > 0:
                continue
            # Every tool fails: no choice can hit, skip.
            if labels.max() <= 0:
                continue

            problemTypeIndex = int(problemTypes.item())
            # Column = this problem type's weight for every tool. Only tools
            # that actually have a label column are eligible, so the index
            # below can never run past `labels`.
            n_tools = min(weight_tools.shape[0], labels.size(1))
            weights = weight_tools[:n_tools, problemTypeIndex]
            top_k = np.argsort(weights)[-k:][::-1]
            toolIndex = int(top_k[0])

            if labels[j, toolIndex].item() > 0:
                mingzhong += 1
                if toolIndex == labelTopk.item():
                    best += 1
            total += 1

    hit_rate = mingzhong / total if total > 0 else 0
    best_rate = best / total if total > 0 else 0
    return np.array([hit_rate, best_rate], dtype=object)

def evaluate4(model, test_set, files, gpu=0, k=3):
    '''
    Function used to evaluate model on test set.

    Each item of ``test_set`` is ``(graph, labels)``. The model is called as
    ``model(x, edge_index, edge_attr, batch, problemType)``.

    Metrics are accumulated per problem type (index ``int(graph.problemType)``):
    - Spearman correlation between labels and scores;
    - top-k accuracy (label argmax within the k highest scores);
    - exact best-tool predictions;
    - correct (positive-label) predictions;
    - the rank of the predicted tool within the label ordering (``predSpot``).

    Parameters
    ----------
    model : scoring model (put in eval mode here).
    test_set : dataset as described above.
    files : names indexed in loader order; used as keys of ``predicts``.
    gpu : device passed to ``Tensor.to``.
    k : k for the top-k accuracy.

    Returns
    -------
    res : list of 5 object arrays.
        ``res[0]`` holds the overall metrics followed by ``a``, the number of
        all-failure instances plus all-success rows. ``res[1:]`` hold the
        per-problem-type metrics. Ratios with a zero denominator are
        reported as 0.
    predicted : counts of argmax predictions (row 0 overall, rows 1..4 per type).
    predicts : ``{files[i]: scores.tolist()}``.
    '''
    n_outputs = test_set[0][1].size(0)
    corr_sum = np.array([0.0] * 4)
    topKAcc = np.array([0.0] * 4)
    bestPredicts = np.array([0] * 4)
    correctPredicts = np.array([0] * 4)
    possibleCorrect = np.array([0] * 4)
    predSpot = np.array([[0] * n_outputs] * 4)
    probCounter = np.array([0] * 4)
    predicted = np.array([[0] * n_outputs] * 5)
    predicts = dict()
    # NOTE(review): the top-1 hit counters below are computed but not part
    # of the returned result.
    best = 0
    total = 0
    mingzhong = 0
    a = 0

    model.eval()

    test_loader = torch_geometric.loader.DataLoader(dataset=test_set, batch_size=1)

    for i, (graphs, labels) in enumerate(tqdm.tqdm(test_loader, leave=False)):
        graphs = graphs.to(device=gpu)
        labels = labels.to(device=gpu)
        problemTypes = graphs.problemType
        pt = int(problemTypes.item())
        if labels.max() <= 0:
            a += 1
        with autocast():
            with torch.no_grad():
                scores = model(graphs.x, graphs.edge_index, graphs.edge_attr, graphs.batch, graphs.problemType)

        predicts[files[i]] = scores.tolist()

        for j in range(len(labels)):
            if labels.min() > 0:
                a += 1
                continue
            corr, _ = spearmanr(labels[j].cpu().detach(), scores[j].cpu().detach().tolist())
            corr_sum[pt] += corr
            _, scoreTopk = scores.topk(k)
            labelTopk = labels.argmax()
            topKAcc[pt] += labelTopk in scoreTopk
            if labels.max() > 0:
                # Try only the highest-scored tool.
                toolIndex = scoreTopk.cpu().numpy().flatten()[0]
                if labels[j, toolIndex].item() > 0:
                    mingzhong += 1
                    if toolIndex == labelTopk.item():
                        best += 1
                total += 1

        argmax_pred = scores.argmax(dim=1)
        bestPredicts[pt] += (argmax_pred == labels.argmax(dim=1)).sum().item()

        for idx in argmax_pred:
            predicted[0][idx.item()] += 1
            predicted[pt + 1][idx.item()] += 1

        gather = labels.gather(1, argmax_pred.reshape(len(scores), 1))
        if labels.min() <= 0:
            correctPredicts[pt] += (gather > 0).sum().item()

        predSpot[pt][np.where((-labels).argsort().cpu().numpy() == scores.argmax().item())[1]] += 1

        if labels.max() > 0 and labels.min() <= 0:
            possibleCorrect[pt] += 1
        probCounter[pt] += 1

    def ratio(numer, denom):
        # Zero-denominator convention shared with the other evaluators.
        return numer / denom if denom != 0 else 0

    res = [[], [], [], [], []]

    res[0] = np.array([ratio(corr_sum.sum(), probCounter.sum()),
                       ratio(topKAcc.sum(), probCounter.sum()),
                       ratio(bestPredicts.sum(), probCounter.sum()),
                       ratio(correctPredicts.sum(), possibleCorrect.sum()),
                       predSpot.sum(axis=0), a], dtype=object)
    for t in range(0, len(res) - 1):
        res[t + 1] = np.array([ratio(corr_sum[t], probCounter[t]),
                               ratio(topKAcc[t], probCounter[t]),
                               ratio(bestPredicts[t], probCounter[t]),
                               ratio(correctPredicts[t], possibleCorrect[t]),
                               predSpot[t]], dtype=object)

    return res, predicted, predicts

def evaluateAlg(Pmodel, test_set, files, gpu=0, k=3, rank=2):
    '''
    Function used to evaluate model on test set.

    The model scores algorithms, and those scores also drive a
    distance-based tool choice.

    Each item of ``test_set`` is ``(graph, labels, labels0, labels_true)``.
    ``labels`` are compared with the model's per-algorithm ``scores``, and
    ``labels_true`` holds the per-tool outcomes used by the tool-choice
    metrics. ``labels0`` is not used.

    Tool choice works as follows:
    - The algorithms the model scores positively form a 0/1 target vector.
    - Every row of ``alg_matrix`` that has a 1 in at least one of those
      algorithm columns is a candidate. Rows are presumably tools and
      columns algorithms — TODO confirm.
    - Candidates are sorted (stably) by Euclidean distance to the target,
      and the one at 0-based position ``rank`` is chosen.
    - Rows with ``rank`` or fewer candidates are skipped. Previously this
      raised IndexError.

    Parameters
    ----------
    Pmodel : model called as ``Pmodel(x, edge_index, edge_attr, batch, problemType)``.
    test_set : dataset as described above.
    files : names indexed in loader order; used as keys of ``predicts``.
    gpu : device passed to ``Tensor.to``.
    k : k for the top-k accuracy.
    rank : ranking position of the chosen tool (default 2, the value that
        used to be hard-coded).

    Returns
    -------
    res : list of 5 object arrays.
        ``res[0]`` is, in order:
        - mean Spearman correlation over non-degenerate rows;
        - top-k accuracy;
        - best-prediction rate;
        - correct-prediction rate;
        - predSpot totals;
        - tool hit rate;
        - best-tool rate;
        - number of rows on which a tool was chosen;
        - number of rows with a positive ``labels_true``.
        ``res[1:]`` hold the per-problem-type top-k accuracy,
        best-prediction rate, correct-prediction rate and predSpot row.
        Ratios with a zero denominator are reported as 0.
    predicted : counts of argmax predictions (row 0 overall, rows 1..4 per type).
    predicts : ``{files[i]: scores.tolist()}``.
    '''
    n_outputs = test_set[0][1].size(0)
    topKAcc = np.array([0.0] * 4)
    bestPredicts = np.array([0] * 4)
    correctPredicts = np.array([0] * 4)
    possibleCorrect = np.array([0] * 4)
    predSpot = np.array([[0] * n_outputs] * 4)
    probCounter = np.array([0] * 4)
    predicted = np.array([[0] * n_outputs] * 5)
    predicts = dict()

    corr_sum = 0.0   # sum of Spearman correlations over non-degenerate rows
    corr_count = 0   # number of non-degenerate rows
    a = 0            # rows whose correlation is undefined (empty/constant)
    mingzhong = 0    # chosen tool has a positive labels_true entry
    best = 0         # chosen tool is also the labels_true argmax
    total = 0        # rows on which a tool was chosen
    num = 0          # rows with at least one positive labels_true entry

    alg_matrix = np.array([
        [0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1]
    ])

    Pmodel.eval()

    test_loader = torch_geometric.loader.DataLoader(dataset=test_set, batch_size=1)

    for i, (graphs, labels, _labels0, labels_true) in enumerate(tqdm.tqdm(test_loader, leave=False)):
        graphs = graphs.to(device=gpu)
        labels = labels.to(device=gpu)
        labels_true = labels_true.to(device=gpu)
        problemTypes = graphs.problemType
        pt = int(problemTypes.item())
        with autocast():
            with torch.no_grad():
                scores = Pmodel(graphs.x, graphs.edge_index, graphs.edge_attr, graphs.batch, graphs.problemType)
        predicts[files[i]] = scores.tolist()

        for j in range(len(labels)):
            # Spearman correlation restricted to algorithms that are positive
            # in both the labels and the scores.
            positive_both = (labels[j] > 0) & (scores[j] > 0)
            positiveScores = scores[j][positive_both]
            positiveLabels = labels[j][positive_both]
            if (positiveScores.numel() == 0
                    or torch.all(positiveLabels == positiveLabels[0])
                    or torch.all(positiveScores == positiveScores[0])):
                # Correlation is undefined for empty or constant input.
                a += 1
            else:
                corr, _ = spearmanr(positiveScores.cpu().detach(), positiveLabels.cpu().detach().tolist())
                corr_sum += float(corr)
                corr_count += 1

            _, scoreTopk = scores.topk(k)
            labelTopk = labels.argmax()
            labels_trueTopk = labels_true.argmax()
            topKAcc[pt] += labelTopk in scoreTopk

            # 0/1 vector of the algorithms the model considers suitable.
            scores_np = scores[j].cpu().detach().numpy().flatten()
            suitableAlg = np.where(scores_np > 0)[0]
            suitable_values = np.zeros_like(scores_np)
            suitable_values[suitableAlg] = 1
            # Comparison target is the same for every candidate row.
            if len(suitable_values) > alg_matrix.shape[1]:
                target = suitable_values[suitableAlg]
            else:
                target = suitable_values

            # Candidate rows: at least one 1 among the suitable columns.
            candidates = np.where(np.any(alg_matrix[:, suitableAlg] == 1, axis=1))[0]
            ranked = sorted(
                ((np.linalg.norm(alg_matrix[c] - target), c) for c in candidates),
                key=lambda pair: pair[0],
            )

            if labels_true.max() > 0:
                if len(ranked) > rank:
                    toolChoice = ranked[rank][1]
                    if labels_true[j, toolChoice].item() > 0:
                        mingzhong += 1
                        if toolChoice == labels_trueTopk.item():
                            best += 1
                    total += 1
                num += 1

        argmax_pred = scores.argmax(dim=1)
        bestPredicts[pt] += (argmax_pred == labels.argmax(dim=1)).sum().item()

        for idx in argmax_pred:
            predicted[0][idx.item()] += 1
            predicted[pt + 1][idx.item()] += 1

        gather = labels.gather(1, argmax_pred.reshape(len(scores), 1))
        if labels.min() <= 0:
            correctPredicts[pt] += (gather > 0).sum().item()

        predSpot[pt][
            np.where((-labels).argsort().cpu().numpy() == scores.argmax().item())[1]] += 1

        if labels.max() > 0 and labels.min() <= 0:
            possibleCorrect[pt] += 1
        probCounter[pt] += 1

    def ratio(numer, denom):
        # Zero-denominator convention shared with the other evaluators.
        return numer / denom if denom != 0 else 0

    res = [[], [], [], [], []]

    res[0] = np.array(
        [ratio(corr_sum, corr_count),
         ratio(topKAcc.sum(), probCounter.sum()),
         ratio(bestPredicts.sum(), probCounter.sum()),
         ratio(correctPredicts.sum(), possibleCorrect.sum()),
         predSpot.sum(axis=0),
         ratio(mingzhong, total),
         ratio(best, total), total, num], dtype=object)
    for t in range(0, len(res) - 1):
        res[t + 1] = np.array([ratio(topKAcc[t], probCounter[t]),
                               ratio(bestPredicts[t], probCounter[t]),
                               ratio(correctPredicts[t], possibleCorrect[t]),
                               predSpot[t]],
                              dtype=object)

    return res, predicted, predicts

def getCorrectProblemTypes(labels, problemTypes):
    '''
    Function used to make sure we are only looking at problem types that we want.

    Each label's first element is a string "<name>|||<code>", where <code> is
    "0" (overflow), "1" (reachSafety), "2" (termination) or "3" (memSafety).
    A label is kept only if its problem type name is contained in
    ``problemTypes`` (tested with ``in``).

    Returns ``labels`` unchanged when every problem type is requested;
    otherwise returns a new list that preserves the original order.
    '''
    type_codes = {
        "overflow": "0",
        "reachSafety": "1",
        "termination": "2",
        "memSafety": "3",
    }
    excluded = {code for name, code in type_codes.items() if name not in problemTypes}
    if not excluded:
        return labels
    return [item for item in labels if item[0].split("|||")[1] not in excluded]

def groupLabels(labels, mapping="../../data/toolMapping.json"):
    '''
    Collapse per-tool label vectors into per-algorithm label vectors.

    ``mapping`` is a JSON file holding a two-element list
    ``[toolToAlg, algToTool]``:
    - ``toolToAlg[tool][0]`` is the column of ``tool`` in a label vector;
    - ``toolToAlg[tool][1]`` is a list whose length divides that tool's
      label (presumably the algorithms the tool belongs to — confirm);
    - ``algToTool[alg]`` lists the tools contributing to ``alg``.

    For each algorithm, the new value is the mean over its tools of the
    tool's label divided by ``len(toolToAlg[tool][1])``.

    Returns a list of ``(label[0], values)`` tuples, one per input label.
    ``values`` follows the key order of ``algToTool``.
    '''
    # Close the mapping file promptly instead of leaking the handle.
    with open(mapping) as f:
        toolToAlg, algToTool = json.load(f)

    newLabels = []
    for label in labels:
        vals = []
        for alg, tools in algToTool.items():
            shares = [label[1][toolToAlg[tool][0]] / len(toolToAlg[tool][1]) for tool in tools]
            vals.append(sum(shares) / len(tools))
        newLabels.append((label[0], vals))

    assert(len(newLabels) == len(labels))
    return newLabels

def getWeights(train_set):
    '''
    Compute inverse-frequency sample weights for a training set.

    Each sample's label vector (``x[1]``) is reduced to the permutation that
    sorts it (``argsort``). A sample's weight is 1 / (number of samples whose
    label vector yields the same permutation). Samples with rare orderings
    therefore get larger weights; presumably these are meant for a weighted
    sampler — confirm at the call site.

    Returns a list with one float weight per sample, in ``train_set`` order.
    '''
    labels = np.array([x[1] for x in train_set])
    if len(labels) == 0:
        return []
    # Compute the permutations once; np.unique then gives, for every row, the
    # id of its permutation (inverse) and how often each permutation occurs.
    orders = labels.argsort()
    _, inverse, counts = np.unique(orders, axis=0, return_inverse=True, return_counts=True)
    # reshape(-1): the inverse's shape with `axis` differs across NumPy versions.
    return list(1. / counts[inverse.reshape(-1)])
