import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import random
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import math


def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds PyTorch (CPU, the current GPU and all GPUs), NumPy and Python's
    ``random`` module with ``seed``. It also forces cuDNN to deterministic
    algorithms with autotuning disabled, to make runs reproducible.
    """
    for torch_seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        torch_seeder(seed)
    for py_seeder in (np.random.seed, random.seed):
        py_seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # only use deterministic convolution algorithms
    cudnn.benchmark = False  # disable per-shape algorithm autotuning


set_seed(1234)  # fix all RNG seeds (see set_seed) for reproducible runs

# Device for training and evaluation: CUDA if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
window_size = 11  # sliding-window length: window_size - 1 input steps + 1 target step
embed_dim = 1  # embedding dim = number of features; NOTE(review): not referenced in the visible code

# Crossformer specific hyperparameters
d_model = 64  # model (hidden) width; must be divisible by n_heads
n_heads = 4  # number of attention heads
e_layers = 2  # number of encoder layers
d_ff = 256  # hidden width of the feed-forward sublayer
dropout = 0.1  # dropout probability
# Segment length for the intra-segment attention; roughly half the window size is
# suggested. The 10-step input (window_size - 1) is zero-padded to 12, a multiple of 4.
seg_len = 4

batch_size = 256
num_epochs = 100
learning_rate = 1e-3


# ============== 数据集类 ==============
class SWaTDataset(Dataset):
    """Sliding-window dataset for one-step-ahead forecasting.

    Each item is built from ``window_size`` consecutive rows of ``features``.
    The first ``window_size - 1`` rows are the input and the last row is the
    target.

    Args:
        features: tensor indexed by time step along dim 0. It is cloned,
            detached and cast to float32.
        window_size: total window length (input steps + 1 target step).
    """

    def __init__(self, features, window_size):
        self.features = features.clone().detach().float()
        self.window_size = window_size  # full window length (inputs + target)

    def __len__(self):
        # Windows start at indices 0 .. len - window_size inclusive, giving
        # len - window_size + 1 of them. Clamp at 0 so a series shorter than
        # one window yields an empty dataset instead of a negative length.
        return max(0, len(self.features) - self.window_size + 1)

    def __getitem__(self, idx):
        """Return ``(x, y)``: ``x`` is ``window_size - 1`` rows and ``y`` is the next row.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_windows = len(self)
        if idx < 0:
            idx += n_windows
        if not 0 <= idx < n_windows:
            # Without this check, tensor slicing would silently return a truncated window
            raise IndexError(f"window index out of range for dataset of length {n_windows}")
        window = self.features[idx:idx + self.window_size]  # full window
        x = window[:-1]  # first window_size - 1 rows are the input
        y = window[-1]  # last row is the target
        return x, y


# Load data
# NOTE(review): the broad `except Exception` below means ANY failure (missing
# file, parse error, missing column, ...) silently switches to random dummy
# data. Check the printed messages before trusting any reported loss.
try:
    dataset_swat = pd.read_csv('../../data/output-cnn.txt', encoding='utf-8',
                               low_memory=False)
    # Only the first column is used, as a single univariate feature.
    # NOTE(review): read_csv treats the file's first line as a header by
    # default. Confirm output-cnn.txt actually has one; otherwise the first
    # data point is consumed as the column name.
    feature_cols = dataset_swat.columns[0]
    features = dataset_swat[feature_cols].values
    print("Successfully loaded data")
    print("feature.shape", features.shape)
except Exception as e:
    print(f"Failed to load data: {e}")
    # Fall back to uniform random dummy data (10000 x 1) so the rest of the script can run
    features = np.random.rand(10000, 1)
    print("Using dummy data with shape:", features.shape)

# Chronological train/test split (no shuffling)
train_size = int(len(features) * 0.8)  # first 80% of rows go to the training set
train_features, test_features = features[:train_size], features[train_size:]

# Reshape to 2-D (n_rows, 1), the layout StandardScaler works on
train_features = train_features.reshape(-1, 1)  # shape=(n_train, 1)
test_features = test_features.reshape(-1, 1)  # shape=(n_test, 1)

# Standardize after the split. The scaler is fit on the training split only and
# then applied to the test split, so test statistics do not leak into training.
scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)
test_features = scaler.transform(test_features)

# Convert both splits to float32 tensors on `device`
train_features = torch.tensor(train_features, dtype=torch.float32).to(device)
test_features = torch.tensor(test_features, dtype=torch.float32).to(device)

print("train_feature.shape", train_features.shape)
print("test_feature.shape", test_features.shape)

# Build sliding-window train and test datasets
train_dataset = SWaTDataset(train_features, window_size)
test_dataset = SWaTDataset(test_features, window_size)

# DataLoaders; shuffle=False keeps batches in time order.
# NOTE(review): shuffling the training loader is usually preferable. Confirm
# whether the fixed order is intentional.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


