import os
import numpy as np
import scipy.io
from sklearn.preprocessing import StandardScaler
from StaR_dRVFL_model import train_and_evaluate_StaR_drvfl  # ✅ updated import

# ========== Configuration ========== #
# Input folder scanned for datasets and output file for the summary table.
data_dir = r"PATH_TO_YOUR_DATASET_FOLDER"
result_path = r"PATH_TO_SAVE_RESULTS"

# Number of cross-validation folds used for every hyperparameter setting.
no_part = 5

# Search grids.
#   C_range        : 10^-5 ... 10^5 in decade steps (11 values)
#   scale_range    : 2^-1 ... 2^2 in half-power steps (7 values)
#   spectral_range : 0.5 ... 2.0 in steps of 0.25 (7 values)
#   percent_range  : multiplicative offsets applied to the coarse-stage N and C
#   activation_list: activation identifiers passed through to the model
C_range = np.power(10.0, np.arange(-5, 6))
scale_range = np.power(2.0, np.arange(-1, 2.1, 0.5))
spectral_range = np.arange(0.5, 2.01, 0.25)
percent_range = [0.9, 0.95, 1, 1.05, 1.1]
activation_list = list(range(1, 4))

# ========== Create the result file and write the header row ========== #
# Opening in 'w' mode truncates any previous results; per-dataset rows are
# appended later. The header is one tab-separated line, kept here as four
# chunks that are written back to back.
header_chunks = (
    "DataSetName\tBestMeanTrainAccuracy\tBestStdTrainAccuracy\tBestMeanTestAccuracy\tBestStdTestAccuracy\t",
    "MeanSensitivity\tStdSensitivity\tMeanSpecificity\tStdSpecificity\tMeanPrecision\tStdPrecision\t",
    "MeanF_measure\tStdF_measure\tMeanGmean\tStdGmean\tBestMeanTrainTime\tBestStdTrainTime\t",
    "BestMeanTestTime\tBestStdTestTime\tBest_C\tBest_s\tBest_N\tBest_L\tBest_act\tBest_rho\n",
)
with open(result_path, 'w') as result_file:
    result_file.writelines(header_chunks)

# Collect the dataset files to evaluate. os.listdir() returns entries in
# arbitrary order, so sort them to keep the processing order (and therefore the
# row order of the results file) reproducible across runs and machines.
mat_files = sorted(f for f in os.listdir(data_dir) if f.endswith(".mat"))

# Column of a per-fold result row that holds the test accuracy, and the number
# of leading columns that are summarised (mean/std) in the results file.
_TEST_ACC_COL = 1
_N_STAT_COLS = 9


def _build_opt(N, C, activation, L, scale, rho):
    """Return the option dict handed to ``train_and_evaluate_StaR_drvfl``.

    ``renormal`` and ``normal_type`` are fixed at 1 and 0 for every run; the
    remaining entries are the hyperparameters being searched.
    """
    return {
        "N": N,
        "C": C,
        "activation": activation,
        "L": L,
        "scale": scale,
        "renormal": 1,
        "normal_type": 0,
        "spectral_radius": rho,
    }


def _cross_validate(data_all, nclass, n_folds, opt):
    """Evaluate one hyperparameter setting with contiguous k-fold splits.

    Args:
        data_all: 2-D array; the last column holds class indices in
            ``[0, nclass)``, all other columns are features.
        nclass: number of classes (width of the one-hot target matrices).
        n_folds: number of folds.
        opt: option dict for ``train_and_evaluate_StaR_drvfl``; a fresh copy is
            passed on every fold.

    Returns:
        2-D array with one row per fold, laid out as
        ``[train_eval[0], *test_eval, train_time, test_time]``.

    Notes:
        Folds are consecutive blocks of ``m // n_folds`` rows taken in file
        order (no shuffling). The trailing ``m % n_folds`` rows are never part
        of a test fold; they are only ever used for training.
    """
    m = data_all.shape[0]
    fold_size = m // n_folds
    one_hot = np.eye(nclass)

    rows = []
    for part in range(n_folds):
        lo = part * fold_size
        hi = lo + fold_size

        test_data = data_all[lo:hi]
        train_data = np.vstack((data_all[:lo], data_all[hi:]))

        X_train, y_train = train_data[:, :-1], train_data[:, -1].astype(int)
        X_test, y_test = test_data[:, :-1], test_data[:, -1].astype(int)

        _, train_eval, test_eval, train_time, test_time = train_and_evaluate_StaR_drvfl(
            X_train, one_hot[y_train], X_test, one_hot[y_test], dict(opt)
        )

        # Explicit unpacking concatenates correctly whether the evaluation
        # results come back as lists, tuples or arrays.
        rows.append([train_eval[0], *test_eval, train_time, test_time])

    return np.array(rows)


