import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import random
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    then switches cuDNN to deterministic mode with auto-tuning disabled.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_seed(1234)  # seed all RNGs before any data or model is created

# Device used for training and evaluation: CUDA when available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Hyperparameters
window_size = 9  # samples per sliding window: the first 8 are the input, the last is the target
embed_dim = 1  # NOTE(review): not referenced anywhere in the visible part of this file
hidden_dim = 64  # channel width of the SCINet hidden layers
stacks = 2  # SCINet builds stacks * levels SCIBlocks in sequence
levels = 3
dropout = 0.5  # dropout probability passed to SCINet

batch_size = 256
num_epochs = 100
learning_rate = 1e-3  # learning rate of the Adam optimizer


# 完全重写的SCINet模型实现
class SCIBlock(nn.Module):
    """Even/odd interaction block operating on ``(batch, channels, seq_len)`` tensors.

    The sequence is split into its even- and odd-indexed time steps. Each half
    goes through its own conv -> LeakyReLU -> dropout branch (``phi`` for the
    even half, ``psi`` for the odd half), is refined by a 1x1 convolution, and
    receives the other half's branch output as interaction. The two halves are
    then interleaved back into their original positions and a residual
    connection to the block input is added.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor. When it differs
            from ``in_channels`` the residual path uses a 1x1 convolution.
        kernel_size: Kernel size of the branch convolutions. Must be odd so
            that the "same" padding keeps the sequence length unchanged.
        dropout: Dropout probability used inside both branches.

    Raises:
        ValueError: If ``kernel_size`` is even.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dropout=0.5):
        super().__init__()

        # With padding (k - 1) // 2 an even kernel shortens the output by one
        # step, which would misalign the halves when they are re-interleaved.
        if kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be odd to preserve sequence length, got {kernel_size}"
            )
        padding = (kernel_size - 1) // 2

        # Interactive-learning branches (even half -> phi, odd half -> psi)
        self.phi = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(dropout),
        )
        self.psi = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(dropout),
        )

        # Point-wise feature refinement of each half
        self.even_conv = nn.Conv1d(out_channels, out_channels, kernel_size=1)
        self.odd_conv = nn.Conv1d(out_channels, out_channels, kernel_size=1)

        # Residual path: project only when the channel count changes
        if in_channels != out_channels:
            self.residual = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor of shape ``(batch, in_channels, seq_len)``.

        Returns:
            Tensor of shape ``(batch, out_channels, seq_len)`` with the same
            dtype and device as the branch outputs.
        """
        _, _, seq_len = x.shape

        # Odd-length sequences are padded by repeating the last step so both
        # halves have the same length; the extra step is dropped again below.
        is_odd = seq_len % 2 == 1
        padded = F.pad(x, (0, 1), mode="replicate") if is_odd else x

        # Split into even- and odd-indexed time steps
        x_even = padded[:, :, 0::2]
        x_odd = padded[:, :, 1::2]

        # Branch features
        phi_even = self.phi(x_even)
        psi_odd = self.psi(x_odd)

        # Refine each half and add the other half's features as interaction
        even_with_odd = self.even_conv(phi_even) + psi_odd
        odd_with_even = self.odd_conv(psi_odd) + phi_even

        # Interleave back to (e0, o0, e1, o1, ...). Stacking on a new trailing
        # axis and flattening keeps the branch outputs' channel count, dtype
        # and device, unlike writing into a pre-allocated float32 buffer sized
        # by the input channels.
        z = torch.stack((even_with_odd, odd_with_even), dim=-1).flatten(start_dim=-2)

        if is_odd:
            z = z[:, :, :seq_len]

        # Residual connection to the original (unpadded) input
        return z + self.residual(x)


