import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

def sigmoid(z):
    """
    Compute the logistic sigmoid 1 / (1 + exp(-z)) element-wise.

    Uses the numerically stable split form, so large-magnitude inputs never
    overflow ``np.exp``. The naive form overflows for z below about -709.

    Parameters:
        z (array_like): real scalar or array.

    Returns:
        numpy.ndarray or numpy scalar: sigmoid(z) with the same shape as ``z``.
        A scalar input yields a NumPy scalar.
    """
    z = np.asarray(z)
    # Integer/bool inputs are promoted to float. Float inputs keep their precision.
    if z.dtype.kind in "biu":
        z = z.astype(np.float64)
    # exp(-|z|) lies in (0, 1], so it cannot overflow.
    e = np.exp(-np.abs(z))
    # z >= 0: 1 / (1 + exp(-z)).  z < 0: exp(z) / (1 + exp(z)).  Same value, no overflow.
    out = np.where(z >= 0, 1.0 / (1.0 + e), e / (1.0 + e))
    # np.where returns a 0-d array for scalar input. Unwrap it to a NumPy scalar.
    return out[()] if out.ndim == 0 else out

def _top_k_support(weights, k):
    """Return the sorted indices of the ``k`` entries of ``weights`` with the largest magnitude."""
    if k >= weights.size:
        return np.arange(weights.size)
    # argpartition runs in O(n) and always yields exactly k indices, even when
    # magnitudes tie (a percentile threshold would keep every tied entry).
    return np.sort(np.argpartition(np.abs(weights), -k)[-k:])


def optimize_l0(X, y, model, max_iter=100, sparsity=0.1, learning_rate=0.01,
                random_state=None):
    """
    Approximate L0-constrained sparse feature selection, then fit ``model``.

    Runs iterative hard thresholding on an intercept-free logistic model:
    1. Take a gradient-descent step on the mean cross-entropy loss.
    2. Zero every weight except the ``k`` largest in magnitude.
    3. Repeat for ``max_iter`` iterations.

    Finally, ``model`` is fitted on the surviving feature columns.

    Parameters:
        X (numpy.ndarray): 2-D feature matrix (n_samples, n_features).
        y (numpy.ndarray): label vector of length n_samples. The loss assumes
            binary labels encoded as 0/1.
        model (sklearn-like): estimator exposing ``fit(X, y)``.
        max_iter (int): number of gradient + thresholding iterations.
        sparsity (float): fraction of features to keep, in [0, 1]. The number
            kept is k = max(1, round(sparsity * n_features)).
        learning_rate (float): gradient-descent step size.
        random_state (int | numpy.random.Generator | None): seed for the random
            weight initialisation. ``None`` uses NumPy's global RNG, so
            ``np.random.seed`` still controls it.

    Returns:
        selected_features (numpy.ndarray): sorted indices of the selected features.
        model (sklearn-like): ``model`` after ``fit(X[:, selected_features], y)``.

    Raises:
        ValueError: if ``sparsity`` is not in [0, 1] or ``X`` has no features.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    n_samples, n_features = X.shape
    if n_features == 0:
        raise ValueError("X must contain at least one feature")
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity!r}")

    # At least one feature is always kept, so model.fit never receives an empty matrix.
    k = max(1, int(round(sparsity * n_features)))

    # Random initial weights.
    if random_state is None:
        weights = np.random.randn(n_features)
    else:
        weights = np.random.default_rng(random_state).standard_normal(n_features)

    # With max_iter == 0 no thresholding happens, so every feature is kept.
    support = np.arange(n_features)
    for _ in range(max_iter):
        # Gradient of the mean cross-entropy loss: X^T (p - y) / n.
        residuals = sigmoid(X @ weights) - y
        weights = weights - learning_rate * (X.T @ residuals) / n_samples

        # Hard threshold: keep exactly the k largest-magnitude weights.
        support = _top_k_support(weights, k)
        pruned = np.zeros_like(weights)
        pruned[support] = weights[support]
        weights = pruned

    selected_features = support
    print(f"Selected features after optimization: {len(selected_features)}")

    # Train the caller's model on the selected columns only.
    model.fit(X[:, selected_features], y)
    return selected_features, model

# Data preparation: 20 features, of which 5 are informative.
X, y = make_classification(n_samples=1000, n_features=20, n_informative=5, random_state=42)

# Hold out a stratified split so the reported accuracy is not measured on the
# same samples used for feature selection and model fitting.
X_train, X_eval, y_train, y_eval = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y
)

# Initialise the logistic-regression model.
model = LogisticRegression()

# Seed NumPy's global RNG so optimize_l0's random weight initialisation is reproducible.
np.random.seed(42)

# Run sparse feature selection (and model fitting) on the training split only.
selected_features, trained_model = optimize_l0(
    X_train, y_train, model, sparsity=0.2, learning_rate=0.01
)

# Predict on the held-out split using the same selected columns.
X_test = X_eval[:, selected_features]
y_pred = trained_model.predict(X_test)

# Evaluate on held-out data.
accuracy = accuracy_score(y_eval, y_pred)
print(f"Model accuracy with selected features: {accuracy:.4f}")
