import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import random
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import math



def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU and
    CUDA generators, and switches cuDNN to deterministic, non-autotuned kernel
    selection so that repeated runs with the same seed are reproducible.

    Parameters
    ----------
    seed: int
        Seed value applied to every generator.
    """
    # Python and NumPy global generators
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU, the current GPU, and all visible GPUs
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Give up cuDNN autotuning speed in exchange for deterministic kernels
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(1234)  # fix every RNG seed for reproducible runs

# Training device: prefer CUDA when available, otherwise fall back to CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---- Hyperparameters ----
window_size = 9     # full sliding-window length: 8 input steps + 1 target step
hidden_dim = 64     # hidden/cell-state channels of every ConvLSTM cell
kernel_size = 3     # Conv1d kernel size; odd sizes keep the sequence length (padding = k // 2)
num_layers = 2      # number of stacked ConvLSTM cells
batch_size = 256    # windows per mini-batch
num_epochs = 150    # passes over the training windows
learning_rate = 1e-4  # Adam step size



# ============== 数据集类 ==============
class SWaTDataset(Dataset):
    """Sliding-window dataset for one-step-ahead forecasting.

    Each sample is a window of ``window_size`` consecutive rows of
    ``features``. The first ``window_size - 1`` rows are the model input and
    the last row is the prediction target.

    Parameters
    ----------
    features: torch.Tensor
        Time-ordered series with time along dim 0. It is cloned, detached from
        any autograd graph and cast to float32.
    window_size: int
        Full window length (input steps plus one target step).
    """

    def __init__(self, features, window_size):
        self.features = features.clone().detach().float()
        self.window_size = window_size  # full window length, target included

    def __len__(self):
        # A series of length L contains L - window_size + 1 complete windows
        # (start indices 0 .. L - window_size). Clamp at 0 so a series shorter
        # than one window gives an empty dataset instead of a negative length,
        # which len() would reject.
        return max(0, len(self.features) - self.window_size + 1)

    def __getitem__(self, idx):
        window = self.features[idx:idx + self.window_size]  # one full window
        x = window[:-1]  # first window_size - 1 steps: model input
        y = window[-1]  # last step: prediction target
        return x, y


# Load the raw series
try:
    dataset_swat = pd.read_csv('../../data/datasets_ETT_afterProcess.txt', encoding='utf-8',
                               low_memory=False)
    # Only the first column is used, so this is a univariate series.
    feature_cols = dataset_swat.columns[0]
    features = dataset_swat[feature_cols].values
except Exception as e:
    # Deliberate best-effort fallback so the pipeline can be smoke-tested
    # without the data file. NOTE(review): every loss printed afterwards is then
    # measured on uniform random noise, not on the real series.
    print(f"Failed to load data: {e}")
    features = np.random.rand(10000, 1)
    print("Using dummy data with shape:", features.shape)
else:
    print("Successfully loaded data")
    print("feature.shape", features.shape)


# Chronological split: the first 60% of rows are used for training and the
# remaining 40% for testing, so every test row comes after every training row.
train_size = int(len(features) * 0.6)
train_features, test_features = features[:train_size], features[train_size:]


# Optional preprocessing, currently disabled:
# # Reshape to 2-D
# train_features = train_features.reshape(-1, 1)  # shape=(n_train, 1)
# test_features = test_features.reshape(-1, 1)  # shape=(n_test, 1)
#
# # Standardise after the split, fitting the scaler on the training split only
# scaler = StandardScaler()
# train_features = scaler.fit_transform(train_features)
# test_features = scaler.transform(test_features)


# Move both splits to the training device once, up front.
train_features = torch.tensor(train_features, dtype=torch.float32).to(device)
test_features = torch.tensor(test_features, dtype=torch.float32).to(device)

print("train_feature.shape", train_features.shape)
print("test_feature.shape", test_features.shape)

# Sliding-window datasets for each split
train_dataset = SWaTDataset(train_features, window_size)
test_dataset = SWaTDataset(test_features, window_size)

# DataLoaders; windows are served in time order (shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# ConvLSTM单元实现
class ConvLSTMCell(nn.Module):
    """Convolutional LSTM cell operating on 1-D sequences.

    All four gates come from a single ``nn.Conv1d`` applied to the
    channel-wise concatenation of the input and the previous hidden state.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias=True):
        """
        Initialise the ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of input channels.
        hidden_dim: int
            Number of channels in the hidden and cell states.
        kernel_size: int
            Convolution kernel size. Must be a positive odd number so that the
            symmetric padding ``kernel_size // 2`` keeps the sequence length.
        bias: bool
            Whether the gate convolution has a bias term.

        Raises
        ------
        ValueError
            If ``kernel_size`` is not a positive odd integer.
        """
        super(ConvLSTMCell, self).__init__()

        # With padding = k // 2, Conv1d preserves the length only for odd k.
        # An even kernel would make the gates one step longer than c_cur and
        # break the element-wise state update in forward().
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        self.bias = bias

        # One convolution computes all four gates at once
        self.conv = nn.Conv1d(
            in_channels=self.input_dim + self.hidden_dim,
            out_channels=4 * self.hidden_dim,  # split in forward() as input, forget, output, candidate
            kernel_size=self.kernel_size,
            padding=self.padding,
            bias=self.bias
        )

    def forward(self, input_tensor, cur_state):
        """Apply the cell once.

        Parameters
        ----------
        input_tensor: torch.Tensor
            Expected shape (batch, input_dim, seq_len).
        cur_state: tuple of torch.Tensor
            ``(h_cur, c_cur)``, each expected to be (batch, hidden_dim, seq_len).

        Returns
        -------
        tuple of torch.Tensor
            ``(h_next, c_next)`` with the same shape as the incoming state.
        """
        h_cur, c_cur = cur_state

        # Concatenate input and hidden state along the channel axis
        combined = torch.cat([input_tensor, h_cur], dim=1)

        # Compute all gate pre-activations, then split them per gate
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        i = torch.sigmoid(cc_i)  # input gate
        f = torch.sigmoid(cc_f)  # forget gate
        o = torch.sigmoid(cc_o)  # output gate
        g = torch.tanh(cc_g)  # candidate cell state

        c_next = f * c_cur + i * g  # new cell state
        h_next = o * torch.tanh(c_next)  # new hidden state

        return h_next, c_next

    def init_hidden(self, batch_size, seq_len):
        """Return zero ``(h, c)`` states of shape (batch_size, hidden_dim, seq_len).

        The states are created on the same device and with the same dtype as
        the conv weights, so a model moved with ``.to()``/``.double()`` works.
        """
        weight = self.conv.weight
        return (torch.zeros(batch_size, self.hidden_dim, seq_len, device=weight.device, dtype=weight.dtype),
                torch.zeros(batch_size, self.hidden_dim, seq_len, device=weight.device, dtype=weight.dtype))


