import itertools
import os

import numpy as np
import scipy.io
from sklearn.preprocessing import StandardScaler

from dRVFL_model import train_and_evaluate_drvfl

# ================= Configuration ================= #
# Output file for the tab-separated results table (opened with mode 'w'
# below, so any previous contents are replaced).
result_path = r"PATH_TO_SAVE_RESULTS"
# Folder scanned for *.mat dataset files.
data_dir = r"PATH_TO_YOUR_DATASET_FOLDER"
# Number of cross-validation folds.
no_part = 5

# Step 1 (coarse) search range for C: 1e-5, 1e-4, ..., 1e5 (11 values).
# The hidden-node range N_range is chosen per dataset from its sample count
# inside the main loop, so it is not configured here.
C_range = 10.0 ** np.arange(-5, 6)

# Step 2 (fine-tuning) ranges.
# Values passed as opt["L"]: 1..10 (presumably the number of hidden layers).
L_range = list(range(1, 11))
# Scale values 2**-1 .. 2**2 in half-power steps (the stop 2.1 includes the
# exponent 2.0), i.e. 7 values:
# [0.5, 0.7071, 1.0, 1.4142, 2.0, 2.8284, 4.0]
# A denser alternative grid is 2.0 ** np.arange(-2, 2.5, 0.5).
scale_range = 2.0 ** np.arange(-1, 2.1, 0.5)

# Multipliers applied to the Step 1 best N and best C.
# A denser alternative is [0.9, 0.925, 0.95, 0.975, 1, 1.025, 1.05, 1.075, 1.1].
percent_range = [0.9, 0.95, 1, 1.05, 1.1]

# Codes passed as opt["activation"]; their meaning is defined by dRVFL_model.
activation_list = [1, 2, 3]

# =================================================== #

# Start a fresh results file (mode 'w' truncates it) with one tab-separated
# header row: the dataset name, nine Mean/Std metric column pairs, and then
# the selected hyperparameters (C, scale, N, L, activation).
header_line = (
    "DataSetName\tBestMeanTrainAccuracy\tBestStdTrainAccuracy\tBestMeanTestAccuracy\tBestStdTestAccuracy\t"
    "MeanSensitivity\tStdSensitivity\tMeanSpecificity\tStdSpecificity\tMeanPrecision\tStdPrecision\t"
    "MeanF_measure\tStdF_measure\tMeanGmean\tStdGmean\tBestMeanTrainTime\tBestStdTrainTime\t"
    "BestMeanTestTime\tBestStdTestTime\tBest_C\tBest_s\tBest_N\tBest_L\tBest_act\n"
)
with open(result_path, 'w') as result_file:
    result_file.write(header_line)

def _load_dataset(path):
    """Load a ``.mat`` dataset and return ``(X, y, nclass)``.

    The data matrix is the last variable in the file whose name does not
    start with ``"__"``; this skips the metadata entries that
    ``scipy.io.loadmat`` adds. The matrix's last column holds the class
    labels, and the remaining columns are the (unscaled) features.

    Labels equal to -1 are mapped to 0. All labels are then re-encoded to
    contiguous indices ``0..nclass-1``, so they can index an identity matrix
    for one-hot encoding even when the file uses 1-based or otherwise
    non-contiguous labels.

    Raises:
        ValueError: if the file contains no data variable.
    """
    mat = scipy.io.loadmat(path)
    var_names = [name for name in mat if not name.startswith("__")]
    if not var_names:
        raise ValueError(f"no data variable found in {path}")
    all_data = mat[var_names[-1]]

    X = all_data[:, :-1]
    y_raw = all_data[:, -1].astype(int)
    y_raw[y_raw == -1] = 0
    classes, y = np.unique(y_raw, return_inverse=True)
    return X, y, len(classes)


def _drvfl_opt(N, C, activation, L, scale):
    """Return the ``train_and_evaluate_drvfl`` option dict for one configuration.

    ``renormal`` and ``normal_type`` are held fixed at 1 and 0 for every run.
    """
    return {
        "N": N,
        "C": C,
        "activation": activation,
        "L": L,
        "scale": scale,
        "renormal": 1,
        "normal_type": 0,
    }


