import itertools
import os

import numpy as np
import scipy.io
from sklearn.preprocessing import StandardScaler

from FLAiR_dRVFL_model import train_and_evaluate_FLAiR_drvfl as train_and_evaluate_drvfl

# Configuration
result_path = r"PATH_TO_SAVE_RESULTS"  # results FILE (opened for writing below), not a folder
data_dir = r"PATH_TO_YOUR_DATASET_FOLDER"  # folder scanned for *.mat datasets
no_part = 5  # number of cross-validation folds

# Hyper-parameter grids.
C_range = 10.0 ** np.arange(-5, 6)  # coarse C grid: 1e-5 .. 1e5
L_range = list(range(1, 11))  # values tried for opt["L"] in fine tuning: 1..10
scale_range = 2.0 ** np.arange(-1, 2.1, 0.5)  # 2^-1 .. 2^2 in steps of 2^0.5
percent_range = [0.9, 0.95, 1, 1.05, 1.1]  # multipliers applied around the coarse-best N and C
activation_list = [1, 2, 3]  # activation codes forwarded to the trainer as opt["activation"]
warmup_epochs_list = [2, 4, 6, 8, 10]  # warmup_epochs values tried in fine tuning

# Column names of the tab-separated results file: nine mean/std metric pairs
# followed by the selected hyper-parameters.
result_columns = [
    "DataSetName",
    "BestMeanTrainAccuracy", "BestStdTrainAccuracy",
    "BestMeanTestAccuracy", "BestStdTestAccuracy",
    "MeanSensitivity", "StdSensitivity",
    "MeanSpecificity", "StdSpecificity",
    "MeanPrecision", "StdPrecision",
    "MeanF_measure", "StdF_measure",
    "MeanGmean", "StdGmean",
    "BestMeanTrainTime", "BestStdTrainTime",
    "BestMeanTestTime", "BestStdTestTime",
    "Best_C", "Best_s", "Best_N", "Best_L", "Best_act", "Best_Epochs",
]

# Make sure the output folder exists, then (re)create the results file with
# its header row; any previous results in that file are discarded.
result_dir = os.path.dirname(result_path)
if result_dir:
    os.makedirs(result_dir, exist_ok=True)
with open(result_path, 'w') as result_file:
    result_file.write("\t".join(result_columns) + "\n")

# Sorted so datasets are processed (and rows written) in a reproducible order;
# os.listdir returns entries in arbitrary order.
mat_files = sorted(f for f in os.listdir(data_dir) if f.lower().endswith(".mat"))

def _load_dataset(path):
    """Load one ``.mat`` dataset as (unscaled features, labels, class count).

    The last column of the stored matrix is the label; every other column is
    a feature. Labels equal to -1 are folded into 0 (as before), and the
    resulting label set is then re-indexed to ``0..nclass-1`` so it can index
    an identity matrix directly for one-hot encoding.

    Raises:
        ValueError: if the file contains no user variable.
    """
    data = scipy.io.loadmat(path)
    # loadmat also returns metadata entries ('__header__', '__version__',
    # '__globals__'); skip them rather than trusting dict order alone.
    var_names = [k for k in data if not k.startswith("__")]
    if not var_names:
        raise ValueError(f"No data variable found in {path}")
    all_data = data[var_names[-1]]

    X_raw = all_data[:, :-1].astype(float)
    labels = all_data[:, -1].astype(int)
    labels[labels == -1] = 0
    # Indexing np.eye(nclass) with raw labels fails for labels >= nclass
    # (e.g. classes coded 1..K); map them onto 0..nclass-1 instead.
    classes, y = np.unique(labels, return_inverse=True)
    return X_raw, y.ravel(), len(classes)


def _make_folds(X_raw, y, nclass):
    """Build the ``no_part`` cross-validation folds for one dataset.

    Returns a list of ``(X_train, y_train_oh, X_test, y_test_oh)`` tuples.
    Folds are contiguous blocks of rows; training rows keep their original
    order. ``np.array_split`` spreads any remainder rows over the test folds
    so every sample is tested exactly once. The StandardScaler is fit on each
    fold's training rows only, so test rows do not leak into preprocessing.

    NOTE(review): folds are not shuffled (same as before); if the .mat rows
    are sorted by class, consider a seeded shuffle before splitting.
    """
    n_samples = len(y)
    eye = np.eye(nclass)
    folds = []
    for test_idx in np.array_split(np.arange(n_samples), no_part):
        if test_idx.size == 0:  # only possible when n_samples < no_part
            continue
        train_mask = np.ones(n_samples, dtype=bool)
        train_mask[test_idx] = False
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_raw[train_mask])
        X_test = scaler.transform(X_raw[test_idx])
        folds.append((X_train, eye[y[train_mask]], X_test, eye[y[test_idx]]))
    return folds


