# new_20 : approximates (max(z_j^0, z) - z) and (max(z_j^1, z) - z)
#
import random
import numpy as np
import torch
import argparse
import scipy.sparse as sp
import dgl
import pickle, gzip
import sys
import itertools
from pyscipopt import SCIP_PARAMSETTING, Model, Branchrule, Nodesel, SCIP_RESULT
from torch.utils.data import Dataset
from treelib import Tree
sys.path.append('../')
from src.logger import Logger
from src.model import GCNN_Net, PD_Net
import time
import torch.nn.functional as F
from collections import defaultdict
import copy
from src.augment import augment

logger = Logger.logger

# 创建 logs 文件夹（如不存在）
# os.makedirs("train_logs", exist_ok=True)

# # 自动生成时间戳文件名
# timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# log_file_path = os.path.join("train_logs", f"train_log_{timestamp}.log")

# # 创建 logger
# logger = logging.getLogger()
# logger.setLevel(logging.INFO)

# # 创建文件处理器（只写文件，不打印到控制台）
# file_handler = logging.FileHandler(log_file_path)
# file_handler.setLevel(logging.INFO)

# # 设置日志格式
# formatter = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")
# file_handler.setFormatter(formatter)

# # 添加文件 handler 到 logger（不添加 console handler）
# logger.addHandler(file_handler)

# def data_argumentation_new_20(instance):
#     sample = instance['data']
#     depth = instance['node_depth']
#     max_depth = instance['max_depth']
#     new_instance = {}

#     weight_init = depth/max_depth if max_depth else 1.0
#     weight_2 = -0.4 * weight_init + 1
#     weight = weight_2

#     data, bestcand, action_set, scores, lp_scores_0, lp_scores_1 = sample
    
