import random
import re
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader


def set_seed(seed):
    """Seed every random generator the script relies on and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch's
    CPU, current-GPU and all-GPU generators. Then it tells cuDNN to use only
    deterministic kernels and disables its benchmark-based algorithm search.

    Args:
        seed: integer seed applied to every generator.
    """
    # Pure-Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU first, then the current GPU and every visible GPU.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN: deterministic algorithms only, and no autotuning between runs.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


set_seed(1234)  # fixed seed so that repeated runs are reproducible

# ------------------------------------------------------------------------------
# Length of the frequent order-preserving patterns that are kept after parsing.
length_pattern = 9

# Where the trained model's state_dict is written at the end of the script.
save_path = "Ablation experiments/base_model_9.pth"

# Hyper-parameters
window_size = 9  # sliding-window length: window_size - 1 input steps + 1 target step
embed_dim = 64  # hidden width each scalar input step is projected to (not the raw feature count, which is 1)

# NOTE(review): the label below says "feature count" but the value printed is embed_dim.
print("特征数：", embed_dim)
num_heads = 4
batch_size = 256
num_epochs = 100
learning_rate = 1e-3

# ------------------------------------------------------------------------------


# Training device: use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Parse the OPR-Miner output file. Every line matching the expected format
# contributes one (pattern tuple of ints, support int) entry.
frequent_patterns = []
pattern_line_re = re.compile(r"频繁模式：([\d,]+), 支持度为：(\d+)")

with open("../data/OPR-output-cnn-450/OPR-Miner-Output.txt", "r",
          encoding="GBK") as file:
    for line in file:
        match = pattern_line_re.match(line.strip())
        if match is None:
            continue  # not a pattern line
        pattern = tuple(int(token) for token in match.group(1).split(','))
        support = int(match.group(2))
        frequent_patterns.append((pattern, support))

# Echo every parsed pattern, then the whole list.
for pattern, support in frequent_patterns:
    print(f"模式: {pattern}, 支持度: {support}")

print(frequent_patterns)

# Keep the patterns whose length equals `length_pattern`, and (for inspection only)
# the ones exactly one element shorter.
patterns_length_n = [(pattern, support) for pattern, support in frequent_patterns if len(pattern) == length_pattern]

length_pattern_son = length_pattern - 1
patterns_length_n_son = [(pattern, support) for pattern, support in frequent_patterns if
                         len(pattern) == length_pattern_son]

# Show the shorter patterns.
for pattern, support in patterns_length_n_son:
    print(f"模式: {pattern}, 支持度: {support}")
print(patterns_length_n_son)

# Subtract 1 from every pattern element. Presumably this turns the miner's 1-based
# ranks into 0-based indices; confirm against the OPR-Miner output format.
adjusted_patterns_son = [(tuple(x - 1 for x in pattern), support) for pattern, support in patterns_length_n_son]
print(adjusted_patterns_son)

# Show the patterns of the target length.
for pattern, support in patterns_length_n:
    print(f"模式: {pattern}, 支持度: {support}")
print(patterns_length_n)

# Same 1-subtraction for the target-length patterns.
adjusted_patterns = [(tuple(x - 1 for x in pattern), support) for pattern, support in patterns_length_n]
print(adjusted_patterns)

# Normalise the supports of the target-length patterns so they sum to 1.
# NOTE: despite its name, `softmax_weights` holds this plain sum-normalisation, not a
# softmax; `val` (the element-wise exp) is computed only to be printed.
support_values = np.array([sup for _, sup in adjusted_patterns])

total_sum = np.sum(support_values)
normalized_values = support_values / total_sum
print("normalized_values:", normalized_values)

support_List = normalized_values

val = np.exp(normalized_values)
print("val:", val)

softmax_weights = normalized_values

print(support_values)
print(softmax_weights)


dataset_swat = pd.read_csv('../data/output-cnn.txt', encoding='utf-8',
                           low_memory=False)

# Only the first column is used, as a univariate series; any other columns are ignored.
feature_cols = dataset_swat.columns[0]

features = dataset_swat[feature_cols].values  # 1-D array of shape (num_samples,)

print("feature.shape", features.shape)

