import os
import json
import torch
from transformers import BertForSequenceClassification
from safetensors.torch import load_model, save_model
from transformers import BertModel, BertTokenizer
import torch.nn as nn

# dirname = "wild-clean"

for dirname in ['reentrancy', 'wild-clean', 'external_call', 'access_control']:
    class ConvolutionalLayer(nn.Module):
        def __init__(self, in_channels, out_channels):
            super(ConvolutionalLayer, self).__init__()
            self.conv1 = nn.Conv1d(in_channels, 512, kernel_size=3, padding=1, stride=2)
            self.conv2 = nn.Conv1d(512, 512, kernel_size=3, padding=1, stride=2)
            self.conv3 = nn.Conv1d(512, 512, kernel_size=3, padding=1, stride=2)
            self.conv4 = nn.Conv1d(512, out_channels, kernel_size=3, padding=1, stride=2)
            # self.pool = nn.MaxPool1d(kernel_size=2)

        def forward(self, x):
            x = self.conv1(x)
            # x = nn.functional.relu(x)
            # x = self.pool(x)
            x = self.conv2(x)
            # x = nn.functional.relu(x)
            # x = self.pool(x)
            x = self.conv3(x)
            # x = nn.functional.relu(x)
            # x = nn.functional.relu(x)
            # x = self.pool(x)
            x = self.conv4(x)
            return x


    # 文件路径
    file_path_folder = "./features/input_features"
    file_path = os.path.join(file_path_folder, dirname + ".txt")
    print("file_path ", file_path)

    # 保存结果到JSON文件
    output_file_path = "./features/output/" + dirname + '.json'
    print("output_file_path ", output_file_path)
    os.makedirs(os.path.dirname(output_file_path), exist_ok=True)

    if os.path.exists(output_file_path):
        continue

    # 读取文件内容
    with open(file_path, 'r', encoding='utf-8') as file:
        text = file.read()

    # 加载预训练的BERT模型和tokenizer
    model_dir = "./google-bert/bert-base-uncased"

    model = BertModel.from_pretrained(model_dir, output_hidden_states=True)
    load_model(model, "./models/" + dirname + "/model.safetensors", strict=False)
    # model = BertForSequenceClassification.from_pretrained(model_dir, num_labels=2, output_hidden_states=True)
    # load_model(model,
    #            "./models/" + dirname + "/model.safetensors", strict=False)
    # model.load_state_dict(r'E:\2024\experiment_code_clone\total4\comment_embedding\comments\models\reentrancy\config.json')
    tokenizer = BertTokenizer.from_pretrained(model_dir)

    # Tokenize数据集
    tokenized_inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length')

    # 获取输入IDs和tokens
    input_ids = tokenized_inputs['input_ids'].squeeze().tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    conv_layer = ConvolutionalLayer(in_channels=768, out_channels=512)
    # conv_layer = ConvolutionalLayer(in_channels=512, out_channels=512)
    # # 获取BERT模型的输出特征向量
    # with torch.no_grad():
    #     outputs = model(**tokenized_inputs)
    #     hidden_states = outputs.hidden_states[-1].squeeze(0)  # 获取倒数第二层的隐藏状态
    #     conv_output = conv_layer(hidden_states)
    #     hidden_states = conv_output
    # 获取BERT模型的输出特征向量
    with torch.no_grad():
        outputs = model(**tokenized_inputs)
        hidden_states = outputs.hidden_states[-1].squeeze(0)  # 获取最后一层的隐藏状态
        hidden_states = hidden_states.permute(1, 0)  # 转换维度，以适应Conv1d的输入
        conv_output = conv_layer(hidden_states)
        hidden_states = conv_output.permute(1, 0)  # 转换维度，以便后续处理
    # 创建词、编码和特征向量的字典
    word_features = []
    for token, token_id, feature in zip(tokens, input_ids, hidden_states):
        word_features.append({
            "token": token,
            "token_id": token_id,
            "features": feature.tolist()
        })


    with open(output_file_path, 'w', encoding='utf-8') as output_file:
        json.dump(word_features, output_file, ensure_ascii=False, indent=4)

    print("结果已保存到", output_file_path)
