# main.py
import numpy as np
import tensorflow as tf
import os
from sklearn.preprocessing import StandardScaler
import pandas as pd
from sklearn.model_selection import train_test_split
import time
import random
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder

# select.py

# Hard Thresholding Gradient Descent with SVRG
class HardThresholdingSVRG:
    """Least-squares feature selector: SVRG-style updates plus hard thresholding.

    Every outer iteration computes the full least-squares gradient at the
    current coefficients, performs a pass of mini-batch variance-reduced
    updates, and then zeroes all coefficients whose magnitude is below the
    ``(100 - top_k)``-th percentile.

    Note:
        ``top_k`` is handed to ``np.percentile`` and therefore acts as the
        *percentage* of coefficients to keep (0-100), not a feature count.
    """

    def __init__(self, top_k, learning_rate=0.001, n_iter=200, batch_size=128):
        """Store the hyper-parameters; ``coef_`` is populated by :meth:`fit`.

        Args:
            top_k: Percentage (0-100) of coefficients kept by thresholding.
            learning_rate: Step size of each coefficient update.
            n_iter: Number of outer iterations (snapshot + thresholding).
            batch_size: Requested mini-batch size; clamped to the number of
                samples during :meth:`fit`.
        """
        self.top_k = top_k
        self.learning_rate = learning_rate
        self.n_iter = n_iter
        self.batch_size = batch_size
        self.coef_ = None

    def fit(self, X, y, theta):
        """Learn sparse coefficients for the least-squares objective.

        Args:
            X: 2-D array of shape (n_samples, n_features).
            y: Targets of length n_samples.
            theta: Divisor applied to the mini-batch correction term
                ``gradient - snapshot_gradient``.

        Returns:
            The fitted instance, so calls can be chained.

        Raises:
            ValueError: If ``X`` contains no samples.
        """
        y = np.asarray(y)  # positional indexing even for lists / Series
        n_samples, n_features = X.shape
        if n_samples == 0:
            raise ValueError("Cannot fit on an empty dataset")

        # A batch larger than the dataset used to yield zero inner steps
        # (n_samples // batch_size == 0), leaving every coefficient at 0.
        batch_size = min(self.batch_size, n_samples)
        steps_per_epoch = n_samples // batch_size

        self.coef_ = np.zeros(n_features)
        for _epoch in range(self.n_iter):
            # Snapshot and full gradient are taken at the same point.
            coef_snapshot = self.coef_.copy()
            full_gradient = -2 * X.T @ (y - X @ coef_snapshot) / n_samples

            for _step in range(steps_per_epoch):
                batch_indices = random.sample(range(n_samples), batch_size)
                X_batch = X[batch_indices]
                y_batch = y[batch_indices]
                gradient = -2 * X_batch.T @ (y_batch - X_batch @ self.coef_) / batch_size
                # Replace NaN/inf entries with finite values before the update.
                gradient = np.nan_to_num(gradient)
                snapshot_gradient = (
                    -2 * X_batch.T @ (y_batch - X_batch @ coef_snapshot) / batch_size
                )
                svrg_gradient = (1 / theta) * (gradient - snapshot_gradient) + full_gradient
                self.coef_ -= self.learning_rate * svrg_gradient

            # Hard thresholding step
            self.hard_threshold()
        return self

    def hard_threshold(self):
        """Zero every coefficient below the ``(100 - top_k)``-th magnitude percentile."""
        magnitudes = np.abs(self.coef_)
        threshold = np.percentile(magnitudes, (100 - self.top_k))
        self.coef_ = np.where(magnitudes >= threshold, self.coef_, 0)

    def transform(self, X):
        """Return the columns of ``X`` whose learned coefficient is non-zero.

        Raises:
            RuntimeError: If called before :meth:`fit`.
        """
        if self.coef_ is None:
            raise RuntimeError("transform() called before fit()")
        mask = self.coef_ != 0
        return X[:, mask]

