import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import random
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import math



def set_seed(seed):
    """Make runs reproducible by seeding every random number generator in use.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, the current CUDA device and all CUDA devices), then configures cuDNN
    to use deterministic algorithms with benchmark autotuning disabled.

    Args:
        seed: integer seed applied to every generator.
    """
    # Python / NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: CPU generator, then current GPU and every GPU.
    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    # Deterministic cuDNN kernels; no per-shape algorithm autotuning.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Fix every RNG seed up front so data generation and weight init are reproducible.
set_seed(1234)

# Device used for training and evaluation.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---- General hyperparameters ----
window_size = 9    # rows per sliding window: window_size - 1 input rows + 1 target row
embed_dim = 1      # features per time step; NOTE(review): not referenced elsewhere in this file
hidden_dim = 64    # passed to LSTNetModel, which stores it but does not use it in any layer
num_layers = 2     # NOTE(review): not referenced elsewhere in this file; leftover from an LSTM variant?

# ---- Optimisation ----
batch_size, num_epochs, learning_rate = 256, 150, 1e-4

# ---- LSTNet-specific hyperparameters ----
conv_channels = 32       # output channels of the Conv1d layer
conv_kernel_size = 3     # temporal kernel size of the Conv1d layer
recurrent_units = 64     # hidden size of the main GRU
skip_size = 3            # stride between time steps fed to the skip GRU
skip_rnn_units = 16      # hidden size of the skip GRU
ar_window = 3            # most recent time steps fed to the linear autoregressive head




