import re
import html
import os
from gensim.models import Word2Vec
import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial
from tqdm import tqdm

# Tokenization function
def tokenize_expression(expression: str) -> list:
    """Split a code expression into tokens for Word2Vec.

    A token is either a run of word characters (identifiers, keywords,
    numbers) or any single non-whitespace character (operators and other
    punctuation). The characters ``,`` ``(`` ``)`` ``[`` ``]`` and ``"`` are
    dropped from the result.

    Args:
        expression: Source-code text to tokenize.

    Returns:
        The remaining tokens, in their original order.
    """
    # Punctuation that is discarded before embedding.
    skipped = frozenset(',()[]"')
    raw_tokens = re.findall(r'\w+|\S', expression)
    return list(filter(lambda tok: tok not in skipped, raw_tokens))

def _split_node_label(clean_label):
    """Split a cleaned node label into ``(first_word, node_text)``.

    ``first_word`` is the first all-uppercase word (``[A-Z_]+``) before the
    first comma, e.g. ``IDENTIFIER`` for ``(IDENTIFIER,x,x = 1)``; when there
    is none the node is typed ``"METHOD"``. ``node_text`` is the remainder of
    the label with leading ``,``/``(`` characters and one trailing ``)``
    removed (``x,x = 1`` in the example).
    """
    type_match = re.search(r'\b[A-Z_]+\b', clean_label.split(',')[0])
    if type_match:
        first_word = type_match.group(0)
        # Slice from where the matched word really ends instead of assuming it
        # is preceded by exactly one character (normally the leading "(").
        remaining_code = clean_label[type_match.end():].strip()
    else:
        # NOTE(review): labels without an uppercase type word (presumably call
        # labels such as "(<operator>.assignment,...)") are all typed "METHOD"
        # and keep their whole text -- confirm this fallback is intended.
        first_word = "METHOD"
        remaining_code = clean_label
    remaining_code = re.sub(r'^[,(]+', '', remaining_code)  # drop leading "(" / ","
    # Remove one trailing ")". count is passed by keyword: passing it
    # positionally to re.sub is deprecated since Python 3.13.
    node_text = re.sub(r'\)$', '', remaining_code, count=1)
    return first_word, node_text


def _split_edge_label(label):
    """Split an edge label such as ``"REACHING_DEF: x"`` into ``(edge_type, edge_text)``.

    Returns ``None`` when the part before the first ``:`` contains no
    all-uppercase word.
    """
    type_match = re.search(r'\b[A-Z_]+\b', label.split(':')[0])
    if type_match is None:
        return None
    edge_type = type_match.group(0)
    # Skip the type word plus the single separator character after it (normally ":").
    edge_code = label[type_match.end() + 1:].strip()
    edge_text = re.sub(r'^[:(]+', '', edge_code)  # drop leading ":" / "("
    return edge_type, edge_text


def parse_dot_file(dot_content):
    """Extract nodes and labelled edges from the text of a ``.dot`` graph.

    Node statements must look like ``"<digits>" [label = <...>`` and edge
    statements like ``"<digits>" -> "<digits>" [label = "..."]``; whitespace
    after ``[`` is also accepted. NOTE(review): the label layouts handled
    here -- ``<(TYPE,code)<SUB>line</SUB>>`` for nodes and ``"TYPE: code"``
    for edges -- look like Joern's dot export; confirm against the data.

    Args:
        dot_content: Full text of a ``.dot`` file.

    Returns:
        A tuple ``(nodes, edges, node_alltokens, node_id_to_index)``:

        * ``nodes`` -- list of ``(node_id, clean_label, first_word, node_text)``
          in file order; ``clean_label`` and ``node_text`` are HTML-unescaped.
        * ``edges`` -- list of ``(source, target, label, edge_type, edge_text)``;
          ``label`` and ``edge_text`` are still HTML-escaped. Edges whose label
          has no uppercase type word are skipped.
        * ``node_alltokens`` -- ``tokenize_expression(node_text)`` per node,
          aligned with ``nodes``.
        * ``node_id_to_index`` -- node id string -> position in ``nodes``
          (the last occurrence wins if an id repeats).
    """
    nodes = []
    edges = []
    node_alltokens = []
    node_id_to_index = {}

    node_pattern = r'"(\d+)"\s+\[\s*label\s*=\s*<([^>]+)>'
    edge_pattern = r'"(\d+)"\s*->\s*"(\d+)"\s*\[\s*label\s*=\s*"([^"]+)"\]'

    for match in re.finditer(node_pattern, dot_content):
        node_id, label = match.group(1), match.group(2)
        # Decode HTML entities (&lt; &gt; &amp; ...). The label capture stops at
        # the first ">", which for a label like "(TYPE,code)<SUB>1</SUB>" is the
        # one closing "<SUB"; only that dangling "<SUB" fragment is removed.
        clean_label = re.sub(r'<SUB', '', html.unescape(label))
        first_word, node_text = _split_node_label(clean_label)
        node_id_to_index[node_id] = len(nodes)
        node_alltokens.append(tokenize_expression(node_text))
        nodes.append((node_id, clean_label, first_word, node_text))

    for match in re.finditer(edge_pattern, dot_content):
        source, target, label = match.groups()
        parsed = _split_edge_label(label)
        if parsed is None:
            # Malformed label: skip this edge instead of failing the whole file.
            continue
        edge_type, edge_text = parsed
        edges.append((source, target, label, edge_type, edge_text))

    return nodes, edges, node_alltokens, node_id_to_index
    