def extract_state_new_20(model, buffer=None):
    """Extract constraint, edge and variable features from the current SCIP state.

    Static features are computed once and cached in ``buffer['state']``:
    - the objective norm;
    - variable types and normalized objective coefficients;
    - row objective cosine similarity and bias;
    - the lhs/rhs row index sets;
    - the edge list.

    Dynamic features (bounds, LP solution, basis status, reduced costs, ages, dual
    values, ...) are recomputed on every call. They are written into the cached
    ``col_feats`` / ``row_feats`` dicts in place.

    Args:
        model: SCIP model exposing ``getNNodes()`` and ``getState(prev_state)``.
            ``getState`` is not part of stock PySCIPOpt -- presumably a patched
            build; TODO confirm.
        buffer: dict used to cache state between calls. A fresh local dict is used
            when ``buffer`` is None or when ``model.getNNodes() == 1``. In that case
            the caller's dict is neither cleared nor populated.
            NOTE(review): if a caller reuses one dict across instances, a stale
            ``'state'`` from a previous instance may be reused at non-root nodes --
            confirm callers pass a fresh buffer per instance.

    Returns:
        ``(constraint_features, edge_features, variable_features)``. Each is a dict
        with ``'names'`` (flat list of column names) and ``'values'`` (2-D array, one
        row per constraint-row / edge / variable). ``edge_features`` also carries
        ``'indices'``, a 2 x nnz array of (constraint row, variable column) indices.

    Constraint rows: every constraint with a non-NaN lhs contributes one row with
    negated coefficients, bias, cosine similarity and dual value. Every constraint
    with a non-NaN rhs contributes one unnegated row. Constraints with both sides
    therefore appear twice.
    """

    if buffer is None or model.getNNodes() == 1:
        buffer = {}

    # Update the state from the buffer, if any: the previous SCIP state is handed
    # back to getState.
    s = model.getState(buffer['scip_state'] if 'scip_state' in buffer else None)

    buffer['scip_state'] = s

    if 'state' in buffer:
        obj_norm = buffer['state']['obj_norm']
    else:
        # Euclidean norm of the objective coefficients; use 1 when it is not positive
        # so the divisions below are safe.
        obj_norm = np.linalg.norm(s['col']['coefs'])
        obj_norm = 1 if obj_norm <= 0 else obj_norm

    row_norms = s['row']['norms']  # norm of each row's (constraint's) coefficient vector
    # Avoid dividing by zero for rows whose coefficients are all zero. NOTE: row_norms
    # is the same array object as s['row']['norms'], so the stored SCIP state is modified too.
    row_norms[row_norms == 0] = 1

    # Column (variable) features
    n_cols = len(s['col']['types'])

    if 'state' in buffer:
        # Reuse the cached dict; the dynamic entries below overwrite its values in place.
        col_feats = buffer['state']['col_feats']
    else:
        col_feats = {}
        col_feats['type'] = np.zeros((n_cols, 4))  # BINARY INTEGER IMPLINT CONTINUOUS
        col_feats['type'][np.arange(n_cols), s['col']['types']] = 1
        col_feats['coef_normalized'] = s['col']['coefs'].reshape(-1, 1) / obj_norm

    # Dynamic variable features, recomputed on every call.
    col_feats['has_lb'] = ~np.isnan(s['col']['lbs']).reshape(-1, 1)
    col_feats['has_ub'] = ~np.isnan(s['col']['ubs']).reshape(-1, 1)
    col_feats['sol_is_at_lb'] = s['col']['sol_is_at_lb'].reshape(-1, 1)
    col_feats['sol_is_at_ub'] = s['col']['sol_is_at_ub'].reshape(-1, 1)
    col_feats['sol_frac'] = s['col']['solfracs'].reshape(-1, 1)
    col_feats['sol_frac'][s['col']['types'] == 3] = 0  # continuous have no fractionality
    col_feats['basis_status'] = np.zeros((n_cols, 4))  # LOWER BASIC UPPER ZERO
    col_feats['basis_status'][np.arange(n_cols), s['col']['basestats']] = 1
    col_feats['reduced_cost'] = s['col']['redcosts'].reshape(-1, 1) / obj_norm
    col_feats['age'] = s['col']['ages'].reshape(-1, 1) / (s['stats']['nlps'] + 5)
    col_feats['sol_val'] = s['col']['solvals'].reshape(-1, 1)
    col_feats['inc_val'] = s['col']['incvals'].reshape(-1, 1)
    col_feats['avg_inc_val'] = s['col']['avgincvals'].reshape(-1, 1)
    
    # One name per output column: 'k' for single-column features, 'k_0', 'k_1', ...
    # otherwise. Column order follows dict insertion order; downstream code may
    # slice columns by position, so do not reorder the entries above.
    col_feat_names = [[k, ] if v.shape[1] == 1 else [f'{k}_{i}' for i in range(v.shape[1])] for k, v in col_feats.items()]
    col_feat_names = [n for names in col_feat_names for n in names]
    col_feat_vals = np.concatenate(list(col_feats.values()), axis=-1)

    variable_features = {
        'names': col_feat_names,
        'values': col_feat_vals,}

    # Row (constraint) features

    if 'state' in buffer:
        row_feats = buffer['state']['row_feats']
        has_lhs = buffer['state']['has_lhs']
        has_rhs = buffer['state']['has_rhs']
    else:
        row_feats = {}
        # Indices of rows with a non-NaN left-/right-hand side.
        has_lhs = np.nonzero(~np.isnan(s['row']['lhss']))[0]
        has_rhs = np.nonzero(~np.isnan(s['row']['rhss']))[0]
        # lhs rows are negated, presumably so every output row reads as a . x <= b.
        row_feats['obj_cosine_similarity'] = np.concatenate((
            -s['row']['objcossims'][has_lhs],
            +s['row']['objcossims'][has_rhs])).reshape(-1, 1)
        row_feats['bias'] = np.concatenate((
            -(s['row']['lhss'] / row_norms)[has_lhs],
            +(s['row']['rhss'] / row_norms)[has_rhs])).reshape(-1, 1)

    row_feats['is_tight'] = np.concatenate((
        s['row']['is_at_lhs'][has_lhs],
        s['row']['is_at_rhs'][has_rhs])).reshape(-1, 1)

    row_feats['age'] = np.concatenate((
        s['row']['ages'][has_lhs],
        s['row']['ages'][has_rhs])).reshape(-1, 1) / (s['stats']['nlps'] + 5)

    # todo: check whether this change helps
    tmp = s['row']['dualsols'] / (row_norms * obj_norm)
    # Alternative normalization (disabled):
    # tmp = s['row']['dualsols'] * row_norms / obj_norm
    row_feats['dualsol_val_normalized'] = np.concatenate((
            -tmp[has_lhs],
            +tmp[has_rhs])).reshape(-1, 1)

    row_feat_names = [[k, ] if v.shape[1] == 1 else [f'{k}_{i}' for i in range(v.shape[1])] for k, v in row_feats.items()]
    row_feat_names = [n for names in row_feat_names for n in names]
    row_feat_vals = np.concatenate(list(row_feats.values()), axis=-1)
    
    constraint_features = {
        'names': row_feat_names,
        'values': row_feat_vals,}

    # Edge features
    if 'state' in buffer:
        edge_row_idxs = buffer['state']['edge_row_idxs']
        edge_col_idxs = buffer['state']['edge_col_idxs']
        edge_feats = buffer['state']['edge_feats']
    else:
        # Sparse (CSR) matrix of nonzero coefficients divided by their row norm,
        # shape (number of SCIP rows, number of columns).
        coef_matrix = sp.csr_matrix(
            (s['nzrcoef']['vals'] / row_norms[s['nzrcoef']['rowidxs']],
            (s['nzrcoef']['rowidxs'], s['nzrcoef']['colidxs'])),
            shape=(len(s['row']['nnzrs']), len(s['col']['types'])))
        # Stack [-A[has_lhs]; A[has_rhs]] so edge row indices line up with the
        # constraint feature rows built above.
        coef_matrix = sp.vstack((
            -coef_matrix[has_lhs, :],
            coef_matrix[has_rhs, :])).tocoo(copy=False)

        edge_row_idxs, edge_col_idxs = coef_matrix.row, coef_matrix.col
        edge_feats = {}

        edge_feats['coef_normalized'] = coef_matrix.data.reshape(-1, 1)

    edge_feat_names = [[k, ] if v.shape[1] == 1 else [f'{k}_{i}' for i in range(v.shape[1])] for k, v in edge_feats.items()]
    edge_feat_names = [n for names in edge_feat_names for n in names]
    edge_feat_indices = np.vstack([edge_row_idxs, edge_col_idxs])
    edge_feat_vals = np.concatenate(list(edge_feats.values()), axis=-1)
    
    # Shape is (nnz, 1); e.g. (397, 1) was observed for one instance.

    edge_features = {
        'names': edge_feat_names,
        'indices': edge_feat_indices,
        'values': edge_feat_vals,}

    if 'state' not in buffer:
        # First call with this buffer: cache the static features for later calls.
        buffer['state'] = {
            'obj_norm': obj_norm,
            'col_feats': col_feats,
            'row_feats': row_feats,
            'has_lhs': has_lhs,
            'has_rhs': has_rhs,
            'edge_row_idxs': edge_row_idxs,
            'edge_col_idxs': edge_col_idxs,
            'edge_feats': edge_feats,
        }

    return constraint_features, edge_features, variable_features

