import copy
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt

# Seed the global PyTorch and NumPy random number generators so repeated
# runs are reproducible.
torch.manual_seed(42)
np.random.seed(42)

# Prefer the default CUDA device when one is available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# 加载数据
def load_data_from_folders(main_folder_path):
    """
    从主文件夹下的各个子文件夹中加载嵌入向量和标签数据
    每个子文件夹包含 embedding.npy 和 vector.npy 文件
    """
    embeddings_list = []
    labels_list = []
    
    # 获取主文件夹下的所有子文件夹
    main_path = Path(main_folder_path)
    subfolders = [f for f in main_path.iterdir() if f.is_dir()]
    
    print(f"找到 {len(subfolders)} 个子文件夹")
    
    # 遍历每个子文件夹
    for folder in subfolders:
        embedding_file = folder / "embedding.npy"
        vector_file = folder / "vector.npy"
        
        # 检查文件是否存在
        if not embedding_file.exists():
            print(f"警告: {embedding_file} 不存在，跳过此文件夹")
            continue
            
        if not vector_file.exists():
            print(f"警告: {vector_file} 不存在，跳过此文件夹")
            continue
        
        # 加载嵌入向量
        embedding = np.load(embedding_file)
        
        # 加载标签向量并提取第一个数字作为标签
        vector_data = np.load(vector_file)
        label = vector_data[0] if len(vector_data.shape) == 1 else vector_data[:, 0][0]
        
        embeddings_list.append(embedding)
        labels_list.append(label)
    
    # 转换为数组
    embeddings = np.array(embeddings_list)
    labels = np.array(labels_list)
    
    print(f"嵌入向量形状: {embeddings.shape}")
    print(f"标签形状: {labels.shape}")
    print(f"标签示例: {labels[:10]}")
    print(f"标签范围: [{np.min(labels)}, {np.max(labels)}]")
    
    return embeddings, labels

# 定义MLP回归模型
class MLPRegressor(nn.Module):
    def __init__(self, input_size=256, hidden_layers=[128, 64], dropout_rate=0.3):
        super(MLPRegressor, self).__init__()
        
        layers = []
        prev_size = input_size
        
        # 添加隐藏层
        for hidden_size in hidden_layers:
            layers.append(nn.Linear(prev_size, hidden_size))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            prev_size = hidden_size
        
        # 输出层 - 回归任务只需要一个输出节点
        layers.append(nn.Linear(prev_size, 1))
        
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)

# 训练函数
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=50):
    train_losses = []
    val_losses = []
    val_maes = []
    
    # 添加早停机制
    best_val_loss = float('inf')
    patience = 10
    patience_counter = 0
    
    for epoch in range(num_epochs):
        # 训练阶段
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device).float().unsqueeze(1)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
        
        epoch_train_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_train_loss)
        
        # 验证阶段
        model.eval()
        val_loss = 0.0
        all_preds = []
        all_labels = []
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device).float().unsqueeze(1)
                
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * inputs.size(0)
                
                all_preds.extend(outputs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        epoch_val_loss = val_loss / len(val_loader.dataset)
        val_losses.append(epoch_val_loss)
        
        # 计算MAE
        epoch_val_mae = mean_absolute_error(all_labels, all_preds)
        val_maes.append(epoch_val_mae)
        
        # 早停检查
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            patience_counter = 0
            # 保存最佳模型
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            patience_counter += 1
            
        if patience_counter >= patience:
            print(f"早停在第 {epoch+1} 轮")
            break
        
        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, '
                  f'Train Loss: {epoch_train_loss:.4f}, '
                  f'Val Loss: {epoch_val_loss:.4f}, '
                  f'Val MAE: {epoch_val_mae:.4f}')
    
    # 加载最佳模型
    model.load_state_dict(torch.load('best_model.pth'))
    
    return train_losses, val_losses, val_maes

# 主函数
def main():
    # 主文件夹路径 - 请替换为您的实际路径
    main_folder_path = '../Representation/inference_results'  # 替换为您的文件夹路径
    
    # 加载数据
    embeddings, labels = load_data_from_folders(main_folder_path)
    
    # 分割数据集
    X_train, X_test, y_train, y_test = train_test_split(
        embeddings, labels, test_size=0.2, random_state=42
    )
    
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.1, random_state=42
    )
    
    print(f"训练集大小: {X_train.shape[0]}")
    print(f"验证集大小: {X_val.shape[0]}")
    print(f"测试集大小: {X_test.shape[0]}")
    
    # 转换为PyTorch张量
    X_train_tensor = torch.FloatTensor(X_train)
    y_train_tensor = torch.FloatTensor(y_train)
    X_val_tensor = torch.FloatTensor(X_val)
    y_val_tensor = torch.FloatTensor(y_val)
    X_test_tensor = torch.FloatTensor(X_test)
    y_test_tensor = torch.FloatTensor(y_test)
    
    # 创建数据加载器
    batch_size = min(32, len(X_train))  # 如果数据量小，使用较小的批次大小
    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
    test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    # 初始化模型
    model = MLPRegressor(input_size=256, hidden_layers=[128, 64]).to(device)
    
    # 定义损失函数和优化器 - 使用MSE损失进行回归
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    
    # 训练模型
    print("开始训练模型...")
    train_losses, val_losses, val_maes = train_model(
        model, train_loader, val_loader, criterion, optimizer, num_epochs=100
    )
    
    # 绘制训练曲线
    plt.figure(figsize=(12, 4))
    
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='训练损失 (MSE')
    plt.plot(val_losses, label='验证损失 (MSE)')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(val_maes, label='验证MAE')
    plt.xlabel('Epoch')
    plt.ylabel('MAE')
    plt.legend()
    
    plt.tight_layout()
    plt.savefig('training_curves.png')
    
    # 在测试集上评估模型
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device).float().unsqueeze(1)
            
            outputs = model(inputs)
            
            all_preds.extend(outputs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # 计算回归指标
    mse = mean_squared_error(all_labels, all_preds)
    mae = mean_absolute_error(all_labels, all_preds)
    r2 = r2_score(all_labels, all_preds)
    
    print(f"测试集MSE: {mse:.4f}")
    print(f"测试集MAE: {mae:.4f}")
    print(f"测试集R²: {r2:.4f}")
    
    # 绘制预测值与真实值的散点图
    plt.figure(figsize=(8, 6))
    plt.scatter(all_labels, all_preds, alpha=0.5)
    plt.plot([min(all_labels), max(all_labels)], [min(all_labels), max(all_labels)], 'r--')
    plt.xlabel('真实值')
    plt.ylabel('预测值')
    plt.title('预测值与真实值对比')
    plt.savefig('predictions_vs_actuals.png')
    
    # 保存最终模型
    torch.save(model.state_dict(), 'mlp_regressor.pth')
    print("模型已保存为 'mlp_regressor.pth'")

if __name__ == "__main__":
    main()