import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import random
from sklearn.preprocessing import StandardScaler


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds PyTorch (CPU generator, the current CUDA device and all CUDA
    devices), NumPy and Python's ``random`` module with ``seed``, then puts
    cuDNN into deterministic mode with benchmark autotuning switched off.
    """
    seeders = (
        torch.manual_seed,           # CPU generator
        torch.cuda.manual_seed,      # current GPU
        torch.cuda.manual_seed_all,  # every GPU, for multi-GPU runs
        np.random.seed,              # NumPy global generator
        random.seed,                 # Python built-in generator
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Deterministic kernels, and no autotuned algorithm selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(1234)  # Fix every RNG seed up front (the value used here is 1234).

# ================================================= Hyperparameters ========================================================
length_pattern = 7  # only mined patterns with exactly this many elements are kept
save_path = "Ablation experiments/pattern_model_7.pth"  # destination of the trained state_dict
# Model / training hyperparameters
window_size = 7  # sliding-window length: the dataset yields window_size - 1 inputs plus 1 target
embed_dim = 64  # width of the embedding each scalar input is projected to

print("特征数：", embed_dim)
num_heads = 4  # attention heads; the model asserts embed_dim % num_heads == 0
batch_size = 256
num_epochs = 150
learning_rate = 1e-4

# ================================================= End of hyperparameters ========================================================

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Mined frequent patterns, collected as (pattern tuple, support) pairs.
frequent_patterns = []

# A matching line carries a comma-separated integer pattern followed by an
# integer support; lines that do not match the expression are skipped.
pattern_line = re.compile(r"频繁模式：([\d,]+), 支持度为：(\d+)")

with open("../data/OPR-output-ETT-50/OPR-Miner-Output.txt", "r",
          encoding="GBK") as miner_output:
    for raw_line in miner_output:
        parsed = pattern_line.match(raw_line.strip())
        if parsed is None:
            continue
        pattern_text, support_text = parsed.groups()
        mined = tuple(int(token) for token in pattern_text.split(','))
        frequent_patterns.append((mined, int(support_text)))

# Echo what was parsed: one line per pattern, then the raw list.
for pattern, support in frequent_patterns:
    print(f"模式: {pattern}, 支持度: {support}")

print(frequent_patterns)

# Keep the mined patterns whose length equals ``length_pattern`` and,
# separately, the ones that are exactly one element shorter.
patterns_length_n = [(pattern, support) for pattern, support in frequent_patterns if len(pattern) == length_pattern]

length_pattern_son = length_pattern - 1
patterns_length_n_son = [(pattern, support) for pattern, support in frequent_patterns if
                         len(pattern) == length_pattern_son]

# Without a single full-length pattern the later ``pattern_list[:, 0:-1]`` slice
# fails with an opaque IndexError, so stop here with an explicit message.
if not patterns_length_n:
    raise ValueError(
        f"no frequent pattern of length {length_pattern} was parsed from the miner output")

# Report the shorter patterns.
for pattern, support in patterns_length_n_son:
    print(f"模式: {pattern}, 支持度: {support}")
print(patterns_length_n_son)

# Shift every pattern element down by one so the values become zero-based.
adjusted_patterns_son = [((tuple(x - 1 for x in pattern)), support) for pattern, support in patterns_length_n_son]

print(adjusted_patterns_son)

# Report the full-length patterns.
for pattern, support in patterns_length_n:
    print(f"模式: {pattern}, 支持度: {support}")
print(patterns_length_n)

# Same zero-based shift for the full-length patterns.
adjusted_patterns = [((tuple(x - 1 for x in pattern)), support) for pattern, support in patterns_length_n]

print(adjusted_patterns)

# Raw supports of the full-length patterns, in pattern order.
support_values = np.array([sup for _, sup in adjusted_patterns])

# Per-pattern weight = support / sum of supports.  Despite the variable name
# further down, this is a plain sum-normalisation, not a softmax.
total_sum = np.sum(support_values)
normalized_values = support_values / total_sum
print("normalized_values:", normalized_values)

support_List = normalized_values

# Exponentiated weights are only printed for inspection; nothing else reads them.
val = np.exp(normalized_values)
print("val:", val)

softmax_weights = normalized_values

print(support_values)
print(softmax_weights)

# B = torch.tensor(B, dtype=torch.float32).to(device)

# Load the preprocessed series.  No header/names arguments are passed, so pandas
# applies its default header handling to the first row.
dataset_swat = pd.read_csv('../data/datasets_ETT_afterProcess.txt', encoding='utf-8',
                           low_memory=False)

# Only the FIRST column is used: columns[0] is a single label, not a list of labels.
feature_cols = dataset_swat.columns[0]

features = dataset_swat[feature_cols].values  # presumably 1-D (num_samples,), given unique column labels -- TODO confirm

print("feature.shape", features.shape)

# One row per retained zero-based full-length pattern; the supports are dropped here.
pattern_list = np.array([pat for pat, _ in adjusted_patterns])

# Debug print: every pattern with its last element cut off.
print("pattern_list_gaibian:", pattern_list[:, 0:-1])


def _last_position_of(pattern, rank):
    """Return the last index of ``pattern`` whose entry equals ``rank``, or ``None`` if absent."""
    found = None
    for position, value in enumerate(pattern):
        if value == rank:
            found = position
    return found


def constrained_loss(pred_mse, pred_pattern, pred_pattern_mse, data_input_pattern, target_mse,
                     target_pattern, pattern_list, support_List, epsilon=0.00001):
    """Combine plain MSE with a support-weighted order-constraint penalty.

    The batch arrives already split in two: samples that matched a frequent
    pattern (``*_pattern`` arguments) and samples that did not (``*_mse``
    arguments).

    For every pattern the entry at the last position is compared with ``0``
    and with ``len(pattern) - 1``:

    * last entry is ``0``: the prediction is pushed below the input column
      whose pattern entry is ``1``;
    * last entry is the maximum: the prediction is pushed above the input
      column whose pattern entry is ``len(pattern) - 2``;
    * otherwise: the prediction is pushed between the columns whose entries are
      ``last - 1`` and ``last + 1``.

    Each violation is a hinge (``clamp(..., min=0)``) with margin ``epsilon``,
    scaled by ``support_List[i]`` and averaged.  Every pattern's constraint is
    applied to every row of ``pred_pattern``.
    NOTE(review): rows are not restricted to the pattern they actually matched;
    confirm that this is intended.

    Args:
        pred_mse: predictions of the non-matching samples.
        pred_pattern: predictions of the matching samples, used for the
            constraint penalty; first dimension is the sample dimension.
        pred_pattern_mse: predictions of the matching samples, used for their MSE.
        data_input_pattern: inputs of the matching samples, indexed as
            ``[:, position]``.
        target_mse: targets of the non-matching samples.
        target_pattern: targets of the matching samples.
        pattern_list: iterable of equal-length integer patterns.
        support_List: one weight per pattern, aligned with ``pattern_list``.
        epsilon: margin required by every order constraint.

    Returns:
        ``(total_loss, pattern_mseloss, original_mseloss)`` where
        ``total_loss = original_mseloss + 0.001 * pattern_mseloss + constraint``.
    """
    mse = nn.MSELoss()

    # MSELoss over zero elements is NaN; an empty subset contributes nothing.
    if pred_mse.numel() == 0 or target_mse.numel() == 0:
        original_mseloss = torch.tensor(0.0, device=pred_mse.device)
    else:
        original_mseloss = mse(pred_mse, target_mse)

    if pred_pattern_mse.numel() == 0 or target_pattern.numel() == 0:
        pattern_mseloss = torch.tensor(0.0, device=pred_pattern_mse.device)
        constraint_loss = torch.tensor(0.0, device=pred_pattern_mse.device)
    else:
        pattern_mseloss = mse(pred_pattern_mse, target_pattern)
        constraint_loss = 0.

        if pred_pattern.numel() != 0:
            rows = pred_pattern.shape[0]
            # Keep predictions and reference columns both as (rows, k).  Subtracting
            # a (rows,) column from (rows, 1) predictions would broadcast to a
            # (rows, rows) all-pairs matrix instead of a per-sample margin.
            preds = pred_pattern.reshape(rows, -1)

            def column(position):
                return data_input_pattern[:, position].reshape(rows, -1)

            for idx, single_pattern in enumerate(pattern_list):
                last_rank = single_pattern[-1]
                top_rank = len(single_pattern) - 1
                weight = support_List[idx]

                if last_rank == 0:
                    # Upper bound only: pred < ref - epsilon.
                    upper = _last_position_of(single_pattern, 1)
                    if upper is None:
                        continue
                    violation = torch.clamp(preds - column(upper) + epsilon, min=0)
                elif last_rank == top_rank:
                    # Lower bound only: pred > ref + epsilon.
                    lower = _last_position_of(single_pattern, top_rank - 1)
                    if lower is None:
                        continue
                    violation = torch.clamp(column(lower) - preds + epsilon, min=0)
                else:
                    # Two-sided: ref_left + epsilon < pred < ref_right - epsilon.
                    lower = _last_position_of(single_pattern, last_rank - 1)
                    upper = _last_position_of(single_pattern, last_rank + 1)
                    if lower is None or upper is None:
                        continue
                    below_left = torch.clamp(column(lower) + epsilon - preds, min=0)
                    above_right = torch.clamp(preds - column(upper) + epsilon, min=0)
                    violation = below_left + above_right

                constraint_loss = constraint_loss + (weight * violation).mean()

    pattern_loss = 0.001 * pattern_mseloss + constraint_loss

    total_loss = original_mseloss + pattern_loss

    return total_loss, pattern_mseloss, original_mseloss


# --------------------------------------------------------------------------------------------------------------------------------------------------------------
# --------------------------------------------------------------------------------------------------------------------------------------------------------------


# ============== 修改数据集类 ==============
class SWaTDataset(Dataset):
    """Sliding-window view over a series.

    Item ``i`` is built from ``features[i:i + window_size]``: the first
    ``window_size - 1`` entries are the input and the last entry is the target.
    """

    def __init__(self, features, window_size):
        # as_tensor accepts a tensor as well as array-likes (NumPy array, list);
        # clone/detach gives the dataset its own float32 copy.
        self.features = torch.as_tensor(features).clone().detach().float()
        self.window_size = window_size  # full window length, target included

    def __len__(self):
        # A series of length L holds L - window_size + 1 complete windows, and
        # none at all when it is shorter than one window.
        return max(0, len(self.features) - self.window_size + 1)

    def __getitem__(self, idx):
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            # Without this check a trailing index would yield a truncated window.
            raise IndexError("window index out of range")
        window = self.features[idx:idx + self.window_size]
        x = window[:-1]  # first window_size - 1 entries: input
        y = window[-1]  # last entry: target
        return x, y


# `features` is the NumPy array loaded above; it is split chronologically (no shuffling).
train_size = int(len(features) * 0.6)  # first 60% of the series is the training split
test_size = len(features) - train_size  # size of the remaining split; not read again in this section

train_features, test_features = features[:train_size], features[train_size:]




# Move both splits onto the selected device as float32 tensors.
train_features = torch.tensor(train_features, dtype=torch.float32).to(device)
test_features = torch.tensor(test_features, dtype=torch.float32).to(device)

print("train_feature.shape", train_features.shape)
print("test_feature.shape", test_features.shape)

# Wrap each split in the sliding-window dataset.
train_dataset = SWaTDataset(train_features, window_size)
test_dataset = SWaTDataset(test_features, window_size)

# Batches are served in series order (shuffle=False) for both splits.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


class OpTransformer(nn.Module):
    """One multi-head self-attention layer between an input and an output projection.

    The input is lifted from ``input_dim`` to ``embed_dim``, passed through
    scaled dot-product attention with ``num_heads`` heads, refined by a
    three-layer MLP (embed_dim -> 100 -> 50 -> embed_dim) and projected back to
    ``input_dim``.  ``window_size`` is stored but not used by the computation.
    """

    def __init__(self, embed_dim, num_heads, window_size, input_dim=1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.input_dim = input_dim

        # Layers are created in a fixed order: with a seeded generator the
        # construction order decides each layer's initial weights.
        self.input_expansion = nn.Linear(input_dim, embed_dim)  # input_dim -> embed_dim
        self.output_compression = nn.Linear(embed_dim, input_dim)  # embed_dim -> input_dim

        assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"
        self.head_dim = embed_dim // num_heads  # width of a single head

        # Query / key / value projections.
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)

        # MLP applied to the merged attention output.
        self.W_out1 = nn.Linear(embed_dim, 100)
        self.W_out2 = nn.Linear(100, 50)
        self.W_out3 = nn.Linear(50, embed_dim)

    def forward(self, x):
        """Run attention over ``x`` of shape (batch_size, seq_len, input_dim).

        Returns a pair ``(output, attention_weights)`` with shapes
        (batch_size, seq_len, input_dim) and
        (batch_size, num_heads, seq_len, seq_len).
        """
        batch, steps, _ = x.shape

        def to_heads(projected):
            # (batch, steps, embed_dim) -> (batch, num_heads, steps, head_dim)
            return projected.view(batch, steps, self.num_heads, self.head_dim).transpose(1, 2)

        embedded = self.input_expansion(x)  # (batch, steps, embed_dim)

        queries = to_heads(self.W_q(embedded))
        keys = to_heads(self.W_k(embedded))
        values = to_heads(self.W_v(embedded))

        # Scaled dot-product scores: Q K^T / sqrt(head_dim).
        scores = (queries @ keys.transpose(-2, -1)) / np.sqrt(self.head_dim)
        attention_weights = F.softmax(scores, dim=-1)  # each row sums to 1 over the keys

        context = attention_weights @ values  # (batch, num_heads, steps, head_dim)

        # Concatenate the heads back into one embedding per step.
        merged = context.transpose(1, 2).contiguous().view(batch, steps, self.embed_dim)

        hidden = F.relu(self.W_out1(merged))  # (batch, steps, 100)
        hidden = F.relu(self.W_out2(hidden))  # (batch, steps, 50)
        restored = self.W_out3(hidden)  # (batch, steps, embed_dim)

        compressed_output = self.output_compression(restored)  # (batch, steps, input_dim)

        return compressed_output, attention_weights


# Build the model and its optimiser.
model = OpTransformer(embed_dim, num_heads, window_size).to(device)

criterion = nn.MSELoss()  # not used by the loop below, which optimises constrained_loss
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


def _pattern_mask(inputs, pred, patterns):
    """Flag the samples whose window ordering equals one of ``patterns``.

    ``inputs`` is (batch, window_size - 1) and ``pred`` is (batch, 1); they are
    joined into a full window whose argsort is compared with every row of
    ``patterns``.  Returns a bool tensor of shape (batch,) on ``pred``'s device.

    NOTE(review): argsort yields positions ordered by value, whereas
    constrained_loss looks pattern entries up by value and uses their positions
    as column indices.  Confirm which encoding the mined patterns use; if they
    are per-position ranks the comparison needs argsort applied twice.
    """
    window = torch.cat((inputs, pred), dim=1)
    order = torch.argsort(window, dim=1).cpu().numpy()  # (batch, window_size)
    # (batch, 1, W) against (1, patterns, W): exact row equality with any pattern.
    hits = (order[:, None, :] == patterns[None, :, :]).all(axis=2).any(axis=1)
    return torch.as_tensor(hits, dtype=torch.bool, device=pred.device)


def _trimmed_mean(values):
    """Mean of ``values`` without its smallest and largest entry.

    With two or fewer values nothing is dropped (the plain mean is returned),
    and an empty list yields 0.0, so short loaders cannot divide by zero.
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    if len(ordered) > 2:
        ordered = ordered[1:-1]
    return sum(ordered) / len(ordered)