def save_to_file(unique_words, file_path):
    """Write a vocabulary to ``file_path``, one ``"WORD": index`` line per entry.

    Entries are written in the dict's iteration order; the file is overwritten.
    ``load_from_file`` reads this format back.
    """
    with open(file_path, 'w') as out:
        out.writelines(f'"{word}": {count}\n' for word, count in unique_words.items())

def collect_unique_words(directory_path):
    """Assign an integer id to every node type and edge type in a directory.

    Scans the ``*.dot`` entries directly inside ``directory_path``
    (non-recursive). Files are visited in sorted name order, so the ids are
    reproducible across runs and file systems (``os.listdir`` order is
    arbitrary). Within each file, node types are registered before edge types.

    Args:
        directory_path: Directory containing ``.dot`` files.

    Returns:
        Dict mapping each type word to its id; ids are 0, 1, 2, ... in
        first-seen order.
    """
    unique_words = {}
    dot_files = sorted(name for name in os.listdir(directory_path) if name.endswith('.dot'))

    for filename in dot_files:
        file_path = os.path.join(directory_path, filename)
        # Same encoding as process_file so both read the files identically.
        with open(file_path, 'r', encoding='utf-8') as file:
            dot_content = file.read()
        nodes, edges, _, _ = parse_dot_file(dot_content)

        node_types = [first_word for _, _, first_word, _ in nodes]
        edge_types = [edge_type for _, _, _, edge_type, _ in edges]
        for word in node_types + edge_types:
            # setdefault assigns the next id only the first time a word is seen.
            unique_words.setdefault(word, len(unique_words))

    return unique_words

def load_from_file(file_path):
    """Read a vocabulary file in the ``"WORD": index`` format of ``save_to_file``.

    Blank lines are ignored.

    Args:
        file_path: Path of the vocabulary file.

    Returns:
        ``(unique_words, counter)``: the word -> index dict, and one more than
        the largest index read (0 for an empty file), i.e. the next free id.
    """
    unique_words = {}
    counter = 0
    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            if not line:
                continue  # e.g. a trailing empty line; unpacking it would raise
            # Split on the last ": " so the index is found even if a word contains ": ".
            word, index_text = line.rsplit(': ', 1)
            index = int(index_text)
            unique_words[word.strip('"')] = index
            counter = max(counter, index + 1)
    return unique_words, counter

def _mean_token_vector(text, w2vmodel, dim):
    """Average the Word2Vec vectors of the in-vocabulary tokens of ``text``.

    Returns a zero vector of length ``dim`` when ``text`` is empty or none of
    its tokens are in the model's vocabulary.
    """
    if not text:
        return np.zeros(dim)
    vectors = [w2vmodel.wv[tok] for tok in tokenize_expression(text) if tok in w2vmodel.wv]
    return np.mean(vectors, axis=0) if vectors else np.zeros(dim)


