# ✅ FLAiR BLS Model with Adam Warmup (Feature + Enhancement Layer)

import numpy as np
import time
from scipy.linalg import orth
from sklearn.preprocessing import MinMaxScaler

def tansig(x):
    """Hyperbolic-tangent sigmoid ``2 / (1 + exp(-2x)) - 1``, which equals ``tanh(x)``.

    Evaluated with ``np.tanh`` because the explicit formula overflows ``exp``
    (emitting a RuntimeWarning) for x below roughly -355, and loses relative
    precision for |x| near 0 because of the final ``- 1`` cancellation.

    Parameters
    ----------
    x : scalar or array-like
        Pre-activation values.

    Returns
    -------
    numpy scalar or ndarray
        Element-wise tanh of ``x``, in the open interval (-1, 1) for finite
        inputs of moderate magnitude (saturating to exactly +/-1 for large |x|).
    """
    return np.tanh(x)

def one_hot(x, n_class):
    """Encode class labels as a one-hot float matrix.

    Parameters
    ----------
    x : array-like
        Class labels, either 1-D or a single column of shape ``(n, 1)``.
        Plain Python lists are accepted.
    n_class : int
        Number of classes, i.e. the number of output columns.

    Returns
    -------
    ndarray of shape (n, n_class), dtype float64
        Row ``k`` has a 1 in column ``x[k]`` and 0 elsewhere.  A label that
        does not equal any of ``0 .. n_class - 1`` produces an all-zero row.
    """
    # Converting first matters: on a plain list, ``x == i`` is a single bool,
    # not an element-wise mask, and the encoding would silently be all zeros.
    labels = np.asarray(x).reshape(-1)
    # Broadcast (n, 1) against (n_class,) so each row flags its own class.
    return (labels[:, None] == np.arange(n_class)).astype(np.float64)

def _ridge_solve(X, Y, C):
    """Return the ridge solution ``beta`` of ``X @ beta ~= Y`` with penalty ``I / C``.

    Uses the primal form ``(I/C + X^T X)^{-1} X^T Y`` when X has fewer columns
    than rows, otherwise the equivalent dual form ``X^T (I/C + X X^T)^{-1} Y``,
    so the linear system that is solved is always the smaller one.
    """
    n_rows, n_cols = X.shape
    if n_cols < n_rows:
        return np.linalg.solve(np.eye(n_cols) / C + X.T @ X, X.T @ Y)
    return X.T @ np.linalg.solve(np.eye(n_rows) / C + X @ X.T, Y)