for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    pattern_mse_log = []
    plain_mse_log = []

    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        rows = x_batch.shape[0]
        x_batch = x_batch.reshape(rows, window_size - 1)
        y_batch = y_batch.reshape(rows, -1)

        optimizer.zero_grad()
        outputs, _ = model(x_batch.unsqueeze(-1))  # model expects (batch, seq_len, 1)
        pred = outputs[:, -1, :]  # the output at the last sequence position is the forecast

        # Split the batch into pattern-matching samples and the rest.  Boolean
        # indexing keeps the original row order and the autograd graph.
        in_pattern = _pattern_mask(x_batch, pred, pattern_list)
        off_pattern = ~in_pattern

        loss, mseloss, myloss = constrained_loss(pred[off_pattern], pred[in_pattern], pred[in_pattern],
                                                 x_batch[in_pattern], y_batch[off_pattern],
                                                 y_batch[in_pattern], pattern_list, support_List)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # constrained_loss returns (total, pattern MSE, plain MSE).
        pattern_mse_log.append(mseloss.item())
        plain_mse_log.append(myloss.item())

    batches = max(len(train_loader), 1)
    print(
        f"Epoch [{epoch + 1}/{num_epochs}], totla_loss: {total_loss / batches:.6f}, pattern_mseloss:{_trimmed_mean(pattern_mse_log):.6f}, original_mseloss:{_trimmed_mean(plain_mse_log):.10f}")