# Hard Thresholding Gradient Descent with SARAH
class HardThresholdingSARAH:
    """Least-squares feature selector: SARAH updates plus hard thresholding.

    Every outer iteration starts the recursive gradient estimator ``v`` from
    the full gradient, runs a pass of mini-batch SARAH updates, and then
    zeroes all coefficients whose magnitude is below the
    ``(100 - top_k)``-th percentile.

    Note:
        ``top_k`` is handed to ``np.percentile`` and therefore acts as the
        *percentage* of coefficients to keep (0-100), not a feature count.
    """

    def __init__(self, top_k, learning_rate=0.001, n_iter=200, batch_size=128):
        """Store the hyper-parameters; ``coef_`` is populated by :meth:`fit`.

        Args:
            top_k: Percentage (0-100) of coefficients kept by thresholding.
            learning_rate: Step size of each coefficient update.
            n_iter: Number of outer iterations (restart + thresholding).
            batch_size: Requested mini-batch size; clamped to the number of
                samples during :meth:`fit`.
        """
        self.top_k = top_k
        self.learning_rate = learning_rate
        self.n_iter = n_iter
        self.batch_size = batch_size
        self.coef_ = None

    def fit(self, X, y):  # SARAH implementation
        """Learn sparse coefficients for the least-squares objective.

        The estimator follows the SARAH recursion
        ``v_t = grad_B(w_t) - grad_B(w_{t-1}) + v_{t-1}``: the subtracted
        mini-batch gradient is evaluated at the *previous iterate*, not at
        the epoch snapshot.

        Args:
            X: 2-D array of shape (n_samples, n_features).
            y: Targets of length n_samples.

        Returns:
            The fitted instance, so calls can be chained.

        Raises:
            ValueError: If ``X`` contains no samples.
        """
        y = np.asarray(y)  # positional indexing even for lists / Series
        n_samples, n_features = X.shape
        if n_samples == 0:
            raise ValueError("Cannot fit on an empty dataset")

        # A batch larger than the dataset used to yield zero inner steps
        # (n_samples // batch_size == 0), leaving every coefficient at 0.
        batch_size = min(self.batch_size, n_samples)
        steps_per_epoch = n_samples // batch_size

        self.coef_ = np.zeros(n_features)
        for _epoch in range(self.n_iter):
            # Restart the recursion from the full gradient at the current point.
            v = -2 * X.T @ (y - X @ self.coef_) / n_samples
            # On the first inner step previous == current, so v is unchanged.
            previous_coef = self.coef_.copy()

            for _step in range(steps_per_epoch):
                batch_indices = random.sample(range(n_samples), batch_size)
                X_batch = X[batch_indices]
                y_batch = y[batch_indices]
                gradient = -2 * X_batch.T @ (y_batch - X_batch @ self.coef_) / batch_size
                # Replace NaN/inf entries with finite values before the update.
                gradient = np.nan_to_num(gradient)
                previous_gradient = (
                    -2 * X_batch.T @ (y_batch - X_batch @ previous_coef) / batch_size
                )
                v = gradient - previous_gradient + v
                # Copy before the in-place update below overwrites coef_.
                previous_coef = self.coef_.copy()
                self.coef_ -= self.learning_rate * v

            # Hard thresholding step
            self.hard_threshold()
        return self

    def hard_threshold(self):
        """Zero every coefficient below the ``(100 - top_k)``-th magnitude percentile."""
        magnitudes = np.abs(self.coef_)
        threshold = np.percentile(magnitudes, (100 - self.top_k))
        self.coef_ = np.where(magnitudes >= threshold, self.coef_, 0)

    def transform(self, X):
        """Return the columns of ``X`` whose learned coefficient is non-zero.

        Raises:
            RuntimeError: If called before :meth:`fit`.
        """
        if self.coef_ is None:
            raise RuntimeError("transform() called before fit()")
        mask = self.coef_ != 0
        return X[:, mask]

