from gensim.models import Word2Vec
import os
from tqdm import tqdm
import re
import html
import numpy as np

# Tokenization helper
def tokenize_expression(expression: str) -> list:
    """Split a code expression into a flat list of string tokens.

    Runs of word characters (identifiers, keywords, digits) become one token
    each; every other non-whitespace character becomes a single-character
    token. Commas, parentheses, square brackets and double quotes are then
    discarded.

    Args:
        expression: The text to tokenize.

    Returns:
        The surviving tokens, in their original order.
    """
    # Punctuation that is matched by the regex but dropped from the result.
    discarded = {',', '(', ')', '[', ']', '"'}

    kept = []
    for piece in re.findall(r'\w+|\S', expression):
        if piece in discarded:
            continue
        kept.append(piece)
    return kept

def parse_dot_file(dot_content):
    """Extract nodes, edges and per-node token lists from DOT graph text.

    Nodes are lines of the form ``"<digits>" [label = <...>`` and edges are
    lines of the form ``"<digits>" -> "<digits>" [ label = "..."]``; anything
    that does not match those two regular expressions is ignored.

    Args:
        dot_content: Full text of a DOT file.

    Returns:
        A 4-tuple ``(nodes, edges, node_alltokens, node_id_to_index)``:

        * ``nodes``: list of ``(node_id, clean_label, first_word, node_text)``
          tuples. ``first_word`` is the upper-case tag found before the first
          comma of the label, or ``"METHOD"`` when there is none;
          ``node_text`` is the label text that follows the tag.
        * ``edges``: list of ``(source, target, label, edge_type, edge_text)``
          tuples. ``edge_type`` is the upper-case tag found before the first
          colon of the label, or ``""`` when there is none; ``edge_text`` is
          the label text that follows the tag.
        * ``node_alltokens``: one token list per node (in node order), as
          produced by ``tokenize_expression(node_text)``.
        * ``node_id_to_index``: maps each node id string to the position of
          that node in ``nodes``.
    """
    nodes = []
    edges = []
    node_alltokens = []
    node_id_to_index = {}

    # Regular expressions matching node and edge declarations.
    node_pattern = r'"(\d+)"\s+\[label\s*=\s*<([^>]+)>'
    edge_pattern = r'"(\d+)"\s*->\s*"(\d+)"\s*\[ label\s*=\s*"([^"]+)"\]'

    # --- Nodes ---
    for index, match in enumerate(re.finditer(node_pattern, dot_content)):
        node_id = match.group(1)
        # Decode HTML entities. The node regex stops at the first '>', so a
        # "<SUB>...</SUB>" suffix survives only as a "<SUB" fragment, which is
        # removed here.
        clean_label = html.unescape(match.group(2))
        clean_label = re.sub(r'<SUB', '', clean_label)

        # Look for an all-upper-case tag (e.g. IDENTIFIER) before the first comma.
        tag_match = re.search(r'\b[A-Z_]+\b', clean_label.split(',')[0])
        if tag_match:
            first_word = tag_match.group(0)
            # Slice from where the tag actually ends instead of assuming it
            # starts at a fixed offset, so a tag that is not at index 1 does
            # not get cut mid-token.
            remaining_code = clean_label[tag_match.end():].strip()
        else:
            # No tag found: fall back to a default type and keep the whole label.
            first_word = "METHOD"
            remaining_code = clean_label
        remaining_code = re.sub(r'^[,(]+', '', remaining_code)  # drop leading '(' / ','
        node_text = re.sub(r'\)$', '', remaining_code, count=1)  # drop the final ')'

        # Record the node id -> index mapping.
        node_id_to_index[node_id] = index
        # Collect this node's tokens for Word2Vec training.
        node_alltokens.append(tokenize_expression(node_text))
        nodes.append((node_id, clean_label, first_word, node_text))

    # --- Edges ---
    for match in re.finditer(edge_pattern, dot_content):
        source, target, label = match.groups()
        type_match = re.search(r'\b[A-Z_]+\b', label.split(':')[0])
        if type_match:
            edge_type = type_match.group(0)
            # Skip the tag and the single separator character that follows it.
            edge_code = label[type_match.end() + 1:]
        else:
            # A label without an upper-case tag used to raise AttributeError;
            # keep the edge with an empty type and the whole label as its text.
            edge_type = ""
            edge_code = label
        edge_text = re.sub(r'^[:(]+', '', edge_code.strip())  # drop leading ':' / '('
        edges.append((source, target, label, edge_type, edge_text))

    return nodes, edges, node_alltokens, node_id_to_index