print("Training complete.")
def _evaluate(loader):
    """Run one gradient-free pass over ``loader`` with the trained model.

    Every batch is split into samples whose argsort ordering (inputs followed by
    the prediction) equals a row of ``pattern_list`` and the remaining samples,
    then scored with ``constrained_loss``.

    Returns ``(mean total loss, trimmed-mean pattern MSE, trimmed-mean plain
    MSE)``.  The trimmed mean drops the smallest and largest batch value when
    more than two batches exist; otherwise it is the plain mean (0.0 when the
    loader is empty).
    """

    def trimmed_mean(values):
        if not values:
            return 0.0
        ordered = sorted(values)
        if len(ordered) > 2:
            ordered = ordered[1:-1]
        return sum(ordered) / len(ordered)

    model.eval()
    running_total = 0.0
    pattern_mse_log = []
    plain_mse_log = []

    with torch.no_grad():
        for x_test, y_test in loader:
            x_test, y_test = x_test.to(device), y_test.to(device)
            rows = x_test.shape[0]
            inputs = x_test.reshape(rows, window_size - 1)
            targets = y_test.reshape(rows, -1)

            outputs, _ = model(inputs.unsqueeze(-1))  # model expects (batch, seq_len, 1)
            pred = outputs[:, -1, :]  # the output at the last sequence position is the forecast

            # Ordering of the full window, compared row-wise with every pattern.
            order = torch.argsort(torch.cat((inputs, pred), dim=1), dim=1).cpu().numpy()
            hits = (order[:, None, :] == pattern_list[None, :, :]).all(axis=2).any(axis=1)
            in_pattern = torch.as_tensor(hits, dtype=torch.bool, device=pred.device)
            off_pattern = ~in_pattern

            # constrained_loss returns (total, pattern MSE, plain MSE).
            loss, pattern_mse, plain_mse = constrained_loss(pred[off_pattern], pred[in_pattern],
                                                            pred[in_pattern], inputs[in_pattern],
                                                            targets[off_pattern], targets[in_pattern],
                                                            pattern_list, support_List)

            running_total += loss.item()
            pattern_mse_log.append(pattern_mse.item())
            plain_mse_log.append(plain_mse.item())

    mean_total = running_total / max(len(loader), 1)
    return mean_total, trimmed_mean(pattern_mse_log), trimmed_mean(plain_mse_log)


def _save_state_dict(module, path):
    """Write ``module.state_dict()`` to ``path``, creating the parent directory first."""
    import os  # local import: the file's import header does not bring in os

    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(module.state_dict(), path)


test_total, test_pattern_mse, test_plain_mse = _evaluate(test_loader)
test_summary = (
    f"totla_loss: {test_total:.6f}, pattern_mseloss:{test_pattern_mse:.6f}, original_mseloss:{test_plain_mse:.10f}")
print(test_summary)
# The script used to run a second, copy-pasted evaluation pass over the same
# unshuffled loader and print its summary again.  The line is repeated so the
# log output keeps its shape; drop it once nothing depends on the duplicate.
print(test_summary)

_save_state_dict(model, save_path)