# Feature selection function that takes the algorithm name and applies the corresponding method
def feature_selection(algorithm, X_train, y_train, X_test, k=30):
    """Fit the named feature selector on the training data and apply it.

    Args:
        algorithm: One of ``'BVRSZHTn'``, ``'BVRSZHT12'``, ``'VRSZHT'``
            (all ``HardThresholdingSVRG``, differing only in ``theta``),
            ``'SARAH'`` (``HardThresholdingSARAH``) or ``'SAGA'``
            (L1-penalised ``LogisticRegression`` with the saga solver).
        X_train: Training feature matrix, shape (n_samples, n_features).
        y_train: Training targets.
        X_test: Test feature matrix with the same columns as ``X_train``.
        k: Forwarded as ``top_k`` to the hard-thresholding selectors;
            unused by ``'SAGA'``.

    Returns:
        Tuple ``(X_train_selected, X_test_selected)`` restricted to the
        selected columns.

    Raises:
        ValueError: If ``algorithm`` is not one of the supported names.
    """
    svrg_variants = ('BVRSZHTn', 'BVRSZHT12', 'VRSZHT')

    if algorithm in svrg_variants:
        # The three SVRG variants share everything except theta.
        theta = {
            'BVRSZHTn': X_train.shape[0],
            'BVRSZHT12': 2,
            'VRSZHT': 1,
        }[algorithm]
        selector = HardThresholdingSVRG(top_k=k, learning_rate=0.001, n_iter=200, batch_size=128)
        selector.fit(X_train, y_train, theta)
        X_train_selected = selector.transform(X_train)
        X_test_selected = selector.transform(X_test)
    elif algorithm == 'SARAH':
        selector = HardThresholdingSARAH(top_k=k, learning_rate=0.001, n_iter=200, batch_size=128)
        selector.fit(X_train, y_train)
        X_train_selected = selector.transform(X_train)
        X_test_selected = selector.transform(X_test)
    elif algorithm == 'SAGA':
        saga = LogisticRegression(penalty='l1', solver='saga', C=0.8, max_iter=300, tol=1e-4, random_state=42)
        saga.fit(X_train, y_train)

        # coef_ has one row per class for multiclass problems. Keep a feature
        # when ANY class gives it a non-zero weight; looking only at the first
        # n_features entries of the flattened matrix would consult the first
        # class row alone.
        mask_saga = np.any(saga.coef_ != 0, axis=0)
        X_train_selected = X_train[:, mask_saga]
        X_test_selected = X_test[:, mask_saga]
    else:
        raise ValueError(f"Unknown algorithm: {algorithm}")

    return X_train_selected, X_test_selected

# Step 1: Load scRNA-seq dataset

print("Loading scRNA-seq dataset...")
# The CSV is read with its first column as the index and then transposed,
# so that every row of `data` is one sample.
data = pd.read_csv(
    '/home/xinzhe/data/GSE81861_Cell_Line_COUNT.csv',  # Update with actual file path
    index_col=0,
).T

# Sample identifiers are '__'-separated; the second field is used as the
# class label and integer-encoded.
labels = [sample_id.split('__')[1] for sample_id in data.index]
le = LabelEncoder()
y = le.fit_transform(labels)
X = data.values.astype(float)  # Ensure X is float type

# Hold out 20% of the samples for testing (fixed seed for reproducibility).
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Standardise features using statistics fitted on the training split only.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Define algorithms to test
algorithms = ['BVRSZHTn', 'SAGA', 'BVRSZHT12', 'VRSZHT', 'SARAH']
results = []

# The device selection never changes between algorithms, so set it once
# before the loop instead of re-assigning it on every iteration.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # Use GPU 2 for all algorithms

# Size the output layer from the label encoder rather than from y_train: the
# split is not stratified, so a class could be missing from the training
# labels while still appearing in y_test.
num_classes = len(le.classes_)

# Loop over each algorithm and perform feature selection, training, and evaluation
for algorithm in algorithms:
    # Release Keras state left over from the previous algorithm's model so
    # memory does not grow with every model built in this loop.
    tf.keras.backend.clear_session()

    print(f"Applying feature selection with {algorithm}...")
    start_time = time.time()
    X_train_selected, X_test_selected = feature_selection(algorithm, X_train, y_train, X_test, k=5)
    selection_time = time.time() - start_time
    print(f"Feature selection for {algorithm} took {selection_time:.2f} seconds")

    # Record the number of selected features
    num_features = X_train_selected.shape[1]

    # Train a DNN classifier for the selected features
    print(f"Training DNN classifier for {algorithm}-selected features...")
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(num_features,)),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(num_classes, activation='softmax')
    ])

    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # Train the model (10% of the training rows are held out for validation)
    model.fit(X_train_selected, y_train, epochs=20, batch_size=64, validation_split=0.1, verbose=2)

    # Evaluate the model on the untouched test split
    test_loss, test_acc = model.evaluate(X_test_selected, y_test, verbose=0)
    print(f"Test accuracy for {algorithm}-selected features: {test_acc:.2f}")

    # Record results
    results.append({
        'Algorithm': algorithm,
        'Accuracy': test_acc,
        'Num_Features': num_features,
        'Selection_Time (s)': selection_time,
    })

# Create a DataFrame to store the results
results_df = pd.DataFrame(results)
print(results_df)

# Save the results to a CSV file
results_df.to_csv('feature_selection_results.csv', index=False)