for file_idx, file in enumerate(mat_files):
    print(f"\n🔍 Processing dataset {file_idx + 1}/{len(mat_files)}: {file}")
    data = scipy.io.loadmat(os.path.join(data_dir, file))

    # loadmat adds metadata entries such as '__header__'; take the last real
    # variable instead of blindly taking the last key.
    var_names = [k for k in data if not k.startswith("__")]
    if not var_names:
        print(f"Skipping {file}: no data variable found in the .mat file")
        continue
    all_data = data[var_names[-1]]

    # Last column is the label, everything before it is a feature.
    X_raw = all_data[:, :-1]
    y_raw = all_data[:, -1].astype(int)
    y_raw[y_raw == -1] = 0

    # Remap labels to 0..nclass-1 so they are valid one-hot row indices even
    # when the file uses e.g. {1, 2}; {0, 1} and {-1, 1} map exactly as before.
    classes, y = np.unique(y_raw, return_inverse=True)
    y = y.reshape(-1)
    nclass = len(classes)

    # NOTE(review): the scaler is fit on the whole dataset before splitting, so
    # test-fold statistics leak into the features. Kept as-is to preserve the
    # existing protocol; confirm whether per-fold scaling is wanted.
    X = StandardScaler().fit_transform(X_raw)

    m = X.shape[0]
    if m < no_part:
        # Fewer rows than folds would produce empty test folds.
        print(f"Skipping {file}: only {m} samples for {no_part} folds")
        continue
    data_all = np.hstack((X, y.reshape(-1, 1)))

    if m < 1000:
        N_range = [256, 512, 1024]
    else:
        N_range = [1024, 2048, 4096]

    # Start below any reachable accuracy so that a setting scoring exactly 0
    # is still recorded and Step 2 always has a coarse optimum to refine.
    best_accuracy = -np.inf
    best_result = None
    best_params = {}

    # ========== STEP 1: Coarse tuning ========== #
    # Search N, C and the spectral radius with activation=2, L=2, scale=1.
    print("⚙️ Step 1: Coarse tuning (C, N, rho)")
    for N in N_range:
        for C in C_range:
            for rho in spectral_range:
                results = _cross_validate(
                    data_all, nclass, no_part, _build_opt(N, C, 2, 2, 1, rho)
                )
                mean_acc = results[:, _TEST_ACC_COL].mean()

                if mean_acc > best_accuracy:
                    best_accuracy = mean_acc
                    best_result = results
                    best_params = {
                        "C": C,
                        "N": N,
                        "activation": 2,
                        "scale": 1,
                        "L": 2,
                        "spectral_radius": rho,
                    }

    if best_result is None:
        # Only reachable when every setting produced a NaN mean accuracy.
        print(f"Skipping {file}: coarse tuning produced no valid accuracy")
        continue

    # ========== STEP 2: Fine tuning ========== #
    # Freeze the coarse optimum: the percentage offsets must stay relative to
    # it. Reading best_params inside the loops would make the grid drift each
    # time a better setting is recorded.
    base_N = best_params["N"]
    base_C = best_params["C"]

    print("🔧 Step 2: Fine tuning (L, scale, C, N, rho)")
    for L in range(1, 11):
        for Np in percent_range:
            N_val = int(np.ceil(base_N * Np))
            for Cp in percent_range:
                C_val = base_C * Cp
                for act in activation_list:
                    for scale in scale_range:
                        for rho in spectral_range:
                            results = _cross_validate(
                                data_all,
                                nclass,
                                no_part,
                                _build_opt(N_val, C_val, act, L, scale, rho),
                            )
                            mean_acc = results[:, _TEST_ACC_COL].mean()

                            # '>=' lets a later setting win ties.
                            if mean_acc >= best_accuracy:
                                best_accuracy = mean_acc
                                best_result = results
                                best_params = {
                                    "C": C_val,
                                    "N": N_val,
                                    "activation": act,
                                    "scale": scale,
                                    "L": L,
                                    "spectral_radius": rho,
                                }

    # ========== Save Best Result ========== #
    # One tab-separated row: dataset name, mean/std pairs for the first nine
    # result columns, then the winning hyperparameters. The column meaning is
    # assumed to follow the header row of the results file; verify against the
    # order returned by train_and_evaluate_StaR_drvfl.
    mean_vals = best_result.mean(axis=0)
    std_vals = best_result.std(axis=0)

    stat_fields = [
        f"{mean_vals[i]:.6f}\t{std_vals[i]:.6f}" for i in range(_N_STAT_COLS)
    ]
    param_fields = [
        f"{best_params[key]}"
        for key in ("C", "scale", "N", "L", "activation", "spectral_radius")
    ]

    with open(result_path, 'a') as result_file:
        result_file.write("\t".join([file, *stat_fields, *param_fields]) + "\n")

    print(f"✅ Finished evaluating: {file}\n")
