import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.nn as pyg_nn


# Step 1: load the data
# class GraphDataset:
#     def __init__(self, data_dir, split='train'):
#         self.data = []
#         file_name = f'{split}.json'
#         with open(os.path.join(data_dir, file_name), 'r') as f:
#             self.data = json.load(f)
#         self.max_length = max(max(max(g[0], g[2]) for g in sample['graph']) for sample in self.data) + 1
#
#         print("self.max_length", self.max_length)
#
#     def __len__(self):
#         return len(self.data)
#
#     def __getitem__(self, idx):
#         sample = self.data[idx]
#
#         # 填充节点特征
#         node_features = sample['node_features']
#         padded_node_features = node_features + [[0.0] * len(node_features[0])] * (self.max_length - len(node_features))
#
#         # 填充graph
#         graph = sample['graph']
#         padded_graph = graph + [[0, 0, 0]] * (self.max_length - len(graph))
#
#         sample['node_features'] = padded_node_features
#         sample['graph'] = padded_graph
#
#         return sample

class GraphDataset:
    """In-memory dataset of graph samples read from ``<data_dir>/<split>.json``.

    The JSON file must contain a list of sample dicts. Each sample provides
    at least:

    * ``'graph'``: a list of 3-item edge records; items 0 and 2 are used as
      node indices when computing ``max_length``.
    * ``'node_features'``: a list of per-node feature rows.

    ``max_length`` is the largest node index found in any edge record plus
    one. Every sample is padded once, at load time, so that both lists hold
    at least ``max_length`` entries. Lists that are already longer are left
    as they are.
    """

    def __init__(self, data_dir, split='train'):
        """Load the split file and pad every sample.

        Args:
            data_dir: Directory that contains the split files.
            split: Stem of the file to read (``'<split>.json'``).
        """
        file_name = f'{split}.json'
        with open(os.path.join(data_dir, file_name), 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        # Flatten over every edge of every sample and supply a default so
        # that an empty dataset, or a sample without edges, does not make
        # max() raise ValueError.
        self.max_length = max(
            (max(edge[0], edge[2]) for sample in self.data for edge in sample['graph']),
            default=-1,
        ) + 1

        print("self.max_length", self.max_length)

        # Pad once here; __getitem__ then only has to look samples up.
        self.data = [self._pad_sample(sample) for sample in self.data]

    def _pad_sample(self, sample):
        """Pad ``sample`` in place up to ``max_length`` entries and return it.

        Node features are padded with all-zero rows as wide as the first
        feature row; the edge list is padded with ``[0, 0, 0]`` records.
        Calling this on an already padded sample is a no-op.

        Raises:
            IndexError: if rows must be added but ``node_features`` is empty,
                because the row width cannot be determined.
        """
        node_features = sample['node_features']
        missing_rows = self.max_length - len(node_features)
        if missing_rows > 0:
            row_width = len(node_features[0])
            # One fresh list per padding row: ``[row] * k`` would make all
            # padding rows the same object, so editing one would edit all.
            zero_rows = [[0.0] * row_width for _ in range(missing_rows)]
            sample['node_features'] = node_features + zero_rows

        graph = sample['graph']
        missing_edges = self.max_length - len(graph)
        if missing_edges > 0:
            sample['graph'] = graph + [[0, 0, 0] for _ in range(missing_edges)]

        return sample

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the (already padded) sample stored at position ``idx``."""
        return self.data[idx]

# Step 2: define the model
class GCNGRU(nn.Module):
    """One GCN layer followed by a GRU, a projection layer and a regression head.

    ``forward`` appends ``edge_type`` to the node features as one extra
    column, which is why the convolution is built with ``input_dim + 1``
    input channels.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, gru_hidden_size, num_gru_layers, dropout=0.1):
        """Build the layers.

        Args:
            input_dim: Width of the node feature rows passed to ``forward``.
            hidden_dim: Output width of the graph convolution / GRU input size.
            output_dim: Width of the ``features`` tensor returned by ``forward``.
            gru_hidden_size: Hidden size of the GRU.
            num_gru_layers: Number of stacked GRU layers.
            dropout: Dropout probability applied after the convolution.
        """
        super().__init__()
        # +1 input channel: forward() concatenates edge_type onto x.
        conv_in_channels = input_dim + 1
        self.conv1 = pyg_nn.GCNConv(conv_in_channels, hidden_dim)
        self.gru = nn.GRU(
            input_size=hidden_dim,
            hidden_size=gru_hidden_size,
            num_layers=num_gru_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=gru_hidden_size, out_features=output_dim)
        # Regression head: one scalar per row of ``features``.
        self.regression = nn.Linear(in_features=output_dim, out_features=1)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, edge_index, edge_type):
        """Run the network and return ``(features, output)``.

        Args:
            x: Node feature tensor; concatenated with ``edge_type`` along
                dim 1, so both need the same size in dim 0.
            edge_index: Edge index tensor handed to the graph convolution.
            edge_type: Tensor appended to ``x`` as extra feature column(s).

        Returns:
            ``features`` (output of ``fc``) and ``output`` (output of the
            regression head applied to ``features``).
        """
        augmented = torch.cat((x, edge_type), dim=1)
        hidden = self.dropout(F.relu(self.conv1(augmented, edge_index)))

        # The GRU is batch_first, so add a leading batch dimension of size 1
        # (all rows form one sequence) and strip it again afterwards.
        sequence_out, _ = self.gru(hidden.unsqueeze(0))
        sequence_out = sequence_out.squeeze(0)

        features = self.fc(sequence_out)
        output = self.regression(features)
        return features, output


