import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import random
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import math



def set_seed(seed):
    """Seed every random number generator used by this script.

    The same seed is handed, in order, to PyTorch's default generator, the
    current CUDA device, all CUDA devices, NumPy's global generator and
    Python's ``random`` module. cuDNN is then set to deterministic mode with
    benchmark auto-tuning switched off.

    Args:
        seed: Integer seed passed to each generator.
    """
    seeders = (
        torch.manual_seed,           # PyTorch default generator
        torch.cuda.manual_seed,      # current GPU
        torch.cuda.manual_seed_all,  # every GPU, for multi-GPU runs
        np.random.seed,              # NumPy
        random.seed,                 # Python built-in generator
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # deterministic results for convolutions etc.
    cudnn.benchmark = False  # no automatic algorithm selection


set_seed(1234)  # use one fixed random seed for the whole run

# Device for the tensors and the model: CUDA when available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



# Hyperparameters
window_size = 9  # sliding-window length (SWaTDataset uses the last element as the target)
embed_dim = 1  # embedding dimension = number of features; passed to the model as input_dim
hidden_dim = 64  # hidden-layer dimension
num_layers = 2  # number of layers
num_nodes = 1  # number of nodes; set to 1 for a univariate time series

batch_size = 256
num_epochs = 150
learning_rate = 1e-4




# ============== Dataset class ==============
class SWaTDataset(Dataset):
    """Sliding-window dataset for one-step-ahead forecasting.

    Sample ``i`` is the window ``features[i:i + window_size]`` split into
    ``x`` (its first ``window_size - 1`` rows) and ``y`` (its last row).

    Args:
        features: Series ordered along dimension 0. A tensor, or anything
            ``torch.as_tensor`` accepts (e.g. a NumPy array).
        window_size: Full window length, i.e. the inputs plus the one target.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """

    def __init__(self, features, window_size):
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        # Keep a private float32 copy that is detached from any autograd graph.
        self.features = torch.as_tensor(features).clone().detach().float()
        self.window_size = window_size  # full window size

    def __len__(self):
        # Every start index in 0 .. len - window_size yields a complete window,
        # hence the "+ 1" (the previous formula dropped the last window).
        # Clamp at zero so a series shorter than one window is simply empty
        # instead of making len() raise on a negative value.
        return max(len(self.features) - self.window_size + 1, 0)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for window ``idx``; negative indices count from the end.

        Raises:
            IndexError: If ``idx`` does not address a complete window.
        """
        count = len(self)
        if idx < 0:
            idx += count
        if not 0 <= idx < count:
            # Without this check, slicing past the end would silently return
            # a truncated window.
            raise IndexError(f"window index out of range for {count} windows")

        window = self.features[idx:idx + self.window_size]  # full window
        return window[:-1], window[-1]  # first N-1 rows as input, last row as target


# Load the data
try:
    dataset_swat = pd.read_csv('../../data/datasets_ETT_afterProcess.txt', encoding='utf-8',
                               low_memory=False)
    # Only the first column of the file is used as the series.
    feature_cols = dataset_swat.columns[0]
    features = dataset_swat[feature_cols].values
    print("Successfully loaded data")
    print("feature.shape", features.shape)
except Exception as e:
    # NOTE(review): every failure (missing file, parse error, ...) lands here and
    # the run continues on random data after only a printed message - confirm
    # this fallback is intended outside of smoke tests.
    print(f"Failed to load data: {e}")
    # Create some dummy data for testing
    # The dummy array is 2-D (10000, 1), whereas a single selected column is
    # presumably 1-D; the reshape calls in the training/evaluation loops
    # accept both layouts.
    features = np.random.rand(10000, 1)
    print("Using dummy data with shape:", features.shape)



# Split into training and test sets (chronological split, no shuffling)
train_size = int(len(features) * 0.6)  # first 60% is the training set (the old comment said 80%, but the factor is 0.6)
train_features, test_features = features[:train_size], features[train_size:]


# # Reshape to 2-D
# train_features = train_features.reshape(-1, 1)  # shape=(n_train, 1)
# test_features = test_features.reshape(-1, 1)  # shape=(n_test, 1)
#
# # Standardization after the split (fit on the training part only)
# scaler = StandardScaler()
# train_features = scaler.fit_transform(train_features)
# test_features = scaler.transform(test_features)


# Convert both parts to float32 tensors and move them to the selected device
train_features = torch.tensor(train_features, dtype=torch.float32).to(device)
test_features = torch.tensor(test_features, dtype=torch.float32).to(device)

print("train_feature.shape", train_features.shape)
print("test_feature.shape", test_features.shape)

# Create the training and test datasets
train_dataset = SWaTDataset(train_features, window_size)
test_dataset = SWaTDataset(test_features, window_size)

# DataLoader (neither loader shuffles, so batches follow the time order)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


# ============== MegaCRN model components ==============
class GraphConvolution(nn.Module):
    """Graph convolution layer computing ``adj @ (x @ weight)`` plus an optional bias.

    Args:
        in_features: Size of the last dimension of the input ``x``.
        out_features: Size of the last dimension of the output.
        bias: When true, a learnable bias of length ``out_features`` is added.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Storage is allocated uninitialised; reset_parameters() fills it below.
        self.weight = nn.Parameter(torch.empty(in_features, out_features, dtype=torch.float32))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, dtype=torch.float32))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform initialise the weight and zero the bias, if there is one."""
        nn.init.xavier_uniform_(self.weight)
        if self.bias is None:
            return
        nn.init.zeros_(self.bias)

    def forward(self, x, adj):
        """Project ``x`` with the weight, then mix the result with ``adj``.

        Args:
            x: Input whose last dimension is ``in_features``.
            adj: Matrix (or batch of matrices) multiplied from the left onto
                the projected input.

        Returns:
            ``adj @ (x @ weight)``, with the bias added when it exists.
        """
        propagated = adj @ (x @ self.weight)
        return propagated if self.bias is None else propagated + self.bias


class GCRUCell(nn.Module):
    """GRU cell whose gates are graph convolutions over a learned adjacency.

    Args:
        node_num: Number of graph nodes; the adjacency is ``node_num x node_num``.
        dim_in: Feature size of the per-node input ``x``.
        dim_out: Feature size of the per-node hidden state.
    """

    def __init__(self, node_num, dim_in, dim_out):
        super(GCRUCell, self).__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out

        # Update gate
        self.gc_update = GraphConvolution(dim_in + dim_out, dim_out)
        self.update_gate = nn.Sigmoid()

        # Reset gate
        self.gc_reset = GraphConvolution(dim_in + dim_out, dim_out)
        self.reset_gate = nn.Sigmoid()

        # Candidate hidden state
        self.gc_candidate = GraphConvolution(dim_in + dim_out, dim_out)
        self.candidate_activation = nn.Tanh()

        # Adaptive (learned) adjacency matrix
        self.adaptive_adj = nn.Parameter(torch.empty(node_num, node_num))
        nn.init.xavier_uniform_(self.adaptive_adj)

    def _adjacency(self):
        """Return the symmetric, non-negative adjacency with self-loops added."""
        adj = F.relu(self.adaptive_adj)
        adj = 0.5 * (adj + adj.transpose(0, 1))  # make the matrix symmetric
        # Xavier init is centred on zero, so ReLU can zero entries out, and a
        # zeroed entry also receives zero gradient. With node_num == 1 that
        # would leave the whole adjacency at 0 and make every gate ignore both
        # x and h for the entire run. Identity self-loops guarantee each node
        # always sees its own features.
        return adj + torch.eye(adj.size(0), dtype=adj.dtype, device=adj.device)

    def forward(self, x, h):
        """Advance the hidden state by one time step.

        Args:
            x: Input of shape (batch, node_num, dim_in).
            h: Previous hidden state of shape (batch, node_num, dim_out).

        Returns:
            The new hidden state, same shape as ``h``.
        """
        # The 2-D adjacency broadcasts over the batch inside the matrix
        # product, so no per-batch copy is needed.
        adj = self._adjacency()

        # Concatenate input and hidden state along the feature dimension
        combined = torch.cat([x, h], dim=2)

        # Update gate
        z = self.update_gate(self.gc_update(combined, adj))

        # Reset gate
        r = self.reset_gate(self.gc_reset(combined, adj))

        # Candidate hidden state, computed from the reset-scaled previous state
        combined_reset = torch.cat([x, r * h], dim=2)
        h_tilde = self.candidate_activation(self.gc_candidate(combined_reset, adj))

        # New hidden state: gate-weighted blend of the old state and the candidate
        return (1 - z) * h + z * h_tilde


class DilatedInception(nn.Module):
    """Parallel dilated convolutions with kernel widths 2, 3, 6 and 7.

    One ``Conv2d`` with kernel ``(1, k)`` is applied per kernel width ``k``;
    the results are aligned on their last (most recent) positions and
    concatenated along the channel dimension.

    Args:
        cin: Number of input channels.
        cout: Number of output channels of EACH convolution, so the module
            emits ``len(kernel_set) * cout`` channels in total.
        dilation_factor: Dilation applied along the last dimension.
    """

    def __init__(self, cin, cout, dilation_factor=2):
        super(DilatedInception, self).__init__()
        self.kernel_set = [2, 3, 6, 7]
        self.tconv = nn.ModuleList([
            nn.Conv2d(cin, cout, (1, kern), dilation=(1, dilation_factor))
            for kern in self.kernel_set
        ])

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the aligned results.

        Args:
            x: Input of shape (batch, cin, height, width). ``width`` must be
                larger than ``dilation_factor * (max(kernel_set) - 1)``.

        Returns:
            Tensor of shape (batch, len(kernel_set) * cout, height, w_min),
            where ``w_min`` is the width produced by the widest kernel.
        """
        branches = [conv(x) for conv in self.tconv]

        # The convolutions are unpadded, so wider kernels give shorter
        # outputs. Crop every branch to the shortest one, keeping the last
        # positions. Cropping to the input width, as before, never shortened
        # anything, so the concatenation below always failed.
        width = min(branch.size(3) for branch in branches)
        return torch.cat([branch[..., -width:] for branch in branches], dim=1)


