import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import random
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch on the
    CPU, the current CUDA device and all CUDA devices. It also makes cuDNN
    pick deterministic kernels, with no benchmark-driven autotuning.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Fix every RNG before anything random (weight init, dummy data) happens.
set_seed(1234)

# Use a CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# ---- Model hyperparameters ----
window_size = 7   # rows per sliding window: window_size - 1 inputs + 1 target
embed_dim = 1     # features per time step (= number of input features)
hidden_dim = 64   # hidden size of every recurrent layer
num_layers = 2    # number of stacked GCRU cells
num_nodes = 1     # graph nodes; a univariate series is a single node

# ---- Optimisation hyperparameters ----
batch_size = 256
num_epochs = 100
learning_rate = 0.001


# ============== 数据集类 ==============
class SWaTDataset(Dataset):
    """Sliding-window dataset for one-step-ahead forecasting.

    Sample ``idx`` is the window ``features[idx : idx + window_size]``.
    The first ``window_size - 1`` rows are the input ``x`` and the last row is
    the target ``y``.

    Args:
        features: tensor whose first dimension is time, e.g. shape ``(T, F)``.
        window_size: total rows per window (inputs plus one target row).
    """

    def __init__(self, features, window_size):
        # Keep a private float32 copy, detached from any autograd graph.
        self.features = features.clone().detach().float()
        self.window_size = window_size  # full window size (inputs + target)

    def __len__(self):
        # A window starting at idx covers rows idx .. idx + window_size - 1.
        # The last valid start is len - window_size, so there are
        # len - window_size + 1 windows. Clamp at 0 so a series shorter than
        # one window gives an empty dataset instead of a negative length.
        return max(0, len(self.features) - self.window_size + 1)

    def __getitem__(self, idx):
        window = self.features[idx:idx + self.window_size]  # one full window
        x = window[:-1]  # first window_size - 1 rows: model input
        y = window[-1]   # last row: prediction target
        return x, y


# Load the data. Only the first column is used, as a univariate series.
# NOTE(review): pd.read_csv treats the file's first line as the header row
# (header=0 by default). If output-cnn.txt is a bare column of numbers with
# no header, the first sample is silently used as the column name. Confirm
# the file format, or pass header=None.
try:
    dataset_swat = pd.read_csv('../../data/output-cnn.txt', encoding='utf-8',
                               low_memory=False)
    feature_cols = dataset_swat.columns[0]  # label of the first column (a single label, not a list)
    features = dataset_swat[feature_cols].values  # that column's values as an array (1-D for a unique label)
    print("Successfully loaded data")
    print("feature.shape", features.shape)
except Exception as e:
    # Any failure (missing file, parse error, ...) falls back to random data.
    # A run can therefore train on noise without crashing; check the log for
    # the message below.
    print(f"Failed to load data: {e}")
    # Create some dummy data for testing
    features = np.random.rand(10000, 1)
    print("Using dummy data with shape:", features.shape)

# Chronological train/test split with no shuffling, so the test set is the
# tail of the series.
train_size = int(len(features) * 0.8)  # first 80% for training
train_features, test_features = features[:train_size], features[train_size:]

# Reshape to 2-D (samples, 1), which StandardScaler requires.
train_features = train_features.reshape(-1, 1)  # shape=(n_train, 1)
test_features = test_features.reshape(-1, 1)  # shape=(n_test, 1)

# Standardise after splitting. The scaler is fitted on the training split
# only, so no test-set statistics leak into the scaling.
scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)
test_features = scaler.transform(test_features)

# Convert both splits to float32 tensors on `device`.
train_features = torch.tensor(train_features, dtype=torch.float32).to(device)
test_features = torch.tensor(test_features, dtype=torch.float32).to(device)

print("train_feature.shape", train_features.shape)
print("test_feature.shape", test_features.shape)

# Build the sliding-window training and test datasets.
train_dataset = SWaTDataset(train_features, window_size)
test_dataset = SWaTDataset(test_features, window_size)