def graph_transform_new_20(state):
    """Convert extracted SCIP features into a bipartite DGL heterograph.

    :param state: sequence whose items 0, 1 and 2 are the constraint, edge and
        variable feature dicts (``'values'``; the edge dict also has ``'indices'``,
        row 0 = constraint index, row 1 = variable index).
    :return: heterograph with node types ``'c'`` / ``'v'`` and edge types
        ``'v2c'`` / ``'c2v'``. Every node and edge type stores its features
        under the key ``'h'``.
    """
    cons_state = state[0]
    edge_state = state[1]
    var_state = state[2]

    edge_index = edge_state['indices']
    cons_idx = edge_index[0, :]
    var_idx = edge_index[1, :]

    # Only the leading variable-feature columns are fed to the model.
    n_var_feats = 17
    edge_h = torch.tensor(edge_state['values'], dtype=torch.float).view(-1, 1)
    cons_h = torch.tensor(cons_state['values'], dtype=torch.float)
    var_h = torch.tensor(var_state['values'][:, :n_var_feats], dtype=torch.float)

    # One relation per direction, each keyed by (src type, edge type, dst type).
    relations = {
        ('v', 'v2c', 'c'): (var_idx, cons_idx),
        ('c', 'c2v', 'v'): (cons_idx, var_idx),
    }
    node_counts = {'c': cons_h.shape[0], 'v': var_h.shape[0]}
    graph = dgl.heterograph(relations, node_counts)

    # Both edge directions share the same coefficient features.
    for etype in ('v2c', 'c2v'):
        graph.edges[etype].data['h'] = edge_h
    graph.nodes['v'].data['h'] = var_h
    graph.nodes['c'].data['h'] = cons_h

    return graph

class BranchDataset_new_20(Dataset):
    """In-memory dataset of strong-branching samples stored as gzip-compressed pickles.

    Each file must unpickle to a dict with the keys ``'data'``, ``'node_depth'`` and
    ``'max_depth'``. ``'data'`` is the 6-tuple
    ``(state, bestcand, action_set, scores, lp_scores_0, lp_scores_1)``.

    Every item is the 9-tuple
    ``(state, candidates, label_sb, label_lp_0, label_lp_1, scores, lp_scores_0,
    lp_scores_1, weight)``.
    - Each ``label_*`` is the index of the first maximum of the matching score list.
    - ``weight = 1 - 0.4 * node_depth / max_depth``, with the ratio taken as 1.0 when
      ``max_depth`` is falsy. So weight is 1.0 at depth 0 and 0.6 at ``max_depth``.

    Security: ``pickle.load`` can execute arbitrary code; only load trusted sample files.
    """

    def __init__(self, dirs):
        super().__init__()
        self.data = [self._load_sample(path) for path in dirs]

    @staticmethod
    def _load_sample(path):
        """Read one sample file (opened and decompressed once) and build its item tuple."""
        with gzip.open(path, 'rb') as f:
            instance = pickle.load(f)

        depth = instance['node_depth']
        max_depth = instance['max_depth']
        weight_init = depth / max_depth if max_depth else 1.0
        weight = -0.4 * weight_init + 1  # deeper nodes get a smaller training weight

        state, _bestcand, candidates, scores, lp_scores_0, lp_scores_1 = instance['data']
        label_sb = scores.index(max(scores))
        label_lp_0 = lp_scores_0.index(max(lp_scores_0))
        label_lp_1 = lp_scores_1.index(max(lp_scores_1))

        return (state, candidates, label_sb, label_lp_0, label_lp_1,
                scores, lp_scores_0, lp_scores_1, weight)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def dgl_collate_new_20(batch):
    """Collate 9-field dataset items into one batched DGL graph plus label tensors.

    The three labels become ``LongTensor``s and the weights a ``FloatTensor``.
    Candidates and the raw score lists are returned as plain Python lists, because
    their lengths differ per sample.
    """
    (states, candidates, label_sb, label_lp_0, label_lp_1,
     scores, lp_scores_0, lp_scores_1, weight) = (list(column) for column in zip(*batch))

    batched_graph = dgl.batch([graph_transform_new_20(s) for s in states])

    return (
        batched_graph,
        candidates,
        torch.LongTensor(label_sb),
        torch.LongTensor(label_lp_0),
        torch.LongTensor(label_lp_1),
        scores,
        lp_scores_0,
        lp_scores_1,
        torch.FloatTensor(weight),
    )

class BranchDataset_new_20_depth(Dataset):
    """In-memory dataset of strong-branching samples that also exposes node depth.

    Each file must be a gzip-compressed pickle of a dict with the keys ``'data'``,
    ``'node_depth'`` and ``'max_depth'``. ``'data'`` is the 6-tuple
    ``(state, bestcand, action_set, scores, lp_scores_0, lp_scores_1)``.

    Every item is the 10-tuple
    ``(state, candidates, label_sb, label_lp_0, label_lp_1, scores, lp_scores_0,
    lp_scores_1, weight, depth)``.
    - Each ``label_*`` is the index of the first maximum of the matching score list.
    - ``weight = 1 - 0.4 * node_depth / max_depth``, with the ratio taken as 1.0 when
      ``max_depth`` is falsy.
    - ``depth`` is the raw ``node_depth``.

    Security: ``pickle.load`` can execute arbitrary code; only load trusted sample files.
    """

    def __init__(self, dirs):
        super().__init__()
        self.data = [self._load_sample(path) for path in dirs]

    @staticmethod
    def _load_sample(path):
        """Read one sample file (opened and decompressed once) and build its item tuple."""
        with gzip.open(path, 'rb') as f:
            instance = pickle.load(f)

        depth = instance['node_depth']
        max_depth = instance['max_depth']
        weight_init = depth / max_depth if max_depth else 1.0
        weight = -0.4 * weight_init + 1  # deeper nodes get a smaller training weight

        state, _bestcand, candidates, scores, lp_scores_0, lp_scores_1 = instance['data']
        label_sb = scores.index(max(scores))
        label_lp_0 = lp_scores_0.index(max(lp_scores_0))
        label_lp_1 = lp_scores_1.index(max(lp_scores_1))

        return (state, candidates, label_sb, label_lp_0, label_lp_1,
                scores, lp_scores_0, lp_scores_1, weight, depth)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
def dgl_collate_new_20_depth(batch):
    """Collate 10-field (depth-annotated) dataset items into a batched graph plus tensors.

    The labels and depths become ``LongTensor``s and the weights a ``FloatTensor``.
    Candidates and score lists stay Python lists, because their lengths differ per sample.
    """
    columns = [list(column) for column in zip(*batch)]
    (states, candidates, label_sb, label_lp_0, label_lp_1,
     scores, lp_scores_0, lp_scores_1, weight, depth) = columns

    batched_graph = dgl.batch([graph_transform_new_20(s) for s in states])

    return (
        batched_graph,
        candidates,
        torch.LongTensor(label_sb),
        torch.LongTensor(label_lp_0),
        torch.LongTensor(label_lp_1),
        scores,
        lp_scores_0,
        lp_scores_1,
        torch.FloatTensor(weight),
        torch.LongTensor(depth),
    )