class MultiScaleModule(nn.Module):
    """Stack of DilatedInception layers with dilation factors 1, 2, 4, ...

    Args:
        in_channels: Number of input channels.
        out_channels: Channel budget; each layer is created with
            ``out_channels // scales`` as its ``cout``.
        scales: Number of layers; layer ``i`` uses dilation ``2 ** i``.

    Note:
        DilatedInception concatenates one convolution output per kernel in
        its ``kernel_set``, so the channel count returned by ``forward`` is a
        multiple of ``out_channels // scales`` and is generally not equal to
        ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, scales=3):
        super(MultiScaleModule, self).__init__()
        self.scales = scales
        self.dilated_layers = nn.ModuleList()

        for i in range(scales):
            dilation_factor = 2 ** i
            self.dilated_layers.append(DilatedInception(in_channels, out_channels // scales, dilation_factor))

    def forward(self, x):
        """Apply every scale to ``x`` and concatenate along the channel dimension.

        A larger dilation shortens the unpadded convolution output more, so
        the per-scale results differ in their last dimension. They are
        cropped to the shortest one (keeping the last positions) before
        concatenation, which would otherwise fail on the size mismatch.
        """
        outputs = [layer(x) for layer in self.dilated_layers]
        width = min(out.size(-1) for out in outputs)
        return torch.cat([out[..., -width:] for out in outputs], dim=1)


# ============== MegaCRN model ==============
class MegaCRNModel(nn.Module):
    """Recurrent forecaster built from a linear input projection and GCRU cells.

    Args:
        num_nodes: Number of graph nodes.
        input_dim: Number of features per node and time step.
        hidden_dim: Width of the projection and of the hidden state.
        output_dim: Number of values emitted per node and time step.
        num_layers: Number of GCRU cells.

    NOTE(review): ``multi_scale`` is constructed here but ``forward`` never
    calls it - confirm whether it should be part of the forward pass.
    NOTE(review): all cells share a single hidden state and each receives the
    same projected input, rather than the output of the previous cell -
    confirm that this chaining is intended.
    """

    def __init__(self, num_nodes, input_dim, hidden_dim, output_dim=1, num_layers=2):
        super().__init__()
        self.num_nodes = num_nodes  # number of nodes
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers

        # Input projection
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # Multi-scale module
        self.multi_scale = MultiScaleModule(hidden_dim, hidden_dim)

        # GCRU cells
        self.gcru_cells = nn.ModuleList()
        for _ in range(num_layers):
            self.gcru_cells.append(GCRUCell(num_nodes, hidden_dim, hidden_dim))

        # Output layer
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the model over a whole input sequence.

        Args:
            x: Input tensor of shape (batch_size, seq_len, input_dim).

        Returns:
            Tensor of shape (batch_size, seq_len, num_nodes, output_dim) with
            one output per time step.
        """
        n_batch, n_steps, _ = x.size()

        # Split the feature axis into (num_nodes, features per node)
        per_node = x.reshape(n_batch, n_steps, self.num_nodes, -1)

        # The hidden state starts at zero
        state = torch.zeros(n_batch, self.num_nodes, self.hidden_dim, device=x.device)

        step_outputs = []
        for step_input in per_node.unbind(dim=1):  # (batch_size, num_nodes, input_dim)
            projected = self.input_proj(step_input)  # (batch_size, num_nodes, hidden_dim)

            # Pass the state through every GCRU cell in turn
            for cell in self.gcru_cells:
                state = cell(projected, state)

            step_outputs.append(self.output_layer(state))  # (batch_size, num_nodes, output_dim)

        # Stack the per-step outputs along a new time dimension
        return torch.stack(step_outputs, dim=1)


