import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import random
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt



def set_seed(seed):
    torch.manual_seed(seed)  # 为CPU设置随机种子
    torch.cuda.manual_seed(seed)  # 为当前GPU设置随机种子
    torch.cuda.manual_seed_all(seed)  # 如果使用多个GPU，也要设置随机种子
    np.random.seed(seed)  # 设置numpy的随机种子
    random.seed(seed)  # 设置Python内置随机数生成器的随机种子
    torch.backends.cudnn.deterministic = True  # 确保卷积等操作的结果确定
    torch.backends.cudnn.benchmark = False  # 禁用cudnn的自动优化算法选择


set_seed(1234)  # 设置为固定的随机种子


# 训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 超参数
window_size = 7  # 滑动窗口
embed_dim = 1  # 嵌入维度 = 特征数
num_channels = [8, 8, 8, 8]  # TCN通道数
kernel_size = 3  # 卷积核大小
dropout = 0.2  # Dropout比率

batch_size = 256
num_epochs = 100
learning_rate = 1e-3


# ============== 数据集类 ==============
class SWaTDataset(Dataset):
    def __init__(self, features, window_size):
        self.features = features.clone().detach().float()
        self.window_size = window_size  # 完整窗口大小

    def __len__(self):
        return len(self.features) - self.window_size  # 修改长度计算

    def __getitem__(self, idx):
        window = self.features[idx:idx + self.window_size]  # 取完整窗口
        x = window[:-1]  # 前N-1个作为输入
        y = window[-1]  # 最后一个作为目标
        return x, y


# ============== TCN模型组件 ==============
class Chomp1d(nn.Module):
    """
    为因果卷积移除多余的padding
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    """
    TCN的基本块，包含两个因果卷积层
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = nn.utils.weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                                    stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.utils.weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                                    stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)

        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)


class TemporalConvNet(nn.Module):
    """
    TCN网络，由多个TemporalBlock组成
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
                                     padding=(kernel_size - 1) * dilation_size, dropout=dropout)]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)


class TCNModel(nn.Module):
    """
    完整的TCN模型
    """

    def __init__(self, input_size, output_size, num_channels, kernel_size, dropout):
        super(TCNModel, self).__init__()
        self.tcn = TemporalConvNet(input_size, num_channels, kernel_size=kernel_size, dropout=dropout)
        self.linear = nn.Linear(num_channels[-1], output_size)

    def forward(self, x):
        # x需要是 [batch_size, channels, seq_len] 的形式
        output = self.tcn(x)
        # 取最后一个时间步的输出
        output = output[:, :, -1]
        output = self.linear(output)
        return output


# 加载数据
try:
    dataset_swat = pd.read_csv('../../data/output-cnn.txt', encoding='utf-8',
                               low_memory=False)
    feature_cols = dataset_swat.columns[0]
    features = dataset_swat[feature_cols].values
    print("Successfully loaded data")
    print("feature.shape", features.shape)
except Exception as e:
    print(f"Failed to load data: {e}")
    # Create some dummy data for testing
    features = np.random.rand(10000, 1)
    print("Using dummy data with shape:", features.shape)

# 划分训练集和测试集
train_size = int(len(features) * 0.8)  # 80% 作为训练集
train_features, test_features = features[:train_size], features[train_size:]

# 重塑为二维
train_features = train_features.reshape(-1, 1)  # shape=(n_train, 1)
test_features = test_features.reshape(-1, 1)  # shape=(n_test, 1)

# 在数据划分后添加标准化逻辑
scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)
test_features = scaler.transform(test_features)

train_features = torch.tensor(train_features, dtype=torch.float32).to(device)
test_features = torch.tensor(test_features, dtype=torch.float32).to(device)

print("train_feature.shape", train_features.shape)
print("test_feature.shape", test_features.shape)

# 创建训练和测试数据集
train_dataset = SWaTDataset(train_features, window_size)
test_dataset = SWaTDataset(test_features, window_size)

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 创建TCN模型
model = TCNModel(input_size=1, output_size=1, num_channels=num_channels,
                 kernel_size=kernel_size, dropout=dropout).to(device)

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        # TCN需要输入形状为 [batch_size, channels, seq_len]
        x_batch = x_batch.reshape(x_batch.shape[0], 1, window_size - 1)
        y_batch = y_batch.reshape(y_batch.shape[0], 1)

        # 前向传播
        optimizer.zero_grad()
        outputs = model(x_batch)

        # 计算损失
        loss = criterion(outputs, y_batch)

        # 反向传播和优化
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # # 打印每个epoch的平均损失
    # if (epoch + 1) % 10 == 0:  # 每10个epoch打印一次
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(train_loader):.10f}")

print("Training complete.")

# 评估模型
model.eval()
test_loss = 0.0

with torch.no_grad():
    for x_test, y_test in test_loader:
        x_test, y_test = x_test.to(device), y_test.to(device)
        # TCN需要输入形状为 [batch_size, channels, seq_len]
        x_test = x_test.reshape(x_test.shape[0], 1, window_size - 1)
        y_test = y_test.reshape(y_test.shape[0], 1)

        # 前向传播
        outputs = model(x_test)

        # 计算损失
        loss = criterion(outputs, y_test)
        test_loss += loss.item()

print(f"Test Loss: {test_loss / len(test_loader):.10f}")