def train_word2vec_on_folders(base_directory, folder_list, vector_size=100, epochs=1, workers=4, final_training=True):
    """Train a skip-gram Word2Vec model incrementally, one folder at a time.

    For every folder, each ``.dot`` file is parsed with ``parse_dot_file`` and
    the per-node token lists are used as training sentences. The vocabulary is
    built from the first folder that yields tokens and extended
    (``update=True``) for each later one. Optionally, one more training pass
    is run over the sentences of all folders combined.

    Args:
        base_directory: Directory that contains the folders in ``folder_list``.
        folder_list: Names of the sub-folders to process, in order.
        vector_size: Dimensionality of the word vectors.
        epochs: Number of epochs for every ``train`` call.
        workers: Number of worker threads passed to ``Word2Vec``.
        final_training: When true, run a final pass over all collected
            sentences after the per-folder training.

    Returns:
        The ``Word2Vec`` model. If no folder yielded any token it is returned
        without a vocabulary.

    Raises:
        OSError: If a folder or file cannot be listed or read.
    """
    w2v_model = Word2Vec(vector_size=vector_size, min_count=1, alpha=0.01, sample=1e-5, workers=workers, sg=1, hs=0, negative=5)
    vocab_built = False
    all_tokens = []

    for folder in folder_list:
        folder_path = os.path.join(base_directory, folder)
        batch_tokens = []

        # Sorted so the processing order does not depend on os.listdir().
        files = sorted(f for f in os.listdir(folder_path) if f.endswith('.dot'))
        for filename in tqdm(files, desc=f"Processing folder {folder}"):
            file_path = os.path.join(folder_path, filename)
            with open(file_path, 'r', encoding='utf-8') as file:
                dot_content = file.read()
            _, _, node_alltokens, _ = parse_dot_file(dot_content)
            batch_tokens.extend(node_alltokens)

        # A folder without any token would hand an empty corpus to
        # build_vocab/train; skip it so the first non-empty folder still
        # performs the initial (non-update) vocabulary build.
        if not any(batch_tokens):
            print(f"No tokens found in folder {folder}, skipping")
            continue
        all_tokens.extend(batch_tokens)

        print("Starting training")
        w2v_model.build_vocab(corpus_iterable=batch_tokens, update=vocab_built)
        w2v_model.train(corpus_iterable=batch_tokens, total_examples=len(batch_tokens), epochs=epochs)
        vocab_built = True

    if final_training and vocab_built:
        print("Starting final fine-tuning...")
        # model.corpus_count reflects only the corpus of the most recent
        # build_vocab call (the last folder), so pass the real sentence count
        # of the combined corpus instead.
        w2v_model.train(corpus_iterable=all_tokens, total_examples=len(all_tokens), epochs=epochs)

    return w2v_model

# Training data lives in sub-folders folder_1 ... folder_24 of base_directory.
# NOTE(review): an earlier comment mentioned 13 folders of 1000 files each,
# but range(1, 25) enumerates 24 folders - confirm the intended count.
base_directory = './data/split-file/'
folder_list = [f"folder_{i}" for i in range(1, 25)]
model_save_path = "word2vec_model.model"

# Train on every folder, then write the resulting model to disk.
w2v_model = train_word2vec_on_folders(base_directory, folder_list)
w2v_model.save(model_save_path)
print(f"Model saved to {model_save_path}")