class ConvLSTM(nn.Module):
    """Stack of ``ConvLSTMCell`` layers followed by a linear read-out.

    NOTE(review): each layer calls its cell exactly once over the whole
    sequence (time is the Conv1d axis), starting from the given or zero state.
    There is no step-by-step recurrence across time steps. With the default
    zero state, the ``f * c_cur`` term contributes nothing, and each layer acts
    as a gated 1-D convolution. Confirm this is the intended design.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers, batch_first=True):
        """
        Parameters
        ----------
        input_dim: int
            Number of input features per time step.
        hidden_dim: int
            Hidden channels of every cell.
        kernel_size: int
            Convolution kernel size of every cell.
        num_layers: int
            Number of stacked cells.
        batch_first: bool
            If True, inputs and outputs are (batch, seq_len, feature);
            otherwise (seq_len, batch, feature).
        """
        super(ConvLSTM, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first

        # The first cell reads the raw input. Deeper cells read the previous
        # cell's hidden state, so their input width is hidden_dim.
        self.cell_list = nn.ModuleList([
            ConvLSTMCell(input_dim=self.input_dim if i == 0 else self.hidden_dim,
                         hidden_dim=self.hidden_dim,
                         kernel_size=self.kernel_size,
                         bias=True)
            for i in range(self.num_layers)
        ])
        # Linear read-out: hidden_dim channels -> one value per time step
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, input_tensor, hidden_state=None):
        """Run the stacked cells over a whole sequence.

        Parameters
        ----------
        input_tensor: torch.Tensor
            3-D tensor of shape (batch, seq_len, input_dim) when ``batch_first``
            is True, otherwise (seq_len, batch, input_dim).
        hidden_state: list of tuple, optional
            One ``(h, c)`` pair per layer, each tensor of shape
            (batch, hidden_dim, seq_len). Zero states are used when omitted.

        Returns
        -------
        output: torch.Tensor
            Per-time-step outputs of shape (batch, seq_len, 1), or
            (seq_len, batch, 1) when ``batch_first`` is False.
        last_state_list: list of tuple
            The ``(h, c)`` pair produced by each layer, each tensor of shape
            (batch, hidden_dim, seq_len).
        """
        if not self.batch_first:
            # (seq_len, batch, feature) -> (batch, seq_len, feature)
            input_tensor = input_tensor.transpose(0, 1)

        batch_size, seq_len, _ = input_tensor.size()

        # Conv1d expects (batch, channels, length): time becomes the conv axis
        h = input_tensor.transpose(1, 2)

        if hidden_state is None:
            hidden_state = [cell.init_hidden(batch_size, seq_len) for cell in self.cell_list]

        last_state_list = []
        for layer_idx, cell in enumerate(self.cell_list):
            h_state, c_state = hidden_state[layer_idx]
            h, c = cell(h, (h_state, c_state))
            last_state_list.append((h, c))

        # (batch, hidden_dim, seq_len) -> (batch, seq_len, hidden_dim), then read out
        output = self.fc(h.transpose(1, 2))

        if not self.batch_first:
            # Back to (seq_len, batch, 1) to mirror the input layout
            output = output.transpose(0, 1)

        return output, last_state_list


# ============== 数据集类 ==============
class SWaTDataset(Dataset):
    """Sliding-window dataset for one-step-ahead forecasting.

    Each sample is a window of ``window_size`` consecutive rows of
    ``features``. The first ``window_size - 1`` rows are the model input and
    the last row is the prediction target.

    NOTE(review): a class named ``SWaTDataset`` is also defined earlier in this
    file, and ``train_dataset``/``test_dataset`` are built from that earlier
    definition. This later class statement only rebinds the name for code that
    runs after it. Consider deleting one of the two copies.

    Parameters
    ----------
    features: torch.Tensor
        Time-ordered series with time along dim 0. It is cloned, detached from
        any autograd graph and cast to float32.
    window_size: int
        Full window length (input steps plus one target step).
    """

    def __init__(self, features, window_size):
        self.features = features.clone().detach().float()
        self.window_size = window_size  # full window length, target included

    def __len__(self):
        # A series of length L contains L - window_size + 1 complete windows
        # (start indices 0 .. L - window_size). Clamp at 0 so a series shorter
        # than one window gives an empty dataset instead of a negative length.
        return max(0, len(self.features) - self.window_size + 1)

    def __getitem__(self, idx):
        window = self.features[idx:idx + self.window_size]  # one full window
        x = window[:-1]  # first window_size - 1 steps: model input
        y = window[-1]  # last step: prediction target
        return x, y

# Build the ConvLSTM model (one input feature per time step)
model = ConvLSTM(input_dim=1, hidden_dim=hidden_dim, kernel_size=kernel_size, num_layers=num_layers).to(device)

# Loss function and optimiser
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


def _prepare_batch(x, y):
    """Move a batch to ``device`` and reshape it for the model.

    Returns ``x`` as (batch, window_size - 1, 1) and ``y`` as (batch, 1).
    """
    x, y = x.to(device), y.to(device)
    x = x.reshape(x.shape[0], window_size - 1, 1)  # (batch_size, seq_len, input_dim)
    y = y.reshape(y.shape[0], 1)  # (batch_size, output_dim)
    return x, y


def _mean_loss(loss_sum, n_samples):
    """Return the per-sample mean of a summed loss, or NaN if no samples were seen."""
    return loss_sum / n_samples if n_samples else float("nan")


# Training
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    n_seen = 0

    for x_batch, y_batch in train_loader:
        x_batch, y_batch = _prepare_batch(x_batch, y_batch)

        # Forward pass
        optimizer.zero_grad()
        outputs, _ = model(x_batch)
        pred = outputs[:, -1, :]  # output at the last input time step

        loss = criterion(pred, y_batch)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # MSELoss averages over the batch. Re-weight by batch size so the epoch
        # figure is a true per-sample mean even when the last batch is smaller.
        total_loss += loss.item() * x_batch.shape[0]
        n_seen += x_batch.shape[0]

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {_mean_loss(total_loss, n_seen):.10f}")

print("Training complete.")

# Evaluation
model.eval()
test_loss = 0.0
n_test = 0

with torch.no_grad():
    for x_test, y_test in test_loader:
        x_test, y_test = _prepare_batch(x_test, y_test)

        outputs, _ = model(x_test)
        pred = outputs[:, -1, :]

        loss = criterion(pred, y_test)
        test_loss += loss.item() * x_test.shape[0]
        n_test += x_test.shape[0]

print(f"Test Loss: {_mean_loss(test_loss, n_test):.10f}")