# ============== Dataset ==============
class SWaTDataset(Dataset):
    """Sliding-window dataset over a time series.

    Sample ``i`` is the window ``features[i : i + window_size]``: its first
    ``window_size - 1`` rows are the input ``x`` and its last row is the
    target ``y``.

    Args:
        features: tensor (or array-like accepted by ``torch.as_tensor``) whose
            first axis is time. A detached float32 copy is stored.
        window_size: full window length (input rows plus the one target row).
    """

    def __init__(self, features, window_size):
        self.features = torch.as_tensor(features).detach().clone().float()
        self.window_size = window_size

    def __len__(self):
        # Valid start indices are 0 .. len(features) - window_size inclusive, i.e.
        # len(features) - window_size + 1 full windows; never report a negative length.
        return max(0, len(self.features) - self.window_size + 1)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for window ``idx``; negative indices count from the end."""
        n = len(self)
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            # Without this check an out-of-range idx would yield a truncated window.
            raise IndexError(f"window index out of range for dataset of length {n}")
        window = self.features[idx:idx + self.window_size]
        return window[:-1], window[-1]


# ============== Data loading ==============
# Only the first column of the file is used, so the series is univariate.
try:
    dataset_swat = pd.read_csv('../../data/datasets_ETT_afterProcess.txt', encoding='utf-8',
                               low_memory=False)
    feature_cols = dataset_swat.columns[0]
    features = dataset_swat[feature_cols].values
    print("Successfully loaded data")
    print("feature.shape", features.shape)
except (OSError, ValueError, IndexError) as e:
    # Expected load failures: missing/unreadable file (OSError), decode errors, CSV parse
    # errors and empty files (pandas raises ValueError subclasses for these), or a file
    # with no columns (IndexError). Fall back to random data so the pipeline can still be
    # smoke-tested; unrelated exceptions (e.g. programming errors) now propagate.
    print(f"Failed to load data: {e}")
    # Create some dummy data for testing
    features = np.random.rand(10000, 1)
    print("Using dummy data with shape:", features.shape)

# Chronological split: the first 60% of rows train, the remaining 40% test.
train_ratio = 0.6
train_size = int(len(features) * train_ratio)
train_features, test_features = features[:train_size], features[train_size:]

# Every sample needs window_size consecutive rows; fail early with a clear message
# instead of deep inside the Dataset/DataLoader/training code.
if min(len(train_features), len(test_features)) < window_size:
    raise ValueError(
        f"Not enough data for window_size={window_size}: "
        f"train split has {len(train_features)} rows, test split has {len(test_features)} rows"
    )

# Optional standardisation (currently disabled): fit on the training split only,
# then apply the same transform to the test split.
# train_features = train_features.reshape(-1, 1)  # shape=(n_train, 1)
# test_features = test_features.reshape(-1, 1)  # shape=(n_test, 1)
# scaler = StandardScaler()
# train_features = scaler.fit_transform(train_features)
# test_features = scaler.transform(test_features)

train_features = torch.tensor(train_features, dtype=torch.float32).to(device)
test_features = torch.tensor(test_features, dtype=torch.float32).to(device)

print("train_feature.shape", train_features.shape)
print("test_feature.shape", test_features.shape)

# Sliding-window datasets for each split.
train_dataset = SWaTDataset(train_features, window_size)
test_dataset = SWaTDataset(test_features, window_size)

# Batches are served in time order (shuffle=False) for both splits.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)



# ============== LSTNet model ==============
class LSTNetModel(nn.Module):
    """LSTNet-style forecaster: Conv1d -> GRU and skip-GRU, plus a linear autoregressive (AR) head.

    Args:
        input_dim: features per time step (input channels of ``conv1``).
        hidden_dim: stored on the module but not used by any layer.
        conv_channels: output channels of ``conv1``.
        kernel_size: temporal kernel size of ``conv1`` (default stride 1, no padding).
        recurrent_units: hidden size of the main GRU.
        skip_size: stride between time steps fed to the skip GRU; values <= 0 disable it.
        skip_rnn_units: hidden size of the skip GRU.
        ar_window: number of most recent raw time steps fed to the AR head; 0 disables it.
        output_dim: prediction size per sample.
        seq_len: length of the sequences ``forward`` will receive. It determines how many
            skip-GRU outputs are concatenated and therefore the input width of
            ``output_linear``. The default 8 equals ``window_size - 1`` in this script and
            yields a width of 112 for the script's hyperparameters.
    """

    def __init__(self, input_dim, hidden_dim, conv_channels, kernel_size, recurrent_units,
                 skip_size, skip_rnn_units, ar_window, output_dim=1, seq_len=8):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.skip_size = skip_size
        self.ar_window = ar_window
        self.seq_len = seq_len

        # CNN layer: short-term patterns.
        self.conv1 = nn.Conv1d(input_dim, conv_channels, kernel_size)

        # Main GRU over the convolved sequence.
        self.gru = nn.GRU(conv_channels, recurrent_units, batch_first=True)

        # Skip GRU over strided sub-sequences of the convolved sequence.
        self.skip_rnn = nn.GRU(conv_channels, skip_rnn_units, batch_first=True)

        # Autoregressive head: linear map of the last ar_window raw inputs.
        self.ar_linear = nn.Linear(ar_window * input_dim, output_dim)

        # Output layer over [last GRU state, last state of each skip-GRU pass]. Its width is
        # derived from the configuration instead of being hard-coded, so other window sizes,
        # kernel sizes or skip settings build a matching layer.
        fused_width = recurrent_units + self._num_skips(seq_len) * skip_rnn_units
        self.output_linear = nn.Linear(fused_width, output_dim)

        # Dropout applied to the fused recurrent features.
        self.dropout = nn.Dropout(0.2)

    def _num_skips(self, seq_len):
        """Return how many skip-GRU passes ``forward`` makes for inputs of length ``seq_len``."""
        if self.skip_size <= 0:
            return 0
        conv_len = seq_len - self.kernel_size + 1  # Conv1d output length (stride 1, no padding)
        return max(0, conv_len - self.skip_size)

    def forward(self, x):
        """Predict the next value for each sequence in the batch.

        Args:
            x: tensor of shape (batch_size, seq_len, input_dim).

        Returns:
            Tensor of shape (batch_size, 1, 1, output_dim).

        Raises:
            ValueError: if the sequence length of ``x`` produces a fused feature width that
                does not match the ``seq_len`` the model was built for.
        """
        batch_size, seq_len = x.size(0), x.size(1)

        # AR input: last ar_window time steps flattened to (batch_size, ar_window * input_dim).
        if self.ar_window > 0:
            ar_x = x[:, -self.ar_window:, :].contiguous().view(batch_size, -1)

        # Conv1d expects (batch, channels, time); convert back to (batch, time, channels).
        x_conv = F.relu(self.conv1(x.permute(0, 2, 1))).permute(0, 2, 1)

        # Main GRU; keep only the final time step: (batch_size, recurrent_units).
        out_gru, _ = self.gru(x_conv)
        last_gru = out_gru[:, -1, :]

        # Skip GRU: pass s (s = 0 .. num_skips - 1) sees conv steps s, s + skip, s + 2*skip, ...
        num_skips = self._num_skips(seq_len)
        if num_skips > 0:
            skip_last = []
            for s in range(num_skips):
                skip_out, _ = self.skip_rnn(x_conv[:, s::self.skip_size, :])
                skip_last.append(skip_out[:, -1, :])
            combined = torch.cat([last_gru, *skip_last], dim=1)
        else:
            combined = last_gru

        if combined.size(1) != self.output_linear.in_features:
            raise ValueError(
                f"LSTNetModel was built for seq_len={self.seq_len} "
                f"({self.output_linear.in_features} fused features), "
                f"but received inputs with seq_len={seq_len} ({combined.size(1)} fused features)"
            )

        out_final = self.output_linear(self.dropout(combined))  # (batch_size, output_dim)

        if self.ar_window > 0:
            out_final = out_final + self.ar_linear(ar_x)

        # Two unsqueezes: (batch_size, output_dim) -> (batch_size, 1, 1, output_dim).
        return out_final.unsqueeze(1).unsqueeze(1)


# ============== Training and evaluation ==============
# Univariate series, so input_dim=1.
model = LSTNetModel(
    input_dim=1,
    hidden_dim=hidden_dim,
    conv_channels=conv_channels,
    kernel_size=conv_kernel_size,
    recurrent_units=recurrent_units,
    skip_size=skip_size,
    skip_rnn_units=skip_rnn_units,
    ar_window=ar_window
).to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Reported losses are sum(batch MSE * batch size) / number of samples, i.e. the true
# per-sample MSE. A plain mean of per-batch losses would give the (usually smaller) final
# batch the same weight as a full batch.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    num_samples = 0

    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        x_batch = x_batch.reshape(x_batch.shape[0], window_size - 1, 1)  # (batch_size, seq_len, input_dim)
        y_batch = y_batch.reshape(y_batch.shape[0], 1)  # (batch_size, output_dim)

        optimizer.zero_grad()
        # The model returns (batch_size, 1, 1, output_dim); reshape to the target's shape
        # instead of squeeze(), which would also drop the batch axis when batch_size == 1.
        pred = model(x_batch).reshape(y_batch.shape)
        loss = criterion(pred, y_batch)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * y_batch.shape[0]
        num_samples += y_batch.shape[0]

    # NaN (rather than a ZeroDivisionError) signals an empty training loader.
    epoch_loss = total_loss / num_samples if num_samples else float("nan")
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.10f}")

print("Training complete.")

# Evaluation on the held-out split.
model.eval()
test_loss = 0.0
test_samples = 0

with torch.no_grad():
    for x_test, y_test in test_loader:
        x_test, y_test = x_test.to(device), y_test.to(device)
        x_test = x_test.reshape(x_test.shape[0], window_size - 1, 1)
        y_test = y_test.reshape(y_test.shape[0], 1)

        pred = model(x_test).reshape(y_test.shape)
        loss = criterion(pred, y_test)

        test_loss += loss.item() * y_test.shape[0]
        test_samples += y_test.shape[0]

mean_test_loss = test_loss / test_samples if test_samples else float("nan")
print(f"Test Loss: {mean_test_loss:.10f}")
