import os
import numpy as np
import scipy.io
from sklearn.preprocessing import StandardScaler
from FLAiR_ELM_function import FLAiR_ELM_function  

# Tab-separated summary table: the header row is written here; one row per
# dataset is appended later in this script.
result_file = r"PATH_TO_SAVE_RESULTS"
# Mode 'w' truncates any table left over from a previous run.
with open(result_file, 'w') as result:
    result.write("\t".join([
        "DataSetName",
        "BestMeanTrainAccuracy", "BestStdTrainAccuracy",
        "BestMeanTestAccuracy", "BestStdTestAccuracy",
        "BestMeanTrainTime", "BestStdTrainTime",
        "BestMeanTestTime", "BestStdTestTime",
        "Best_C", "Best_N", "Best_Activation", "Best_Epoch",
    ]) + "\n")

# Directory of datasets: every `.mat` file in it is evaluated independently.
directory = r"PATH_TO_YOUR_DATASET_FOLDER"
# Sorted so result rows come out in a reproducible order (os.listdir order
# is arbitrary).
files = sorted(f for f in os.listdir(directory) if f.endswith('.mat'))
print("Datasets found:", files)

N_FOLDS = 5  # number of cross-validation folds


def _make_folds(features, labels, n_folds=N_FOLDS):
    """Build contiguous k-fold splits with per-fold standardization.

    Fold ``k`` tests on rows ``[k*b, (k+1)*b)`` with ``b = n_rows // n_folds``
    and trains on all remaining rows (original row order kept). The last fold
    also takes the ``n_rows % n_folds`` leftover rows, so every row is tested
    exactly once. The StandardScaler is fitted on each fold's training rows
    only, so no statistics of the test rows leak into training.

    Args:
        features: 2-D array, one sample per row.
        labels: 1-D array with one label per row of ``features``.
        n_folds: number of folds; must not exceed the number of rows.

    Returns:
        List of ``(trainX, trainY, testX, testY)`` tuples, one per fold.
    """
    n_rows = features.shape[0]
    block_size = n_rows // n_folds
    folds = []
    for part in range(n_folds):
        t1 = part * block_size
        t2 = n_rows if part == n_folds - 1 else (part + 1) * block_size
        is_test = np.zeros(n_rows, dtype=bool)
        is_test[t1:t2] = True
        scaler = StandardScaler().fit(features[~is_test])
        folds.append((
            scaler.transform(features[~is_test]), labels[~is_test],
            scaler.transform(features[is_test]), labels[is_test],
        ))
    return folds


for file in files:
    data_path = os.path.join(directory, file)
    file_data = scipy.io.loadmat(data_path)

    # Use the first variable whose name does not start with '__' (those are
    # loadmat metadata entries such as '__header__').
    dataset_key = next((key for key in file_data if not key.startswith('__')), None)
    if dataset_key is None:
        print(f"No valid dataset key found in {file}. Skipping.")
        continue

    # Layout used below: one sample per row, features in every column but the
    # last, class label in the last column.
    all_data = np.asarray(file_data[dataset_key])
    if all_data.ndim != 2 or all_data.shape[1] < 2 or all_data.shape[0] < N_FOLDS:
        print(f"Variable '{dataset_key}' in {file} is not a 2-D "
              f"[features, label] matrix with at least {N_FOLDS} rows. Skipping.")
        continue

    X = all_data[:, :-1]
    y = all_data[:, -1].astype(int)
    num_classes = len(np.unique(y))
    # NOTE(review): folds are contiguous row blocks without shuffling; if the
    # rows in the .mat file are ordered by class, folds will be class-skewed.
    # Confirm the datasets are stored pre-shuffled.
    # Labels are passed as float (integer-valued), exactly as the previous
    # version did after re-merging them into the float feature matrix.
    folds = _make_folds(X, y.astype(float))

    # Hyperparameter grid: C in {1e-5, ..., 1e5}, N in {3, 23, ..., 203},
    # activation codes 1..6 (their meaning is defined by FLAiR_ELM_function).
    C = 10.0 ** np.arange(-5, 6)
    N = np.arange(3, 204, 20)
    activations = range(1, 7)

    # Tracked separately so that a best mean accuracy of exactly 0 is still
    # recorded; best_metrics stays empty until some combination is evaluated.
    best_score = -np.inf
    best_metrics = {}

    for c in C:
        for n in N:
            for act in activations:
                # Per fold: [train_acc, valid_acc, train_time, valid_time, warmup_epochs]
                fold_results = []

                for fold in folds:
                    # Copies: the folds are reused for every grid point, so a
                    # callee that modifies its inputs in place must not
                    # corrupt them.
                    trainX, trainY, testX, testY = (a.copy() for a in fold)
                    option = {'C': c, 'N': n, 'activation': act}

                    best_result, _ = FLAiR_ELM_function(trainX, trainY, testX, testY, option, num_classes)
                    warmup_epochs, train_acc, valid_acc, train_time, valid_time = best_result

                    fold_results.append([train_acc, valid_acc, train_time, valid_time, warmup_epochs])

                results = np.array(fold_results)
                mean_test_accuracy = np.mean(results[:, 1])
                if mean_test_accuracy > best_score:
                    best_score = mean_test_accuracy
                    best_metrics = {
                        'MeanTrainAccuracy': np.mean(results[:, 0]),
                        'StdTrainAccuracy': np.std(results[:, 0]),
                        'MeanTestAccuracy': mean_test_accuracy,
                        'StdTestAccuracy': np.std(results[:, 1]),
                        'MeanTrainTime': np.mean(results[:, 2]),
                        'StdTrainTime': np.std(results[:, 2]),
                        'MeanTestTime': np.mean(results[:, 3]),
                        'StdTestTime': np.std(results[:, 3]),
                        'C': c,
                        'N': n,
                        'Activation': act,
                        # Mean warm-up epoch count over folds, rounded to an int.
                        'BestEpoch': int(np.round(np.mean(results[:, 4]))),
                    }

    # Append this dataset's row immediately so partial results survive a crash.
    with open(result_file, 'a') as result:
        result.write(f"{file}\t{best_metrics.get('MeanTrainAccuracy', 0):.4f}\t{best_metrics.get('StdTrainAccuracy', 0):.4f}\t"
                     f"{best_metrics.get('MeanTestAccuracy', 0):.4f}\t{best_metrics.get('StdTestAccuracy', 0):.4f}\t"
                     f"{best_metrics.get('MeanTrainTime', 0):.6f}\t{best_metrics.get('StdTrainTime', 0):.6f}\t"
                     f"{best_metrics.get('MeanTestTime', 0):.6f}\t{best_metrics.get('StdTestTime', 0):.6f}\t"
                     f"{best_metrics.get('C', 0.0):.6f}\t{best_metrics.get('N', 0)}\t"
                     f"{best_metrics.get('Activation', 0)}\t{best_metrics.get('BestEpoch', 0)}\n")

    print(f"✅ Completed: {file}")
