import os
import random
import re
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader


def set_seed(seed):
    """Seed every RNG the script uses and make cuDNN kernel selection deterministic.

    Seeds PyTorch (CPU, current GPU, all GPUs), NumPy and Python's ``random``
    module, then forces deterministic cuDNN algorithms and disables cuDNN's
    benchmark-based autotuning.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(1234)  # fixed seed so that runs are reproducible

# ------------------------------------------------------------------------------
# Hyperparameters
save_path = "Ablation experiments/pattern_model_9.pth"  # destination of the trained state_dict
length_pattern = 9  # length of the frequent patterns used for the order constraints
window_size = 9  # sliding-window length: window_size - 1 inputs followed by one target
embed_dim = 64  # attention embedding width (the series itself has one value per step)

print("特征数：", embed_dim)
num_heads, batch_size, num_epochs, learning_rate = 4, 256, 100, 1e-3

# ------------------------------------------------------------------------------


# Training device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# (pattern, support) pairs parsed from the OPR-Miner report.
frequent_patterns = []
# Compiled once; matches lines such as "频繁模式：1,3,2, 支持度为：17".
pattern_line_re = re.compile(r"频繁模式：([\d,]+), 支持度为：(\d+)")


def _print_patterns(entries):
    """Print one "pattern / support" line per (pattern, support) pair."""
    for pattern, support in entries:
        print(f"模式: {pattern}, 支持度: {support}")


with open("../data/OPR-output-cnn-450/OPR-Miner-Output.txt", "r",
          encoding="GBK") as report:
    for line in report:
        match = pattern_line_re.match(line.strip())
        if match:
            pattern = tuple(map(int, match.group(1).split(',')))  # pattern as a tuple of ints
            support = int(match.group(2))
            frequent_patterns.append((pattern, support))

_print_patterns(frequent_patterns)
print(frequent_patterns)

# Keep only patterns of length `length_pattern` (inputs + predicted point).
patterns_length_n = [(pattern, support) for pattern, support in frequent_patterns if len(pattern) == length_pattern]

# Patterns one element shorter; only printed/kept for inspection.
length_pattern_son = length_pattern - 1
patterns_length_n_son = [(pattern, support) for pattern, support in frequent_patterns if
                         len(pattern) == length_pattern_son]

_print_patterns(patterns_length_n_son)
print(patterns_length_n_son)

# Shift every pattern value down by one (the miner's values start at 1).
adjusted_patterns_son = [((tuple(x - 1 for x in pattern)), support) for pattern, support in patterns_length_n_son]
print(adjusted_patterns_son)

_print_patterns(patterns_length_n)
print(patterns_length_n)

adjusted_patterns = [((tuple(x - 1 for x in pattern)), support) for pattern, support in patterns_length_n]
print(adjusted_patterns)

if not adjusted_patterns:
    # Without this check the run fails later with a NaN weight vector and an opaque
    # IndexError on `pattern_list[:, 0:-1]`.
    raise ValueError(
        f"No frequent patterns of length {length_pattern} found in the OPR-Miner output")

# Per-pattern weights: supports normalised to sum to 1.
support_values = np.array([sup for _, sup in adjusted_patterns])
total_sum = np.sum(support_values)
normalized_values = support_values / total_sum
print("normalized_values:", normalized_values)

support_List = normalized_values

val = np.exp(normalized_values)  # printed for inspection only
print("val:", val)

# NOTE: despite its name this is the plain sum-normalised weight vector, not a softmax.
softmax_weights = normalized_values

print(support_values)
print(softmax_weights)

dataset_swat = pd.read_csv('../data/output-cnn.txt', encoding='utf-8', low_memory=False)

# Only the FIRST column is used: the model works on a single univariate series.
feature_cols = dataset_swat.columns[0]
features = dataset_swat[feature_cols].values  # 1-D array, one value per row
print("feature.shape", features.shape)

# Full-length frequent patterns as a 2-D array, one pattern per row.
pattern_list = np.array([entry[0] for entry in adjusted_patterns])
print("pattern_list_gaibian:", pattern_list[:, 0:-1])


def constrained_loss(pred_mse, pred_pattern, pred_pattern_mse, data_input_pattern, target_mse,
                     target_pattern, pattern_list, support_List, epsilon=0.00001):
    """Return the training loss: plain MSE plus frequent-pattern order constraints.

    total = MSE(pred_mse, target_mse)
            + 0.001 * MSE(pred_pattern_mse, target_pattern)
            + sum over patterns of a support-weighted hinge penalty on pred_pattern

    Args:
        pred_mse, target_mse: predictions/targets scored with plain MSE (in this
            script, the samples that matched no frequent pattern).
        pred_pattern: predictions the order constraints act on, shape (k,) or (k, d).
        pred_pattern_mse, target_pattern: predictions/targets for the down-weighted
            pattern MSE term (in this script, the pattern-matched samples).
        data_input_pattern: (k, w) input windows; row ``i`` belongs to ``pred_pattern[i]``.
        pattern_list: 2-D array of patterns; the constraint code reads each value as
            the 0-based rank of the element at that position.
        support_List: per-pattern weights aligned with ``pattern_list``.
        epsilon: hinge margin.

    Returns:
        ``(total_loss, pattern_mseloss, original_mseloss)``. An empty group contributes
        0 instead of NaN, so ``backward()`` stays finite when a batch has no matched
        (or no unmatched) samples.
    """
    if pred_mse.numel() == 0 or target_mse.numel() == 0:
        # nn.MSELoss over 0 elements is NaN and would corrupt the parameter update.
        original_mseloss = torch.tensor(0.0, device=pred_mse.device)
    else:
        original_mseloss = nn.MSELoss()(pred_mse, target_mse)

    if pred_pattern_mse.numel() == 0 or target_pattern.numel() == 0:
        pattern_mseloss = torch.tensor(0.0, device=pred_pattern_mse.device)
        constraint_loss = torch.tensor(0.0, device=pred_pattern_mse.device)
    else:
        pattern_mseloss = nn.MSELoss()(pred_pattern_mse, target_pattern)
        constraint_loss = _pattern_constraint_loss(pred_pattern, data_input_pattern, pattern_list,
                                                   support_List, epsilon)

    pattern_loss = 0.001 * pattern_mseloss + constraint_loss
    total_loss = original_mseloss + pattern_loss
    return total_loss, pattern_mseloss, original_mseloss


def _pattern_constraint_loss(pred_pattern, data_input_pattern, pattern_list, support_List, epsilon):
    """Sum over patterns of the support-weighted mean hinge penalty on ``pred_pattern``.

    For a pattern whose last value (the prediction's rank) is
    * 0: penalise ``pred > x[pos of rank 1] - epsilon``;
    * the maximum rank: penalise ``pred < x[pos of rank max-1] + epsilon``;
    * otherwise: penalise leaving ``(x[rank-1] + epsilon, x[rank+1] - epsilon)``.
    Patterns lacking the required neighbouring rank are skipped.

    NOTE(review): every pattern is applied to every row of ``pred_pattern``, not only
    to the pattern that row matched, so constraints from different patterns can
    contradict each other — confirm this is intended.
    """
    constraint_loss = 0.
    if pred_pattern.numel() == 0 or len(pattern_list) == 0:
        return constraint_loss

    # (k, d) layout so that row i of the prediction is compared with row i of the
    # inputs; subtracting a (k,) reference from (k, 1) would broadcast to (k, k).
    pred_col = pred_pattern.reshape(pred_pattern.shape[0], -1)
    maxbord = len(pattern_list[0]) - 1

    def ref_column(pattern, rank):
        """(k, 1) input column at the position holding ``rank``, or None if absent."""
        pos = next((p for p, value in enumerate(pattern) if value == rank), None)
        return None if pos is None else data_input_pattern[:, pos].unsqueeze(-1)

    for idx, single_pattern in enumerate(pattern_list):
        last = single_pattern[-1]
        if last == 0:
            ref_val = ref_column(single_pattern, 1)
            if ref_val is None:
                continue
            margin = torch.clamp(pred_col - ref_val + epsilon, min=0)
        elif last == maxbord:
            ref_val = ref_column(single_pattern, maxbord - 1)
            if ref_val is None:
                continue
            margin = torch.clamp(ref_val - pred_col + epsilon, min=0)
        else:
            ref_left = ref_column(single_pattern, last - 1)
            ref_right = ref_column(single_pattern, last + 1)
            if ref_left is None or ref_right is None:
                continue
            margin = (torch.clamp(ref_left + epsilon - pred_col, min=0)
                      + torch.clamp(pred_col - ref_right + epsilon, min=0))

        lambda_surp = float(support_List[idx])  # plain float keeps the tensor dtype unchanged
        constraint_loss += (lambda_surp * margin).mean()

    return constraint_loss


# --------------------------------------------------------------------------------------------------------------------------------------------------------------
# --------------------------------------------------------------------------------------------------------------------------------------------------------------


# ============== Sliding-window dataset ==============
class SWaTDataset(Dataset):
    """Sliding-window dataset for next-step prediction.

    Item ``idx`` is the window ``features[idx:idx + window_size]``: its first
    ``window_size - 1`` rows are the input and its last row is the target.
    """

    def __init__(self, features, window_size):
        # as_tensor also accepts numpy arrays/lists; clone keeps the dataset
        # independent of the caller's buffer.
        self.features = torch.as_tensor(features).clone().detach().float()
        self.window_size = window_size  # full window length (inputs + target)

    def __len__(self):
        # Window starts run from 0 to len - window_size inclusive, i.e.
        # len - window_size + 1 windows; clamp at 0 for series shorter than a window.
        return max(0, len(self.features) - self.window_size + 1)

    def __getitem__(self, idx):
        window = self.features[idx:idx + self.window_size]
        x = window[:-1]  # first window_size - 1 rows: input
        y = window[-1]  # last row: target
        return x, y


# Chronological 80/20 split of the univariate series (no shuffling).
train_size = int(len(features) * 0.8)
test_size = len(features) - train_size
train_features = features[:train_size].reshape(-1, 1)  # (n_train, 1)
test_features = features[train_size:].reshape(-1, 1)  # (n_test, 1)

# Standardise with statistics fitted on the training portion only.
scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)
test_features = scaler.transform(test_features)

train_features = torch.tensor(train_features, dtype=torch.float32).to(device)
test_features = torch.tensor(test_features, dtype=torch.float32).to(device)

print("train_feature.shape", train_features.shape)
print("test_feature.shape", test_features.shape)

# Sliding-window datasets with sequential (unshuffled) loaders.
train_dataset = SWaTDataset(train_features, window_size)
test_dataset = SWaTDataset(test_features, window_size)
loader_kwargs = dict(batch_size=batch_size, shuffle=False)
train_loader = DataLoader(train_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)


class OpTransformer(nn.Module):
    """One unmasked multi-head self-attention layer for per-step regression.

    Each time step (``input_dim`` values) is projected to ``embed_dim`` channels,
    passed through multi-head self-attention and an output projection, then
    projected back to ``input_dim`` channels at every position.

    Args:
        embed_dim: attention width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        window_size: stored on the instance; ``forward`` does not use it.
        input_dim: values per time step (default 1).
    """

    def __init__(self, embed_dim, num_heads, window_size, input_dim=1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.input_dim = input_dim

        # Sub-modules are created in the same order as before so that seeded
        # initialisation and the state_dict layout stay identical.
        self.input_expansion = nn.Linear(input_dim, embed_dim)  # lift to embed_dim
        self.output_compression = nn.Linear(embed_dim, input_dim)  # project back

        assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"
        self.head_dim = embed_dim // num_heads

        # Query / key / value projections and the post-attention projection.
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_out = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, tensor, batch_size, seq_len):
        """Reshape (B, L, E) to (B, H, L, E / H)."""
        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Return ``(output, attention_weights)``.

        ``x`` is (batch, seq_len, input_dim); ``output`` has the same shape and
        ``attention_weights`` is (batch, num_heads, seq_len, seq_len).
        """
        batch_size, seq_len, _ = x.shape
        hidden = self.input_expansion(x)

        q, k, v = (self._split_heads(projection(hidden), batch_size, seq_len)
                   for projection in (self.W_q, self.W_k, self.W_v))

        # Scaled dot-product attention, softmax over the key axis.
        scores = q @ k.transpose(-2, -1) / np.sqrt(self.head_dim)
        weights = F.softmax(scores, dim=-1)

        merged = (weights @ v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.output_compression(self.W_out(merged)), weights


# Model and optimiser.
model = OpTransformer(embed_dim, num_heads, window_size)
model = model.to(device)

# NOTE(review): `criterion` is not used by the training loop, which scores with constrained_loss.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)


# Constrained loss variant that also reports the constraint term as a float.
def constrained_loss_with_stats(pred_mse, pred_pattern, pred_pattern_mse, data_input_pattern, target_mse,
                                target_pattern, pattern_list, support_List, epsilon=0.00001):
    """Compute MSE + pattern constraints and also return the constraint term as a float.

    total = MSE(pred_mse, target_mse)
            + 0.001 * MSE(pred_pattern_mse, target_pattern)
            + support-weighted hinge penalties (pattern ``i`` applied to row ``i``)

    Returns:
        ``(total_loss, pattern_mseloss, original_mseloss, constraint_loss_value)``;
        the last item is a Python float. Empty groups contribute 0 instead of NaN.
    """
    if pred_mse.numel() == 0 or target_mse.numel() == 0:
        # nn.MSELoss over 0 elements is NaN and would poison backward().
        original_mseloss = torch.tensor(0.0, device=pred_mse.device)
    else:
        original_mseloss = nn.MSELoss()(pred_mse, target_mse)

    if pred_pattern_mse.numel() == 0 or target_pattern.numel() == 0:
        pattern_mseloss = torch.tensor(0.0, device=pred_pattern_mse.device)
        constraint_loss = torch.tensor(0.0, device=pred_pattern_mse.device)
    else:
        pattern_mseloss = nn.MSELoss()(pred_pattern_mse, target_pattern)
        constraint_loss = _rowwise_constraint_loss(pred_pattern, data_input_pattern, pattern_list,
                                                   support_List, epsilon)

    pattern_loss = 0.001 * pattern_mseloss + constraint_loss
    total_loss = original_mseloss + pattern_loss

    # The constraint may still be the float 0. when no pattern contributed.
    constraint_loss_value = constraint_loss.item() if hasattr(constraint_loss, 'item') else 0.0

    return total_loss, pattern_mseloss, original_mseloss, constraint_loss_value


def _rowwise_constraint_loss(pred_pattern, data_input_pattern, pattern_list, support_List, epsilon):
    """Support-weighted hinge penalties pairing pattern ``i`` with row ``i``.

    Stops at the shorter of ``pattern_list`` and ``pred_pattern``. When
    ``support_List`` is shorter than the pattern list, its last weight is reused.
    Patterns lacking the required neighbouring rank are skipped.

    NOTE(review): the pairing is positional (pattern index == sample index); it does
    not use the pattern each sample actually matched — confirm this is intended.
    """
    def position_of(pattern, rank):
        return next((pos for pos, value in enumerate(pattern) if value == rank), None)

    constraint_loss = 0.
    if len(pattern_list) == 0:
        return constraint_loss
    maxbord = len(pattern_list[0]) - 1

    for row, single_pattern in zip(range(pred_pattern.shape[0]), pattern_list):
        last = single_pattern[-1]
        pred_row = pred_pattern[row]
        weight = float(support_List[row] if row < len(support_List) else support_List[-1])

        if last == 0:  # prediction should be the smallest value
            pos = position_of(single_pattern, 1)
            if pos is None:
                continue
            margin = torch.clamp(pred_row - data_input_pattern[row, pos] + epsilon, min=0)
        elif last == maxbord:  # prediction should be the largest value
            pos = position_of(single_pattern, maxbord - 1)
            if pos is None:
                continue
            margin = torch.clamp(data_input_pattern[row, pos] - pred_row + epsilon, min=0)
        else:  # prediction should sit between its two neighbouring ranks
            left = position_of(single_pattern, last - 1)
            right = position_of(single_pattern, last + 1)
            if left is None or right is None:
                continue
            margin = (torch.clamp(data_input_pattern[row, left] + epsilon - pred_row, min=0)
                      + torch.clamp(pred_row - data_input_pattern[row, right] + epsilon, min=0))

        constraint_loss += (weight * margin).mean()

    return constraint_loss


# Helper: count order-constraint violations.
def calculate_constraint_violations(pred_pattern, data_input_pattern, pattern_list, epsilon=0.00001):
    """Count how many (pattern, row) pairs break the pattern's ordering rule.

    Pattern ``i`` is checked against row ``i`` of ``pred_pattern`` /
    ``data_input_pattern``; checking stops at the shorter of ``pattern_list``
    and ``pred_pattern``. Each pattern value is treated as a 0-based rank:

    * last value 0: violated when the prediction exceeds ``ref - epsilon``,
      ``ref`` being the input at the position holding rank 1;
    * last value == pattern length - 1: violated when the prediction is below
      ``ref + epsilon``, ``ref`` at the position holding the rank just below it;
    * otherwise: violated unless the prediction lies strictly inside
      ``(left + epsilon, right - epsilon)``, the inputs holding the
      neighbouring ranks.

    Patterns lacking the required rank are not counted.

    NOTE(review): pairing is positional (pattern index == sample index); it does
    not look up the pattern each sample actually matched — confirm intent.
    """
    top_rank = len(pattern_list[0]) - 1

    def locate(pattern, rank):
        return next((pos for pos, value in enumerate(pattern) if value == rank), -1)

    count = 0
    for row, pattern in zip(range(pred_pattern.shape[0]), pattern_list):
        predicted = pred_pattern[row]
        final_rank = pattern[-1]

        if final_rank == 0:
            pos = locate(pattern, 1)
            if pos != -1 and bool(predicted > data_input_pattern[row, pos] - epsilon):
                count += 1
        elif final_rank == top_rank:
            pos = locate(pattern, top_rank - 1)
            if pos != -1 and bool(predicted < data_input_pattern[row, pos] + epsilon):
                count += 1
        else:
            lo = locate(pattern, final_rank - 1)
            hi = locate(pattern, final_rank + 1)
            if lo != -1 and hi != -1:
                lower = data_input_pattern[row, lo] + epsilon
                upper = data_input_pattern[row, hi] - epsilon
                if not (bool(lower < predicted) and bool(predicted < upper)):
                    count += 1

    return count


import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict

# Per-epoch training statistics (one list per metric), consumed by plot_training_analysis.
training_stats = {
    metric: []
    for metric in (
        'epochs',
        'total_losses',
        'original_mse_losses',
        'pattern_mse_losses',
        'constraint_losses',
        'pattern_match_ratios',
        'constraint_violation_counts',
    )
}

# Training loop. Each epoch prints its averages and appends them to `training_stats`.
# NOTE(review): samples are matched by comparing argsort([inputs, prediction]) -- window
# positions listed in ascending value order -- with the mined patterns. The constraint
# code appears to read pattern values as ranks instead. The two encodings agree only for
# self-inverse permutations; confirm which one OPR-Miner emits.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    epoch_pattern_mse = []  # per batch: 2nd value returned by constrained_loss (pattern_mseloss)
    epoch_original_mse = []  # per batch: 3rd value returned by constrained_loss (original_mseloss)
    epoch_constraint = 0.0
    epoch_matched = 0
    epoch_samples = 0
    epoch_violations = 0

    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(device).reshape(x_batch.shape[0], window_size - 1, 1)
        y_flat = y_batch.to(device).reshape(y_batch.shape[0], -1)

        optimizer.zero_grad()
        outputs, _ = model(x_batch)
        pred = outputs[:, -1, :]  # model output at the last input position, (B, 1)

        # Window with the prediction appended, (B, window_size). argsort is not
        # differentiable, so a detached copy is sorted.
        window_with_pred = torch.cat((x_batch, pred.unsqueeze(1)), dim=1).squeeze(-1)
        sorted_indices = torch.argsort(window_with_pred.detach(), dim=1).cpu().numpy()

        # Vectorised "row equals some row of pattern_list" test. It replaces the per-sample
        # loop of torch.cat calls, which was O(B^2). Boolean-mask indexing keeps batch order.
        matched_np = (sorted_indices[:, None, :] == pattern_list[None, :, :]).all(axis=2).any(axis=1)
        matched = torch.as_tensor(matched_np, device=pred.device)

        x_flat = x_batch.reshape(x_batch.shape[0], -1)
        pred_pattern = pred[matched]
        data_input_pattern = x_flat[matched]

        loss, mseloss, myloss = constrained_loss(pred[~matched], pred_pattern, pred_pattern, data_input_pattern,
                                                 y_flat[~matched], y_flat[matched], pattern_list, support_List)
        loss.backward()
        optimizer.step()

        loss_value, pattern_mse_value, original_mse_value = loss.item(), mseloss.item(), myloss.item()
        total_loss += loss_value
        epoch_pattern_mse.append(pattern_mse_value)
        epoch_original_mse.append(original_mse_value)
        # constrained_loss returns original + 0.001 * pattern_mse + constraint; recover the constraint term.
        epoch_constraint += loss_value - original_mse_value - 0.001 * pattern_mse_value
        epoch_matched += int(matched_np.sum())
        epoch_samples += len(matched_np)
        epoch_violations += calculate_constraint_violations(pred_pattern.detach().cpu(),
                                                            data_input_pattern.cpu(), pattern_list)

    n_batches = max(len(train_loader), 1)
    # Trimmed means: with >= 3 batches, drop the smallest and largest batch value (as
    # before). Otherwise use a plain mean rather than dividing by len(train_loader) - 2 <= 0.
    pattern_trimmed = sorted(epoch_pattern_mse)[1:-1] or epoch_pattern_mse
    original_trimmed = sorted(epoch_original_mse)[1:-1] or epoch_original_mse
    avg_pattern_mse = sum(pattern_trimmed) / len(pattern_trimmed) if pattern_trimmed else float("nan")
    avg_original_mse = sum(original_trimmed) / len(original_trimmed) if original_trimmed else float("nan")
    avg_total = total_loss / n_batches

    print(
        f"Epoch [{epoch + 1}/{num_epochs}], totla_loss: {avg_total:.6f}, pattern_mseloss:{avg_pattern_mse:.6f}, original_mseloss:{avg_original_mse:.10f}")

    # Record the epoch; without this, plot_training_analysis receives empty lists.
    training_stats['epochs'].append(epoch + 1)
    training_stats['total_losses'].append(avg_total)
    training_stats['original_mse_losses'].append(avg_original_mse)
    training_stats['pattern_mse_losses'].append(avg_pattern_mse)
    training_stats['constraint_losses'].append(epoch_constraint / n_batches)
    training_stats['pattern_match_ratios'].append(epoch_matched / epoch_samples if epoch_samples else 0.0)
    training_stats['constraint_violation_counts'].append(epoch_violations)


# Visualise the per-epoch statistics collected during training.
def plot_training_analysis(training_stats, save_path=None):
    """Plot a 2x3 grid of training diagnostics, then print a summary.

    Args:
        training_stats: mapping of per-epoch lists under the keys 'epochs',
            'total_losses', 'original_mse_losses', 'pattern_mse_losses',
            'constraint_losses', 'pattern_match_ratios' and
            'constraint_violation_counts'.
        save_path: optional path; when given, the figure is also saved there at 300 dpi.

    If any of those lists is empty, nothing is plotted and a message is printed
    instead. The summary indexes the first and last epoch and would otherwise
    raise IndexError.
    """
    required_keys = ('epochs', 'total_losses', 'original_mse_losses', 'pattern_mse_losses',
                     'constraint_losses', 'pattern_match_ratios', 'constraint_violation_counts')
    if any(len(training_stats[key]) == 0 for key in required_keys):
        print("No training statistics recorded; skipping training analysis.")
        return

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    epochs = training_stats['epochs']

    def finish(ax, ylabel, title, with_legend=False):
        """Apply the shared axis labels, title and grid."""
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if with_legend:
            ax.legend()
        ax.grid(True)

    # 1. Loss decomposition
    axes[0, 0].plot(epochs, training_stats['total_losses'], 'b-', label='Total Loss', linewidth=2)
    axes[0, 0].plot(epochs, training_stats['original_mse_losses'], 'g-', label='Original MSE', linewidth=2)
    axes[0, 0].plot(epochs, training_stats['pattern_mse_losses'], 'r-', label='Pattern MSE', linewidth=2)
    finish(axes[0, 0], 'Loss', 'Loss Decomposition', with_legend=True)

    # 2. Constraint loss
    axes[0, 1].plot(epochs, training_stats['constraint_losses'], 'purple', linewidth=2)
    finish(axes[0, 1], 'Constraint Loss', 'Constraint Loss Evolution')

    # 3. Pattern match ratio
    axes[0, 2].plot(epochs, training_stats['pattern_match_ratios'], 'orange', linewidth=2)
    finish(axes[0, 2], 'Pattern Match Ratio', 'Pattern Matching Rate')

    # 4. Constraint violation count
    axes[1, 0].plot(epochs, training_stats['constraint_violation_counts'], 'red', linewidth=2)
    finish(axes[1, 0], 'Violation Count', 'Constraint Violations')

    # 5. Share of the total loss coming from the constraint term
    total_losses = np.array(training_stats['total_losses'])
    constraint_losses = np.array(training_stats['constraint_losses'])
    constraint_ratios = constraint_losses / (total_losses + 1e-8)
    axes[1, 1].plot(epochs, constraint_ratios, 'brown', linewidth=2)
    finish(axes[1, 1], 'Constraint Loss Ratio', 'Constraint Loss Contribution')

    # 6. Epoch-to-epoch change of the total loss
    if len(training_stats['total_losses']) > 1:
        loss_improvements = np.diff(training_stats['total_losses'])
        axes[1, 2].plot(epochs[1:], loss_improvements, 'navy', linewidth=2)
        finish(axes[1, 2], 'Loss Change', 'Loss Improvement Trend')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # release the figure once shown/saved

    # Summary
    print("\n=== Training Statistics Summary ===")
    print(f"Final Total Loss: {training_stats['total_losses'][-1]:.6f}")
    print(f"Final Original MSE: {training_stats['original_mse_losses'][-1]:.6f}")
    print(f"Final Pattern MSE: {training_stats['pattern_mse_losses'][-1]:.6f}")
    print(f"Final Constraint Loss: {training_stats['constraint_losses'][-1]:.6f}")
    print(f"Average Pattern Match Ratio: {np.mean(training_stats['pattern_match_ratios']):.3f}")
    print(f"Final Constraint Violations: {training_stats['constraint_violation_counts'][-1]}")
    first_constraint = training_stats['constraint_losses'][0]
    last_constraint = training_stats['constraint_losses'][-1]
    if first_constraint != 0:
        print(
            f"Constraint Loss Reduction: {(first_constraint - last_constraint) / first_constraint * 100:.2f}%")
    else:
        # Avoid ZeroDivisionError when the first epoch had no constraint loss.
        print("Constraint Loss Reduction: n/a (initial constraint loss is 0)")


# After training: persist the weights, plot diagnostics, then evaluate on the test split.
print("Training complete.")

# Save first, so that a failure in the purely diagnostic steps below (plotting,
# evaluation) cannot discard the trained model. Create the target folder if needed.
save_dir = os.path.dirname(save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), save_path)

plot_training_analysis(training_stats, save_path='pattern_constraint_analysis.png')

model.eval()

total_loss = 0.0
test_pattern_mse = []  # per batch: 2nd value returned by constrained_loss
test_original_mse = []  # per batch: 3rd value returned by constrained_loss
with torch.no_grad():
    for x_test, y_test in test_loader:
        x_batch = x_test.to(device).reshape(x_test.shape[0], window_size - 1, 1)
        y_flat = y_test.to(device).reshape(y_test.shape[0], -1)
        outputs, _ = model(x_batch)
        pred = outputs[:, -1, :]  # model output at the last input position, (B, 1)

        # Same matching as in training: argsort of [inputs, prediction] compared
        # with every row of pattern_list. Mask indexing keeps the batch order.
        window_with_pred = torch.cat((x_batch, pred.unsqueeze(1)), dim=1).squeeze(-1)
        sorted_indices = torch.argsort(window_with_pred, dim=1).cpu().numpy()
        matched_np = (sorted_indices[:, None, :] == pattern_list[None, :, :]).all(axis=2).any(axis=1)
        matched = torch.as_tensor(matched_np, device=pred.device)

        x_flat = x_batch.reshape(x_batch.shape[0], -1)
        loss, mseloss, myloss = constrained_loss(pred[~matched], pred[matched], pred[matched],
                                                 x_flat[matched], y_flat[~matched], y_flat[matched],
                                                 pattern_list, support_List)

        total_loss += loss.item()
        test_pattern_mse.append(mseloss.item())
        test_original_mse.append(myloss.item())

# Trimmed means: with >= 3 batches, drop the smallest and largest batch value.
# Otherwise use a plain mean instead of dividing by len(test_loader) - 2 <= 0.
pattern_trimmed = sorted(test_pattern_mse)[1:-1] or test_pattern_mse
original_trimmed = sorted(test_original_mse)[1:-1] or test_original_mse
test_pattern_mean = sum(pattern_trimmed) / len(pattern_trimmed) if pattern_trimmed else float("nan")
test_original_mean = sum(original_trimmed) / len(original_trimmed) if original_trimmed else float("nan")

print(
    f"totla_loss: {total_loss / max(len(test_loader), 1):.6f}, pattern_mseloss:{test_pattern_mean:.6f}, original_mseloss:{test_original_mean:.10f}")
