import itertools
import os

import numpy as np
import scipy.io
from sklearn.preprocessing import StandardScaler

from FLAiR_BLS_Model import FLAiR_BLS_Model # Import the model

# Tab-separated summary file with one row per dataset. Replace the placeholder
# with a real path before running; the file is truncated on every run.
result_file = r"PATH_TO_SAVE_RESULTS"

# Column names of the summary file, in output order.
header_fields = (
    "DataSetName",
    "BestMeanTrainAccuracy", "BestStdTrainAccuracy",
    "BestMeanTestAccuracy", "BestStdTestAccuracy",
    "BestMeanTrainTime", "BestStdTrainTime",
    "BestMeanTestTime", "BestStdTestTime",
    "Best_C", "Best_N1", "Best_N2", "Best_N3", "Best_N4", "Best_Epoch",
)
with open(result_file, 'w') as header_out:
    header_out.write("\t".join(header_fields) + "\n")

# Folder containing the MATLAB .mat datasets to benchmark (placeholder path).
directory = r"PATH_TO_YOUR_DATASET_FOLDER"
# os.listdir returns entries in arbitrary order; sort so datasets are
# processed, and appended to the results file, in a reproducible order.
# The extension check is case-insensitive so "X.MAT" is picked up too.
files = sorted(f for f in os.listdir(directory) if f.lower().endswith('.mat'))
print("Datasets found:", files)

# Cross-validation settings and hyper-parameter grid. None of these depend on
# the dataset, so they are built once instead of once per file.
N_FOLDS = 5
C1_list = 10.0 ** np.arange(-5, 6)  # values for option['C']: 1e-5 ... 1e5
N1_list = np.arange(5, 51, 5)       # values for option['N']
N2_list = np.arange(1, 22, 2)       # values for option['NN']
N3_list = np.arange(5, 106, 10)     # values for option['NNN']
N4_list = [1]                       # values for option['N4']


def _make_cv_folds(X, y, n_folds=N_FOLDS):
    """Split (X, y) into contiguous, unshuffled CV folds, standardising each.

    Fold ``k`` tests on rows ``[k*b, (k+1)*b)`` with ``b = len(X) // n_folds``;
    the last fold also takes the ``len(X) % n_folds`` leftover rows so every
    sample lands in exactly one test fold. A StandardScaler is fit on each
    fold's training rows only and then applied to that fold's test rows, so
    test-fold statistics never leak into the training features.

    NOTE(review): rows are not shuffled. If a .mat file is sorted by class,
    contiguous folds will be class-imbalanced -- confirm this is intended.

    Returns:
        A list of ``(trainX, trainY, testX, testY)`` tuples, one per fold.
    """
    n_samples = X.shape[0]
    block_size = n_samples // n_folds
    folds = []
    for part in range(n_folds):
        t1 = part * block_size
        # Last fold runs to the end so the remainder rows are also tested.
        t2 = n_samples if part == n_folds - 1 else (part + 1) * block_size
        scaler = StandardScaler()
        trainX = scaler.fit_transform(np.vstack((X[:t1], X[t2:])))
        testX = scaler.transform(X[t1:t2])
        trainY = np.concatenate((y[:t1], y[t2:]))
        folds.append((trainX, trainY, testX, y[t1:t2]))
    return folds