class SCINet(nn.Module):
    """Sequence-to-sequence network built from a flat chain of SCIBlocks.

    A 1x1 convolution lifts the input to ``hidden_dim`` channels,
    ``stacks * levels`` SCIBlocks are applied one after another, and a small
    point-wise head maps the result to ``output_dim`` channels per time step.

    Args:
        input_dim: Number of features per time step of the input.
        hidden_dim: Channel width used by every SCIBlock.
        output_dim: Number of features per time step of the output.
        stacks: Outer repetition count of the block chain.
        levels: Inner repetition count of the block chain.
        dropout: Dropout probability used in the blocks and the output head.
    """

    def __init__(self, input_dim, hidden_dim, output_dim=1, stacks=2, levels=3, dropout=0.5):
        super().__init__()
        self.input_dim, self.hidden_dim, self.output_dim = input_dim, hidden_dim, output_dim
        self.stacks, self.levels = stacks, levels

        # Input projection to the hidden width
        self.input_projection = nn.Conv1d(input_dim, hidden_dim, kernel_size=1)

        # One SCIBlock per (stack, level) pair, applied sequentially
        self.blocks = nn.ModuleList(
            [
                SCIBlock(hidden_dim, hidden_dim, dropout=dropout)
                for _stack in range(stacks)
                for _level in range(levels)
            ]
        )

        # Output head: hidden -> hidden // 2 -> output_dim, all point-wise
        bottleneck = hidden_dim // 2
        self.output_projection = nn.Sequential(
            nn.Conv1d(hidden_dim, bottleneck, kernel_size=1),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(dropout),
            nn.Conv1d(bottleneck, output_dim, kernel_size=1),
        )

    def forward(self, x):
        """Map ``(batch, seq_len, input_dim)`` to ``(batch, seq_len, output_dim)``."""
        # Conv1d expects channels first: (batch, input_dim, seq_len)
        hidden = self.input_projection(x.transpose(1, 2))

        for block in self.blocks:
            hidden = block(hidden)

        # Back to (batch, seq_len, output_dim)
        return self.output_projection(hidden).transpose(1, 2)


# 数据集类
class SWaTDataset(Dataset):
    """Sliding-window dataset for one-step-ahead prediction.

    Each item is a window of ``window_size`` consecutive rows of ``features``:
    the first ``window_size - 1`` rows are the input and the last row is the
    target.

    Args:
        features: Tensor (or array-like convertible by ``torch.as_tensor``)
            whose first dimension is time. It is copied and cast to float32.
        window_size: Total window length, inputs plus the single target row.

    Raises:
        ValueError: If ``window_size`` is smaller than 2, since a window needs
            at least one input row and one target row.
    """

    def __init__(self, features, window_size):
        if window_size < 2:
            raise ValueError(
                f"window_size must be at least 2 (inputs + target), got {window_size}"
            )
        # as_tensor lets NumPy arrays and lists through; clone/detach keeps the
        # dataset independent of the caller's tensor and of any autograd graph.
        self.features = torch.as_tensor(features).clone().detach().float()
        self.window_size = window_size

    def __len__(self):
        # A series of length T holds T - window_size + 1 complete windows;
        # clamp at zero so a too-short series yields an empty dataset instead
        # of an invalid negative length.
        return max(0, len(self.features) - self.window_size + 1)

    def __getitem__(self, idx):
        num_windows = len(self)
        if idx < 0:
            idx += num_windows
        # Raising IndexError stops plain iteration at the last complete window
        # instead of silently returning truncated windows.
        if not 0 <= idx < num_windows:
            raise IndexError(f"index {idx} out of range for {num_windows} windows")

        window = self.features[idx:idx + self.window_size]
        x = window[:-1]  # first window_size - 1 rows are the input
        y = window[-1]  # last row is the target
        return x, y


# Load the data
# NOTE(review): every failure inside the try (missing file, parse error, file
# without columns, ...) is caught and replaced by random dummy data, so a wrong
# path silently trains on noise — check the console output before trusting results.
try:
    dataset_swat = pd.read_csv('../../data/output-cnn.txt', encoding='utf-8',
                               low_memory=False)
    # Only the first column of the file is used as the series to model.
    feature_cols = dataset_swat.columns[0]
    features = dataset_swat[feature_cols].values
    print("Successfully loaded data")
    print("feature.shape", features.shape)
except Exception as e:
    print(f"Failed to load data: {e}")
    # Fall back to synthetic data so the rest of the script can still run
    features = np.random.rand(10000, 1)
    print("Using dummy data with shape:", features.shape)

# Chronological split: the first 80% is used for training, the rest for testing
train_size = int(len(features) * 0.8)

# Each portion becomes a 2-D column array of shape (n_samples, 1)
train_features = features[:train_size].reshape(-1, 1)
test_features = features[train_size:].reshape(-1, 1)

