import os
import pickle
import argparse
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np


# Simple dataset built from the pickled files in a folder
class SimpleTrainDataset(Dataset):
    """Training samples read from every file in ``emb_fold``.

    Files are visited in sorted filename order and each one is unpickled.
    Element 0 of the loaded object is flattened with ``view(-1)`` and stored
    with label 1; element 1 is flattened the same way and stored with label 0.
    Samples are therefore interleaved: positive, negative, positive, ...

    NOTE(review): assumes every pickle holds an indexable pair of tensors
    (``.view(-1)`` is called on both items) -- confirm against the producer.

    Attributes:
        data: list of flattened 1-D tensors.
        label: list of int labels aligned index-by-index with ``data``.
        data_dim: element count of the first file's flattened element 0, or
            None when the folder contains no files.
    """

    def __init__(self, emb_fold):
        self.data, self.label = [], []
        self.data_dim = None
        for name in sorted(os.listdir(emb_fold)):
            path = os.path.join(emb_fold, name)
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # this at folders whose contents you trust.
            with open(path, "rb") as fh:
                pair = pickle.load(fh)
            positive = pair[0].view(-1)
            negative = pair[1].view(-1)
            if self.data_dim is None:
                # The feature size is taken from the first file only.
                self.data_dim = positive.size()[-1]
            self.data.extend((positive, negative))
            self.label.extend((1, 0))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.label[idx]
        return sample, target


class SimpleTestDataset(Dataset):
    """Evaluation samples read from every file in ``emb_fold``.

    Files are visited in sorted filename order; each file's full path is
    printed to stdout as it is opened, then the file is unpickled. Element 0
    of the loaded object is flattened with ``view(-1)`` and stored with
    label 1; element 1 is flattened and stored with label 0. Any further
    elements (e.g. index 2) are ignored.

    NOTE(review): assumes every pickle holds an indexable sequence of at
    least two tensors -- confirm against the producer.

    Attributes:
        data: list of flattened 1-D tensors.
        label: list of int labels aligned index-by-index with ``data``.
        data_dim: element count of the first file's flattened element 0, or
            None when the folder contains no files.
    """

    def __init__(self, emb_fold):
        self.data, self.label = [], []
        self.data_dim = None
        for name in sorted(os.listdir(emb_fold)):
            path = os.path.join(emb_fold, name)
            with open(path, "rb") as fh:
                print(path)
                # NOTE(review): pickle.load can execute arbitrary code; only
                # point this at folders whose contents you trust.
                pair = pickle.load(fh)
            positive = pair[0].view(-1)
            negative = pair[1].view(-1)
            if self.data_dim is None:
                self.data_dim = positive.size()[-1]
            self.data.extend((positive, negative))
            self.label.extend((1, 0))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.label[idx]
        return sample, target


# Single-layer classifier: one nn.Linear, no hidden layer or activation
class SimpleMLP(nn.Module):
    """One fully-connected layer mapping ``input_dim`` features to ``output_dim`` outputs.

    Despite the name there is no hidden layer and no activation: the forward
    pass is a single affine transform returning raw (unnormalized) scores.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        scores = self.fc(x)
        return scores


# Training function: one pass (epoch) over the dataloader
def train(model, dataloader, criterion, optimizer, device, l2_lambda=0.1):
    """Run one training epoch over ``dataloader``.

    Each batch's loss is ``criterion(outputs, labels)`` plus ``l2_lambda``
    times the sum of ``torch.norm(param, 2)`` over every model parameter
    (weights and biases alike). This is the plain 2-norm, not the squared
    norm used by classic weight decay.

    Args:
        model: module whose output's argmax over dim 1 is taken as the
            predicted class.
        dataloader: iterable of ``(inputs, labels)`` batches; inputs are cast
            to float32 before the forward pass.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per batch.
        device: device that both inputs and labels are moved to.
        l2_lambda: weight of the parameter-norm penalty. Defaults to 0.1,
            the value that used to be hard-coded.

    Returns:
        ``(epoch_loss, epoch_acc)``: the total loss (penalty included)
        averaged over batches, and the fraction of samples whose prediction
        matched the label. Each is NaN when there was nothing to average
        (an empty dataloader) instead of raising ZeroDivisionError.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0

    # Wrapping the iterable lets tqdm size and close the bar itself.
    for inputs, labels in tqdm(dataloader):
        inputs, labels = inputs.float().to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        # Data term (cross-entropy when used with nn.CrossEntropyLoss).
        ce_loss = criterion(outputs, labels)

        # Penalty term: sum of the (unsquared) 2-norms of all parameters.
        l2_reg = torch.tensor(0.0, device=device)
        for param in model.parameters():
            l2_reg += torch.norm(param, 2)

        # Total loss = data term + l2_lambda * penalty term.
        loss = ce_loss + l2_lambda * l2_reg

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / num_batches if num_batches else float("nan")
    epoch_acc = correct / total if total else float("nan")
    return epoch_loss, epoch_acc