# ============== Crossformer 辅助组件 ==============
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding that is added to the input.

    Even channels hold ``sin(pos * w_i)`` and odd channels ``cos(pos * w_i)``,
    with ``w_i = 10000 ** (-2i / d_model)``. The table is precomputed for
    ``max_len`` positions and stored as the buffer ``pe`` of shape
    ``(1, max_len, d_model)``.

    Args:
        d_model: embedding width (odd widths are supported).
        max_len: number of positions in the precomputed table.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cos column than sin column, so
        # only the first d_model // 2 frequencies are used here.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding of the first ``x.size(1)`` positions to ``x``.

        ``x`` is expected to be (batch, seq_len, d_model) with seq_len <= max_len.
        """
        return x + self.pe[:, :x.size(1), :]


class TimeFeatureEmbedding(nn.Module):
    """Linear value embedding: projects ``c_in`` features per time step to ``d_model``.

    Args:
        d_model: output embedding width.
        embed_type: unused; kept for backward compatibility with existing callers.
        c_in: number of input features per time step. The default of 1 matches
            the previous hard-coded single-feature projection.
    """

    def __init__(self, d_model, embed_type='fixed', c_in=1):
        super(TimeFeatureEmbedding, self).__init__()
        self.embed = nn.Linear(c_in, d_model)

    def forward(self, x):
        # x: (batch, seq_len, c_in) -> (batch, seq_len, d_model)
        return self.embed(x)


class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention with separate query/key/value inputs.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super(CrossAttention, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        assert self.head_dim * n_heads == d_model, "d_model must be divisible by n_heads"

        # Construction order fixes parameter initialisation under a given seed
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch):
        """Reshape (batch, length, d_model) into (batch, n_heads, length, head_dim)."""
        return projected.view(batch, -1, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``; all inputs are (batch, length, d_model).

        ``mask`` (optional) must broadcast against the (batch, n_heads, len_q, len_k)
        score tensor; positions where it equals 0 are excluded from the softmax.
        """
        batch = q.size(0)
        queries = self._split_heads(self.query(q), batch)
        keys = self._split_heads(self.key(k), batch)
        values = self._split_heads(self.value(v), batch)

        logits = queries.matmul(keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            logits = logits.masked_fill(mask == 0, -1e9)

        weights = self.dropout(F.softmax(logits, dim=-1))

        # Merge heads back: (batch, n_heads, len_q, head_dim) -> (batch, len_q, d_model)
        merged = weights.matmul(values).transpose(1, 2).contiguous()
        return self.proj(merged.view(batch, -1, self.d_model))


class SegmentationAttention(nn.Module):
    """Self-attention restricted to non-overlapping segments of ``seg_len`` steps.

    The sequence is right-padded with zeros to a multiple of ``seg_len`` and
    split into segments. Each segment attends only within itself, using one
    shared ``CrossAttention`` module. Padded positions are masked out as keys,
    so they cannot influence real positions, and are dropped from the output.

    Args:
        d_model: model width.
        seg_len: number of time steps per segment.
        n_heads: number of attention heads.
        dropout: dropout probability on the attention weights.
    """

    def __init__(self, d_model, seg_len, n_heads, dropout=0.1):
        super(SegmentationAttention, self).__init__()
        self.seg_len = seg_len
        self.attention = CrossAttention(d_model, n_heads, dropout)

    def forward(self, x):
        """Apply intra-segment attention to ``x`` of shape (batch, seq_len, d_model).

        Returns a tensor with the same shape as ``x``.
        """
        batch_size, seq_len, d_model = x.size()

        # Right-pad the time axis so it splits evenly into segments
        pad_len = (self.seg_len - seq_len % self.seg_len) % self.seg_len
        if pad_len > 0:
            x = F.pad(x, (0, 0, 0, pad_len))
        padded_len = seq_len + pad_len
        n_segments = padded_len // self.seg_len

        # Fold segments into the batch axis so a single attention call handles all
        # of them: (B, padded_len, D) -> (B * n_segments, seg_len, D)
        segments = x.reshape(batch_size * n_segments, self.seg_len, d_model)

        mask = None
        if pad_len > 0:
            # Key mask (True = real step). Row b * n_segments + i of `segments` is
            # segment i, so the per-segment mask is tiled over the batch. The shape
            # (B * n_segments, 1, 1, seg_len) broadcasts over heads and queries.
            # Only the last segment contains padding, and it keeps at least one
            # real key because pad_len < seg_len.
            valid = torch.arange(padded_len, device=x.device) < seq_len
            mask = valid.view(n_segments, 1, 1, self.seg_len).repeat(batch_size, 1, 1, 1)

        out = self.attention(segments, segments, segments, mask=mask)

        # Unfold segments back into the time axis and drop the padded steps
        out = out.reshape(batch_size, padded_len, d_model)
        return out[:, :seq_len, :]


class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input and output width.
        d_ff: hidden width.
        dropout: dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = self.dropout(torch.relu(self.linear1(x)))
        return self.linear2(hidden)


class CrossformerEncoderLayer(nn.Module):
    """One encoder layer made of three post-norm residual sublayers.

    1. intra-segment self-attention (``SegmentationAttention``);
    2. self-attention over the whole sequence (``CrossAttention`` with q = k = v
       and no mask);
    3. position-wise feed-forward network.

    Each sublayer output goes through the shared dropout, is added to the
    sublayer input, and is layer-normalised.
    """

    def __init__(self, d_model, seg_len, n_heads, d_ff, dropout=0.1):
        super(CrossformerEncoderLayer, self).__init__()
        # Construction order fixes parameter initialisation under a given seed
        self.segment_attention = SegmentationAttention(d_model, seg_len, n_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.cross_attention = CrossAttention(d_model, n_heads, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.feedforward = FeedForward(d_model, d_ff, dropout)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _residual(self, skip, sublayer_out, norm):
        """Post-norm residual connection: ``norm(skip + dropout(sublayer_out))``."""
        return norm(skip + self.dropout(sublayer_out))

    def forward(self, x):
        x = self._residual(x, self.segment_attention(x), self.norm1)
        x = self._residual(x, self.cross_attention(x, x, x), self.norm2)
        return self._residual(x, self.feedforward(x), self.norm3)


# ============== Crossformer 模型 ==============
class CrossformerModel(nn.Module):
    """Crossformer-style encoder producing one output vector per time step.

    Pipeline: linear value embedding -> sinusoidal positional encoding ->
    dropout -> ``e_layers`` x ``CrossformerEncoderLayer`` -> linear head.

    Args:
        input_dim: number of input features per time step.
        d_model: hidden width.
        n_heads: attention heads per attention module.
        e_layers: number of encoder layers.
        d_ff: hidden width of the feed-forward sublayers.
        seg_len: segment length for the intra-segment attention.
        output_dim: number of outputs per time step.
        dropout: dropout probability.
    """

    def __init__(self, input_dim, d_model, n_heads, e_layers, d_ff, seg_len, output_dim=1, dropout=0.1):
        super(CrossformerModel, self).__init__()
        self.input_dim = input_dim

        # Input embedding
        self.embedding = TimeFeatureEmbedding(d_model)
        if input_dim != 1:
            # TimeFeatureEmbedding's default projection takes a single feature; swap in
            # one that accepts `input_dim` features. The attribute name `embed` is kept
            # so state_dict keys stay the same.
            self.embedding.embed = nn.Linear(input_dim, d_model)

        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model)

        # Encoder layers
        self.encoder_layers = nn.ModuleList([
            CrossformerEncoderLayer(d_model, seg_len, n_heads, d_ff, dropout)
            for _ in range(e_layers)
        ])

        # Output layer
        self.output_layer = nn.Linear(d_model, output_dim)

        # Dropout on the position-encoded embeddings (previously created but never used)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        x: input tensor of shape (batch_size, seq_len, input_dim).

        Returns a tensor of shape (batch_size, seq_len, output_dim).
        """
        x = self.embedding(x)  # [B, T, d_model]
        x = self.dropout(self.pos_encoder(x))

        for enc_layer in self.encoder_layers:
            x = enc_layer(x)

        return self.output_layer(x)  # [B, T, output_dim]


# Build the Crossformer model
model = CrossformerModel(
    input_dim=1,
    d_model=d_model,
    n_heads=n_heads,
    e_layers=e_layers,
    d_ff=d_ff,
    seg_len=seg_len,
    dropout=dropout
).to(device)

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0  # sum of per-sample losses over the epoch
    n_train = 0  # samples seen this epoch

    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        x_batch = x_batch.reshape(x_batch.shape[0], window_size - 1, 1)  # (batch_size, seq_len, input_dim)
        y_batch = y_batch.reshape(y_batch.shape[0], 1)  # (batch_size, output_dim)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(x_batch)
        pred = outputs[:, -1, :]  # the last time step's output is the prediction

        # Compute loss
        loss = criterion(pred, y_batch)

        # Backward pass and optimizer step
        loss.backward()
        optimizer.step()

        # MSELoss (default reduction) returns a batch mean; weight it by the batch size
        # so a smaller final batch does not skew the epoch average.
        total_loss += loss.item() * x_batch.shape[0]
        n_train += x_batch.shape[0]

    # Per-sample mean training loss for this epoch (nan if the loader is empty)
    epoch_loss = total_loss / n_train if n_train else float('nan')
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.10f}")

print("Training complete.")

# Evaluate the model
model.eval()
test_loss = 0.0  # sum of per-sample losses over the test set
n_test = 0

with torch.no_grad():
    for x_test, y_test in test_loader:
        x_test, y_test = x_test.to(device), y_test.to(device)
        x_test = x_test.reshape(x_test.shape[0], window_size - 1, 1)
        y_test = y_test.reshape(y_test.shape[0], 1)

        # Forward pass
        outputs = model(x_test)
        pred = outputs[:, -1, :]

        # Compute loss, weighted by batch size as in training
        loss = criterion(pred, y_test)
        test_loss += loss.item() * x_test.shape[0]
        n_test += x_test.shape[0]

mean_test_loss = test_loss / n_test if n_test else float('nan')
print(f"Test Loss: {mean_test_loss:.10f}")