def balanced_collate_fn(batch):
    """Collate depth-annotated samples after adding one augmented copy of each.

    Input items use the 10-field layout
    ``(data, candidates, label_sb, label_lp_0, label_lp_1, scores, lp_scores_0,
    lp_scores_1, weight, depth)``.

    Despite the name, the current strategy does not rebalance classes. Every sample
    is kept unchanged, and a deep-copied copy whose ``data`` (field 0) went through
    ``augment`` is appended, so the output batch is twice the input size. Samples
    are first grouped into two depth classes (depth < 10 and depth >= 10). This only
    determines the order handed to ``random.shuffle`` and is the hook for per-class
    balancing.

    Returns:
        ``(batched_graph, candidates, label_sb, label_lp_0, label_lp_1, scores,
        lp_scores_0, lp_scores_1, weight, depth)``. Labels and depth are
        ``LongTensor``s and weight is a ``FloatTensor``.
    """
    depth_split = 10  # class 0: depth < depth_split, class 1: the rest

    class_to_samples = defaultdict(list)
    for item in batch:
        label = 0 if int(item[-1]) < depth_split else 1  # depth is the last field
        class_to_samples[label].append(item)

    augmented_batch = []
    for samples in class_to_samples.values():
        augmented = []
        for item in samples:
            # Deep-copy so an in-place augment() cannot touch the original sample;
            # only the graph state (field 0) is augmented, labels/scores are kept.
            new_item = list(copy.deepcopy(item))
            new_item[0], _method_name = augment(new_item[0])
            augmented.append(tuple(new_item))
        augmented_batch.extend(samples + augmented)

    random.shuffle(augmented_batch)

    (data, candidates, label_sb, label_lp_0, label_lp_1,
     scores, lp_scores_0, lp_scores_1, weight, depth) = map(list, zip(*augmented_batch))

    batched_graph = dgl.batch([graph_transform_new_20(d) for d in data])
    label_sb = torch.LongTensor(label_sb)
    label_lp_0 = torch.LongTensor(label_lp_0)
    label_lp_1 = torch.LongTensor(label_lp_1)
    weight = torch.FloatTensor(weight)
    depth = torch.LongTensor(depth)

    return batched_graph, candidates, label_sb, label_lp_0, label_lp_1, scores, lp_scores_0, lp_scores_1, weight, depth

# def balanced_collate_fn(batch):
#     """
#     输入 batch: data, candidates, label_sb, label_lp_0, label_lp_1, scores, lp_scores_0, lp_scores_1, weight, depth
#     输出增强后的 batch，使得类别大致平衡
#     """
#     method_counts = defaultdict(int)
#     # 按类别分组
#     class_to_samples = defaultdict(list)
#     for item in batch:
#         depth = int(item[-1])  # depth 是最后一项
#         if (depth < 10):
#             label = 0
#         # elif(depth <= 4):
#         #     label = 1
#         # elif(depth <= 7):
#         #     label = 2
#         else:
#             label = 1
#         class_to_samples[label].append(item)

#     max_class_count = max(len(v) for v in class_to_samples.values())

#     # 增强不足的类别
#     augmented_batch = []
#     for cls, samples in class_to_samples.items():
#         count = len(samples)
#         if count < max_class_count:
#             augmented = []
#             for _ in range(max_class_count - count):
#                 item = copy.deepcopy(random.choice(samples))
#                 item = list(item)  # 转为可变列表
#                 item[0], method_name = augment(item[0])  # 仅增强 data（第0项）
#                 augmented.append(tuple(item))
#                 if method_name in method_counts.keys():
#                     method_counts[method_name] += 1
#                 else:
#                     method_counts[method_name] = 1
#             samples = samples + augmented
#         augmented_batch.extend(samples)

#     # 打乱
#     random.shuffle(augmented_batch)

#     # 拆包
#     data, candidates, label_sb, label_lp_0, label_lp_1, scores, lp_scores_0, lp_scores_1, weight, depth = map(list, zip(*augmented_batch))
#     graphs = []
#     for d in data:
#         graphs.append(graph_transform_new_20(d))
#     # 转 tensor
#     batched_graph = dgl.batch(graphs)
#     label_sb = torch.LongTensor(label_sb)
#     label_lp_0 = torch.LongTensor(label_lp_0)
#     label_lp_1 = torch.LongTensor(label_lp_1)
#     weight = torch.FloatTensor(weight)
#     depth = torch.LongTensor(depth)
    
#     # print(f"Augmentation counts: {dict(method_counts)}")
#     return batched_graph, candidates, label_sb, label_lp_0, label_lp_1, scores, lp_scores_0, lp_scores_1, weight, depth

# def dgl_collate_new_20_aug(batch):
#     # 解压原始批次数据
#     data, candidates, label_sb, label_lp_0, label_lp_1, scores, lp_scores_0, lp_scores_1, weight, depth = map(list, zip(*batch))
    
#     # 定义四种增强方法
    

#     # 增强方法列表
#     augmentation_methods = [
#         flip_variables,
#         add_redundant_constraint,
#         perturb_objective,
#         perturb_constraints,
#         perturb_duals
#     ]

#     # 应用数据增强
#     augmented_data = []
#     method_counts = defaultdict(int)
#     augmented_sample, method_name = augment(copy.deepcopy(sample))
#     augmented_data.append(augmented_sample)
#     method_counts['original'] += 1

#     # 打印当前批次增强统计信息 (调试用)
#     print(f"Augmentation counts: {dict(method_counts)}")

#     # 处理增强后的数据
#     graphs = []
#     for d in augmented_data:
#         graphs.append(graph_transform_new_20(d))
    