# DataLoaders. shuffle=False keeps windows in time order for both splits.
# NOTE(review): shuffling the training windows is common practice; it is
# left off here.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


# ============== MegaCRN模型组件 ==============
class GraphConvolution(nn.Module):
    """Dense graph convolution computing ``adj @ (x @ weight) + bias``.

    Args:
        in_features: size of the last dimension of ``x``.
        out_features: size of the last dimension of the output.
        bias: if True, add a learnable per-feature bias initialised to zero.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(in_features, out_features, dtype=torch.float32))
        if not bias:
            # Registered as None so `self.bias` exists but holds no tensor.
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.empty(out_features, dtype=torch.float32))
        self.reset_parameters()

    def reset_parameters(self):
        """Use Xavier-uniform weights, and a zero bias when a bias exists."""
        nn.init.xavier_uniform_(self.weight)
        if self.bias is None:
            return
        nn.init.zeros_(self.bias)

    def forward(self, x, adj):
        """Project node features, then aggregate them over ``adj``."""
        aggregated = adj @ (x @ self.weight)
        return aggregated if self.bias is None else aggregated + self.bias


class GCRUCell(nn.Module):
    """GRU cell whose gates are graph convolutions over a learned adjacency.

    The update gate, reset gate and candidate state are each a
    ``GraphConvolution`` applied to the concatenation of the input and the
    previous hidden state. The adjacency matrix is a free parameter
    (``adaptive_adj``) that is ReLU-clamped and symmetrised on every call.

    Args:
        node_num: number of graph nodes N.
        dim_in: per-node input feature size.
        dim_out: per-node hidden-state size.
    """

    def __init__(self, node_num, dim_in, dim_out):
        super(GCRUCell, self).__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out

        # Update gate
        self.gc_update = GraphConvolution(dim_in + dim_out, dim_out)
        self.update_gate = nn.Sigmoid()

        # Reset gate
        self.gc_reset = GraphConvolution(dim_in + dim_out, dim_out)
        self.reset_gate = nn.Sigmoid()

        # Candidate hidden state
        self.gc_candidate = GraphConvolution(dim_in + dim_out, dim_out)
        self.candidate_activation = nn.Tanh()

        # Learnable (adaptive) adjacency matrix, shape (node_num, node_num).
        self.adaptive_adj = nn.Parameter(torch.FloatTensor(node_num, node_num))
        nn.init.xavier_uniform_(self.adaptive_adj)

    def forward(self, x, h):
        """Run one recurrent step.

        Args:
            x: input, shape ``(batch, node_num, dim_in)``.
            h: previous hidden state, shape ``(batch, node_num, dim_out)``.

        Returns:
            New hidden state ``(1 - z) * h + z * h_tilde``, with shape
            ``(batch, node_num, dim_out)``.
        """
        batch_size = x.size(0)

        # Clamp negative weights to zero, then symmetrise.
        # NOTE(review): xavier_uniform_ samples a symmetric range, so a 1x1
        # matrix (num_nodes=1) starts negative about half the time. ReLU then
        # makes adj all zeros, and ReLU passes no gradient back, so it stays
        # zero. Every GraphConvolution below then returns only its bias, and
        # the cell ignores x. Consider self-loops or softmax normalisation.
        adj = F.relu(self.adaptive_adj)
        adj = 0.5 * (adj + adj.transpose(0, 1))

        # Broadcast over the batch as a view. `expand` avoids the per-call
        # copy that `repeat` made; values and gradients are unchanged.
        adj = adj.unsqueeze(0).expand(batch_size, -1, -1)

        # Concatenate input and hidden state on the feature axis.
        combined = torch.cat([x, h], dim=2)

        z = self.update_gate(self.gc_update(combined, adj))  # update gate
        r = self.reset_gate(self.gc_reset(combined, adj))    # reset gate

        # The candidate state sees the reset-gated hidden state.
        combined_reset = torch.cat([x, r * h], dim=2)
        h_tilde = self.candidate_activation(self.gc_candidate(combined_reset, adj))

        # Interpolate between the old state and the candidate.
        h_new = (1 - z) * h + z * h_tilde

        return h_new


class DilatedInception(nn.Module):
    """Inception-style bank of dilated temporal convolutions.

    Runs one ``Conv2d`` with kernel ``(1, k)`` per width ``k`` in
    ``kernel_set`` along the last (time) axis, all with the same dilation.
    Every branch is cropped to the length of the shortest branch, keeping
    the most recent steps, and the branches are concatenated on the channel
    axis.

    Args:
        cin: input channels.
        cout: output channels *per branch*; the result has
            ``len(kernel_set) * cout`` channels.
        dilation_factor: dilation along the time axis.

    The input has shape ``(batch, cin, nodes, T)``. The output has shape
    ``(batch, len(kernel_set) * cout, nodes, T - dilation_factor * (max(kernel_set) - 1))``.
    """

    def __init__(self, cin, cout, dilation_factor=2):
        super(DilatedInception, self).__init__()
        self.kernel_set = [2, 3, 6, 7]
        self.tconv = nn.ModuleList()

        for kern in self.kernel_set:
            self.tconv.append(nn.Conv2d(cin, cout, (1, kern), dilation=(1, dilation_factor)))

    def forward(self, x):
        x_out = [conv(x) for conv in self.tconv]

        # Unpadded convolutions shorten the time axis by
        # dilation * (kernel - 1), so wider kernels give shorter outputs.
        # Crop every branch to the shortest one so they can be concatenated.
        # The old crop to x.size(3) was a no-op and made torch.cat fail.
        out_len = min(out.size(3) for out in x_out)
        x_out = [out[..., -out_len:] for out in x_out]

        return torch.cat(x_out, dim=1)


class MultiScaleModule(nn.Module):
    """Stack of ``DilatedInception`` blocks with dilations 1, 2, 4, ...

    Scale ``i`` uses dilation ``2 ** i``. Each scale's output is cropped to
    the shortest time length among the scales, keeping the most recent
    steps, and the scales are concatenated on the channel axis.

    Args:
        in_channels: input channels.
        out_channels: channel budget split across scales. Each scale is
            built with ``out_channels // scales`` channels per inception
            branch, so the total output channel count is generally not
            ``out_channels``.
        scales: number of dilation scales.
    """

    def __init__(self, in_channels, out_channels, scales=3):
        super(MultiScaleModule, self).__init__()
        self.scales = scales
        self.dilated_layers = nn.ModuleList()

        for i in range(scales):
            dilation_factor = 2 ** i
            self.dilated_layers.append(DilatedInception(in_channels, out_channels // scales, dilation_factor))

    def forward(self, x):
        outputs = [layer(x) for layer in self.dilated_layers]

        # A larger dilation shortens the time axis more, so align every scale
        # on the shortest (most recent) span before concatenating channels.
        # Without this crop, torch.cat fails on mismatched time lengths.
        out_len = min(out.size(3) for out in outputs)
        return torch.cat([out[..., -out_len:] for out in outputs], dim=1)


# ============== MegaCRN模型 ==============
class MegaCRNModel(nn.Module):
    """Recurrent forecaster built from stacked graph-convolutional GRU cells.

    Args:
        num_nodes: number of graph nodes N.
        input_dim: features per node per time step, after the input is
            reshaped to ``(batch, seq_len, num_nodes, -1)``.
        hidden_dim: hidden-state size of every GCRU layer.
        output_dim: features predicted per node per time step.
        num_layers: number of stacked GCRU cells.
    """

    def __init__(self, num_nodes, input_dim, hidden_dim, output_dim=1, num_layers=2):
        super(MegaCRNModel, self).__init__()
        self.num_nodes = num_nodes
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers

        # Input projection to hidden_dim.
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # NOTE(review): multi_scale is constructed, and its parameters are
        # included in model.parameters(), but forward() never calls it. It is
        # kept so parameter-initialisation order and state_dict keys are
        # unchanged.
        self.multi_scale = MultiScaleModule(hidden_dim, hidden_dim)

        # Stacked GCRU cells. Every layer has input size hidden_dim: layer 0
        # takes the projected input, deeper layers take the previous state.
        self.gcru_cells = nn.ModuleList()
        for i in range(num_layers):
            self.gcru_cells.append(GCRUCell(num_nodes, hidden_dim, hidden_dim))

        # Output head applied to the top layer's hidden state.
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the recurrence over every time step.

        Args:
            x: shape ``(batch_size, seq_len, num_nodes * input_dim)``. For a
                univariate series (num_nodes=1) this is one feature per step.

        Returns:
            Tensor of shape ``(batch_size, seq_len, num_nodes, output_dim)``
            with one prediction per time step.
        """
        batch_size, seq_len, _ = x.size()

        # Split the feature axis into (num_nodes, per-node features).
        x = x.reshape(batch_size, seq_len, self.num_nodes, -1)

        # Linear acts on the last axis, so project all steps in one call.
        x = self.input_proj(x)  # (batch, seq_len, num_nodes, hidden_dim)

        # One hidden state per layer, as in a standard stacked GRU.
        # Previously every layer re-read the projected input and overwrote a
        # single shared state.
        hidden = [x.new_zeros(batch_size, self.num_nodes, self.hidden_dim)
                  for _ in self.gcru_cells]

        outputs = []
        for t in range(seq_len):
            layer_in = x[:, t]  # (batch, num_nodes, hidden_dim)
            for i, cell in enumerate(self.gcru_cells):
                hidden[i] = cell(layer_in, hidden[i])
                layer_in = hidden[i]  # becomes the input of the next layer

            # (batch, num_nodes, output_dim)
            outputs.append(self.output_layer(layer_in))

        # (batch, seq_len, num_nodes, output_dim)
        return torch.stack(outputs, dim=1)