def process_file(file_path, unique_words, w2vmodel, outputgraph_path):
    """Convert one ``.dot`` graph file into GNN input arrays saved as ``.npz``.

    With ``name.dot`` as input, two files are written to ``outputgraph_path``:

    * ``name.c.npz`` with key ``node_rep``: one row per node,
      ``[type id, mean Word2Vec vector of the node text]``.
    * ``name.c_Edges.npz`` with keys ``edges_index`` (shape ``(2, E)``:
      source/target node positions) and ``edges_attr`` (one row per edge,
      ``[type id, mean Word2Vec vector of the edge text]``).

    Edges referring to unknown node ids are dropped. Types missing from
    ``unique_words`` get id 0. NOTE(review): that collides with whichever word
    owns id 0 -- confirm this is intended.

    Args:
        file_path: Path of the ``.dot`` file.
        unique_words: Type word -> integer id vocabulary.
        w2vmodel: Trained gensim Word2Vec model.
        outputgraph_path: Existing output directory.

    Returns:
        ``"<name>.c 已处理"`` on success, otherwise ``"Error processing ..."``
        with the exception text (errors are reported, not raised).
    """
    base_name = os.path.basename(file_path)
    # Swap only the trailing ".dot"; str.replace would also rewrite ".dot"
    # occurring elsewhere in the name.
    filename = base_name[:-len('.dot')] + '.c' if base_name.endswith('.dot') else base_name
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            dot_content = file.read()
        nodes, edges, _, node_id_to_index = parse_dot_file(dot_content)

        # Take the width from the model instead of hard-coding 100, so the
        # zero-vector fallbacks always match the real vectors.
        dim = w2vmodel.wv.vector_size

        node_rows = []
        for _, _, first_word, node_text in nodes:
            label_id = unique_words.get(first_word, 0)
            # node_text was already HTML-unescaped by parse_dot_file; unescaping
            # it again would corrupt code such as "&param" into "¶m".
            text_vec = _mean_token_vector(node_text, w2vmodel, dim)
            node_rows.append(np.concatenate((np.array([label_id]), text_vec), axis=0))

        edges_index = []
        edge_rows = []
        for source, target, _, edge_type, edge_text in edges:
            if source not in node_id_to_index or target not in node_id_to_index:
                continue  # the edge refers to a node that was not parsed
            label_id = unique_words.get(edge_type, 0)
            # Edge text is still HTML-escaped at this point, so decode it once.
            text_vec = _mean_token_vector(html.unescape(edge_text), w2vmodel, dim)
            edges_index.append([node_id_to_index[source], node_id_to_index[target]])
            edge_rows.append(np.concatenate((np.array([label_id]), text_vec), axis=0))

        # reshape keeps 2-D shapes even for graphs without nodes or edges.
        node_rep = np.asarray(node_rows).reshape(-1, dim + 1)
        edges_array = np.asarray(edges_index, dtype=np.int64).reshape(-1, 2).T
        edges_attr = np.asarray(edge_rows).reshape(-1, dim + 1)

        np.savez_compressed(os.path.join(outputgraph_path, f"{filename}.npz"), node_rep=node_rep)
        np.savez_compressed(os.path.join(outputgraph_path, f"{filename}_Edges.npz"), edges_index=edges_array, edges_attr=edges_attr)
        return filename + " 已处理"
    except Exception as e:
        # Best-effort batch processing: report the failure and let the run continue.
        return f"Error processing {filename}: {e}"

# Per-worker-process state installed by ``_init_worker``. The vocabulary and
# Word2Vec model are passed to each worker once, instead of being pickled again
# for every submitted file.
_WORKER_STATE = {}


def _init_worker(unique_words, w2vmodel):
    """Pool initializer: keep the shared vocabulary and model in this worker."""
    _WORKER_STATE['unique_words'] = unique_words
    _WORKER_STATE['w2vmodel'] = w2vmodel


def _process_file_in_worker(file_path, outputgraph_path):
    """Run ``process_file`` on one file using the state set by ``_init_worker``."""
    return process_file(file_path, _WORKER_STATE['unique_words'],
                        _WORKER_STATE['w2vmodel'], outputgraph_path)