#     batched_graph = dgl.batch(graphs)
#     label_sb = torch.LongTensor(label_sb)
#     label_lp_0 = torch.LongTensor(label_lp_0)
#     label_lp_1 = torch.LongTensor(label_lp_1)
#     weight = torch.FloatTensor(weight)
#     depth = torch.LongTensor(depth)

#     return batched_graph, candidates, label_sb, label_lp_0, label_lp_1, scores, lp_scores_0, lp_scores_1, weight, depth

def train_new_20(args, model, trainData, optimizer, epoch, log_interval, device):
    """Train the model for one epoch on strong-branching imitation data.

    The per-batch loss is the sum of three terms, each averaged over the samples:
    - cross-entropy of ``s`` against ``label_sb``;
    - MSE of ``s_0`` against ``lp_scores_0``;
    - MSE of ``s_1`` against ``lp_scores_1``.
    ``s``, ``s_0`` and ``s_1`` are the model outputs on each sample's candidate
    variables.

    Top-k accuracy counts a sample as correct when the best true strong-branching
    score among the model's top-k candidates equals the sample's best score. Ties
    therefore count as hits.

    Args:
        args: unused; kept for signature compatibility.
        model: wrapper whose ``.module.forward_sb(graph)`` returns the batched graph
            with ``'s'``, ``'s_0'`` and ``'s_1'`` stored on the ``'v'`` nodes
            (presumably a DataParallel/DistributedDataParallel wrapper).
        trainData: sized iterable of 9-field batches as produced by ``dgl_collate_new_20``.
        optimizer: optimizer over the model parameters.
        epoch: zero-based epoch index (logged as ``epoch + 1``).
        log_interval: log running metrics every ``log_interval`` batches.
        device: device the graphs and labels are moved to.

    Returns:
        ``(mean loss per sample, top-1 accuracy)`` over the epoch.
    """
    model.train()
    top_k = (1, 3, 5, 10)
    acc_dict = {k: 0.0 for k in top_k}
    batch_loss = 0.0
    total_count = 0.0

    for batch_idx, batch in enumerate(trainData):
        (bg, candidates, label_sb, _label_lp_0, _label_lp_1,
         scores, lp_scores_0, lp_scores_1, _weight) = batch
        optimizer.zero_grad()
        bg = bg.to(device)
        label_sb = label_sb.to(device)

        count = len(label_sb)
        total_count += count

        graphs = dgl.unbatch(model.module.forward_sb(bg))

        # Per-sample outputs restricted to the branching candidates.
        logits_sb = [g.nodes['v'].data['s'][candidates[i]] for i, g in enumerate(graphs)]
        # reshape(-1) rather than squeeze(): squeeze() turns a single-candidate
        # sample into a 0-d tensor, which breaks len()/topk() and mse shape matching.
        logits_0 = [g.nodes['v'].data['s_0'][candidates[i]].reshape(-1) for i, g in enumerate(graphs)]
        logits_1 = [g.nodes['v'].data['s_1'][candidates[i]].reshape(-1) for i, g in enumerate(graphs)]

        loss_ce_sb = sum(F.cross_entropy(logits_sb[i].T, label_sb[i:i + 1]) for i in range(count)) / count
        loss_mse_lp_0 = sum(F.mse_loss(logits_0[i], torch.FloatTensor(lp_scores_0[i]).to(device))
                            for i in range(count)) / count
        loss_mse_lp_1 = sum(F.mse_loss(logits_1[i], torch.FloatTensor(lp_scores_1[i]).to(device))
                            for i in range(count)) / count
        loss = loss_ce_sb + loss_mse_lp_0 + loss_mse_lp_1

        loss.backward()
        optimizer.step()

        batch_loss += loss.item() * count

        best_scores = [max(s) for s in scores]
        flat_logits = [logit.detach().reshape(-1) for logit in logits_sb]
        for k in top_k:
            for i in range(count):
                top_idx = flat_logits[i].topk(min(k, flat_logits[i].numel()))[1].tolist()
                acc_dict[k] += max(scores[i][j] for j in top_idx) == best_scores[i]

        if (batch_idx+1) % log_interval == 0:
            logger.info('Train epoch {}:\t [{}/{}({:.0f}%)]\tTrain Loss: {:.6f}\t acc@1: {:.6f}\t acc@3: {:.6f}\t acc@5: {:.6f}\t acc@10: {:.6f}'.format(
                epoch + 1, batch_idx+1, len(trainData), 100. * (batch_idx+1) / len(trainData), batch_loss / total_count,
            acc_dict[1]/total_count, acc_dict[3]/total_count, acc_dict[5]/total_count, acc_dict[10]/total_count))

    return batch_loss / total_count, acc_dict[1]/total_count