def process_graph_data(data, device):
    """Convert one sample dict into the tensors the model consumes.

    Args:
        data: Sample dict with ``'graph'`` (a list of edge records that are
            indexed as ``record[0]`` = source, ``record[1]`` = edge type,
            ``record[2]`` = destination) and ``'node_features'`` (nested
            numeric rows accepted by ``torch.tensor``).
        device: Device the returned tensors are moved to.

    Returns:
        A tuple ``(node_features, edge_index, edge_type)``:

        * ``node_features``: float tensor built from ``data['node_features']``.
        * ``edge_index``: long tensor of shape ``[2, num_edges]``
          (row 0 = sources, row 1 = destinations).
        * ``edge_type``: long tensor of shape ``[num_edges, 1]``.
    """
    edges = data['graph']

    if len(edges) == 0:
        # torch.cat / torch.tensor cannot infer these shapes from nothing,
        # so spell out the empty result explicitly.
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_type = torch.empty((0, 1), dtype=torch.long)
    else:
        sources = [edge[0] for edge in edges]
        kinds = [edge[1] for edge in edges]
        destinations = [edge[2] for edge in edges]
        # Build each tensor in one go instead of one tiny tensor (and one
        # device transfer) per edge.
        edge_index = torch.tensor([sources, destinations], dtype=torch.long)
        edge_type = torch.tensor(kinds, dtype=torch.long).view(-1, 1)

    node_features = torch.tensor(data['node_features'], dtype=torch.float)

    # edge_type always has exactly one row per column of edge_index, so no
    # length reconciliation between the two is needed.
    return node_features.to(device), edge_index.to(device), edge_type.to(device)


def calculate_accuracy(output, targets, threshold=0.1):
    """Return the share of predictions that lie within ``threshold`` of their target.

    A prediction counts as correct when ``|output - targets| < threshold``
    (strict inequality). The count is divided by ``len(targets)``; the
    original comment notes that ``output`` and ``targets`` are assumed to
    have matching shapes.

    Returns:
        A scalar tensor holding the fraction of correct predictions.
    """
    within_tolerance = (output - targets).abs().lt(threshold)
    hit_count = within_tolerance.float().sum()
    return hit_count / len(targets)


def evaluate(model, dataset, device, criterion=None):
    """Compute the average loss and accuracy of ``model`` over ``dataset``.

    Samples are fed one at a time through a ``DataLoader`` with
    ``batch_size=1`` and no shuffling. The model is left in eval mode.

    Args:
        model: Callable as ``model(node_features, edge_index, edge_type)``
            and returning ``(features, output)``.
        dataset: Dataset whose samples contain ``'targets'`` plus the keys
            read by ``process_graph_data``.
        device: Device on which tensors are placed.
        criterion: Loss function. When omitted, the module-level
            ``criterion`` defined by the training script is used if it
            exists; otherwise ``nn.MSELoss()`` (the loss that script uses).

    Returns:
        ``(avg_loss, avg_accuracy)`` as Python floats, averaged over the
        samples. Both are ``nan`` when the dataset is empty.
    """
    if criterion is None:
        # Previously this function could only run after the script had
        # created a global ``criterion``; keep honouring that global so the
        # script behaves as before, but do not require it.
        criterion = globals().get('criterion')
    if criterion is None:
        criterion = nn.MSELoss()

    model.eval()
    data_loader = DataLoader(dataset, batch_size=1, shuffle=False)
    total_loss = 0
    total_accuracy = 0

    with torch.no_grad():
        for data in data_loader:
            targets = torch.tensor([float(target) for target in data['targets']], dtype=torch.float).to(device)
            node_features, edge_index, edge_type = process_graph_data(data, device)

            features, output = model(node_features, edge_index, edge_type)
            # Drop only the trailing size-1 regression dimension. A bare
            # squeeze() would also collapse a single-row output to 0-dim and
            # make output.shape[0] below raise.
            output = output.squeeze(-1)

            # Repeat the sample's targets once per output row so the loss
            # compares tensors of equal length.
            targets = torch.cat([targets] * output.shape[0])

            loss = criterion(output, targets)
            total_loss += loss.item()
            total_accuracy += calculate_accuracy(output, targets).item()

    print("total_loss", total_loss)
    print("total_accuracy", total_accuracy)

    num_samples = len(data_loader)
    if num_samples == 0:
        # Nothing to average over; report "undefined" instead of dividing by zero.
        return float('nan'), float('nan')

    avg_loss = total_loss / num_samples
    avg_accuracy = total_accuracy / num_samples

    return avg_loss, avg_accuracy


