import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.nn as pyg_nn
from graphextractor import process_graph_data
from graphextractor import GraphDataset


class GCNGRU(nn.Module):
    """One GCN layer followed by a GRU over the node sequence and two linear heads.

    The GRU is run with ``batch_first=True`` on a tensor of shape
    ``(1, num_nodes, hidden_dim)``. All nodes of the single graph passed to
    ``forward`` are therefore treated as one sequence, in node order.

    Args:
        input_dim: Number of node-feature columns in ``x``. ``conv1`` expects
            one additional column, which comes from ``edge_type``.
        hidden_dim: Output width of the GCN layer and input width of the GRU.
        output_dim: Width of the per-node feature vectors returned by ``forward``.
        gru_hidden_size: Hidden size of the GRU.
        num_gru_layers: Number of stacked GRU layers.
        dropout: Dropout probability applied after the GCN activation.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, gru_hidden_size, num_gru_layers, dropout=0.1):
        super().__init__()
        # Attribute names double as state_dict keys for saved checkpoints; keep them stable.
        # The "+ 1" accounts for the edge-type column concatenated onto x in forward().
        self.conv1 = pyg_nn.GCNConv(input_dim + 1, hidden_dim)
        self.gru = nn.GRU(hidden_dim, gru_hidden_size, num_layers=num_gru_layers, batch_first=True)
        self.fc = nn.Linear(gru_hidden_size, output_dim)
        # Regression head: maps each feature vector to a single scalar.
        self.regression = nn.Linear(output_dim, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, edge_type):
        """Encode one graph.

        Args:
            x: Node-feature matrix. It is concatenated column-wise with
                ``edge_type``, so both must have the same number of rows.
                Presumably ``x`` is ``(num_nodes, input_dim)`` and ``edge_type``
                is ``(num_nodes, 1)``; TODO confirm against ``process_graph_data``.
            edge_index: Graph connectivity passed straight to ``GCNConv``.
            edge_type: Extra per-row feature column (see ``x``).

        Returns:
            Tuple ``(features, output)``. ``features`` is the ``fc`` output, one
            row per node. ``output`` is the regression head applied to
            ``features``, one scalar per node.
        """
        augmented = torch.cat((x, edge_type), dim=1)
        hidden = self.dropout(F.relu(self.conv1(augmented, edge_index)))

        # Add a batch axis of size 1 so the batch_first GRU sees the nodes as time steps,
        # then drop that axis again.
        sequence, _last_state = self.gru(hidden.unsqueeze(0))
        node_states = sequence.squeeze(0)

        features = self.fc(node_states)
        return features, self.regression(features)


# Extract a pooled feature vector for every graph in the dataset
def extract_features(model, dataset, device):
    """Compute one mean-pooled feature vector per graph in ``dataset``.

    Each record is converted to tensors with ``process_graph_data`` and run
    through ``model`` in eval mode without gradient tracking. The ``features``
    output is then averaged over dim 0 (one row per node for ``GCNGRU``),
    giving a single vector per graph.

    Args:
        model: Module whose forward takes ``(node_features, edge_index, edge_type)``
            and returns ``(features, output)``, e.g. ``GCNGRU``. It is switched
            to eval mode as a side effect.
        dataset: Iterable of graph records. Each must be accepted by
            ``process_graph_data`` and support ``record['contract_name']``.
        device: Torch device forwarded to ``process_graph_data``.

    Returns:
        dict mapping each ``contract_name`` to its pooled vector. Each component
        is formatted as a string with 8 decimal places. A later record with the
        same ``contract_name`` overwrites an earlier one.
    """
    model.eval()
    features_by_contract = {}

    with torch.no_grad():
        for data in dataset:
            node_features, edge_index, edge_type = process_graph_data(data, device)
            features, _ = model(node_features, edge_index, edge_type)

            # Average over dim 0 to get a 1-D vector. The previous
            # `mean(keepdim=True).squeeze()` collapsed to a 0-d tensor when the
            # feature width was 1, so `.tolist()` returned a bare float and
            # iterating it raised TypeError.
            pooled = features.mean(dim=0).tolist()
            features_by_contract[data['contract_name']] = [f"{x:.8f}" for x in pooled]

    return features_by_contract


if __name__ == '__main__':
    # Each subdirectory of data_base_dir holds a best_model.pth checkpoint and receives
    # output_features.json. The matching graph data lives in data_cfg_dir/<subdir>.
    data_base_dir = './features-cfg'
    data_cfg_dir = './data-cfg'  # Replace with the path to the data folder
    # Pick the device once; it is reused for every checkpoint.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # sorted() gives a deterministic processing order (os.listdir order is arbitrary).
    for subdir in sorted(os.listdir(data_base_dir)):
        data_dir = os.path.join(data_base_dir, subdir)
        if not os.path.isdir(data_dir) or "__pycache__" in data_dir:
            continue
        print(data_dir)

        checkpoint_path = os.path.join(data_dir, 'best_model.pth')
        # Load the saved weights. map_location remaps tensors stored on a GPU so a
        # checkpoint trained on CUDA can still be loaded on a CPU-only machine.
        model = GCNGRU(input_dim=256, hidden_dim=512, output_dim=512, gru_hidden_size=512, num_gru_layers=3)
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        model = model.to(device)

        all_dataset = GraphDataset(os.path.join(data_cfg_dir, subdir), split='all')
        output_features = extract_features(model, all_dataset.data, device)

        # Save the features to output_features.json
        with open(os.path.join(data_dir, 'output_features.json'), 'w', encoding='utf-8') as f:
            json.dump(output_features, f)
