import re
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import os
import torch.nn as nn
from safetensors.torch import load_model, save_model
import torch
import numpy as np
from sklearn.cluster import KMeans
from transformers import BertTokenizer, BertModel
import json

# Fetch the NLTK stop-word corpus that preprocess_text reads via stopwords.words('english').
# NOTE(review): preprocess_text also calls word_tokenize, which normally needs NLTK
# tokenizer data that is not downloaded here - confirm it is installed out of band.
nltk.download('stopwords')

# NOTE(review): presumably intended to switch off TensorFlow's oneDNN optimizations; this
# assignment runs after the imports above, so confirm it takes effect early enough.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'


def preprocess_text(text):
    """Clean a raw comment string and return its remaining keywords.

    Processing steps, in order:
      1. strip URLs (anything starting with ``http``);
      2. strip dates shaped like ``2024-07-07`` or ``07/07/2024``;
      3. drop every character that is not an ASCII letter, digit or whitespace;
      4. lowercase;
      5. remove the multi-word domain phrases "smart contract(s)";
      6. tokenize with NLTK and drop English stop words and domain-specific
         filler words.

    Args:
        text: Raw text to clean.

    Returns:
        A ``(cleaned_text, num_words)`` tuple where ``cleaned_text`` is the
        surviving tokens joined by single spaces and ``num_words`` is how many
        tokens survived.
    """
    # Remove URLs.
    text = re.sub(r'http\S+', '', text)
    # Remove dates (e.g. 2024-07-07 or 07/07/2024).
    text = re.sub(r'\b\d{4}[-/]\d{2}[-/]\d{2}\b', '', text)
    text = re.sub(r'\b\d{2}[-/]\d{2}[-/]\d{4}\b', '', text)
    # Remove special characters.
    text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
    # Lowercase.
    text = text.lower()

    # Multi-word exclusions must be removed before tokenization: once the text
    # is split into single tokens, an entry such as "smart contract" can never
    # equal any individual token. Replace with a space so neighbours don't merge.
    text = re.sub(r'\bsmart\s+contracts?\b', ' ', text)

    # Tokenize.
    words = word_tokenize(text)

    stop_words = set(stopwords.words('english'))
    # Single-token domain words that carry no topical information.
    exclude_words = {'contract', 'functions', 'ethereum', 'smartcontract', 'smartcontracts',
                     'solidity', 'etherscan', 'may', '2019'}
    # Tokens are already lowercase because the whole text was lowercased above.
    words = [word for word in words
             if word not in stop_words and word not in exclude_words]
    return ' '.join(words), len(words)


class ConvolutionalLayer(nn.Module):
    """Four chained 1-D convolutions with no activation in between.

    Every layer uses ``kernel_size=3``, ``padding=1`` and ``stride=2``; the
    three intermediate feature maps have 512 channels.

    Args:
        in_channels: Channel count of the input fed to the first convolution.
        out_channels: Channel count produced by the last convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = 512
        conv_options = dict(kernel_size=3, padding=1, stride=2)
        # Attribute names and creation order are kept so parameter names and
        # seeded initialisation stay the same.
        self.conv1 = nn.Conv1d(in_channels, hidden_channels, **conv_options)
        self.conv2 = nn.Conv1d(hidden_channels, hidden_channels, **conv_options)
        self.conv3 = nn.Conv1d(hidden_channels, hidden_channels, **conv_options)
        self.conv4 = nn.Conv1d(hidden_channels, out_channels, **conv_options)

    def forward(self, x):
        """Apply conv1 through conv4 in sequence and return the result."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = conv(x)
        return x


def get_model_dirname(file_num):
    """Map a comment-file number to the name of its model directory.

    Numbers up to and including 41 map to "access_control"; 299-1060 map to
    "reentrancy"; 1061-2918 map to "wild-clean"; 2919-3336 map to
    "external_call". Any other number yields an empty string.
    """
    if file_num <= 41:
        return "access_control"

    # (lowest number, highest number, directory name), bounds inclusive.
    numbered_ranges = (
        (299, 1060, "reentrancy"),
        (1061, 2918, "wild-clean"),
        (2919, 3336, "external_call"),
    )
    for lowest, highest, name in numbered_ranges:
        if lowest <= file_num <= highest:
            return name
    return ''


# Driver script: for every numbered comment file (<n>.txt) found in the
# sub-directories of comments_dir, extract up to five keyword embeddings and
# write them to <topic_dir>/<n>.json.
comments_dir = '/root/pll/comments/'
topic_dir = '/root/pll/features/topicembedding/'

# Pretrained BERT checkpoint that supplies both the tokenizer and the base weights.
model_dir = "/root/pll/google-bert/bert-base-uncased/"
# Root folder holding one fine-tuned checkpoint directory per model name.
model_folder_root = '/root/pll/models/'

# The tokenizer and model are loaded lazily and reused across files instead of
# being re-read from disk for every comment file. The model is only reloaded
# when the fine-tuned checkpoint it should carry changes.
tokenizer = None
model = None
loaded_model_dirname = None

