import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import random
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt


def set_seed(seed):
    """Seed every random number generator the script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, the PyTorch CPU RNG
    and the CUDA RNGs (current device and all devices). It also forces cuDNN
    into deterministic mode with autotuning disabled, so repeated runs pick
    the same kernels.

    Parameters
    ----------
    seed: int
        Seed value applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # deterministic convolution algorithms
    cudnn.benchmark = False  # no per-shape algorithm autotuning


set_seed(1234)  # fixed seed so runs are reproducible

# Compute device: CUDA when available, otherwise CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---- Hyperparameters ----
window_size = 7  # full sliding-window length: window_size - 1 inputs + 1 target
hidden_dim = 64  # channels in each ConvLSTM hidden state
kernel_size = 3  # gate convolution kernel size
num_layers = 2  # number of stacked ConvLSTM layers
batch_size = 256
num_epochs = 100
learning_rate = 1e-3


# ConvLSTM单元实现
class ConvLSTMCell(nn.Module):
    """One ConvLSTM cell whose four gates come from a single 1-D convolution.

    The cell input and the previous hidden state are concatenated along the
    channel axis. They are convolved into ``4 * hidden_dim`` channels, which
    are read, in this order, as the input gate, forget gate, output gate and
    candidate cell state.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias=True):
        """
        Parameters
        ----------
        input_dim: int
            Number of input channels (features).
        hidden_dim: int
            Number of channels of the hidden and cell states.
        kernel_size: int
            Kernel size of the gate convolution. Padding is
            ``kernel_size // 2``, so an odd size keeps the length unchanged.
        bias: bool
            Whether the gate convolution has a bias term.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        self.bias = bias

        # Single convolution producing all gates: [input | forget | output | candidate]
        self.conv = nn.Conv1d(
            in_channels=input_dim + hidden_dim,
            out_channels=4 * hidden_dim,
            kernel_size=kernel_size,
            padding=self.padding,
            bias=bias,
        )

    def forward(self, input_tensor, cur_state):
        """Apply one cell update.

        Parameters
        ----------
        input_tensor: Tensor
            Laid out as (batch, input_dim, length) for ``nn.Conv1d``.
        cur_state: tuple
            (h, c), each laid out as (batch, hidden_dim, length).

        Returns
        -------
        tuple
            (h_next, c_next) with the same layout as ``cur_state``.
        """
        prev_h, prev_c = cur_state

        gate_logits = self.conv(torch.cat((input_tensor, prev_h), dim=1))
        in_logit, forget_logit, out_logit, cand_logit = gate_logits.split(self.hidden_dim, dim=1)

        in_gate = torch.sigmoid(in_logit)
        forget_gate = torch.sigmoid(forget_logit)
        out_gate = torch.sigmoid(out_logit)
        candidate = torch.tanh(cand_logit)

        next_c = forget_gate * prev_c + in_gate * candidate
        next_h = out_gate * torch.tanh(next_c)
        return next_h, next_c

    def init_hidden(self, batch_size, seq_len):
        """Return zero (h, c) states of shape (batch_size, hidden_dim, seq_len).

        The states are created on the same device as the convolution weights.
        """
        state_shape = (batch_size, self.hidden_dim, seq_len)
        target_device = self.conv.weight.device
        return (torch.zeros(state_shape, device=target_device),
                torch.zeros(state_shape, device=target_device))


class ConvLSTM(nn.Module):
    """Stack of ConvLSTMCell layers followed by a per-position linear head.

    The sequence axis is used as the spatial axis of a 1-D convolution. Each
    layer applies its cell exactly once to the whole sequence and passes its
    hidden state on as the next layer's input. There is no step-by-step
    recurrence over time. Because the gate convolutions use symmetric padding
    of ``kernel_size // 2``, each output position (for odd ``kernel_size``)
    sees ``num_layers * (kernel_size // 2)`` neighbouring positions on each
    side, including zero padding past the sequence ends.

    NOTE(review): with the default zero initial state, the forget gate
    multiplies a zero cell state, so each layer reduces to a gated
    convolution. Confirm this is the intended architecture.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers, batch_first=True):
        """
        Parameters
        ----------
        input_dim: int
            Number of features per time step.
        hidden_dim: int
            Hidden channels of every layer.
        kernel_size: int
            Gate convolution kernel size (odd sizes preserve the sequence length).
        num_layers: int
            Number of stacked cells.
        batch_first: bool
            If True, inputs/outputs are (batch, seq_len, ...). Otherwise they
            are (seq_len, batch, ...).
        """
        super(ConvLSTM, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first

        # The first layer consumes the raw features; deeper layers consume the previous hidden state
        self.cell_list = nn.ModuleList([
            ConvLSTMCell(input_dim=self.input_dim if i == 0 else self.hidden_dim,
                         hidden_dim=self.hidden_dim,
                         kernel_size=self.kernel_size,
                         bias=True)
            for i in range(self.num_layers)
        ])
        self.fc = nn.Linear(hidden_dim, 1)  # projects each position's hidden vector to a scalar

    def forward(self, input_tensor, hidden_state=None):
        """Run the stacked cells over a whole sequence.

        Parameters
        ----------
        input_tensor: 3-D Tensor
            (batch, seq_len, input_dim) if ``batch_first`` else
            (seq_len, batch, input_dim).
        hidden_state: list of (h, c) tuples, optional
            One initial (h, c) pair per layer, each (batch, hidden_dim, seq_len).
            Zero states are used when omitted.

        Returns
        -------
        output: 3-D Tensor
            (batch, seq_len, 1) if ``batch_first`` else (seq_len, batch, 1).
        last_state_list: list of (h, c) tuples
            The (h, c) produced by each layer, each (batch, hidden_dim, seq_len).
        """
        if not self.batch_first:
            # Normalise to batch-major so the rest of the method has a single layout
            input_tensor = input_tensor.transpose(0, 1)

        batch_size, seq_len, _ = input_tensor.size()

        # nn.Conv1d expects (batch, channels, length): features become channels
        h = input_tensor.transpose(1, 2)

        if hidden_state is None:
            hidden_state = [cell.init_hidden(batch_size, seq_len) for cell in self.cell_list]

        last_state_list = []
        for layer_idx, cell in enumerate(self.cell_list):
            h_state, c_state = hidden_state[layer_idx]
            h, c = cell(h, (h_state, c_state))
            last_state_list.append((h, c))

        # Back to (batch, seq_len, hidden_dim), then project every position
        output = self.fc(h.transpose(1, 2))

        if not self.batch_first:
            output = output.transpose(0, 1)

        return output, last_state_list


# ============== 数据集类 ==============
class SWaTDataset(Dataset):
    """Sliding-window dataset for one-step-ahead forecasting.

    Item ``idx`` is the window ``features[idx:idx + window_size]``. The first
    ``window_size - 1`` rows are the input ``x`` and the last row is the
    target ``y``.

    Parameters
    ----------
    features: torch.Tensor or array-like
        Time-ordered samples with time on the first axis. A float32 copy is stored.
    window_size: int
        Full window length (inputs plus one target row).
    """

    def __init__(self, features, window_size):
        # as_tensor also accepts numpy arrays / lists; clone so the dataset owns its data
        self.features = torch.as_tensor(features).clone().detach().float()
        self.window_size = window_size

    def __len__(self):
        # Full windows start at 0 .. len(features) - window_size (inclusive);
        # clamp at 0 when the series is shorter than one window.
        return max(len(self.features) - self.window_size + 1, 0)

    def __getitem__(self, idx):
        n_windows = len(self)
        if idx < 0:
            idx += n_windows
        if not 0 <= idx < n_windows:
            # Raising here (instead of returning a truncated window) also makes
            # plain `for item in dataset` iteration stop at the right place.
            raise IndexError(f"index {idx} out of range for dataset of length {n_windows}")
        window = self.features[idx:idx + self.window_size]
        x = window[:-1]  # first window_size - 1 rows: model input
        y = window[-1]  # last row: prediction target
        return x, y


# ============== Load data ==============
# NOTE(review): pd.read_csv treats the first line as a header by default. If
# output-cnn.txt is a bare column of numbers, its first value becomes the
# column name and is dropped from `features`. Confirm the file format.
try:
    dataset_swat = pd.read_csv('../../data/output-cnn.txt', encoding='utf-8',
                               low_memory=False)
except (OSError, ValueError) as e:
    # Only file/format problems trigger the fallback. OSError covers a missing
    # or unreadable file; ValueError covers UnicodeDecodeError and pandas'
    # EmptyDataError/ParserError. Anything else is a real bug and propagates
    # instead of silently training on random data.
    print(f"Failed to load data: {e}")
    # Create some dummy data for testing
    features = np.random.rand(10000, 1)
    print("Using dummy data with shape:", features.shape)
else:
    feature_cols = dataset_swat.columns[0]  # a single column name: only the first column is used
    features = dataset_swat[feature_cols].values
    print("Successfully loaded data")
    print("feature.shape", features.shape)

# Chronological 80/20 split: the test set is the tail of the series
train_size = int(len(features) * 0.8)
train_features, test_features = features[:train_size], features[train_size:]

# StandardScaler expects 2-D input (n_samples, n_features)
train_features = train_features.reshape(-1, 1)  # shape=(n_train, 1)
test_features = test_features.reshape(-1, 1)  # shape=(n_test, 1)

# Fit scaling statistics on the training split only, so no test information leaks in
scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)
test_features = scaler.transform(test_features)