def train_new_20_cons(args, model, trainData, optimizer, epoch, log_interval, device, alpha=0.5,
                      low_threshold=83, high_threshold=166):
    """Train one epoch with the imitation losses plus a supervised contrastive term.

    The per-batch loss is the sum of:
    - cross-entropy of ``s`` against ``label_sb``;
    - MSE of ``s_0`` against ``lp_scores_0``;
    - MSE of ``s_1`` against ``lp_scores_1``;
    - ``alpha`` times ``model.module.supervised_contrastive_loss`` on the graph
      embeddings returned by ``forward_sb``.
    The first three terms are averaged over the samples.

    Contrastive class labels come from each sample's number of branching
    candidates: 0 if ``<= low_threshold``, 1 if ``<= high_threshold``, else 2. The
    batch's ``depth`` field is not used.

    Args:
        args: unused; kept for signature compatibility.
        model: wrapper whose ``.module.forward_sb(graph)`` returns
            ``(graph, graph_embeds)``, with ``'s'``, ``'s_0'``, ``'s_1'`` on the
            ``'v'`` nodes. It also exposes ``supervised_contrastive_loss`` and scalar
            tensors ``alpha`` / ``beta``, which are logged per batch and are distinct
            from the ``alpha`` argument.
        trainData: sized iterable of 10-field batches (depth-annotated collate layout).
        optimizer: optimizer over the model parameters.
        epoch: zero-based epoch index (logged as ``epoch + 1``).
        log_interval: log running metrics every ``log_interval`` batches.
        device: device the graphs and labels are moved to.
        alpha: weight of the contrastive term.
        low_threshold, high_threshold: candidate-count cut points for the
            contrastive classes.

    Returns:
        ``(mean loss per sample, top-1 accuracy)`` over the epoch.
    """
    model.train()
    top_k = (1, 3, 5, 10)
    acc_dict = {k: 0.0 for k in top_k}
    batch_loss = 0.0
    total_count = 0.0

    for batch_idx, batch in enumerate(trainData):
        (bg, candidates, label_sb, _label_lp_0, _label_lp_1,
         scores, lp_scores_0, lp_scores_1, _weight, _depth) = batch
        optimizer.zero_grad()
        bg = bg.to(device)
        label_sb = label_sb.to(device)

        # Contrastive class of each sample, bucketed by its number of candidates.
        depth_label = torch.tensor(
            [0 if len(c) <= low_threshold else 1 if len(c) <= high_threshold else 2
             for c in candidates],
            device=device,
        )

        count = len(label_sb)
        total_count += count

        bg, graph_embeds = model.module.forward_sb(bg)
        graphs = dgl.unbatch(bg)

        # Per-sample outputs restricted to the branching candidates.
        logits_sb = [g.nodes['v'].data['s'][candidates[i]] for i, g in enumerate(graphs)]
        # reshape(-1) rather than squeeze(): squeeze() turns a single-candidate
        # sample into a 0-d tensor, which breaks len()/topk() and mse shape matching.
        logits_0 = [g.nodes['v'].data['s_0'][candidates[i]].reshape(-1) for i, g in enumerate(graphs)]
        logits_1 = [g.nodes['v'].data['s_1'][candidates[i]].reshape(-1) for i, g in enumerate(graphs)]

        loss_ce_sb = sum(F.cross_entropy(logits_sb[i].T, label_sb[i:i + 1]) for i in range(count)) / count
        loss_mse_lp_0 = sum(F.mse_loss(logits_0[i], torch.FloatTensor(lp_scores_0[i]).to(device))
                            for i in range(count)) / count
        loss_mse_lp_1 = sum(F.mse_loss(logits_1[i], torch.FloatTensor(lp_scores_1[i]).to(device))
                            for i in range(count)) / count
        contrastive_loss = model.module.supervised_contrastive_loss(graph_embeds, depth_label)
        loss = loss_ce_sb + loss_mse_lp_0 + loss_mse_lp_1 + alpha * contrastive_loss

        loss.backward()
        optimizer.step()
        logger.info(f"alpha: {model.module.alpha.item():.4f}, beta: {model.module.beta.item():.4f}")

        batch_loss += loss.item() * count

        best_scores = [max(s) for s in scores]
        flat_logits = [logit.detach().reshape(-1) for logit in logits_sb]
        for k in top_k:
            for i in range(count):
                top_idx = flat_logits[i].topk(min(k, flat_logits[i].numel()))[1].tolist()
                acc_dict[k] += max(scores[i][j] for j in top_idx) == best_scores[i]

        if (batch_idx+1) % log_interval == 0:
            logger.info('Train epoch {}:\t [{}/{}({:.0f}%)]\tTrain Loss: {:.6f}\t acc@1: {:.6f}\t acc@3: {:.6f}\t acc@5: {:.6f}\t acc@10: {:.6f}'.format(
                epoch + 1, batch_idx+1, len(trainData), 100. * (batch_idx+1) / len(trainData), batch_loss / total_count,
            acc_dict[1]/total_count, acc_dict[3]/total_count, acc_dict[5]/total_count, acc_dict[10]/total_count))

    return batch_loss / total_count, acc_dict[1]/total_count


def test_new_20(args, model, testData, optimizer, epoch, device):
    """Evaluate the model on 9-field batches and log loss plus top-k accuracy.

    The loss is the same per-batch sum as in training: cross-entropy on ``s`` plus
    MSE on ``s_0`` and on ``s_1``, each averaged over the samples. A sample counts
    as a top-k hit when the best true strong-branching score among the model's
    top-k candidates equals the sample's best score.

    Args:
        args: unused; kept for signature compatibility.
        model: wrapper whose ``.module.forward_sb(graph)`` returns the batched graph
            with ``'s'``, ``'s_0'``, ``'s_1'`` on the ``'v'`` nodes.
        testData: iterable of batches as produced by ``dgl_collate_new_20``.
        optimizer: only used to clear leftover gradients before evaluating.
        epoch: zero-based epoch index (logged as ``epoch + 1``).
        device: device the graphs and labels are moved to.

    Returns:
        ``(mean loss per sample, top-1 accuracy, top-3 accuracy)``.
    """
    model.eval()
    top_k = (1, 3, 5, 10)
    acc_dict = {k: 0.0 for k in top_k}
    batch_loss = 0.0
    total_count = 0

    optimizer.zero_grad()  # drop leftover training gradients; the optimizer is otherwise unused
    with torch.no_grad():
        for (bg, candidates, label_sb, _label_lp_0, _label_lp_1,
             scores, lp_scores_0, lp_scores_1, _weight) in testData:
            bg = bg.to(device)
            label_sb = label_sb.to(device)

            count = len(label_sb)
            total_count += count

            graphs = dgl.unbatch(model.module.forward_sb(bg))

            logits_sb = [g.nodes['v'].data['s'][candidates[i]] for i, g in enumerate(graphs)]
            # reshape(-1) rather than squeeze(): squeeze() turns a single-candidate
            # sample into a 0-d tensor, which breaks len()/topk() and mse shape matching.
            logits_0 = [g.nodes['v'].data['s_0'][candidates[i]].reshape(-1) for i, g in enumerate(graphs)]
            logits_1 = [g.nodes['v'].data['s_1'][candidates[i]].reshape(-1) for i, g in enumerate(graphs)]

            loss_ce_sb = sum(F.cross_entropy(logits_sb[i].T, label_sb[i:i + 1]) for i in range(count)) / count
            loss_mse_lp_0 = sum(F.mse_loss(logits_0[i], torch.FloatTensor(lp_scores_0[i]).to(device))
                                for i in range(count)) / count
            loss_mse_lp_1 = sum(F.mse_loss(logits_1[i], torch.FloatTensor(lp_scores_1[i]).to(device))
                                for i in range(count)) / count
            loss = loss_ce_sb + loss_mse_lp_0 + loss_mse_lp_1
            batch_loss += loss.item() * count

            best_scores = [max(s) for s in scores]
            flat_logits = [logit.reshape(-1) for logit in logits_sb]
            for k in top_k:
                for i in range(count):
                    top_idx = flat_logits[i].topk(min(k, flat_logits[i].numel()))[1].tolist()
                    acc_dict[k] += max(scores[i][j] for j in top_idx) == best_scores[i]

    loss = batch_loss / total_count
    acc = {k: acc_dict[k] / total_count for k in top_k}
    logger.info('Test epoch {}:\t Test Loss: {:.6f} acc@1: {:.4f} acc@3: {:.4f} acc@5: {:.4f} acc@10: {:.4f}'.format(
        epoch + 1, loss, acc[1], acc[3], acc[5], acc[10]))
    logger.info('Test epoch {}:\t Test Loss: {:.6f} acc: {:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(
        epoch + 1, loss, acc[1], acc[3], acc[5], acc[10]))
    return loss, acc[1], acc[3]

