import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import random
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import math


def set_seed(seed):
    """Seed every random number generator the script uses and make cuDNN deterministic.

    Args:
        seed: integer seed passed to Python's ``random``, NumPy and PyTorch
            (CPU, current GPU and all GPUs).
    """
    # Python / NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU, current CUDA device, then every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed(1234)  # fixed seed so runs are reproducible

# Device used for both training and evaluation.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


# ---------------- Hyperparameters ----------------
# Windowing: each sample is window_size - 1 inputs followed by 1 target.
window_size = 7
embed_dim = 16  # NOTE(review): not referenced anywhere else in this script

# Model architecture (kept small to limit overfitting).
d_model = 16  # model width
n_heads = 4  # attention heads (d_model must split across them)
e_layers = 2  # encoder layers
d_layers = 1  # decoder layers
d_ff = 32  # feed-forward width
moving_avg = 3  # moving-average kernel for series decomposition
dropout = 0.1
activation = 'gelu'

# Optimisation.
batch_size = 256
num_epochs = 100
learning_rate = 0.001


# ============== 数据集类 ==============
class SWaTDataset(Dataset):
    """Sliding-window dataset for one-step-ahead forecasting.

    Each item is a window of ``window_size`` consecutive rows of ``features``.
    The first ``window_size - 1`` rows form the input and the last row is the
    target.

    Args:
        features: tensor (or anything ``torch.as_tensor`` accepts, e.g. a
            NumPy array) whose first axis is time; in this script it has
            shape ``(n_steps, 1)``. It is copied and cast to float32.
        window_size: total window length (inputs + target); must be >= 2.

    Raises:
        ValueError: if ``window_size`` < 2, which would leave the input empty.
    """

    def __init__(self, features, window_size):
        if window_size < 2:
            raise ValueError(f"window_size must be >= 2, got {window_size}")
        # as_tensor also accepts NumPy input; clone() guarantees an independent copy.
        self.features = torch.as_tensor(features).clone().detach().float()
        self.window_size = window_size  # full window length (inputs + target)

    def __len__(self):
        # Valid start indices are 0 .. n - window_size, i.e. n - window_size + 1
        # windows. Clamp at 0 when the series is shorter than a single window.
        return max(0, len(self.features) - self.window_size + 1)

    def __getitem__(self, idx):
        window = self.features[idx:idx + self.window_size]  # full window
        x = window[:-1]  # first window_size - 1 steps are the input
        y = window[-1]  # last step is the target
        return x, y


# ---------------- Load the raw series ----------------
# Best-effort: if the file cannot be read, fall back to random data so the
# rest of the pipeline can still be exercised.
try:
    dataset_swat = pd.read_csv(
        '../../data/output-cnn.txt',
        encoding='utf-8',
        low_memory=False,
    )
    # Only the first column of the file is used as the series.
    feature_cols = dataset_swat.columns[0]
    features = dataset_swat[feature_cols].values
    print("Successfully loaded data")
    print("feature.shape", features.shape)
except Exception as e:
    print(f"Failed to load data: {e}")
    # Synthetic stand-in used for testing when the real file is unavailable.
    features = np.random.rand(10000, 1)
    print("Using dummy data with shape:", features.shape)

# ---------------- Chronological 80/20 split, each as (n, 1) ----------------
train_size = int(len(features) * 0.8)
train_features = features[:train_size].reshape(-1, 1)
test_features = features[train_size:].reshape(-1, 1)

# ---------------- Standardise using training-split statistics only ----------------
scaler = StandardScaler()
train_features = torch.tensor(scaler.fit_transform(train_features), dtype=torch.float32).to(device)
test_features = torch.tensor(scaler.transform(test_features), dtype=torch.float32).to(device)

print("train_feature.shape", train_features.shape)
print("test_feature.shape", test_features.shape)

# ---------------- Windowed datasets and loaders (chronological order kept) ----------------
train_dataset = SWaTDataset(train_features, window_size)
test_dataset = SWaTDataset(test_features, window_size)