for dirpath, dirnames, filenames in os.walk(comments_dir):
    for dirname in dirnames:
        if '.idea' in dirname or "delegatecall" in dirname:
            continue
        dir_folder = os.path.join(dirpath, dirname)
        for file in os.listdir(dir_folder):
            if not file.endswith('.txt'):
                continue
            cluster_keywords = []

            # Existing output files are overwritten (no skip-if-exists check).
            topic_file = os.path.join(topic_dir, file[:-4] + ".json")
            print("the input file is", topic_file)

            comment_path = os.path.join(dir_folder, file)
            print("The comment file is:", comment_path)
            # Read and preprocess the text.
            with open(comment_path, 'r') as f:
                text = f.read()
            # File names are expected to be "<number>.txt"; int() raises ValueError otherwise.
            file_num = int(file[:-4])

            text, num_words = preprocess_text(text)
            if not text:
                continue
            # Roughly one topic per ten surviving words, but at least one.
            topic_num = max(1, num_words // 10)
            print("text is", text)

            # Load the base BERT model plus the fine-tuned weights for this file.
            # Always start from a fresh base model: with strict=False a checkpoint
            # may cover only part of the weights, so reusing an instance that
            # already carries another checkpoint could leave stale values behind.
            model_dirname = get_model_dirname(file_num)
            if model is None or model_dirname != loaded_model_dirname:
                model_folder = os.path.join(model_folder_root, model_dirname)
                model_path = os.path.join(model_folder, "model.safetensors")
                fresh_model = BertModel.from_pretrained(model_dir, output_hidden_states=True)
                load_model(fresh_model, model_path, strict=False)
                model, loaded_model_dirname = fresh_model, model_dirname
            if tokenizer is None:
                tokenizer = BertTokenizer.from_pretrained(model_dir)

            # Tokenize the text.
            tokenized_inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True,
                                         max_length=512)

            # Input ids and their token strings.
            input_ids = tokenized_inputs['input_ids'].squeeze().tolist()
            tokens = tokenizer.convert_ids_to_tokens(input_ids)

            # NOTE(review): this layer is freshly (randomly) initialised for every
            # file and no weights are loaded into it - confirm that is intended.
            conv_layer = ConvolutionalLayer(in_channels=768, out_channels=512)

            # Run BERT and project its last hidden layer through the conv stack.
            with torch.no_grad():
                outputs = model(**tokenized_inputs)
                hidden_states = outputs.hidden_states[-1].squeeze(0)  # last layer's hidden states
                hidden_states = hidden_states.permute(1, 0)  # channels first, as Conv1d expects
                conv_output = conv_layer(hidden_states)
                hidden_states = conv_output.permute(1, 0)  # back to one row per position

            # NOTE(review): ConvolutionalLayer applies four stride-2 convolutions, so
            # hidden_states has about 1/16 as many rows as there are tokens. The zip
            # below stops at the shorter sequence, so only the leading tokens are
            # paired with a feature row, and row i does not correspond to token i.
            # Confirm whether this token/feature alignment is intended.

            # Merge word pieces back into whole words, averaging their features.
            word_features = []
            current_word = ""
            current_features = []

            for token, feature in zip(tokens, hidden_states):
                if token in ('[CLS]', '[SEP]', '[PAD]'):
                    continue
                if token.startswith("##"):
                    # Continuation piece: extend the word being assembled.
                    current_word += token[2:]
                    current_features.append(feature)
                    continue
                if current_word:
                    # A new word starts, so flush the previous one.
                    word_features.append({
                        "token": current_word,
                        "features": torch.stack(current_features).mean(dim=0).tolist()
                    })
                current_word = token
                current_features = [feature]

            if current_word:
                # Flush the last word.
                word_features.append({
                    "token": current_word,
                    "features": torch.stack(current_features).mean(dim=0).tolist()
                })

            # Only cluster when at least one word was collected.
            if word_features:
                feature_vectors = np.array([word["features"] for word in word_features])

                # K-Means needs no more clusters than there are samples.
                num_clusters = min(topic_num, len(feature_vectors))
                kmeans = KMeans(n_clusters=num_clusters, random_state=0, n_init='auto').fit(feature_vectors)

                centroids = kmeans.cluster_centers_

                # One representative word per cluster, keyed by token text.
                merged_keywords = {}

                for i in range(num_clusters):
                    # Indices and features of the words assigned to this cluster.
                    cluster_indices = np.where(kmeans.labels_ == i)[0]
                    cluster_features = feature_vectors[cluster_indices]

                    # The word closest to the centroid represents the cluster.
                    distances = np.linalg.norm(cluster_features - centroids[i], axis=1)
                    closest_index = cluster_indices[np.argmin(distances)]
                    cluster_word = word_features[closest_index]

                    token = cluster_word["token"]
                    if token in merged_keywords:
                        # Same word represents several clusters: sum the feature vectors.
                        merged_keywords[token]["features"] = (np.array(merged_keywords[token]["features"]) +
                                                             np.array(cluster_word["features"])).tolist()
                    else:
                        merged_keywords[token] = {
                            "token": token,
                            "features": cluster_word["features"]
                        }

                cluster_keywords = list(merged_keywords.values())

            # Keep the top five keywords.
            # NOTE(review): the sort key is the feature list itself, so the order is
            # lexicographic (driven by the first feature value) - confirm this
            # ranking is what is wanted.
            cluster_keywords.sort(key=lambda x: x["features"], reverse=True)
            cluster_keywords = cluster_keywords[:5]

            print("Top 5 keywords and their embeddings:")
            for keyword in cluster_keywords:
                print(f"Keyword: {keyword['token']}")

            # Write the result for this file as JSON.
            with open(topic_file, 'w') as f:
                json.dump(cluster_keywords, f)