def _cross_validate(folds, opt, warmup_epochs):
    """Train and evaluate one configuration on every fold.

    Returns a 2-D array with one row per fold laid out as
    ``[train_acc, *test_eval, train_time, test_time]``. Column 1 is what the
    header labels the test accuracy and is used for model selection.
    """
    rows = []
    for X_train, y_train_oh, X_test, y_test_oh in folds:
        # Pass a fresh copy each fold (the original rebuilt opt per fold), in
        # case the trainer mutates its options dict.
        _, train_eval, test_eval, train_time, test_time = train_and_evaluate_drvfl(
            X_train, y_train_oh, X_test, y_test_oh, dict(opt),
            warmup_epochs=warmup_epochs, lr=0.001,
        )
        # list(...) keeps this a concatenation even if the trainer returns
        # NumPy arrays (where '+' would add element-wise instead).
        rows.append(list(train_eval[:1]) + list(test_eval) + [train_time, test_time])
    return np.array(rows)


def _write_result(file_name, fold_results, params):
    """Append one tab-separated summary row to ``result_path``.

    Writes the mean and std over folds of the first nine result columns
    (matching the header), then the selected hyper-parameters.
    """
    mean_vals = fold_results.mean(axis=0)
    std_vals = fold_results.std(axis=0)
    metric_cols = "\t".join(f"{mean_vals[i]:.6f}\t{std_vals[i]:.6f}" for i in range(9))
    param_cols = "\t".join(
        f"{params[k]}" for k in ("C", "scale", "N", "L", "activation", "warmup_epochs")
    )
    with open(result_path, 'a') as result_file:
        result_file.write(f"{file_name}\t{metric_cols}\t{param_cols}\n")


for file_idx, file in enumerate(mat_files):
    print(f"\n🔍 Processing dataset {file_idx + 1}/{len(mat_files)}: {file}")

    X_raw, y, nclass = _load_dataset(os.path.join(data_dir, file))
    m = X_raw.shape[0]
    # Splits, per-fold scaling and one-hot targets depend only on the data,
    # so build them once instead of once per hyper-parameter setting.
    folds = _make_folds(X_raw, y, nclass)

    N_range = [256, 512, 1024] if m < 1000 else [1024, 2048, 4096]
    # -inf (not 0) so Step 1 always records a configuration even when every
    # mean accuracy is 0; otherwise best_params stays {} and Step 2 fails.
    best_accuracy = -np.inf
    best_result = None
    best_params = {}

    # Step 1: coarse search over (N, C) with the remaining options fixed.
    print("⚙️ Step 1: Coarse tuning (C, N)")
    coarse_warmup = 5
    for N, C in itertools.product(N_range, C_range):
        opt = {
            "N": N, "C": C, "activation": 2, "L": 2,
            "scale": 1, "renormal": 1, "normal_type": 0,
        }
        results = _cross_validate(folds, opt, coarse_warmup)
        mean_acc = results[:, 1].mean()
        if mean_acc > best_accuracy:
            best_accuracy = mean_acc
            best_result = results
            best_params = {**opt, "warmup_epochs": coarse_warmup}

    if not best_params:  # e.g. every coarse configuration produced NaN accuracy
        print(f"⚠️ No valid configuration found for {file}; skipping.")
        continue

    # Step 2: fine search around the coarse optimum. The N/C grid is anchored
    # at the Step-1 winner; reading best_params inside the loop (as before)
    # let every improvement shift and compound the remaining grid.
    print("🔧 Step 2: Fine tuning (L, scale, C, N, Epochs)")
    base_N = best_params["N"]
    base_C = best_params["C"]
    for L, Np, Cp, act, scale, warmup_epochs in itertools.product(
        L_range, percent_range, percent_range, activation_list, scale_range, warmup_epochs_list
    ):
        opt = {
            "N": int(np.ceil(base_N * Np)), "C": base_C * Cp, "activation": act, "L": L,
            "scale": scale, "renormal": 1, "normal_type": 0,
        }
        results = _cross_validate(folds, opt, warmup_epochs)
        mean_acc = results[:, 1].mean()
        # '>=' (unlike Step 1's '>') lets a fine-tuned setting replace a tie,
        # so the last tied configuration in grid order wins.
        if mean_acc >= best_accuracy:
            best_accuracy = mean_acc
            best_result = results
            best_params = {**opt, "warmup_epochs": warmup_epochs}

    # Save best result
    _write_result(file, best_result, best_params)

    print(f"✅ Finished evaluating: {file}\n")