# Evaluation function: classification accuracy over a dataloader
def test(model, dataloader, device):
    """Return the classification accuracy of ``model`` on ``dataloader``.

    The model is switched to eval mode and run with autograd disabled. The
    predicted class is the argmax of the model output over dim 1.

    Args:
        model: module to evaluate.
        dataloader: iterable of ``(inputs, labels)`` batches; inputs are cast
            to float32 before the forward pass.
        device: device that both inputs and labels are moved to.

    Returns:
        Fraction of samples whose prediction matched the label, or NaN when
        the dataloader yields no samples (instead of ZeroDivisionError).
    """
    model.eval()
    correct = 0
    total = 0

    # Inference only: without no_grad every forward pass would record an
    # autograd graph (and keep its activations alive) that is never used.
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader):
            inputs, labels = inputs.float().to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    return correct / total if total else float("nan")


# Entry point: parse arguments, then train and evaluate
def main():
    """Parse CLI arguments, then train and evaluate a linear binary classifier.

    The training and test datasets are built from the pickled files in
    ``--train_emb_fold`` and ``--test_emb_fold``. Each epoch trains with
    Adam and cross-entropy, then prints train loss/accuracy and test accuracy.
    Exits with an argparse error if a folder argument is missing or if the
    training folder contains no files.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=100, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    # Both folders are mandatory: without a value the dataset would call
    # os.listdir(None), which lists the current working directory and then
    # unpickles whatever files happen to be there.
    parser.add_argument('--train_emb_fold', type=str, required=True,
                        help='folder of pickled training files')
    parser.add_argument('--test_emb_fold', type=str, required=True,
                        help='folder of pickled test files')
    args = parser.parse_args()

    # Fix random seeds for reproducibility.
    torch.manual_seed(42)
    np.random.seed(42)

    # Use the first GPU when available, otherwise the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Build datasets and data loaders.
    train_ds = SimpleTrainDataset(args.train_emb_fold)
    if train_ds.data_dim is None:
        # An empty training folder leaves data_dim unset, which would
        # otherwise surface as an obscure TypeError inside nn.Linear.
        parser.error(f"no files found in --train_emb_fold {args.train_emb_fold!r}")
    test_ds = SimpleTestDataset(args.test_emb_fold)
    train_dataloader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)
    test_dataloader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False)

    # Build the model.
    input_dim = train_ds.data_dim
    output_dim = 2  # binary classification
    model = SimpleMLP(input_dim, output_dim).to(device)

    # Loss function and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Train, evaluating on the test set after every epoch.
    for epoch in range(args.epochs):
        epoch_loss, epoch_acc = train(model, train_dataloader, criterion, optimizer, device)
        print(f"Epoch [{epoch+1}/{args.epochs}]")
        print(f"Train: Loss {epoch_loss:.4f}, Accuracy {epoch_acc*100:.4f}%")
        test_acc = test(model, test_dataloader, device)
        print(f"Test Accuracy: {test_acc*100:.4f}%")


# Run training/evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