def _adam_step(w, m, v, grad, t, lr, beta1, beta2, epsilon):
    """Apply one bias-corrected Adam update at 1-based step ``t``; return ``(w, m, v)``."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * (grad ** 2)
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return w - lr * m_hat / (np.sqrt(v_hat) + epsilon), m, v


def _feature_windows(H, weights):
    """Project ``H`` through each window's weights and min-max scale each to [-1, 1].

    Returns the horizontally stacked window outputs and the list of fitted
    scalers (one per window) so the same mapping can be applied to new data.
    """
    blocks, scalers = [], []
    for w in weights:
        scaler = MinMaxScaler(feature_range=(-1, 1))
        blocks.append(scaler.fit_transform(H @ w))
        scalers.append(scaler)
    return np.hstack(blocks), scalers


def _enhancement_weights(n_feature_nodes, n_enh_nodes):
    """Draw a random ``(n_feature_nodes + 1, n_enh_nodes)`` enhancement weight matrix.

    The extra input row multiplies the bias column.  The matrix has
    orthonormal columns when ``n_feature_nodes >= n_enh_nodes`` and
    orthonormal rows otherwise (assuming the random draw is full rank, which
    holds almost surely).
    """
    n_in = n_feature_nodes + 1
    if n_feature_nodes >= n_enh_nodes:
        return orth(2 * np.random.rand(n_in, n_enh_nodes) - 1)
    # Orthonormalize the tall (n_enh_nodes, n_in) draw and transpose it, giving
    # (n_in, n_enh_nodes).  Transposing *before* orth (as in orth(R.T).T) would
    # instead return an (n_in, n_in) matrix and break every downstream shape.
    return orth(2 * np.random.rand(n_enh_nodes, n_in) - 1).T


def FLAiR_BLS_Model(trainX, trainY, testX, testY, option, num_classes,
                    warmup_epochs_list=(2, 4, 6, 8, 10), lr=0.01,
                    beta1=0.9, beta2=0.999, epsilon=1e-8):
    """Train a Broad Learning System whose random weights get a short Adam warmup.

    For each candidate in ``warmup_epochs_list`` a model is built from scratch:

    1. Feature layer: ``option['NN']`` windows of ``option['N']`` nodes.  Each
       window is a random linear map of ``[X, 0.1]`` followed by min-max scaling
       to [-1, 1].  The maps get ``warmup_epochs`` Adam steps.
    2. Enhancement layer: ``option['N4']`` windows of ``option['NNN']`` nodes,
       each ``tansig`` of a random (orthonormalized) map of ``[Z, 0.1]``.  These
       also get ``warmup_epochs`` Adam steps.
    3. Output weights: closed-form ridge regression of the one-hot targets on
       ``[Z, H]`` with penalty ``I / option['C']``.

    Parameters
    ----------
    trainX, testX : ndarray of shape (n_samples, n_features)
        Input data (``.shape`` is read, so NumPy arrays are expected).
    trainY, testY : array-like
        Class labels in ``0 .. num_classes - 1``, 1-D or a single column.
    option : dict
        Keys ``'C'`` (ridge strength; the identity is scaled by ``1 / C``),
        ``'N'`` (nodes per feature window), ``'NN'`` (number of feature
        windows), ``'NNN'`` (nodes per enhancement window), and ``'N4'``
        (number of enhancement windows).
    num_classes : int
        Number of output classes.
    warmup_epochs_list : iterable of int
        Candidate numbers of Adam warmup epochs to try.
    lr, beta1, beta2, epsilon : float
        Adam hyper-parameters.

    Returns
    -------
    tuple or None
        ``(warmup_epochs, train_acc, test_acc, train_time, test_time)`` for the
        candidate with the highest test accuracy.  Accuracies are in percent and
        times in seconds.  Returns ``None`` if every candidate raised.

    Notes
    -----
    * The warmup gradients are approximate.  ``beta_temp`` is treated as
      constant, and in the feature layer the min-max scaling is treated as the
      identity.
    * The best candidate is selected by *test* accuracy, so the reported test
      accuracy is optimistically biased.  Use a separate validation split for
      an unbiased estimate.
    * Random weights are drawn from the global ``np.random`` state; seed it for
      reproducibility.
    * Errors for a candidate are printed and that candidate is skipped.
    """
    best_result = None
    best_test_acc = -np.inf

    for warmup_epochs in warmup_epochs_list:
        try:
            C1 = option['C']
            N1 = option['N']    # nodes per feature window
            N2 = option['NN']   # number of feature windows
            N3 = option['NNN']  # nodes per enhancement window
            N4 = option['N4']   # number of enhancement windows

            start_train = time.perf_counter()
            # Flatten so column-vector labels neither break one-hot encoding
            # nor broadcast the accuracy comparison to an (n, n) matrix.
            train_labels = np.ravel(trainY)
            Nsample, Nfea = trainX.shape
            trainY_one_hot = one_hot(train_labels, num_classes)

            # === Feature layer warmup ===
            # The constant 0.1 column serves as the bias input.
            H1 = np.hstack((trainX, 0.1 * np.ones((Nsample, 1))))
            We = [2 * np.random.rand(Nfea + 1, N1) - 1 for _ in range(N2)]
            m_We = [np.zeros_like(w) for w in We]
            v_We = [np.zeros_like(w) for w in We]

            for epoch in range(1, warmup_epochs + 1):
                Z, _ = _feature_windows(H1, We)
                beta_temp = _ridge_solve(Z, trainY_one_hot, C1)
                error = Z @ beta_temp - trainY_one_hot
                for i in range(N2):
                    # Output weights that read window i's N1 columns of Z.
                    beta_block = beta_temp[i * N1:(i + 1) * N1, :]
                    grad = H1.T @ (error @ beta_block.T) / Nsample
                    We[i], m_We[i], v_We[i] = _adam_step(
                        We[i], m_We[i], v_We[i], grad, epoch,
                        lr, beta1, beta2, epsilon)

            # Final feature output.  Keep the training-fitted scalers so the
            # test data is mapped with exactly the same transform.
            Z, feature_scalers = _feature_windows(H1, We)
            n_z = Z.shape[1]

            # === Enhancement layer warmup ===
            H2 = np.hstack((Z, 0.1 * np.ones((Nsample, 1))))
            Wh = [_enhancement_weights(N1 * N2, N3) for _ in range(N4)]
            m_Wh = [np.zeros_like(w) for w in Wh]
            v_Wh = [np.zeros_like(w) for w in Wh]

            for epoch in range(1, warmup_epochs + 1):
                raw_list = [H2 @ w for w in Wh]
                X_final = np.hstack([Z] + [tansig(A) for A in raw_list])
                beta_temp = _ridge_solve(X_final, trainY_one_hot, C1)
                error = X_final @ beta_temp - trainY_one_hot
                for i in range(N4):
                    # Output weights for enhancement window i (after the Z columns).
                    beta_block = beta_temp[n_z + i * N3:n_z + (i + 1) * N3, :]
                    dA = 1 - np.tanh(raw_list[i]) ** 2  # d tansig / dA
                    grad = H2.T @ ((error @ beta_block.T) * dA) / Nsample
                    Wh[i], m_Wh[i], v_Wh[i] = _adam_step(
                        Wh[i], m_Wh[i], v_Wh[i], grad, epoch,
                        lr, beta1, beta2, epsilon)

            # === Final output weights ===
            H_all = np.hstack([tansig(H2 @ w) for w in Wh])
            X_final = np.hstack((Z, H_all))
            beta = _ridge_solve(X_final, trainY_one_hot, C1)

            train_output = X_final @ beta
            train_time = time.perf_counter() - start_train
            train_preds = np.argmax(train_output, axis=1)
            train_acc = np.mean(train_preds == train_labels) * 100

            # === Testing ===
            start_test = time.perf_counter()
            n_test = testX.shape[0]
            T1 = np.hstack((testX, 0.1 * np.ones((n_test, 1))))
            # Apply the *training* scalers.  Refitting on the test batch would
            # make each prediction depend on the rest of the batch.  Transformed
            # values may fall slightly outside [-1, 1].
            Z_test = np.hstack([scaler.transform(T1 @ w)
                                for scaler, w in zip(feature_scalers, We)])
            H2_test = np.hstack((Z_test, 0.1 * np.ones((n_test, 1))))
            H_test_all = np.hstack([tansig(H2_test @ w) for w in Wh])
            X_test_final = np.hstack((Z_test, H_test_all))
            test_output = X_test_final @ beta
            test_time = time.perf_counter() - start_test
            test_preds = np.argmax(test_output, axis=1)
            test_acc = np.mean(test_preds == np.ravel(testY)) * 100

            if test_acc > best_test_acc:
                best_test_acc = test_acc
                best_result = (warmup_epochs, train_acc, test_acc, train_time, test_time)

        except Exception as e:
            # Best-effort sweep: report the failing candidate and try the next one.
            print(f"❌ Error at warmup_epochs={warmup_epochs}: {e}")
            continue

    return best_result