# Build the MegaCRN model. Construction draws from the global RNG for weight
# initialisation, so it must stay after set_seed().
model = MegaCRNModel(num_nodes=num_nodes, input_dim=embed_dim, hidden_dim=hidden_dim,
                     output_dim=1, num_layers=num_layers).to(device)

# Loss function and optimiser
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0  # sum of per-sample losses over the epoch
    n_seen = 0        # number of samples seen this epoch

    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        x_batch = x_batch.reshape(x_batch.shape[0], window_size - 1, 1)  # (batch_size, seq_len, input_dim)
        y_batch = y_batch.reshape(y_batch.shape[0], 1)  # (batch_size, output_dim)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(x_batch)  # (batch_size, seq_len, num_nodes, output_dim)
        pred = outputs[:, -1, 0, :]  # last time step, first node

        loss = criterion(pred, y_batch)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # MSELoss averages over the batch. Weight it by the batch size so a
        # short final batch does not skew the epoch mean.
        total_loss += loss.item() * x_batch.shape[0]
        n_seen += x_batch.shape[0]

    # Per-sample mean loss for the epoch (NaN if the loader was empty).
    epoch_loss = total_loss / n_seen if n_seen else float("nan")
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.10f}")

print("Training complete.")

# Evaluate the model
model.eval()
test_loss = 0.0  # sum of per-sample losses over the test set
n_test = 0

with torch.no_grad():
    for x_test, y_test in test_loader:
        x_test, y_test = x_test.to(device), y_test.to(device)
        x_test = x_test.reshape(x_test.shape[0], window_size - 1, 1)
        y_test = y_test.reshape(y_test.shape[0], 1)

        outputs = model(x_test)
        pred = outputs[:, -1, 0, :]  # last time step, first node

        loss = criterion(pred, y_test)
        test_loss += loss.item() * x_test.shape[0]
        n_test += x_test.shape[0]

# Per-sample mean test loss (NaN if the test loader was empty).
mean_test_loss = test_loss / n_test if n_test else float("nan")
print(f"Test Loss: {mean_test_loss:.10f}")