import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import random
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds PyTorch (CPU, the current CUDA device and all CUDA devices), NumPy's
    global RNG and Python's ``random`` module, in that order. It then asks cuDNN
    for deterministic algorithms and turns off its benchmark-based algorithm
    selection, so repeated runs give the same results.

    Args:
        seed: Integer seed passed to every generator.
    """
    for seeder in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    ):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


SEED = 1234
set_seed(SEED)  # fixed seed so runs are reproducible

# ============================== Hyperparameters ==============================
save_path = "Ablation experiments/base_model_7.pth"  # destination of the trained state_dict
window_size = 7  # sliding-window length: window_size - 1 input steps + 1 target step
embed_dim = 64  # hidden width each scalar time step is projected to inside the model

# NOTE(review): the label reads "number of features", but the value printed is
# embed_dim (the model's hidden width), not the number of input features.
print("特征数：", embed_dim)
num_heads = 4  # attention heads; embed_dim must be divisible by this
batch_size = 256
num_epochs = 150
learning_rate = 1e-4
# =============================================================================

# Training device: CUDA when available, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# 读取频繁模式和支持度
frequent_patterns = []
try:
    with open("../data/OPR-output-ETT-450/OPR-Miner-Output.txt", "r",
              encoding="GBK") as file:
        for line in file:
            match = re.match(r"频繁模式：([\d,]+), 支持度为：(\d+)", line.strip())
            if match:
                pattern = tuple(map(int, match.group(1).split(',')))  # 将模式转换为元组
                support = int(match.group(2))  # 支持度转换为整数
                frequent_patterns.append((pattern, support))
    print("Success reading frequent patterns")
except Exception as e:
    print(f"Failed to read frequent patterns: {e}")
    frequent_patterns = []


# ============== 数据集类 ==============
class SWaTDataset(Dataset):
    """Sliding-window dataset for one-step-ahead forecasting.

    Each window of ``window_size`` consecutive rows is split into an input made
    of the first ``window_size - 1`` rows and a target equal to the last row.
    Every full window in the series is produced, including the last one.

    Args:
        features: Tensor, or anything ``torch.as_tensor`` accepts (e.g. a NumPy
            array), whose first dimension is time. A detached float32 copy is
            stored on the same device as the input.
        window_size: Total window length (inputs + 1 target). Must be >= 2.

    Raises:
        ValueError: If ``window_size`` is smaller than 2, which would leave no input steps.
    """

    def __init__(self, features, window_size):
        if window_size < 2:
            raise ValueError(f"window_size must be >= 2, got {window_size}")
        # clone().detach() gives a private copy with no storage sharing or autograd history.
        self.features = torch.as_tensor(features).clone().detach().float()
        self.window_size = window_size  # full window length (inputs + target)

    def __len__(self):
        # Valid start indices are 0 .. len - window_size inclusive. The result is
        # clamped at 0 so a series shorter than one window gives an empty dataset.
        return max(0, len(self.features) - self.window_size + 1)

    def __getitem__(self, idx):
        window = self.features[idx:idx + self.window_size]  # one full window
        x = window[:-1]  # first window_size - 1 rows are the input
        y = window[-1]  # last row is the target
        return x, y


# ============== Data loading ==============
# Read the series from the first column of the processed ETT file. Expected
# failures (missing or unreadable file, unparsable or empty content, no columns)
# fall back to random dummy data so the pipeline can still be smoke-tested.
# Any other exception now propagates instead of being masked.
try:
    dataset_swat = pd.read_csv('../data/datasets_ETT_afterProcess.txt', encoding='utf-8',
                               low_memory=False)
    feature_cols = dataset_swat.columns[0]  # label of the first column only
    features = dataset_swat[feature_cols].values
    print("Successfully loaded data")
    print("feature.shape", features.shape)
except (OSError, ValueError, IndexError) as e:
    # OSError: file missing/unreadable. ValueError: pandas ParserError/EmptyDataError
    # and UnicodeDecodeError all derive from it. IndexError: the frame has no columns.
    print(f"Failed to load data: {e}")
    # Create some dummy data for testing.
    # NOTE(review): the script still trains on this noise and then writes save_path.
    features = np.random.rand(10000, 1)
    print("Using dummy data with shape:", features.shape)

# Chronological split (no shuffling): the first 60% of the series is used for
# training and the remaining 40% for testing.
train_ratio = 0.6
train_size = int(len(features) * train_ratio)
train_features, test_features = features[:train_size], features[train_size:]

# Standardization is currently disabled (commented out). If re-enabled, the
# scaler is fitted on the training split only and then applied to the test split.
# train_features = train_features.reshape(-1, 1)  # shape=(n_train, 1)
# test_features = test_features.reshape(-1, 1)  # shape=(n_test, 1)
#
# scaler = StandardScaler()
# train_features = scaler.fit_transform(train_features)
# test_features = scaler.transform(test_features)

# Convert both splits to float32 tensors on the training device.
train_features = torch.tensor(train_features, dtype=torch.float32).to(device)
test_features = torch.tensor(test_features, dtype=torch.float32).to(device)

print("train_feature.shape", train_features.shape)
print("test_feature.shape", test_features.shape)

# Sliding-window datasets for the two splits.
train_dataset = SWaTDataset(train_features, window_size)
test_dataset = SWaTDataset(test_features, window_size)

# DataLoaders keep chronological order (shuffle=False) for both splits.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


class OpTransformer(nn.Module):
    """Single-layer multi-head self-attention regressor for sequences.

    Processing per forward pass:
      1. project every time step from ``input_dim`` to ``embed_dim``;
      2. apply multi-head scaled dot-product self-attention (no mask, no
         positional encoding);
      3. pass the result through an MLP ``embed_dim -> 100 -> 50 -> embed_dim``
         with ReLU after the first two layers;
      4. project back to ``input_dim``.

    Args:
        embed_dim: Hidden width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        window_size: Stored on the instance; not used by ``forward``.
        input_dim: Features per time step (default 1).
    """

    def __init__(self, embed_dim, num_heads, window_size, input_dim=1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.input_dim = input_dim

        # Per-step projections into and out of the hidden width.
        self.input_expansion = nn.Linear(input_dim, embed_dim)
        self.output_compression = nn.Linear(embed_dim, input_dim)

        assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"
        self.head_dim = embed_dim // num_heads  # width of each attention head

        # Query / key / value projections; every head is packed into one matrix each.
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)

        # Feed-forward stack applied to the attention output.
        self.W_out1 = nn.Linear(embed_dim, 100)
        self.W_out2 = nn.Linear(100, 50)
        self.W_out3 = nn.Linear(50, embed_dim)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq, embed_dim) into (batch, num_heads, seq, head_dim)."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Run the model on a batch of sequences.

        Args:
            x: Tensor of shape (batch_size, seq_len, input_dim).

        Returns:
            ``(output, attention_weights)``:

            - ``output`` has shape (batch_size, seq_len, input_dim);
            - ``attention_weights`` has shape (batch_size, num_heads, seq_len,
              seq_len), and each row sums to 1.
        """
        batch_size, seq_len, _ = x.shape
        hidden = self.input_expansion(x)  # (batch, seq, embed_dim)

        q, k, v = (
            self._split_heads(projection(hidden), batch_size, seq_len)
            for projection in (self.W_q, self.W_k, self.W_v)
        )

        # Scaled dot-product attention: softmax(Q K^T / sqrt(head_dim)) V
        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(self.head_dim)
        weights = F.softmax(scores, dim=-1)  # (batch, heads, seq, seq)
        context = torch.matmul(weights, v)  # (batch, heads, seq, head_dim)

        # Merge heads back into (batch, seq, embed_dim).
        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)

        mlp = F.relu(self.W_out1(merged))  # (batch, seq, 100)
        mlp = F.relu(self.W_out2(mlp))  # (batch, seq, 50)
        mlp = self.W_out3(mlp)  # (batch, seq, embed_dim)

        return self.output_compression(mlp), weights


