import os
import sys
import json
import os.path as osp
import pandas as pd
import numpy as np
import torch
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
import dgl
# from ogb.nodeproppred import DglNodePropPredDataset
dgl.use_libxsmm(False)

def convert_code_to_token(code, tokenizer, block_size):
    """Turn a code snippet into a padded list of token ids.

    Whitespace runs in ``code`` are collapsed to single spaces, the text is
    tokenized, and the tokens are truncated so that, together with the
    tokenizer's CLS and SEP tokens, at most ``block_size`` tokens remain.
    The id list is then right-padded with ``tokenizer.pad_token_id`` up to
    ``block_size`` entries.

    Args:
        code: Source text of the snippet.
        tokenizer: Object providing ``tokenize``, ``convert_tokens_to_ids``,
            ``cls_token``, ``sep_token`` and ``pad_token_id``.
        block_size: Target length of the returned id list.

    Returns:
        The token ids followed by padding ids.
    """
    normalized = ' '.join(code.split())
    # Two slots are reserved for the CLS and SEP markers.
    max_code_tokens = block_size - 2
    kept_tokens = tokenizer.tokenize(normalized)[:max_code_tokens]
    wrapped = [tokenizer.cls_token] + kept_tokens + [tokenizer.sep_token]
    ids = tokenizer.convert_tokens_to_ids(wrapped)
    missing = block_size - len(ids)
    padding = [tokenizer.pad_token_id] * missing
    return ids + padding

def preprocess_edges(edges, num_nodes):
    """Drop edges whose endpoints fall outside ``[0, num_nodes)``.

    Both the upper and the lower bound are enforced: a negative endpoint is
    not a valid node index, so such edges are discarded as well.

    Args:
        edges: Iterable of ``(source, destination)`` node-index pairs.
        num_nodes: Number of nodes that are kept in the graph.

    Returns:
        A two-element list ``[src, dst]`` holding the source and destination
        indices of the surviving edges, in their original order.
    """
    src, dst = [], []
    for s, d in edges:
        if 0 <= s < num_nodes and 0 <= d < num_nodes:
            src.append(s)
            dst.append(d)
    return [src, dst]

def preprocess_features(features, block_size):
    """Pad a feature matrix with zero rows up to ``block_size`` rows.

    Args:
        features: Two-dimensional array-like with one row per node.
        block_size: Number of rows the result should have.

    Returns:
        A NumPy array containing ``features`` followed by all-zero rows.
    """
    missing_rows = block_size - len(features)
    row_padding = (0, missing_rows)  # append rows at the bottom only
    col_padding = (0, 0)             # leave the feature axis untouched
    return np.pad(features, (row_padding, col_padding), mode='constant')

def process_sample(sample, tokenizer, word_embeddings, token_index, block_size, feature_dim):
    """Convert one raw graph sample into fixed-size, serialisable model inputs.

    At most ``block_size`` nodes are kept. Each kept node gets a feature
    vector made of its type entry from ``token_index`` concatenated with the
    sum of the embeddings of its code tokens. The feature matrix, the target
    list and the line list are all padded to ``block_size`` entries.

    Args:
        sample: Mapping with the keys ``"nodes"``, ``"edges"``,
            ``"nodes_codes"``, ``"nodes_label"``, ``"code_lines"``,
            ``"filename"`` and ``"target"``.
        tokenizer: Tokenizer forwarded to ``convert_code_to_token``.
        word_embeddings: Lookup indexed by token id; a lookup that raises
            ``IndexError`` contributes a zero vector instead.
        token_index: Mapping from node label to its type entry; unknown
            labels fall back to ``[0]``.
        block_size: Maximum number of nodes, and also the token length used
            for each node's code.
        feature_dim: Length of the zero vector used for failed embedding
            lookups (expected to equal the width of one embedding row).

    Returns:
        A dict with the keys ``"filename"``, ``"node_feature"``,
        ``"node_target"``, ``"edges"``, ``"node_lines"`` and ``"target"``.
    """
    nodes = sample["nodes"]
    edges = sample["edges"]
    nodes_codes = sample["nodes_codes"]
    node_labels = sample["nodes_label"]

    # Nodes beyond block_size are dropped, together with any edge touching them.
    num_nodes = min(len(nodes), block_size)
    edges = preprocess_edges(edges, num_nodes)

    unknown_type = [0]  # unknown node types map to 0
    # Shared, read-only fallback: it is only ever summed or concatenated,
    # both of which produce new arrays.
    zero_embedding = np.zeros(feature_dim)

    node_features = []
    for i in range(num_nodes):
        type_emb = token_index.get(node_labels[i], unknown_type)

        code_ids = convert_code_to_token(nodes_codes[i], tokenizer, block_size)
        emb_seq = []
        for cid in code_ids:
            try:
                emb_seq.append(word_embeddings[cid])
            except IndexError:
                emb_seq.append(zero_embedding)

        code_emb = np.sum(emb_seq, axis=0) if emb_seq else zero_embedding
        node_features.append(np.concatenate([type_emb, code_emb]))

    if node_features:
        node_features = preprocess_features(node_features, block_size)
    else:
        # An empty list has no feature axis, so it cannot be padded as a
        # 2-D matrix. Emit an all-zero matrix with the width a real node
        # would have had (type entry + code embedding).
        type_width = len(next(iter(token_index.values()), unknown_type))
        node_features = np.zeros((block_size, type_width + feature_dim))

    # NOTE(review): node_target is sliced from "nodes_codes" yet padded with
    # 0 -- confirm whether a per-node target field was intended here.
    pad_count = block_size - num_nodes
    node_target = sample["nodes_codes"][:num_nodes] + [0] * pad_count
    node_lines = sample["code_lines"][:num_nodes] + [-1] * pad_count

    return {
        "filename": sample["filename"],
        "node_feature": node_features.tolist(),
        "node_target": node_target,
        "edges": edges,
        "node_lines": node_lines,
        "target": sample["target"]
    }