def predata_npz(directory_paths, outputgraph_path, out_file_path, model_load_path, batch_size=8, max_workers=8):
    """Convert every ``.dot`` file in ``directory_paths`` into GNN ``.npz`` inputs.

    Args:
        directory_paths: Directories scanned (non-recursively) for ``*.dot`` files.
        outputgraph_path: Output directory for the ``.npz`` files; created if missing.
        out_file_path: Vocabulary file in the format read by ``load_from_file``.
        model_load_path: Path of the trained Word2Vec model to load.
        batch_size: Unused; kept only for backward compatibility.
        max_workers: Number of worker processes.
    """
    os.makedirs(outputgraph_path, exist_ok=True)

    # Step 1: load the node/edge type vocabulary that was built earlier.
    unique_words, _ = load_from_file(out_file_path)

    # Step 2: load the trained Word2Vec model from the caller-supplied path
    # (previously ignored in favour of a hard-coded file name).
    w2vmodel = Word2Vec.load(model_load_path)

    # Step 3: process the files of all folders in parallel.
    files = [
        os.path.join(directory_path, name)
        for directory_path in directory_paths
        for name in os.listdir(directory_path)
        if name.endswith('.dot')
    ]

    with ProcessPoolExecutor(max_workers=max_workers, initializer=_init_worker,
                             initargs=(unique_words, w2vmodel)) as executor:
        futures = {
            executor.submit(_process_file_in_worker, file_path, outputgraph_path): file_path
            for file_path in files
        }
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing files"):
            print(future.result())

    print("----所有文件处理已完成----")

# Entry point: convert the .dot files of all split folders into GNN inputs.
if __name__ == "__main__":
    # The 24 split folders ./data/split-file/folder_1/ ... folder_24/;
    # widen the range to add more folders.
    directory_paths = [f'./data/split-file/folder_{i}/' for i in range(1, 25)]
    outputgraph_path = './data/GNNInput-npz/'
    out_file_path = './data/token.txt'
    model_load_path = "./word2vec_model.model"
    predata_npz(
        directory_paths,
        outputgraph_path,
        out_file_path,
        model_load_path,
        batch_size=8,
        max_workers=8,
    )

## Test code (kept for reference; commented out)
# Iterate over all .dot files in the directory
#for filename in os.listdir(directory_path):
    #if filename.endswith('.dot'):
        #file_path = os.path.join(directory_path, filename)
        
        # 从文件中读取内容
        #with open(file_path, 'r') as file:
            #dot_content = file.read()
        
        # 解析 .dot 文件内容
        #nodes, edges, node_alltokens, node_id_to_index = parse_dot_file(dot_content)
        
        #for node_id, clean_label, first_word,node_text in nodes:
            #if first_word not in unique_words:
                #unique_words[first_word] = counter
                #counter += 1
            #tokens=tokenize_expression(node_text)
            #all_tokens.append(tokens)
        
            # 将 node_text 转换为向量 (例如通过 word2vec, GloVe, 或者直接使用预训练的模型)
            #node_embedding = get_embedding_from_text(node_text)  # 自定义的函数
        
            # 存储 node_id 和 index 的映射
            #node_id_to_index[node_id] = index

            # 将向量存入 node_embeddings 列表中
            #node_embeddings.append(node_embedding)

            #index += 1
        #w2v_init = True
        #w2vmodel.build_vocab(sentences=all_tokens, update=not w2v_init)
        #w2vmodel.train(all_tokens, total_examples=w2vmodel.corpus_count, epochs=1)
        #if w2v_init:
            #w2v_init = False    
        # 打印结果
        # 打印或保存 unique_words 字典
        #for word, index in unique_words.items():
            #print(f'"{word}": {index}')
        #print(f"\n文件: {filename}")
        #print("节点:")
        #count1=0
        #count2=0
        #for node_id, clean_label,first_word,node_text in nodes:     
            #print(f'{node_id}: {clean_label}')
            #print(first_word)
            #print(node_text)
            #count1=count1+1
           
        #print(count1)
        #print("\n边:")
        #for source, target, label, edge_type, edge_text in edges:          
            #print(f'{source} -> {target} [label = "{label}"]')
            #count2=count2+1
            #source_index=node_id_to_index[source]
            #target_index=node_id_to_index[target]
            #edge_type_match = re.search(r'\b[A-Z_]+\b', label.split(':')[0])
            #edge_type=edge_type_match.group(0)
            #edge_code = label[len(edge_type_match.group(0)) + 1:]
            #edge_code = edge_code.strip()
            #edge_text = re.sub(r'^[:(]+', '', edge_code)  # 去掉开头的 ( 或 ,
            #edges_index.append([source_index,target_index])
            #if edge_type not in unique_words:
                #unique_words[edge_type] = counter
                #counter += 1
            
            #print(edge_text)
        #print(count2)
        #print(edges_index)	