train_loader, test_loader = (
    DataLoader(train_dataset, batch_size=batch_size, shuffle=False),
    DataLoader(test_dataset, batch_size=batch_size, shuffle=False),
)


# ============== Autoformer模型组件 ==============

class SeriesDecomp(nn.Module):
    """Split a series into a seasonal (residual) part and a moving-average trend.

    The trend is a moving average of width ``kernel_size`` along the time
    axis. The series is padded by repeating its first and last values rather
    than with zeros, so the trend near the edges is not pulled toward 0. The
    total padding is ``kernel_size - 1``, which keeps the output length equal
    to the input length for both odd and even kernels.

    Args:
        kernel_size: moving-average window length (>= 1).
    """

    def __init__(self, kernel_size):
        super(SeriesDecomp, self).__init__()
        self.kernel_size = kernel_size
        # Split the kernel_size - 1 padding between the front and the end
        # (the extra step goes to the end for even kernels).
        self.pad_front = (kernel_size - 1) // 2
        self.pad_end = kernel_size - 1 - self.pad_front
        # Padding is done manually in forward, so the pooling itself pads nothing.
        self.moving_avg = nn.AvgPool1d(kernel_size=kernel_size, stride=1, padding=0)

    def forward(self, x):
        """
        Args:
            x: tensor of shape [Batch, Length, Channel].

        Returns:
            (seasonal, trend), both with the same shape as ``x``, where
            ``seasonal = x - trend``.
        """
        # Replicate the edge time steps instead of zero-padding.
        front = x[:, :1, :].repeat(1, self.pad_front, 1)
        end = x[:, -1:, :].repeat(1, self.pad_end, 1)
        padded = torch.cat([front, x, end], dim=1)

        # AvgPool1d works on [Batch, Channel, Length]; transpose in and out.
        trend = self.moving_avg(padded.transpose(1, 2)).transpose(1, 2)

        seasonal = x - trend
        return seasonal, trend