def test_new_20_cons(args, model, testData, optimizer, epoch, device):
    """Evaluate the contrastive model; log accuracy overall and per node-depth bucket.

    The loss and top-k hit definition match ``test_new_20``. Each sample also falls
    into one of four depth buckets, logged as "depth 1" .. "depth 4":
    depth < 2, 2 <= depth < 5, 5 <= depth < 8, and depth >= 8.
    A bucket with no samples logs ``nan`` accuracies instead of raising
    ZeroDivisionError.

    Args:
        args: unused; kept for signature compatibility.
        model: wrapper whose ``.module.forward_sb(graph)`` returns
            ``(graph, graph_embeds)``, with ``'s'``, ``'s_0'``, ``'s_1'`` on the
            ``'v'`` nodes.
        testData: iterable of 10-field batches (depth-annotated collate layout).
        optimizer: only used to clear leftover gradients before evaluating.
        epoch: zero-based epoch index (logged as ``epoch + 1``).
        device: device the graphs and labels are moved to.

    Returns:
        ``(mean loss per sample, top-1 accuracy, top-3 accuracy)``.
    """
    model.eval()
    top_k = (1, 3, 5, 10)
    bucket_edges = (2, 5, 8)
    n_buckets = len(bucket_edges) + 1

    def depth_bucket(d):
        # Index of the first edge that d is below; the last bucket takes d >= 8.
        for b, edge in enumerate(bucket_edges):
            if d < edge:
                return b
        return n_buckets - 1

    acc_dict = {k: 0.0 for k in top_k}
    bucket_num = [0] * n_buckets
    bucket_acc = [{k: 0.0 for k in top_k} for _ in range(n_buckets)]
    batch_loss = 0.0
    total_count = 0

    optimizer.zero_grad()  # drop leftover training gradients; the optimizer is otherwise unused
    with torch.no_grad():
        for (bg, candidates, label_sb, _label_lp_0, _label_lp_1,
             scores, lp_scores_0, lp_scores_1, _weight, depth) in testData:
            bg = bg.to(device)
            label_sb = label_sb.to(device)

            count = len(label_sb)
            total_count += count

            bg, _ = model.module.forward_sb(bg)
            graphs = dgl.unbatch(bg)

            logits_sb = [g.nodes['v'].data['s'][candidates[i]] for i, g in enumerate(graphs)]
            # reshape(-1) rather than squeeze(): squeeze() turns a single-candidate
            # sample into a 0-d tensor, which breaks len()/topk() and mse shape matching.
            logits_0 = [g.nodes['v'].data['s_0'][candidates[i]].reshape(-1) for i, g in enumerate(graphs)]
            logits_1 = [g.nodes['v'].data['s_1'][candidates[i]].reshape(-1) for i, g in enumerate(graphs)]

            loss_ce_sb = sum(F.cross_entropy(logits_sb[i].T, label_sb[i:i + 1]) for i in range(count)) / count
            loss_mse_lp_0 = sum(F.mse_loss(logits_0[i], torch.FloatTensor(lp_scores_0[i]).to(device))
                                for i in range(count)) / count
            loss_mse_lp_1 = sum(F.mse_loss(logits_1[i], torch.FloatTensor(lp_scores_1[i]).to(device))
                                for i in range(count)) / count
            loss = loss_ce_sb + loss_mse_lp_0 + loss_mse_lp_1
            batch_loss += loss.item() * count

            # Bucket on the host (one transfer) instead of comparing device tensors per element.
            buckets = [depth_bucket(d) for d in depth.tolist()]
            for b in buckets:
                bucket_num[b] += 1

            best_scores = [max(s) for s in scores]
            flat_logits = [logit.reshape(-1) for logit in logits_sb]
            for k in top_k:
                for i in range(count):
                    top_idx = flat_logits[i].topk(min(k, flat_logits[i].numel()))[1].tolist()
                    if max(scores[i][j] for j in top_idx) == best_scores[i]:
                        acc_dict[k] += 1
                        bucket_acc[buckets[i]][k] += 1

    loss = batch_loss / total_count
    acc = {k: acc_dict[k] / total_count for k in top_k}
    logger.info('Test epoch {}:\t Test Loss: {:.6f} acc@1: {:.4f} acc@3: {:.4f} acc@5: {:.4f} acc@10: {:.4f}'.format(
        epoch + 1, loss, acc[1], acc[3], acc[5], acc[10]))
    for b in range(n_buckets):
        n = bucket_num[b]
        bucket_rates = [bucket_acc[b][k] / n if n else float('nan') for k in top_k]
        logger.info('depth {} num: {}  acc@1: {:.4f} acc@3: {:.4f} acc@5: {:.4f} acc@10: {:.4f}'.format(
            b + 1, n, *bucket_rates))
    logger.info('Test epoch {}:\t Test Loss: {:.6f} acc: {:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(
        epoch + 1, loss, acc[1], acc[3], acc[5], acc[10]))
    return loss, acc[1], acc[3]