# Stack the adjusted patterns into a (num_patterns, length_pattern) integer array.
# The explicit reshape keeps the array 2-D even when no pattern of that length was
# found. Without it, np.array([]) is 1-D and the `[:, 0:-1]` slice below would
# raise IndexError.
pattern_list = np.array([pat for pat, _ in adjusted_patterns], dtype=int).reshape(-1, length_pattern)
print("pattern_list:", pattern_list)

# Every pattern without its last element.
print("pattern_list_gaibian:", pattern_list[:, 0:-1])


# ============== 修改数据集类 ==============
class SWaTDataset(Dataset):
    """Sliding-window dataset over a time series for next-step prediction.

    Item ``idx`` is the window ``features[idx : idx + window_size]``. Its first
    ``window_size - 1`` rows form the input and its last row is the target.

    Args:
        features: tensor or array-like whose first axis is time. In this script
            it is a ``(num_steps, 1)`` float32 tensor.
        window_size: total window length (input rows plus one target row).
    """

    def __init__(self, features, window_size):
        # as_tensor accepts both tensors and NumPy arrays (keeping the device of a
        # tensor). clone() gives the dataset its own copy, and float() casts to float32.
        self.features = torch.as_tensor(features).detach().clone().float()
        self.window_size = window_size  # full window: inputs + 1 target

    def __len__(self):
        # A series of length L contains L - window_size + 1 complete windows
        # (start indices 0 .. L - window_size). Clamp at 0 so a series shorter
        # than one window yields an empty dataset instead of a negative length.
        return max(len(self.features) - self.window_size + 1, 0)

    def __getitem__(self, idx):
        window = self.features[idx:idx + self.window_size]  # one complete window
        x = window[:-1]  # first window_size - 1 rows are the input
        y = window[-1]  # last row is the prediction target
        return x, y


# Chronological 80/20 train/test split of the 1-D series. There is no shuffling,
# so the test segment directly follows the training segment in time.
train_size = int(len(features) * 0.8)  # first 80% for training
test_size = len(features) - train_size

# StandardScaler expects 2-D (n_samples, n_features) input, so each split becomes a
# single column. The scaler is fitted on the training split only and then reused,
# unchanged, on the test split.
scaler = StandardScaler()
train_features = scaler.fit_transform(features[:train_size].reshape(-1, 1))
test_features = scaler.transform(features[train_size:].reshape(-1, 1))

# Convert both splits to float32 tensors on the training device.
train_features, test_features = (
    torch.tensor(split, dtype=torch.float32).to(device)
    for split in (train_features, test_features)
)

print("train_feature.shape", train_features.shape)
print("test_feature.shape", test_features.shape)

# One sliding-window dataset per split.
train_dataset = SWaTDataset(train_features, window_size)
test_dataset = SWaTDataset(test_features, window_size)

# Ordered (shuffle=False) batch loaders for both splits.
train_loader, test_loader = (
    DataLoader(ds, batch_size=batch_size, shuffle=False)
    for ds in (train_dataset, test_dataset)
)