# Standardize both portions with statistics fitted on the training portion only
scaler = StandardScaler().fit(train_features)
train_features = scaler.transform(train_features)
test_features = scaler.transform(test_features)

# Float32 tensors placed directly on the compute device
train_features = torch.tensor(train_features, dtype=torch.float32, device=device)
test_features = torch.tensor(test_features, dtype=torch.float32, device=device)

print("train_feature.shape", train_features.shape)
print("test_feature.shape", test_features.shape)

# Sliding-window datasets and sequential (unshuffled) loaders
train_dataset = SWaTDataset(train_features, window_size)
test_dataset = SWaTDataset(test_features, window_size)

loader_options = {"batch_size": batch_size, "shuffle": False}
train_loader = DataLoader(train_dataset, **loader_options)
test_loader = DataLoader(test_dataset, **loader_options)

# Build the model from the hyperparameters above and move it to the device
model_config = {
    "input_dim": 1,
    "hidden_dim": hidden_dim,
    "output_dim": 1,
    "stacks": stacks,
    "levels": levels,
    "dropout": dropout,
}
model = SCINet(**model_config)
model = model.to(device)

# Print the model structure (optional)
print(model)

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0  # sum of batch losses, each weighted by its batch size
    num_samples = 0

    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        x_batch = x_batch.reshape(x_batch.shape[0], window_size - 1, 1)  # (batch_size, seq_len, input_dim)
        y_batch = y_batch.reshape(y_batch.shape[0], 1)  # (batch_size, output_dim)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(x_batch)
        pred = outputs[:, -1, :]  # output at the last time step is the prediction

        # Loss
        loss = criterion(pred, y_batch)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean, so weight it by the batch size:
        # otherwise the smaller final batch counts as much as a full one.
        batch_count = x_batch.shape[0]
        total_loss += loss.item() * batch_count
        num_samples += batch_count

    # Per-sample mean over the epoch; NaN (not a ZeroDivisionError) when the
    # loader produced no batches.
    epoch_loss = total_loss / num_samples if num_samples else float("nan")
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.10f}")

print("Training complete.")

# Evaluate the model
model.eval()
test_loss = 0.0  # sum of batch losses, each weighted by its batch size
test_samples = 0

predictions = []
actuals = []

with torch.no_grad():
    for x_test, y_test in test_loader:
        x_test, y_test = x_test.to(device), y_test.to(device)
        x_test = x_test.reshape(x_test.shape[0], window_size - 1, 1)
        y_test = y_test.reshape(y_test.shape[0], 1)

        # Forward pass
        outputs = model(x_test)
        pred = outputs[:, -1, :]

        # Collect predictions and ground truth
        predictions.extend(pred.cpu().numpy())
        actuals.extend(y_test.cpu().numpy())

        # The criterion returns the batch mean, so weight it by the batch size:
        # otherwise the smaller final batch counts as much as a full one.
        loss = criterion(pred, y_test)
        batch_count = x_test.shape[0]
        test_loss += loss.item() * batch_count
        test_samples += batch_count

# Per-sample mean over the test set; NaN (not a ZeroDivisionError) when the
# test loader produced no batches.
mean_test_loss = test_loss / test_samples if test_samples else float("nan")
print(f"Test Loss: {mean_test_loss:.10f}")


# # 可视化预测结果
# def visualize_predictions(predictions, actuals, n_samples=100):
#     plt.figure(figsize=(12, 6))
#
#     # 取前n_samples个样本进行可视化
#     pred_sample = np.array(predictions[:n_samples]).flatten()
#     actual_sample = np.array(actuals[:n_samples]).flatten()
#
#     # 反标准化（如果需要）
#     pred_sample = scaler.inverse_transform(pred_sample.reshape(-1, 1)).flatten()
#     actual_sample = scaler.inverse_transform(actual_sample.reshape(-1, 1)).flatten()
#
#     plt.plot(actual_sample, label='Actual')
#     plt.plot(pred_sample, label='Predicted')
#     plt.legend()
#     plt.title('SCINet Model: Predictions vs Actual Values')
#     plt.xlabel('Sample Index')
#     plt.ylabel('Value')
#     plt.tight_layout()
#     plt.savefig('scinet_predictions.png')
#     plt.show()
#
# # 取消注释以下行来可视化结果
# # visualize_predictions(predictions, actuals)