def _depth_bucket_index_new_20(depth_value):
    """Map a node depth to one of the four reporting buckets.

    0 -> depth < 2, 1 -> depth < 5, 2 -> depth < 8, 3 -> everything deeper.
    """
    if depth_value < 2:
        return 0
    if depth_value < 5:
        return 1
    if depth_value < 8:
        return 2
    return 3


def _ratio_or_nan_new_20(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero.

    Empty depth buckets are logged as ``nan`` instead of raising
    ZeroDivisionError after the whole evaluation has run.
    """
    return numerator / denominator if denominator else float('nan')


def test_new_20_depth(args, model, testData, optimizer, epoch, device):
    """Evaluate the new_20 model and log top-k accuracy overall and per depth bucket.

    A sample counts as a hit at k when the best true score among the model's
    top-k candidates equals the best true score over all its candidates.

    Args:
        args: Unused; kept for signature compatibility.
        model: Object exposing ``model.module.forward_sb(bg)`` (presumably a
            DataParallel-style wrapper -- confirm with caller).
        testData: Iterable of 10-tuples ``(bg, candidates, label_sb,
            label_lp_0, label_lp_1, scores, lp_scores_0, lp_scores_1,
            weight, depth)``.
        optimizer: Only ``zero_grad()`` is called on it.
        epoch: Epoch index; logged as ``epoch + 1``.
        device: Device the graph batch and strong-branching labels are moved to.

    Returns:
        ``(mean_loss, acc@1, acc@3)`` over all samples, where the loss is the
        per-sample average of cross-entropy on the strong-branching label plus
        the two LP-score MSE terms.

    Raises:
        ValueError: if ``testData`` yields no samples.
    """
    top_k = (1, 3, 5, 10)
    n_buckets = 4

    model.eval()
    batch_loss = 0.0
    total_count = 0
    # acc_dict[k]: number of samples that are a hit at k.
    acc_dict = {k: 0 for k in top_k}
    # Same hit counters split by depth bucket, plus the sample count per bucket.
    bucket_acc = [{k: 0 for k in top_k} for _ in range(n_buckets)]
    bucket_num = [0] * n_buckets

    with torch.no_grad():
        for (bg, candidates, label_sb, label_lp_0, label_lp_1, scores,
             lp_scores_0, lp_scores_1, weight, depth) in testData:
            # No backward pass runs here; the call is kept so the optimizer
            # sees the same calls as before.
            optimizer.zero_grad()
            bg = bg.to(device)
            label_sb = label_sb.to(device)

            count = len(label_sb)
            total_count += count

            bg = model.module.forward_sb(bg)
            graphs = dgl.unbatch(bg)

            # Per-graph node scores restricted to that graph's candidate variables.
            logits_sb = [g.nodes['v'].data['s'][candidates[i]] for i, g in enumerate(graphs)]
            logits_0 = [g.nodes['v'].data['s_0'][candidates[i]] for i, g in enumerate(graphs)]
            logits_1 = [g.nodes['v'].data['s_1'][candidates[i]] for i, g in enumerate(graphs)]

            # Assumes each logits tensor is a (num_candidates, 1) column -- `.T`
            # then gives the (1, num_candidates) row cross_entropy expects.
            loss_ce_sb = sum(F.cross_entropy(logits_sb[i].T, label_sb[i:i + 1])
                             for i in range(count)) / count
            # reshape(-1) rather than squeeze(): a single-candidate column stays
            # 1-D instead of collapsing to a 0-d scalar.
            loss_mse_lp_0 = sum(F.mse_loss(torch.FloatTensor(lp_scores_0[i]).to(device), logits_0[i].reshape(-1))
                                for i in range(count)) / count
            loss_mse_lp_1 = sum(F.mse_loss(torch.FloatTensor(lp_scores_1[i]).to(device), logits_1[i].reshape(-1))
                                for i in range(count)) / count

            loss = loss_ce_sb + loss_mse_lp_0 + loss_mse_lp_1
            batch_loss += loss.item() * count

            # Depth is only used for bucketing: read it on the host once instead
            # of comparing device tensors element by element.
            buckets = [_depth_bucket_index_new_20(d) for d in depth.reshape(-1).tolist()]

            for i in range(count):
                bucket = buckets[i]
                bucket_num[bucket] += 1
                flat_logits = logits_sb[i].reshape(-1)
                best_score = max(scores[i])
                for k in top_k:
                    top_idx = flat_logits.topk(min(k, flat_logits.numel()))[1].tolist()
                    if max(scores[i][j] for j in top_idx) == best_score:
                        acc_dict[k] += 1
                        bucket_acc[bucket][k] += 1

    if total_count == 0:
        raise ValueError('test_new_20_depth: testData yielded no samples')

    loss = batch_loss / total_count
    overall = [acc_dict[k] / total_count for k in top_k]
    logger.info('Test epoch {}:\t Test Loss: {:.6f} acc@1: {:.4f} acc@3: {:.4f} acc@5: {:.4f} acc@10: {:.4f}'.format(
        epoch + 1, loss, *overall))
    for bucket in range(n_buckets):
        num = bucket_num[bucket]
        bucket_rates = [_ratio_or_nan_new_20(bucket_acc[bucket][k], num) for k in top_k]
        logger.info('depth {} num: {}  acc@1: {:.4f} acc@3: {:.4f} acc@5: {:.4f} acc@10: {:.4f}'.format(
            bucket + 1, num, *bucket_rates))
    logger.info('Test epoch {}:\t Test Loss: {:.6f} acc: {:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(
        epoch + 1, loss, *overall))
    return loss, overall[0], overall[1]