class OpTransformer(nn.Module):
    """One multi-head scaled dot-product self-attention layer over embed_dim-wide inputs.

    NOTE(review): a second ``class OpTransformer`` (with input/output projection
    layers) is defined later in this file under the same name. That rebinding
    shadows this version, so the script never instantiates it. Consider
    deleting or renaming it.

    Args:
        embed_dim: width of the input and output features; must be divisible by num_heads.
        num_heads: number of attention heads.
        window_size: stored on the instance; not used by ``forward``.
    """

    def __init__(self, embed_dim, num_heads, window_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size

        assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"
        self.head_dim = embed_dim // num_heads  # width of one head

        # Query/key/value projections, then the output projection.
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_out = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, t, batch_size, seq_len):
        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x: tensor of shape (batch_size, seq_len, embed_dim).

        Returns:
            Tuple ``(output, attention_weights)``. ``output`` has shape
            (batch_size, seq_len, embed_dim) and ``attention_weights`` has shape
            (batch_size, num_heads, seq_len, seq_len), each row summing to 1.
        """
        batch_size, seq_len, _ = x.shape

        q, k, v = (
            self._split_heads(projection(x), batch_size, seq_len)
            for projection in (self.W_q, self.W_k, self.W_v)
        )

        # Scaled dot-product scores, normalised over the key axis.
        scores = (q @ k.transpose(-2, -1)) / np.sqrt(self.head_dim)
        weights = F.softmax(scores, dim=-1)

        # Weighted sum of values, heads concatenated back into embed_dim.
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)

        return self.W_out(merged), weights


class OpTransformer(nn.Module):
    """Single-layer multi-head self-attention model for low-dimensional sequences.

    Each ``input_dim``-wide step is projected up to ``embed_dim``, passed through
    one multi-head scaled dot-product self-attention layer with an output
    projection, and projected back down to ``input_dim``.

    NOTE(review): no positional encoding is added. Each position's output
    therefore depends only on its own value and on the (unordered) set of values
    in the window; the order of the other steps is invisible to the model.

    Args:
        embed_dim: hidden width; must be divisible by num_heads.
        num_heads: number of attention heads.
        window_size: stored on the instance; not used by ``forward``.
        input_dim: width of each input/output step (default 1).
    """

    def __init__(self, embed_dim, num_heads, window_size, input_dim=1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.input_dim = input_dim

        # input_dim -> embed_dim before attention, embed_dim -> input_dim after it.
        self.input_expansion = nn.Linear(input_dim, embed_dim)
        self.output_compression = nn.Linear(embed_dim, input_dim)

        assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"
        self.head_dim = embed_dim // num_heads  # width of one head

        # Query/key/value projections, then the attention output projection.
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_out = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, t, batch_size, seq_len):
        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Run the expand -> self-attention -> compress pipeline.

        Args:
            x: tensor of shape (batch_size, seq_len, input_dim).

        Returns:
            Tuple ``(compressed_output, attention_weights)``. ``compressed_output``
            has shape (batch_size, seq_len, input_dim) and ``attention_weights``
            has shape (batch_size, num_heads, seq_len, seq_len).
        """
        batch_size, seq_len, _ = x.shape

        hidden = self.input_expansion(x)  # (batch, seq, embed_dim)

        q, k, v = (
            self._split_heads(projection(hidden), batch_size, seq_len)
            for projection in (self.W_q, self.W_k, self.W_v)
        )

        # Scaled dot-product scores, normalised over the key axis.
        scores = (q @ k.transpose(-2, -1)) / np.sqrt(self.head_dim)
        weights = F.softmax(scores, dim=-1)

        # Weighted sum of values, heads concatenated back into embed_dim.
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)

        return self.output_compression(self.W_out(merged)), weights


# Build the model, the loss and the optimiser.
model = OpTransformer(embed_dim, num_heads, window_size).to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Create the checkpoint directory up front. A missing folder then fails fast here
# instead of making torch.save() raise after the whole training run.
Path(save_path).parent.mkdir(parents=True, exist_ok=True)


def _predict_last_step(inputs, targets):
    """Forward one batch and return ``(prediction, target)`` for the final window step.

    ``inputs`` is moved to ``device`` and reshaped to ``(batch, window_size - 1, 1)``.
    ``targets`` is moved and reshaped to ``(batch, 1)``. Only the model output at
    the last sequence position is used as the prediction.
    """
    inputs = inputs.to(device).reshape(inputs.shape[0], window_size - 1, 1)
    targets = targets.to(device).reshape(targets.shape[0], 1)
    outputs, _ = model(inputs)
    return outputs[:, -1, :], targets


def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when ``count`` is 0 (empty loader)."""
    return total / count if count else float("nan")


for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    num_elements = 0

    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        pred, target = _predict_last_step(x_batch, y_batch)

        loss = criterion(pred, target)
        loss.backward()
        optimizer.step()

        # MSELoss returns a per-element mean. Weight it by the element count so the
        # epoch average is a true per-sample mean even when the last batch is smaller.
        total_loss += loss.item() * target.numel()
        num_elements += target.numel()

    # Print the average training loss of this epoch.
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {_mean_or_nan(total_loss, num_elements):.10f}")

print("Training complete.")

# Evaluate on the held-out split, with no gradient tracking.
model.eval()
test_loss = 0.0
test_elements = 0

with torch.no_grad():
    for x_test, y_test in test_loader:
        pred, target = _predict_last_step(x_test, y_test)
        loss = criterion(pred, target)
        test_loss += loss.item() * target.numel()
        test_elements += target.numel()

print(f"Test Loss: {_mean_or_nan(test_loss, test_elements):.10f}")

torch.save(model.state_dict(), save_path)