train_features = torch.tensor(train_features, dtype=torch.float32).to(device)
test_features = torch.tensor(test_features, dtype=torch.float32).to(device)

print("train_feature.shape", train_features.shape)
print("test_feature.shape", test_features.shape)

# Sliding-window datasets for both splits
train_dataset = SWaTDataset(train_features, window_size)
test_dataset = SWaTDataset(test_features, window_size)

# NOTE(review): training windows are served in order (shuffle=False); shuffling
# windows is usually preferable for SGD. Confirm this is intentional.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# ============== Model, training and evaluation ==============
model = ConvLSTM(input_dim=1, hidden_dim=hidden_dim, kernel_size=kernel_size, num_layers=num_layers).to(device)

criterion = nn.MSELoss()  # default reduction="mean": returns the per-batch mean
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


def _prepare_batch(x, y):
    """Move a batch to `device` and reshape it for the model.

    Returns x as (batch, window_size - 1, 1) and y as (batch, 1).
    """
    x = x.to(device).reshape(x.shape[0], window_size - 1, 1)
    y = y.to(device).reshape(y.shape[0], 1)
    return x, y


for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_samples = 0

    for x_batch, y_batch in train_loader:
        x_batch, y_batch = _prepare_batch(x_batch, y_batch)

        optimizer.zero_grad()
        outputs, _ = model(x_batch)
        pred = outputs[:, -1, :]  # prediction read from the last input position
        loss = criterion(pred, y_batch)
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so the epoch figure is a true
        # per-sample mean even when the final batch is smaller.
        n = y_batch.shape[0]
        total_loss += loss.item() * n
        total_samples += n

    epoch_loss = total_loss / total_samples if total_samples else float("nan")
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.10f}")

print("Training complete.")

model.eval()
test_loss = 0.0
test_samples = 0

with torch.no_grad():
    for x_test, y_test in test_loader:
        x_test, y_test = _prepare_batch(x_test, y_test)

        outputs, _ = model(x_test)
        pred = outputs[:, -1, :]
        loss = criterion(pred, y_test)

        n = y_test.shape[0]
        test_loss += loss.item() * n
        test_samples += n

mean_test_loss = test_loss / test_samples if test_samples else float("nan")
print(f"Test Loss: {mean_test_loss:.10f}")