def train_model(model, train_dataset, valid_dataset, device, num_epochs=100):
    """Train ``model`` and return the weights of the epoch with the lowest training loss.

    The function reads the module-level ``optimizer`` and ``criterion``
    created by the script entry point, so those must exist before it is
    called. After every epoch the model is evaluated on ``valid_dataset``
    and a summary line is printed.

    Args:
        model: Model called as ``model(node_features, edge_index, edge_type)``
            and returning ``(features, output)``.
        train_dataset: Iterable of sample dicts containing ``'targets'`` plus
            the keys read by ``process_graph_data``; one optimisation step is
            taken per sample.
        valid_dataset: Dataset passed to ``evaluate`` after each epoch.
        device: Device on which tensors are placed.
        num_epochs: Number of passes over ``train_dataset`` (default 100).

    Returns:
        A snapshot of ``model.state_dict()`` taken at the epoch with the
        lowest average training loss, or ``None`` if no epoch produced a
        loss lower than infinity (e.g. ``num_epochs == 0``).

    Raises:
        ValueError: if ``train_dataset`` yields no samples.
    """
    # Local import so this function does not depend on the file's import block.
    import copy

    best_loss = float('inf')
    best_model_wts = None

    for epoch in range(num_epochs):
        model.train()
        num_batches = 0
        total_train_loss = 0.0

        for data in train_dataset:
            targets = torch.tensor([float(target) for target in data['targets']], dtype=torch.float).to(device)
            node_features, edge_index, edge_type = process_graph_data(data, device)

            optimizer.zero_grad()
            _, output = model(node_features, edge_index, edge_type)
            # Drop only the trailing size-1 regression dimension. A bare
            # squeeze() would collapse a single-row output to 0-dim and make
            # output.shape[0] below raise.
            output = output.squeeze(-1)
            # Repeat the sample's targets once per output row so the loss
            # compares tensors of equal length.
            targets = torch.cat([targets] * output.shape[0])

            loss = criterion(output, targets)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError('train_dataset yielded no samples; cannot train')

        avg_train_loss = total_train_loss / num_batches

        # NOTE(review): the best epoch is chosen by *training* loss, as in the
        # original code; selecting on validation loss may be what is wanted.
        if avg_train_loss < best_loss:
            best_loss = avg_train_loss
            # state_dict() returns references to the live parameter tensors,
            # which later optimizer steps overwrite in place. Deep-copy so the
            # snapshot really holds this epoch's weights.
            best_model_wts = copy.deepcopy(model.state_dict())

        valid_loss, valid_accuracy = evaluate(model, valid_dataset, device)

        print(
            f'Epoch {epoch + 1}, Train Loss: {avg_train_loss:.8f}, Valid Loss: {valid_loss:.8f}, Valid Accuracy: {valid_accuracy:.8f}')
    return best_model_wts


if __name__ == '__main__':
    # Script entry point: train one model per sub-directory of the data root
    # and save each model's best weights under the matching output directory.
    data_base_dir = './data-cfg'
    output_base_dir = './features-cfg'

    # Model hyper-parameters.
    input_dim = 256
    hidden_dim = 512
    output_dim = 512
    gru_hidden_size = 512
    num_gru_layers = 3  # number of stacked GRU layers

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    for subdir in os.listdir(data_base_dir):
        data_dir = os.path.join(data_base_dir, subdir)

        # Only real data directories are processed.
        if not os.path.isdir(data_dir) or "__pycache__" in data_dir:
            continue

        print(f"Processing directory: {data_dir}")

        train_dataset = GraphDataset(data_dir, split='train')
        valid_dataset = GraphDataset(data_dir, split='valid')

        model = GCNGRU(input_dim, hidden_dim, output_dim, gru_hidden_size, num_gru_layers).to(device)

        # These two names must stay module-level and keep these exact names:
        # the training/evaluation functions read them as globals.
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.MSELoss()  # mean squared error loss

        best_weights = train_model(model, train_dataset, valid_dataset, device)

        # Create the per-dataset output directory if it does not exist yet.
        checkpoint_dir = os.path.join(output_base_dir, subdir)
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Persist the best weights returned by training.
        checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
        torch.save(best_weights, checkpoint_path)
