import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import random
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt


def set_seed(seed):
    """Seed every random number generator the script relies on.

    Seeds PyTorch (CPU, the current GPU and all GPUs), NumPy's global RNG and
    Python's ``random`` module with *seed*. It also switches cuDNN to
    deterministic kernels with autotuning disabled.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)
    # Deterministic cuDNN kernels and no per-shape algorithm search, so repeated
    # runs pick the same convolution implementation.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(1234)  # Seed torch / NumPy / Python RNGs up front for reproducible runs.

# Compute device: CUDA if available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Hyperparameters
window_size = 7  # Rows per sliding window: the first window_size - 1 rows are the input, the last is the target.
embed_dim = 1  # Embedding dimension = number of features. NOTE(review): not referenced in the visible code.
hidden_dim = 64  # Passed to LSTNetModel as hidden_dim. NOTE(review): the model stores it but no layer appears to use it.
num_layers = 2  # Originally described as the LSTM layer count. NOTE(review): not referenced in the visible code.

batch_size = 256  # DataLoader batch size.
num_epochs = 100  # Number of passes over the training set.
learning_rate = 1e-3  # Adam learning rate.

# LSTNet-specific hyperparameters
conv_channels = 32  # Conv1d output channels.
conv_kernel_size = 3  # Conv1d kernel width.
recurrent_units = 64  # Hidden size of the main recurrent (GRU) layer.
skip_size = 3  # Stride between time steps fed to the skip recurrent layer.
skip_rnn_units = 16  # Hidden size of the skip recurrent layer.
ar_window = 3  # Number of most recent input steps fed to the autoregressive linear layer.


# ============== 数据集类 ==============
class SWaTDataset(Dataset):
    """Sliding-window dataset for one-step-ahead forecasting.

    Item ``idx`` is the block of ``window_size`` consecutive rows starting at
    ``idx``. The first ``window_size - 1`` rows are the input ``x`` and the last
    row is the target ``y``.

    Args:
        features: tensor or array-like indexed along its first (time) dimension.
            The script passes an ``(n_samples, 1)`` tensor.
        window_size: rows per window (input steps + 1 target step).
    """

    def __init__(self, features, window_size):
        # as_tensor also accepts NumPy arrays. detach + clone gives the dataset its
        # own float32 copy, as before.
        self.features = torch.as_tensor(features).detach().clone().float()
        self.window_size = window_size  # full window length, target included

    def __len__(self):
        # Valid start indices are 0 .. len - window_size inclusive, which is
        # len - window_size + 1 windows. Clamp at 0 so a series shorter than one
        # window yields an empty dataset rather than a negative length.
        return max(0, len(self.features) - self.window_size + 1)

    def __getitem__(self, idx):
        window = self.features[idx:idx + self.window_size]
        x = window[:-1]  # first window_size - 1 rows: model input
        y = window[-1]  # last row: prediction target
        return x, y


# Load data: the first CSV column is used as the single feature.
# NOTE(review): read_csv uses its default header=0, so the first line of the file is
# treated as column names. If output-cnn.txt has no header row, its first value is
# consumed as the header. Confirm the file format.
try:
    dataset_swat = pd.read_csv('../../data/output-cnn.txt', encoding='utf-8',
                               low_memory=False)
    feature_cols = dataset_swat.columns[0]
    features = dataset_swat[feature_cols].values
    print("Successfully loaded data")
    print("feature.shape", features.shape)
except Exception as e:
    # NOTE(review): any failure here (missing file, parse error, ...) silently falls
    # back to random data, and training then proceeds on noise. Consider narrowing the
    # exception or failing loudly outside of testing.
    print(f"Failed to load data: {e}")
    # Create some dummy data for testing
    features = np.random.rand(10000, 1)
    print("Using dummy data with shape:", features.shape)

# Chronological split, no shuffling: the first 80% of rows train, the rest test.
train_size = int(len(features) * 0.8)  # 80% for training
train_features, test_features = features[:train_size], features[train_size:]

# Reshape to 2-D (rows, 1), the layout StandardScaler expects.
train_features = train_features.reshape(-1, 1)  # shape=(n_train, 1)
test_features = test_features.reshape(-1, 1)  # shape=(n_test, 1)

# Standardize after the split: the scaler is fit on the training rows only, so
# test-set statistics do not leak into training.
scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)
test_features = scaler.transform(test_features)

# Convert to float32 tensors and move the whole split to the compute device once.
train_features = torch.tensor(train_features, dtype=torch.float32).to(device)
test_features = torch.tensor(test_features, dtype=torch.float32).to(device)

print("train_feature.shape", train_features.shape)
print("test_feature.shape", test_features.shape)

# Sliding-window datasets built separately per split, so no window spans the boundary.
train_dataset = SWaTDataset(train_features, window_size)
test_dataset = SWaTDataset(test_features, window_size)

# DataLoaders; shuffle=False keeps windows in time order for both splits.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


# ============== LSTNet模型 ==============
class LSTNetModel(nn.Module):
    """LSTNet-style one-step-ahead forecaster.

    Pipeline: Conv1d (+ReLU) over time. Then the last step of a main GRU, plus the
    last step of a skip GRU run over strided sub-sequences of the conv output.
    Then dropout and a linear fusion layer, plus an optional linear autoregressive
    (AR) term on the raw input.

    Args:
        input_dim: features per time step (Conv1d ``in_channels``).
        hidden_dim: stored as ``self.hidden_dim``; no layer uses it.
        conv_channels: Conv1d output channels, also the input size of both GRUs.
        kernel_size: Conv1d kernel width. There is no padding, so the conv output
            has ``seq_len - kernel_size + 1`` steps.
        recurrent_units: hidden size of the main GRU.
        skip_size: stride of the skip-GRU sub-sequences; ``<= 0`` disables the skip GRU.
        skip_rnn_units: hidden size of the skip GRU.
        ar_window: number of most recent input steps fed to the AR layer;
            ``0`` disables the AR term.
        output_dim: prediction size per sample.
        seq_len: input sequence length the model will be called with. When given,
            the fusion layer's input width is derived from it. When ``None`` (the
            default, for backward compatibility) the width falls back to the
            previously hard-coded value of 80.
    """

    # Fusion-layer width used when ``seq_len`` is not supplied. This is the value
    # earlier versions hard-coded. It equals recurrent_units + 1 * skip_rnn_units
    # (64 + 16) for seq_len=6, kernel_size=3, skip_size=3, which give one skip window.
    _LEGACY_FUSION_WIDTH = 80

    def __init__(self, input_dim, hidden_dim, conv_channels, kernel_size, recurrent_units,
                 skip_size, skip_rnn_units, ar_window, output_dim=1, seq_len=None):
        super(LSTNetModel, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim  # kept for interface compatibility; unused
        self.kernel_size = kernel_size
        self.skip_size = skip_size
        self.ar_window = ar_window

        # Submodules are created in the same order as before, so seeded weight
        # initialisation is unchanged.

        # CNN layer: short-term local patterns.
        self.conv1 = nn.Conv1d(input_dim, conv_channels, kernel_size)

        # Main GRU over the convolved sequence.
        self.gru = nn.GRU(conv_channels, recurrent_units, batch_first=True)

        # Skip GRU, run once per start offset over every skip_size-th conv step.
        self.skip_rnn = nn.GRU(conv_channels, skip_rnn_units, batch_first=True)

        # Autoregressive component: linear map of the last ar_window raw steps.
        self.ar_linear = nn.Linear(ar_window * input_dim, output_dim)

        # Fusion layer. Its input is the last main-GRU state concatenated with one
        # skip-GRU state per skip window, so its width depends on the input length.
        if seq_len is None:
            fusion_width = self._LEGACY_FUSION_WIDTH
        else:
            fusion_width = recurrent_units + self._num_skip_windows(seq_len) * skip_rnn_units
        self.output_linear = nn.Linear(fusion_width, output_dim)

        # Dropout on the fused features, for regularisation.
        self.dropout = nn.Dropout(0.2)

    def _num_skip_windows(self, seq_len):
        """Return how many skip-GRU passes ``forward`` makes for inputs of length *seq_len*."""
        if self.skip_size <= 0:
            return 0
        conv_len = seq_len - self.kernel_size + 1
        return max(0, conv_len - self.skip_size)

    def forward(self, x):
        """Predict one step ahead.

        Args:
            x: input tensor of shape (batch_size, seq_len, input_dim).

        Returns:
            Tensor of shape (batch_size, 1, 1, output_dim).

        Raises:
            ValueError: if the fused feature width for this ``seq_len`` does not
                match ``output_linear.in_features``. This means the model was built
                for a different sequence length or hyperparameters.
        """
        batch_size = x.size(0)
        seq_len = x.size(1)

        # Conv1d expects (batch, channels, length); permute in and back out.
        x_conv = F.relu(self.conv1(x.permute(0, 2, 1)))
        x_conv = x_conv.permute(0, 2, 1)  # (batch_size, seq_len - kernel_size + 1, conv_channels)

        # Main GRU: keep only the last time step.
        out_gru, _ = self.gru(x_conv)
        features = [out_gru[:, -1, :]]  # (batch_size, recurrent_units)

        # Skip GRU: for each offset s, run over conv steps s, s + skip_size, ...
        # and keep the last output.
        for s in range(self._num_skip_windows(seq_len)):
            skip_out, _ = self.skip_rnn(x_conv[:, s::self.skip_size, :])
            features.append(skip_out[:, -1, :])  # (batch_size, skip_rnn_units)

        combined = torch.cat(features, dim=1)
        expected_width = self.output_linear.in_features
        if combined.size(1) != expected_width:
            raise ValueError(
                f"fused feature width {combined.size(1)} does not match "
                f"output_linear.in_features={expected_width}; construct "
                f"LSTNetModel with seq_len={seq_len}"
            )

        combined = self.dropout(combined)
        out_final = self.output_linear(combined)  # (batch_size, output_dim)

        # Add the linear AR term on the last ar_window raw input steps.
        if self.ar_window > 0:
            ar_x = x[:, -self.ar_window:, :].contiguous().view(batch_size, -1)
            out_final = out_final + self.ar_linear(ar_x)

        return out_final.unsqueeze(1).unsqueeze(1)  # (batch_size, 1, 1, output_dim)


# Build the LSTNet model on the selected device. input_dim=1 matches the single
# feature column produced by the data pipeline.
# NOTE(review): the model is fed window_size - 1 time steps (see the reshape in the
# training loop), and its output-layer width depends on that length. Confirm the model
# still accepts these inputs whenever window_size or the kernel/skip settings change.
model = LSTNetModel(
    input_dim=1,
    hidden_dim=hidden_dim,
    conv_channels=conv_channels,
    kernel_size=conv_kernel_size,
    recurrent_units=recurrent_units,
    skip_size=skip_size,
    skip_rnn_units=skip_rnn_units,
    ar_window=ar_window
).to(device)

# Mean-squared-error loss and an Adam optimizer over all model parameters.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_samples = 0

    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        x_batch = x_batch.reshape(x_batch.shape[0], window_size - 1, 1)  # (batch_size, seq_len, input_dim)
        y_batch = y_batch.reshape(y_batch.shape[0], 1)  # (batch_size, output_dim)
        batch_len = x_batch.shape[0]

        # Forward pass; squeeze drops the model's singleton output dimensions.
        optimizer.zero_grad()
        outputs = model(x_batch)
        pred = outputs.squeeze()
        y_batch = y_batch.squeeze()

        loss = criterion(pred, y_batch)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        # MSELoss returns the batch mean. Weight it by the batch size so the epoch
        # figure is the per-sample mean even when the final batch is short.
        total_loss += loss.item() * batch_len
        total_samples += batch_len

    # max(..., 1) avoids ZeroDivisionError on an empty training set.
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / max(total_samples, 1):.10f}")

print("Training complete.")

# Evaluate on the held-out test windows without tracking gradients.
model.eval()
test_loss = 0.0
test_samples = 0

with torch.no_grad():
    for x_test, y_test in test_loader:
        x_test, y_test = x_test.to(device), y_test.to(device)
        x_test = x_test.reshape(x_test.shape[0], window_size - 1, 1)
        y_test = y_test.reshape(y_test.shape[0], 1)
        batch_len = x_test.shape[0]

        outputs = model(x_test)
        pred = outputs.squeeze()
        y_test = y_test.squeeze()

        loss = criterion(pred, y_test)
        # Sample-weighted accumulation, as in the training loop above.
        test_loss += loss.item() * batch_len
        test_samples += batch_len

print(f"Test Loss: {test_loss / max(test_samples, 1):.10f}")
