import torch.nn.functional as F
import torch.nn as nn
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import random
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report


def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, the current CUDA device and all CUDA devices), then configures
    cuDNN for deterministic execution.

    Args:
        seed: Integer seed handed to each generator.
    """
    # Python-level and NumPy generators.
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    # PyTorch generators: CPU first, then the current GPU, then every GPU.
    for seeder in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)

    # Turn off cuDNN's algorithm auto-selection and request deterministic
    # kernels so convolution results are repeatable between runs.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Fix all random seeds to 1234 before anything random happens.
set_seed(1234)

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Load the pre-processed table; low_memory=False makes pandas read the file
# in one pass instead of in chunks.
dataset = pd.read_csv(
    './datasets_swat1_afterProcess.csv',
    encoding='utf-8',
    low_memory=False,
)


# print(df)

class Meself_dataset(Dataset):
    """Map-style dataset that pairs each feature entry with its label.

    Args:
        features: Indexable collection of samples; its length is the
            dataset size.
        labels: Indexable collection of targets, looked up with the same
            index as ``features``.
    """

    def __init__(self, features, labels):
        self.features, self.labels = features, labels

    def __len__(self):
        # Only ``features`` is consulted for the dataset size.
        return len(self.features)

    def __getitem__(self, item):
        """Return the ``(feature, label)`` pair stored at index ``item``."""
        return self.features[item], self.labels[item]


# Per-class partitions are collected here and concatenated after the loop.
train_data = []
test_data = []
val_data = []  # never filled: this script produces no validation split

# Split each 'Normal/Attack' class on its own so that the train and test
# partitions both receive 80% / 20% of every class.
for _, group in dataset.groupby('Normal/Attack'):
    # With an integer train_size, train_test_split takes exactly that many
    # rows for training and puts all remaining rows of the class in the
    # test partition.
    train_size = int(len(group) * 0.8)
    train, test = train_test_split(group, train_size=train_size, shuffle=True)

    train_data.append(train)
    test_data.append(test)

# Merge the per-class pieces and convert them to NumPy arrays.
train_set = pd.concat(train_data).reset_index(drop=True).to_numpy()
test_set = pd.concat(test_data).reset_index(drop=True).to_numpy()

print(train_set.shape)
print(type(train_set))

# The last column is used as the label, every column before it as a feature.
# NOTE(review): this presumes 'Normal/Attack' is the final column of the
# CSV - confirm against the data file.
train_set = Meself_dataset(train_set[:, :-1], train_set[:, -1])
test_set = Meself_dataset(test_set[:, :-1], test_set[:, -1])

batch_size = 64

# drop_last=True discards the final incomplete batch, so every batch holds
# exactly batch_size samples. For the test loader this also means that up to
# batch_size - 1 test samples are left out of the evaluation.
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=True, drop_last=True)


class SimpleCNN(nn.Module):
    """1-D CNN classifier: four conv/ReLU/max-pool stages and four linear layers.

    The convolutional stack widens the channels 1 -> 4 -> 8 -> 16 -> 32 and
    halves the sequence length after every stage. ``fc1`` expects 96 values
    per sample, i.e. 32 channels x 3 remaining positions, which holds for
    input lengths from 33 to 48.

    Args:
        batch: Kept for backward compatibility and stored on ``self.batch``.
            ``forward`` no longer depends on it; the batch size is read from
            the input tensor instead.
    """

    def __init__(self, batch=64):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=2, stride=1, padding=1)
        self.conv2 = nn.Conv1d(in_channels=4, out_channels=8, kernel_size=2, stride=1, padding=1)
        self.conv3 = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=2, stride=1, padding=1)
        self.conv4 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=2, stride=1, padding=1)

        self.fc1 = nn.Linear(96, 50)
        self.fc2 = nn.Linear(50, 10)
        self.fc3 = nn.Linear(10, 1)
        self.fc4 = nn.Linear(1, 2)
        self.batch = batch

    def forward(self, x, return_intermediate=False):
        """Run the network on a batch of single-channel sequences.

        Args:
            x: Float tensor of shape ``(N, 1, L)``; any batch size ``N`` is
                accepted.
            return_intermediate: When true, return the output of ``fc3``
                (the layer before the final one) instead of the logits.

        Returns:
            A ``(N, 2)`` tensor of class logits, or a ``(N, 1)`` tensor when
            ``return_intermediate`` is true.
        """
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.max_pool1d(F.relu(conv(x)), 2)

        # Flatten per sample using the input's own batch dimension, so the
        # model is not tied to the batch size given at construction time.
        x = x.flatten(start_dim=1)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        intermediate = self.fc3(x)

        if return_intermediate:
            # Skip the final layer: only the penultimate output is wanted.
            return intermediate
        return self.fc4(intermediate)


# NOTE: this first instance is discarded two statements below. It is kept
# because constructing it draws from PyTorch's global random generator;
# removing it would change the initial weights of the model that is actually
# trained and therefore the saved checkpoint.
model = SimpleCNN()

num_epochs = 50

# Accuracy of each evaluation run. This script evaluates once, so the list
# ends up with a single entry (no K-fold loop is performed here).
fold_accuracies = []

# Not used anywhere in this script.
accuracies = []
all_y_true = []
all_y_pred = []

model = SimpleCNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# ---- Training ----
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs = inputs.float().to(device)
        labels = labels.long().to(device)  # CrossEntropyLoss needs integer class indices
        # Insert the channel dimension: (N, features) -> (N, 1, features).
        # The batch size is taken from the tensor rather than hard-coded.
        inputs = inputs.reshape(inputs.size(0), 1, -1)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}')

# ---- Evaluation on the test loader ----
model.eval()
correct = 0
total = 0
all_labels = []
all_preds = []
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.float().to(device)
        labels = labels.long().to(device)
        inputs = inputs.reshape(inputs.size(0), 1, -1)

        outputs = model(inputs)
        # Predicted class = index of the largest logit.
        _, predicted = torch.max(outputs, 1)

        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(predicted.cpu().numpy())

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
fold_accuracies.append(accuracy)

# Per-class precision, recall and F1-score.
report = classification_report(all_labels, all_preds, digits=4)
print('Classification Report:\n', report)
print(f'Accuracy: {accuracy:.2f}%')

# Persist only the learned parameters (state_dict), not the whole module.
torch.save(model.state_dict(), "cnn-model.pth")
print("模型已保存为 cnn-model.pth")