def _cross_validate(X, y, nclass, folds, opt):
    """Evaluate one dRVFL configuration with k-fold cross-validation.

    Args:
        X: unscaled feature matrix, one row per sample.
        y: integer class indices in ``0..nclass-1``.
        nclass: number of classes (width of the one-hot targets).
        folds: list of index arrays. Each one is used once as the test fold,
            and the remaining samples (kept in row order) form the training
            fold.
        opt: option dict for ``train_and_evaluate_drvfl``. A fresh copy is
            passed for every fold, so in-place changes cannot leak between
            folds.

    Returns:
        ``np.ndarray`` with one row per fold, laid out as
        ``[train_eval[0], *test_eval, train_time, test_time]``. Column 1 is
        therefore ``test_eval[0]``, which the results header labels as the
        test accuracy.
    """
    one_hot = np.eye(nclass)
    rows = []
    for test_idx in folds:
        train_mask = np.ones(len(y), dtype=bool)
        train_mask[test_idx] = False

        # Fit the scaler on the training fold only. Fitting on all samples
        # would leak test-fold statistics into the preprocessing.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X[train_mask])
        X_test = scaler.transform(X[test_idx])

        _, train_eval, test_eval, train_time, test_time = train_and_evaluate_drvfl(
            X_train, one_hot[y[train_mask]], X_test, one_hot[y[test_idx]], dict(opt)
        )
        # list() keeps this a concatenation whether the evaluator returns
        # lists, tuples or arrays (``+`` on arrays would add element-wise).
        rows.append(list(train_eval[:1]) + list(test_eval) + [train_time, test_time])
    return np.array(rows)


# Sorted so the order of rows in the results file is reproducible.
mat_files = sorted(f for f in os.listdir(data_dir) if f.endswith(".mat"))

for file_idx, file in enumerate(mat_files):
    print(f"\n🔍 Processing dataset {file_idx + 1}/{len(mat_files)}: {file}")

    X, y, nclass = _load_dataset(os.path.join(data_dir, file))
    m = X.shape[0]
    # Contiguous folds in the file's row order (no shuffling). array_split
    # spreads the m % no_part remainder over the first folds, so every
    # sample is used as a test sample exactly once.
    folds = np.array_split(np.arange(m), no_part)

    # ✅ Dynamically choose N_range based on sample count
    N_range = [256, 512, 1024] if m < 1000 else [1024, 2048, 4096]

    # NOTE: configurations are ranked by mean test-fold accuracy (column 1).
    # There is no separate validation split, so the reported best test
    # accuracy is an optimistic estimate.
    best_accuracy = -np.inf
    best_result = None
    best_params = {}

    # ================== STEP 1 ================== #
    print("⚙️ Step 1: Coarse tuning (C, N)")
    for N, C in itertools.product(N_range, C_range):
        results = _cross_validate(
            X, y, nclass, folds, _drvfl_opt(N, C, activation=2, L=2, scale=1)
        )
        mean_acc = results[:, 1].mean()
        if mean_acc > best_accuracy:
            best_accuracy = mean_acc
            best_result = results
            best_params = {"C": C, "N": N, "activation": 2, "scale": 1, "L": 2}

    if not best_params:
        # Step 1 selected nothing (e.g. every mean accuracy was NaN), so
        # there is no centre for the Step 2 fine grid.
        print(f"⚠️ No valid configuration for {file}; skipping.\n")
        continue

    # ================== STEP 2 ================== #
    print("🔧 Step 2: Fine tuning (L, scale, C, N)")
    # Freeze the Step 1 optimum as the centre of the fine grid. Reading
    # best_params inside the loop would let the centre drift, and the percent
    # multipliers compound, whenever a better configuration is found.
    base_C, base_N = best_params["C"], best_params["N"]
    fine_grid = itertools.product(
        L_range, percent_range, percent_range, activation_list, scale_range
    )
    for L, Np, Cp, act, scale in fine_grid:
        N_val = int(np.ceil(base_N * Np))
        C_val = base_C * Cp
        results = _cross_validate(
            X, y, nclass, folds, _drvfl_opt(N_val, C_val, act, L, scale)
        )
        mean_acc = results[:, 1].mean()
        # ">=": on ties, the later fine-grid configuration wins.
        if mean_acc >= best_accuracy:
            best_accuracy = mean_acc
            best_result = results
            best_params = {
                "C": C_val,
                "N": N_val,
                "activation": act,
                "scale": scale,
                "L": L,
            }

    mean_vals = best_result.mean(axis=0)
    std_vals = best_result.std(axis=0)
    # Nine Mean/Std pairs in header column order, then the chosen
    # hyperparameters in header order: C, scale, N, L, activation.
    fields = [file]
    fields += [f"{mean_vals[i]:.6f}\t{std_vals[i]:.6f}" for i in range(9)]
    fields += [f"{best_params[key]}" for key in ("C", "scale", "N", "L", "activation")]
    with open(result_path, 'a', encoding="utf-8") as result_file:
        result_file.write("\t".join(fields) + "\n")

    print(f"✅ Finished evaluating: {file}\n")