# Create the MegaCRN model and move it to the selected device
model = MegaCRNModel(num_nodes=num_nodes, input_dim=embed_dim, hidden_dim=hidden_dim,
                     output_dim=1, num_layers=num_layers).to(device)

# Define the loss function and the optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0  # sum of the per-batch losses of this epoch

    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        x_batch = x_batch.reshape(x_batch.shape[0], window_size - 1, 1)  # (batch_size, seq_len, input_dim)
        y_batch = y_batch.reshape(y_batch.shape[0], 1)  # (batch_size, output_dim)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(x_batch)  # shape: (batch_size, seq_len, num_nodes, output_dim)
        pred = outputs[:, -1, 0, :]  # output of the last time step, first node

        # Compute the loss
        loss = criterion(pred, y_batch)

        # Backward pass and optimization step
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Print the epoch loss: the mean of the per-batch losses
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(train_loader):.10f}")

print("Training complete.")

# Evaluate the model on the test set
model.eval()
test_loss = 0.0  # sum of the per-batch losses

with torch.no_grad():
    for x_test, y_test in test_loader:
        x_test, y_test = x_test.to(device), y_test.to(device)
        x_test = x_test.reshape(x_test.shape[0], window_size - 1, 1)  # (batch_size, seq_len, input_dim)
        y_test = y_test.reshape(y_test.shape[0], 1)  # (batch_size, output_dim)

        # Forward pass
        outputs = model(x_test)
        pred = outputs[:, -1, 0, :]  # output of the last time step, first node

        # Compute the loss
        loss = criterion(pred, y_test)
        test_loss += loss.item()

# Reported value is the mean of the per-batch losses
print(f"Test Loss: {test_loss / len(test_loader):.10f}")