def build_token_index(input_file):
    """Build the node-type index from a JSON-lines file.

    Every line of ``input_file`` is parsed as a JSON object and the labels in
    its ``"nodes_label"`` list are numbered in order of first appearance.
    Index 0 is reserved for the ``"UNK"`` entry. Blank lines (for example a
    trailing newline at the end of the file) are skipped.

    Args:
        input_file: Path to the JSON-lines file, read as UTF-8.

    Returns:
        A dict mapping each label to a one-element list holding its index.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a sample has no ``"nodes_label"`` key.
    """
    token_index = {"UNK": [0]}  # 0 is reserved for unknown types
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            sample = json.loads(line)
            for label in sample["nodes_label"]:
                if label not in token_index:
                    token_index[label] = [len(token_index)]
    return token_index

def get_confidence(output, with_softmax=False):
    """Return the top score and its class index along dimension 1.

    Args:
        output: Tensor of scores with the class axis at dimension 1.
        with_softmax: When true, ``output`` is used as-is; otherwise a
            softmax over dimension 1 is applied first.

    Returns:
        A tuple ``(confidence, pred_label)`` holding the maximum value and
        its index along dimension 1.
    """
    scores = output if with_softmax else torch.softmax(output, dim=1)
    best = scores.max(dim=1)
    return best.values, best.indices

def visualize_vulnerability(g, node_preds, orig_file):
    """Print the source lines whose predicted vulnerability score exceeds 0.5.

    Each node's score (entry 1 of its prediction) is attached to every line
    listed for that node in ``g.ndata['node_lines']``. A node entry may be a
    single line number or a row of line numbers. When several nodes cover the
    same line, the highest score is kept so that one low-scoring node cannot
    hide a high-scoring one.

    Args:
        g: Graph object exposing ``ndata['node_lines']`` as a tensor with one
            entry (scalar or row) per node.
        node_preds: Per-node predictions; index 1 of each is used as the score.
        orig_file: Path of the source file to display.

    Returns:
        None. Matching lines are written to standard output.
    """
    node_lines = g.ndata['node_lines'].cpu().numpy()

    vul_map = {}
    for node, pred in enumerate(node_preds):
        prob = float(pred[1])
        # atleast_1d lets a scalar entry (one line per node) be handled the
        # same way as a row of several lines.
        for line in np.atleast_1d(node_lines[node]).tolist():
            vul_map[line] = max(vul_map.get(line, prob), prob)

    # NOTE(review): line numbers are matched against the 0-based position of
    # each line in the file -- confirm node_lines uses the same base.
    # Undecodable bytes are replaced so display never aborts on odd encodings.
    with open(orig_file, 'r', encoding='utf-8', errors='replace') as f:
        for i, line in enumerate(f):
            prob = vul_map.get(i, 0.0)
            if prob > 0.5:
                print(f"{i:4d} [{prob:.2f}]: {line.rstrip()}")