import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import random
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import math



def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Args:
        seed: Integer seed handed to every random number generator.
    """
    # Python's built-in generator and NumPy's global generator.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: CPU generator, the current GPU, then every GPU.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # Ask cuDNN for deterministic kernels and switch off its auto-tuner,
    # which could otherwise select different algorithms between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(1234)  # use a fixed random seed

# Device used for training: CUDA when available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



# Hyperparameters (first definition).
# NOTE(review): window_size, embed_dim, hidden_dim, dropout, batch_size,
# num_epochs and learning_rate are assigned again further down, before the
# model is built; num_layers, n_heads and d_ff are not referenced anywhere
# else in the visible part of this file.
window_size = 9  # sliding-window length (model inputs plus the target step)
embed_dim = 64  # embedding dimension
hidden_dim = 64  # hidden-layer dimension
num_layers = 2  # number of encoder layers
n_heads = 8  # number of heads in multi-head attention
dropout = 0.1
d_ff = 128  # feed-forward network dimension

batch_size = 256
num_epochs = 150
learning_rate = 1e-4




# ============== Dataset class ==============
class SWaTDataset(Dataset):
    """Sliding-window dataset for one-step-ahead prediction.

    Each sample is a run of ``window_size`` consecutive rows of ``features``:
    the first ``window_size - 1`` rows are the model input and the final row
    is the target.

    Args:
        features: Tensor (or array-like accepted by ``torch.as_tensor``)
            whose first dimension is time.
        window_size: Full window length, i.e. input rows plus the target row.
    """

    def __init__(self, features, window_size):
        # Keep a private float32 copy, detached from any autograd graph.
        self.features = torch.as_tensor(features).clone().detach().float()
        self.window_size = window_size  # full window size

    def __len__(self):
        # A series of n rows contains n - w + 1 windows of w rows. Clamp at
        # zero so a series shorter than the window gives an empty dataset
        # instead of a negative length (which makes len() raise).
        return max(0, len(self.features) - self.window_size + 1)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for window ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_windows = len(self)
        if idx < 0:
            idx += n_windows
        if not 0 <= idx < n_windows:
            # Without this check an out-of-range index would silently return
            # a truncated window.
            raise IndexError("SWaTDataset index out of range")

        window = self.features[idx:idx + self.window_size]  # full window
        x = window[:-1]  # first N-1 rows are the input
        y = window[-1]  # last row is the target
        return x, y


# Load the data
try:
    dataset_swat = pd.read_csv('../../data/datasets_ETT_afterProcess.txt', encoding='utf-8',
                               low_memory=False)
    # Only the first column of the file is used as the series to model.
    feature_cols = dataset_swat.columns[0]
    features = dataset_swat[feature_cols].values
    print("Successfully loaded data")
    print("feature.shape", features.shape)
except Exception as e:
    # Best-effort fallback: any failure while reading the file is reported
    # and replaced by random data so the rest of the script can still run.
    print(f"Failed to load data: {e}")
    # Create some dummy data for testing
    # NOTE(review): the dummy array is 2-D (10000, 1), whereas selecting a
    # single column above presumably yields a 1-D array; the reshape calls in
    # the training and evaluation loops accept both layouts.
    features = np.random.rand(10000, 1)
    print("Using dummy data with shape:", features.shape)



# Split into training and test sets (chronological split, no shuffling)
train_size = int(len(features) * 0.6)  # the first 60% of the rows form the training set
train_features, test_features = features[:train_size], features[train_size:]


# # Reshape to 2-D
# train_features = train_features.reshape(-1, 1)  # shape=(n_train, 1)
# test_features = test_features.reshape(-1, 1)  # shape=(n_test, 1)
#
# # Standardization applied after the split (currently disabled)
# scaler = StandardScaler()
# train_features = scaler.fit_transform(train_features)
# test_features = scaler.transform(test_features)


# Convert to float32 tensors and move them to the selected device up front
train_features = torch.tensor(train_features, dtype=torch.float32).to(device)
test_features = torch.tensor(test_features, dtype=torch.float32).to(device)

print("train_feature.shape", train_features.shape)
print("test_feature.shape", test_features.shape)

# Build the training and test datasets
train_dataset = SWaTDataset(train_features, window_size)
test_dataset = SWaTDataset(test_features, window_size)

# DataLoader (both loaders keep the samples in their original order)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


# Hyperparameters (second definition, used by the model and loops below).
# These assignments replace the values set near the top of the file. The
# datasets and loaders above were already built with the earlier
# window_size and batch_size, which hold the same values as here.
window_size = 9
embed_dim = 1  # NOTE(review): not referenced in the visible part of this file
hidden_dim = 64
stacks = 2
levels = 3
dropout = 0.5

batch_size = 256
num_epochs = 150
learning_rate = 1e-4


# SCINet model implementation, rewritten from scratch
class SCIBlock(nn.Module):
    """Even/odd interaction block with a residual connection.

    The input is split along the time axis into its even- and odd-indexed
    samples. Each half goes through its own convolutional branch, the two
    branches exchange information, the halves are interleaved back into their
    original positions and a residual of the input is added.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        kernel_size: Kernel width of the two branch convolutions. The padding
            is ``(kernel_size - 1) // 2``, which keeps the sequence length
            only for odd kernel sizes.
        dropout: Dropout probability used inside both branches.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dropout=0.5):
        super(SCIBlock, self).__init__()

        # Interaction-learning units: one branch per half of the sequence.
        self.phi = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(dropout)
        )

        self.psi = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(dropout)
        )

        # Point-wise feature extraction applied to each branch output.
        self.even_conv = nn.Conv1d(out_channels, out_channels, kernel_size=1)
        self.odd_conv = nn.Conv1d(out_channels, out_channels, kernel_size=1)

        # Residual path: project the input only when the channel count changes.
        if in_channels != out_channels:
            self.residual = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor of shape ``(batch, in_channels, seq_len)``.

        Returns:
            Tensor of shape ``(batch, out_channels, seq_len)``.
        """
        _, _, seq_len = x.shape

        # Both halves must have the same length, so an odd-length sequence is
        # extended by repeating its last sample; the extra step is dropped
        # again before returning.
        is_odd = (seq_len % 2 == 1)
        if is_odd:
            x = F.pad(x, (0, 1), "replicate")

        # Split into even- and odd-indexed time steps.
        x_even = x[:, :, 0::2]
        x_odd = x[:, :, 1::2]

        # Branch convolutions.
        phi_even = self.phi(x_even)
        psi_odd = self.psi(x_odd)

        # Feature refinement plus cross-branch interaction.
        even_with_odd = self.even_conv(phi_even) + psi_odd
        odd_with_even = self.odd_conv(psi_odd) + phi_even

        # Interleave the halves back into time order. Stacking on a new last
        # axis and flattening gives [even0, odd0, even1, odd1, ...]. The
        # result inherits channel count, dtype and device from the branch
        # outputs, so the block also works when in_channels != out_channels.
        z = torch.stack((even_with_odd, odd_with_even), dim=-1).flatten(start_dim=-2)

        # Remove the padded step if the original length was odd.
        if is_odd:
            z = z[:, :, :-1]
            x = x[:, :, :-1]

        # Residual connection.
        return z + self.residual(x)


class SCINet(nn.Module):
    """Sequence model: 1x1 input projection, a chain of SCIBlocks, 1x1 output head.

    Args:
        input_dim: Number of features per time step of the input.
        hidden_dim: Channel width used by every SCIBlock.
        output_dim: Number of features per time step of the output.
        stacks: Number of stacks.
        levels: Number of SCIBlocks per stack; ``stacks * levels`` blocks are
            chained one after another.
        dropout: Dropout probability for the blocks and the output head.
    """

    def __init__(self, input_dim, hidden_dim, output_dim=1, stacks=2, levels=3, dropout=0.5):
        super().__init__()
        self.input_dim, self.hidden_dim, self.output_dim = input_dim, hidden_dim, output_dim
        self.stacks, self.levels = stacks, levels

        # Lift the input features to the hidden width.
        self.input_projection = nn.Conv1d(input_dim, hidden_dim, kernel_size=1)

        # Flat list holding the blocks of every stack, in application order.
        self.blocks = nn.ModuleList(
            [
                SCIBlock(hidden_dim, hidden_dim, dropout=dropout)
                for _ in range(stacks)
                for _ in range(levels)
            ]
        )

        # Two-layer point-wise head that maps hidden features to the output.
        bottleneck = hidden_dim // 2
        self.output_projection = nn.Sequential(
            nn.Conv1d(hidden_dim, bottleneck, kernel_size=1),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(dropout),
            nn.Conv1d(bottleneck, output_dim, kernel_size=1),
        )

    def forward(self, x):
        """Map ``(batch, seq_len, input_dim)`` to ``(batch, seq_len, output_dim)``."""
        # Conv1d expects channels first: (batch, input_dim, seq_len).
        hidden = self.input_projection(x.transpose(1, 2))

        # Run through every SCIBlock in order.
        for layer in self.blocks:
            hidden = layer(hidden)

        # Project to the output width and restore the time-major layout.
        return self.output_projection(hidden).transpose(1, 2)



# Build the model and move it to the selected device
model = SCINet(
    input_dim=1,
    hidden_dim=hidden_dim,
    output_dim=1,
    stacks=stacks,
    levels=levels,
    dropout=dropout
).to(device)

# Print the model structure (optional)
print(model)

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        x_batch = x_batch.reshape(x_batch.shape[0], window_size - 1, 1)  # (batch_size, seq_len, input_dim)
        y_batch = y_batch.reshape(y_batch.shape[0], 1)  # (batch_size, output_dim)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(x_batch)
        pred = outputs[:, -1, :]  # keep only the output at the last time step

        # Compute the loss
        loss = criterion(pred, y_batch)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Report progress: mean of the per-batch losses for this epoch
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(train_loader):.10f}")

print("Training complete.")

# Evaluate the model
model.eval()
test_loss = 0.0

# Per-sample predictions and targets, collected across all test batches
predictions = []
actuals = []

with torch.no_grad():
    for x_test, y_test in test_loader:
        x_test, y_test = x_test.to(device), y_test.to(device)
        x_test = x_test.reshape(x_test.shape[0], window_size - 1, 1)
        y_test = y_test.reshape(y_test.shape[0], 1)

        # Forward pass; keep only the output at the last time step
        outputs = model(x_test)
        pred = outputs[:, -1, :]

        # Collect predictions and ground-truth values
        predictions.extend(pred.cpu().numpy())
        actuals.extend(y_test.cpu().numpy())

        # Accumulate the batch loss
        loss = criterion(pred, y_test)
        test_loss += loss.item()

# NOTE(review): this is the mean of the per-batch losses, so a smaller final
# batch carries the same weight as a full one.
print(f"Test Loss: {test_loss / len(test_loader):.10f}")