import os

# Create the Transformer.
model = OpTransformer(embed_dim, num_heads, window_size).to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Create the checkpoint directory up front. A missing folder then fails (or is
# created) immediately instead of torch.save raising after every epoch has run.
save_dir = os.path.dirname(save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)


def _prepare_batch(x, y):
    """Move a batch to `device` and reshape it for the model and the loss.

    Inputs become (batch, window_size - 1, 1), i.e. one scalar per time step,
    and targets become (batch, 1).
    """
    x = x.to(device).reshape(x.shape[0], window_size - 1, 1)
    y = y.to(device).reshape(y.shape[0], 1)
    return x, y


for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_samples = 0

    for x_batch, y_batch in train_loader:
        x_batch, y_batch = _prepare_batch(x_batch, y_batch)

        optimizer.zero_grad()
        outputs, _ = model(x_batch)

        # The forecast is the model output at the last input position.
        pred = outputs[:, -1, :]

        loss = criterion(pred, y_batch)

        loss.backward()
        optimizer.step()
        # MSELoss averages within a batch. Weighting by batch size makes the epoch
        # figure the true per-sample mean even when the last batch is smaller.
        total_loss += loss.item() * x_batch.shape[0]
        total_samples += x_batch.shape[0]

    # Print the average per-sample loss of each epoch (0 if the loader is empty).
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / max(total_samples, 1):.10f}")

print("Training complete.")

# Evaluate the model on the held-out split.
model.eval()
test_loss = 0.0
test_samples = 0

with torch.no_grad():
    for x_test, y_test in test_loader:
        x_test, y_test = _prepare_batch(x_test, y_test)

        outputs, _ = model(x_test)
        pred = outputs[:, -1, :]

        loss = criterion(pred, y_test)
        test_loss += loss.item() * x_test.shape[0]
        test_samples += x_test.shape[0]

print(f"Test Loss: {test_loss / max(test_samples, 1):.10f}")

torch.save(model.state_dict(), save_path)