def _cross_validate(folds, option, num_classes):
    """Run FLAiR_BLS_Model on every fold and summarise the per-fold results.

    Folds for which the model returns None are skipped.

    Returns:
        None if no fold produced a result; otherwise a dict with the mean and
        standard deviation of train/test accuracy and train/test time, plus
        the mean warm-up epoch count rounded to an int (key 'Best_Epoch').
    """
    per_fold = []
    for trainX, trainY, testX, testY in folds:
        outcome = FLAiR_BLS_Model(trainX, trainY, testX, testY, option, num_classes)
        if outcome is None:
            continue
        warmup_epochs, train_acc, test_acc, train_time, test_time = outcome
        per_fold.append([train_acc, test_acc, train_time, test_time, warmup_epochs])

    if not per_fold:
        return None

    runs = np.array(per_fold)
    return {
        'BestMeanTrainAccuracy': np.mean(runs[:, 0]),
        'BestStdTrainAccuracy': np.std(runs[:, 0]),
        'BestMeanTestAccuracy': np.mean(runs[:, 1]),
        'BestStdTestAccuracy': np.std(runs[:, 1]),
        'BestMeanTrainTime': np.mean(runs[:, 2]),
        'BestStdTrainTime': np.std(runs[:, 2]),
        'BestMeanTestTime': np.mean(runs[:, 3]),
        'BestStdTestTime': np.std(runs[:, 3]),
        'Best_Epoch': int(np.round(np.mean(runs[:, 4]))),
    }


def _format_result_row(dataset_name, metrics):
    """Return one tab-separated results line; missing metrics are written as 0."""
    return (f"{dataset_name}\t"
            f"{metrics.get('BestMeanTrainAccuracy', 0):.4f}\t"
            f"{metrics.get('BestStdTrainAccuracy', 0):.4f}\t"
            f"{metrics.get('BestMeanTestAccuracy', 0):.4f}\t"
            f"{metrics.get('BestStdTestAccuracy', 0):.4f}\t"
            f"{metrics.get('BestMeanTrainTime', 0):.6f}\t"
            f"{metrics.get('BestStdTrainTime', 0):.6f}\t"
            f"{metrics.get('BestMeanTestTime', 0):.6f}\t"
            f"{metrics.get('BestStdTestTime', 0):.6f}\t"
            f"{metrics.get('Best_C', 0):.6f}\t"
            f"{metrics.get('Best_N1', 0)}\t"
            f"{metrics.get('Best_N2', 0)}\t"
            f"{metrics.get('Best_N3', 0)}\t"
            f"{metrics.get('Best_N4', 0)}\t"
            f"{metrics.get('Best_Epoch', 0)}\n")


for file in files:
    file_data = scipy.io.loadmat(os.path.join(directory, file))

    # Use the first variable in the .mat file, skipping loadmat's
    # '__'-prefixed metadata entries.
    dataset_key = next((key for key in file_data if not key.startswith('__')), None)
    if dataset_key is None:
        print(f"No valid dataset key found in {file}. Skipping.")
        continue

    # Layout (as used below): one sample per row, features in every column
    # except the last, class label in the last column.
    all_data = np.asarray(file_data[dataset_key])
    if all_data.ndim != 2 or all_data.shape[1] < 2 or all_data.shape[0] < N_FOLDS:
        print(f"{file}: expected a 2-D array with at least {N_FOLDS} rows and 2 columns, "
              f"got shape {all_data.shape}. Skipping.")
        continue

    # Labels stay integer (re-stacking them next to float features used to
    # silently turn them into floats); feature scaling is done per fold.
    X = all_data[:, :-1]
    y = all_data[:, -1].astype(int)
    num_classes = len(np.unique(y))
    folds = _make_cv_folds(X, y)

    # Exhaustive grid search: keep the configuration with the highest mean
    # test accuracy. Ties keep the earliest; NaN scores never compare greater.
    best_metrics = {}
    best_score = -np.inf
    for C1, N1, N2, N3, N4 in itertools.product(C1_list, N1_list, N2_list, N3_list, N4_list):
        option = {'C': C1, 'N': N1, 'NN': N2, 'NNN': N3, 'N4': N4}
        stats = _cross_validate(folds, option, num_classes)
        if stats is None:
            continue
        mean_test_acc = stats['BestMeanTestAccuracy']
        if mean_test_acc > best_score:
            best_score = mean_test_acc
            best_metrics = {**stats, 'Best_C': C1, 'Best_N1': N1, 'Best_N2': N2,
                            'Best_N3': N3, 'Best_N4': N4}

    with open(result_file, 'a') as result_out:
        result_out.write(_format_result_row(file, best_metrics))

    print(f"✅ Completed processing: {file}")