class AutoCorrelation(nn.Module):
    """Simplified correlation-based attention.

    ``forward`` does not compute Autoformer's FFT auto-correlation. It scores
    every query/key pair by cosine similarity, normalises the scores with a
    softmax over the key axis, applies dropout to those normalised weights and
    returns the weighted sum of the values.

    Args:
        mask_flag: stored; not used by ``forward``.
        factor: used only by ``time_delay_agg_training`` to choose
            ``top_k = factor * log(length)`` (clipped to ``[1, length]``).
        scale: optional multiplier applied to the cosine scores before the
            softmax (a softmax temperature). ``None`` leaves them unscaled.
        attention_dropout: dropout probability applied to the attention weights.
        output_attention: stored; not used by ``forward``.
    """

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(AutoCorrelation, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.dropout = nn.Dropout(attention_dropout)
        self.output_attention = output_attention

    def time_delay_agg_training(self, values, corr):
        """
        Top-k aggregation helper (training variant); not called by ``forward``.

        Averages ``corr`` over dim 1, keeps the ``top_k`` highest-scoring
        positions per (batch, head), softmax-weights the matching rows of
        ``values``, sums them and broadcasts the result across the length axis.

        Args:
            values: tensor of shape (batch, head, length, d_k).
            corr: tensor reshapeable to (batch * head, length, length).

        Returns:
            Tensor of shape (batch, head, length, d_k).
        """
        batch, head, length, d_k = values.shape
        # Ensure top_k is at least 1 and not larger than sequence length
        top_k = min(max(1, int(self.factor * math.log(length))), length)

        # Reshape and compute correlation scores
        corr_flat = corr.reshape(batch * head, length, length)

        # Get top-k indices for each batch and head
        scores = torch.mean(corr_flat, dim=1)  # [batch*head, length]
        top_k_indices = torch.topk(scores, k=top_k, dim=-1)[1]  # [batch*head, top_k]

        # Gather values using top-k indices
        batch_indices = torch.arange(batch * head).view(-1, 1).to(values.device)
        batch_indices = batch_indices.repeat(1, top_k)  # [batch*head, top_k]

        # Reshape values for gathering
        values_flat = values.reshape(batch * head, length, d_k)

        # Gather top-k values
        top_k_values = values_flat[batch_indices.flatten(), top_k_indices.flatten()].view(batch * head, top_k, d_k)

        # Calculate attention weights
        weights = torch.softmax(scores.gather(-1, top_k_indices), dim=-1).unsqueeze(-1)  # [batch*head, top_k, 1]

        # Apply weights to values
        weighted_values = weights * top_k_values  # [batch*head, top_k, d_k]
        output = weighted_values.sum(dim=1).view(batch, head, d_k).unsqueeze(2).repeat(1, 1, length, 1)

        return output

    def forward(self, queries, keys, values):
        """
        Args:
            queries: tensor of shape (B, L, H, E).
            keys: tensor of shape (B, S, H, E).
            values: tensor of shape (B, S, H, D).

        Returns:
            Tensor of shape (B, L, H * D).
        """
        B, L, H, E = queries.shape
        _, S, _, D = values.shape

        # Fold heads into the batch axis: (B * H, length, features).
        queries = queries.transpose(1, 2).reshape(B * H, L, E)
        keys = keys.transpose(1, 2).reshape(B * H, S, E)
        values = values.transpose(1, 2).reshape(B * H, S, D)

        # Cosine similarity instead of FFT (simplified to avoid numerical issues);
        # the epsilon guards against division by zero for all-zero vectors.
        q_norm = queries / (torch.norm(queries, dim=-1, keepdim=True) + 1e-8)
        k_norm = keys / (torch.norm(keys, dim=-1, keepdim=True) + 1e-8)
        corr = torch.bmm(q_norm, k_norm.transpose(-1, -2))  # (B * H, L, S)

        if self.scale is not None:
            corr = corr * self.scale

        # Normalise first, then drop attention weights. Dropping raw scores before
        # the softmax would not remove any path, because a zeroed score still
        # receives weight exp(0).
        attn = torch.softmax(corr, dim=-1)
        attn = self.dropout(attn)

        output = torch.bmm(attn, values).reshape(B, H, L, D)

        # Unfold heads back: (B, L, H * D).
        output = output.transpose(1, 2).reshape(B, L, H * D)

        return output


class AutoCorrelationLayer(nn.Module):
    """Multi-head wrapper around ``AutoCorrelation`` with input/output projections.

    Queries, keys and values are linearly projected and split into
    ``n_heads`` heads before the inner correlation. The concatenated heads are
    then projected back to ``d_model``.

    Args:
        d_model: input/output feature width.
        n_heads: number of heads.
        d_keys: per-head query/key width (defaults to ``d_model // n_heads``).
        d_values: per-head value width (defaults to ``d_model // n_heads``).
    """

    def __init__(self, d_model, n_heads, d_keys=None, d_values=None):
        super().__init__()
        per_head = d_model // n_heads
        key_width = d_keys or per_head
        value_width = d_values or per_head

        self.inner_correlation = AutoCorrelation(mask_flag=False, factor=5, attention_dropout=0.1)

        # Projection layers are created in this order: query, key, value, output.
        self.query_projection = nn.Linear(d_model, n_heads * key_width)
        self.key_projection = nn.Linear(d_model, n_heads * key_width)
        self.value_projection = nn.Linear(d_model, n_heads * value_width)
        self.out_projection = nn.Linear(n_heads * value_width, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values):
        """
        Args:
            queries: tensor of shape (B, L, d_model).
            keys: tensor of shape (B, S, d_model).
            values: tensor of shape (B, S, d_model).

        Returns:
            Tensor of shape (B, L, d_model).
        """
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        projected_q = self.query_projection(queries).view(batch, q_len, heads, -1)
        projected_k = self.key_projection(keys).view(batch, kv_len, heads, -1)
        projected_v = self.value_projection(values).view(batch, kv_len, heads, -1)

        mixed = self.inner_correlation(projected_q, projected_k, projected_v)
        return self.out_projection(mixed)


class EncoderLayer(nn.Module):
    """Autoformer-style encoder layer (pre-norm).

    The layer applies a residual self-correlation sub-layer followed by a
    residual position-wise feed-forward sub-layer. The feed-forward part is
    implemented as two kernel-size-1 convolutions.

    Args:
        d_model: feature width.
        n_heads: number of correlation heads.
        d_ff: feed-forward width (defaults to ``4 * d_model``).
        dropout: dropout probability.
        activation: ``"relu"`` or ``"gelu"``.

    Raises:
        ValueError: if ``activation`` is not ``"relu"`` or ``"gelu"``.
    """

    def __init__(self, d_model, n_heads, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        # Validate up front; previously an unknown name left self.activation
        # unset and only failed later inside forward().
        activations = {"relu": F.relu, "gelu": F.gelu}
        if activation not in activations:
            raise ValueError(f"Unsupported activation {activation!r}; expected 'relu' or 'gelu'")

        d_ff = d_ff or 4 * d_model
        self.attention = AutoCorrelationLayer(d_model, n_heads)
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = activations[activation]

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, L, d_model).

        Returns:
            Tensor of the same shape as ``x``.
        """
        new_x = self.norm1(x)
        x = x + self.dropout(self.attention(new_x, new_x, new_x))

        # Conv1d expects (B, channels, L), so swap the last two axes around the FFN.
        y = self.norm2(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return x + y


class Encoder(nn.Module):
    """Stack of ``e_layers`` EncoderLayer blocks followed by a final LayerNorm."""

    def __init__(self, e_layers, d_model, n_heads, d_ff, dropout, activation):
        super().__init__()
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, n_heads, d_ff, dropout, activation)
            for _ in range(e_layers)
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Run ``x`` through every layer in order, then normalise the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)


class DecoderLayer(nn.Module):
    """Autoformer-style decoder layer (pre-norm).

    The layer applies three residual sub-layers in order: self-correlation,
    cross-correlation against ``cross``, and a position-wise feed-forward
    network implemented as two kernel-size-1 convolutions.

    Args:
        d_model: feature width.
        n_heads: number of correlation heads.
        d_ff: feed-forward width (defaults to ``4 * d_model``).
        dropout: dropout probability.
        activation: ``"relu"`` or ``"gelu"``.

    Raises:
        ValueError: if ``activation`` is not ``"relu"`` or ``"gelu"``.
    """

    def __init__(self, d_model, n_heads, d_ff=None, dropout=0.1, activation="relu"):
        super(DecoderLayer, self).__init__()
        # Validate up front; previously an unknown name left self.activation
        # unset and only failed later inside forward().
        activations = {"relu": F.relu, "gelu": F.gelu}
        if activation not in activations:
            raise ValueError(f"Unsupported activation {activation!r}; expected 'relu' or 'gelu'")

        d_ff = d_ff or 4 * d_model
        self.self_attention = AutoCorrelationLayer(d_model, n_heads)
        self.cross_attention = AutoCorrelationLayer(d_model, n_heads)
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = activations[activation]

    def forward(self, x, cross):
        """
        Args:
            x: decoder input of shape (B, L, d_model).
            cross: encoder output of shape (B, S, d_model), used as keys/values.

        Returns:
            Tensor of the same shape as ``x``.
        """
        new_x = self.norm1(x)
        x = x + self.dropout(self.self_attention(new_x, new_x, new_x))

        new_x = self.norm2(x)
        x = x + self.dropout(self.cross_attention(new_x, cross, cross))

        # Conv1d expects (B, channels, L), so swap the last two axes around the FFN.
        y = self.norm3(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return x + y


class Decoder(nn.Module):
    """Stack of ``d_layers`` DecoderLayer blocks followed by a final LayerNorm."""

    def __init__(self, d_layers, d_model, n_heads, d_ff, dropout, activation):
        super().__init__()
        self.layers = nn.ModuleList(
            DecoderLayer(d_model, n_heads, d_ff, dropout, activation)
            for _ in range(d_layers)
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, cross):
        """Run ``x`` through every layer (each attending to ``cross``), then normalise."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, cross)
        return self.norm(hidden)


class Autoformer(nn.Module):
    """Encoder-decoder forecaster built from the correlation layers above.

    The input is linearly embedded to ``d_model``, encoded, and decoded (the
    decoder reuses the same embedding as its input and attends to the encoder
    output). Each time step is then projected to ``output_dim``.

    NOTE(review): ``self.decomp`` is constructed but not used in ``forward``.
    """

    def __init__(self, input_dim, d_model, n_heads, e_layers, d_layers, d_ff, moving_avg, dropout, activation,
                 output_dim=1):
        super().__init__()

        # Submodules are created in this order: embedding, decomposition,
        # encoder, decoder, projection.
        self.input_embedding = nn.Linear(input_dim, d_model)
        self.decomp = SeriesDecomp(moving_avg)
        self.encoder = Encoder(e_layers, d_model, n_heads, d_ff, dropout, activation)
        self.decoder = Decoder(d_layers, d_model, n_heads, d_ff, dropout, activation)
        self.projection = nn.Linear(d_model, output_dim)

        # Length of the most recent input sequence; updated on every forward call.
        self.seq_len = None

    def forward(self, x):
        """
        Args:
            x: tensor of shape (batch_size, seq_len, input_dim).

        Returns:
            Tensor of shape (batch_size, seq_len, output_dim).
        """
        self.seq_len = x.shape[1]

        embedded = self.input_embedding(x)
        memory = self.encoder(embedded)
        decoded = self.decoder(embedded, memory)
        return self.projection(decoded)


# ---------------- Model, loss and optimiser ----------------
# Univariate in, univariate out; the remaining sizes come from the hyperparameters.
model = Autoformer(
    input_dim=1,
    output_dim=1,
    d_model=d_model,
    n_heads=n_heads,
    d_ff=d_ff,
    e_layers=e_layers,
    d_layers=d_layers,
    moving_avg=moving_avg,
    dropout=dropout,
    activation=activation,
)
model = model.to(device)

criterion = nn.MSELoss()  # mean squared error on the next-step prediction
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# ---------------- Training ----------------
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0  # sum of per-sample squared errors this epoch
    seen_samples = 0  # number of windows processed this epoch

    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        x_batch = x_batch.reshape(x_batch.shape[0], window_size - 1, 1)  # (batch_size, seq_len, input_dim)
        y_batch = y_batch.reshape(y_batch.shape[0], 1)  # (batch_size, output_dim)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(x_batch)
        pred = outputs[:, -1, :]  # use the output at the last time step as the prediction

        loss = criterion(pred, y_batch)

        # Backward pass and optimisation
        loss.backward()
        # Clip gradients to guard against exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # MSELoss averages over the batch; weight by batch size so a smaller
        # final batch does not skew the epoch mean.
        batch_n = x_batch.shape[0]
        total_loss += loss.item() * batch_n
        seen_samples += batch_n

    # Per-sample mean loss for the epoch (NaN if the loader yielded nothing).
    epoch_loss = total_loss / seen_samples if seen_samples else float('nan')
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.10f}")

print("Training complete.")

# ---------------- Evaluation ----------------
model.eval()
test_loss = 0.0  # sum of per-sample squared errors over the test set
test_samples = 0  # number of test windows evaluated

with torch.no_grad():
    for x_test, y_test in test_loader:
        x_test, y_test = x_test.to(device), y_test.to(device)
        x_test = x_test.reshape(x_test.shape[0], window_size - 1, 1)
        y_test = y_test.reshape(y_test.shape[0], 1)

        outputs = model(x_test)
        pred = outputs[:, -1, :]  # last-time-step prediction, as in training

        loss = criterion(pred, y_test)
        # Weight each batch mean by its size so the result is the true
        # per-sample MSE even when the last batch is smaller.
        test_loss += loss.item() * x_test.shape[0]
        test_samples += x_test.shape[0]

mean_test_loss = test_loss / test_samples if test_samples else float('nan')
print(f"Test Loss: {mean_test_loss:.10f}")