import numpy as np
from scipy.linalg import pinv
import time
import math
from sklearn.metrics import confusion_matrix
from metrics import evaluate
from l2_weights import l2_weights

# ----------------- Activation Functions ----------------- #
def sigmoid(x):
    """Return the logistic sigmoid 1 / (1 + exp(-x)) of x, element-wise.

    The exponential is only ever evaluated at -|x|, so it cannot overflow.
    The naive form computes exp(-x), which overflows for large negative x
    (RuntimeWarning, or FloatingPointError under np.errstate(over='raise')).

    Args:
        x: Scalar or array-like of real values.

    Returns:
        A NumPy scalar for scalar input, otherwise an ndarray shaped like x,
        with values in [0, 1].
    """
    x = np.asarray(x)
    # exp(-|x|) lies in (0, 1], so both branches below are overflow-free
    # even though np.where evaluates each of them for every element.
    decay = np.exp(-np.abs(x))
    out = np.where(x >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))
    # Indexing with () turns a 0-d result back into a scalar and is a no-op
    # for arrays with one or more dimensions.
    return out[()]

def relu(x):
    """Rectified linear unit: the element-wise maximum of 0 and x."""
    floor = 0
    rectified = np.maximum(floor, x)
    return rectified

def selu(x, lambda_=1.0507, alpha=1.67326):
    """Return the scaled exponential linear unit of x, element-wise.

    Computes lambda_ * x where x > 0 and lambda_ * alpha * (exp(x) - 1)
    elsewhere.

    Args:
        x: Scalar or ndarray of real values.
        lambda_: Overall scale factor.
        alpha: Scale of the negative (exponential) branch.

    Returns:
        The activated value(s), shaped like x.
    """
    # np.where evaluates both branches for every element, so the exponential
    # argument is clamped at 0 to stop large positive x from overflowing in
    # the branch whose result is discarded. expm1 also keeps precision for
    # tiny negative x, where exp(x) - 1 would round to 0.
    negative_branch = alpha * np.expm1(np.minimum(x, 0))
    return lambda_ * np.where(x > 0, x, negative_branch)

# ----------------- StaR Regulation Function ----------------- #
def regulate_weights_using_StaR(W_init, rho_target):
    """Rebuild W_init with its singular values mapped into a target band.

    The singular values from a reduced SVD are min-max normalised and then
    affinely mapped so the smallest becomes 0.1 * rho_target and the largest
    becomes (up to a 1e-8 stabiliser in the denominator) rho_target. The
    singular vectors are left untouched.

    Note that it is the singular values, not the eigenvalues, that are
    rescaled. When all singular values are equal, every one of them maps to
    the lower bound 0.1 * rho_target.

    Args:
        W_init: 2-D weight matrix to regulate.
        rho_target: Upper end of the target singular-value range.

    Returns:
        A matrix shaped like W_init with the rescaled singular values.
    """
    left, sing_vals, right_h = np.linalg.svd(W_init, full_matrices=False)

    lower = 0.1 * rho_target
    smallest = np.min(sing_vals)
    spread = np.max(sing_vals) - smallest

    # The small constant keeps the division finite when all singular values
    # coincide (spread == 0).
    unit = (sing_vals - smallest) / (spread + 1e-8)
    rescaled = lower + unit * (rho_target - lower)

    return np.matmul(np.matmul(left, np.diag(rescaled)), right_h)

# ----------------- StaR-dRVFL Training & Evaluation ----------------- #
def _star_drvfl_activate(Z, activation):
    """Apply the activation selected by its integer code.

    1 -> sigmoid, 2 -> relu, 3 -> selu; any other code returns Z unchanged.
    """
    if activation == 1:
        return sigmoid(Z)
    if activation == 2:
        return relu(Z)
    if activation == 3:
        return selu(Z)
    return Z


def _star_drvfl_column_stats(A):
    """Return per-column (mean, std) of A with zero stds replaced by 1."""
    mu = np.mean(A, axis=0)
    sigma = np.std(A, axis=0)
    # A constant column (e.g. a ReLU unit that is zero for every sample) has
    # std 0; dividing by it would fill the features with NaN/inf. After
    # centring such a column is all zeros, so dividing by 1 keeps it at 0.
    sigma = np.where(sigma == 0, 1.0, sigma)
    return mu, sigma


def _star_drvfl_layer(A, W, b, activation, renormal, norm_type, stats=None):
    """Propagate A through one hidden layer.

    ``stats`` is a (mu, sigma) pair to reuse (prediction); when it is None the
    statistics are computed from the current batch (training). Returns
    ``(output, mu, sigma)``; mu and sigma are None when no normalisation was
    applied (renormal is falsy or norm_type is neither 0 nor 1).
    """
    Z = A @ W + b
    mu = sigma = None

    if renormal and norm_type == 0:
        # norm_type 0: standardise the pre-activations.
        mu, sigma = _star_drvfl_column_stats(Z) if stats is None else stats
        Z = (Z - mu) / sigma

    Z = _star_drvfl_activate(Z, activation)

    if renormal and norm_type == 1:
        # norm_type 1: standardise the activations instead.
        mu, sigma = _star_drvfl_column_stats(Z) if stats is None else stats
        Z = (Z - mu) / sigma

    return Z, mu, sigma


def _star_drvfl_predict_labels(scores):
    """Return the row-wise argmax of the softmax of ``scores``."""
    # Subtracting the row maximum keeps exp() from overflowing.
    shifted = scores - np.max(scores, axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    prob_scores = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
    return np.argmax(prob_scores, axis=1)


# ----------------- StaR-dRVFL Training & Evaluation ----------------- #
def train_and_evaluate_StaR_drvfl(trainX, trainY, testX, testY, option):
    """Train a StaR-regulated deep RVFL classifier and evaluate it.

    Each of the L hidden layers uses a random weight matrix whose singular
    values are rescaled by ``regulate_weights_using_StaR`` and a random bias.
    The input, the output of every hidden layer and a constant 1 column are
    concatenated into one design matrix, and the output weights ``beta`` are
    obtained from ``l2_weights`` (presumably a ridge-regression solver with
    regularisation ``C`` -- confirm against the l2_weights module).

    Side effect: NumPy's global RNG is reseeded with 0 on every call.

    Args:
        trainX: 2-D training inputs, shape (n_samples, n_features).
        trainY: 2-D training targets; labels are recovered with
            argmax over axis 1 (presumably one-hot encoded).
        testX: 2-D test inputs with the same number of columns as trainX.
        testY: 2-D test targets, encoded like trainY.
        option: Mapping with the keys
            'N' (hidden units per layer), 'L' (number of hidden layers),
            'C' (passed to l2_weights), 'scale' (scale of the random
            weights and biases), 'activation' (1 sigmoid, 2 relu, 3 selu,
            anything else linear), 'renormal' (truthy to standardise),
            'normal_type' (0: standardise before the activation,
            1: after it) and 'spectral_radius' (target passed to
            ``regulate_weights_using_StaR``).

    Returns:
        Tuple ``(model, EVAL_Train, EVAL_Test, training_time, testing_time)``
        where ``model`` is a dict with keys "L", "w", "b", "beta", "mu",
        "sigma"; the EVAL values are whatever ``evaluate`` returns for the
        (actual, predicted) label arrays; the times are elapsed seconds.
        Entries of "sigma" that would have been 0 are stored as 1.
    """
    np.random.seed(0)

    # perf_counter is monotonic, unlike the wall clock, so the measured
    # durations cannot be distorted by system clock adjustments.
    start_train = time.perf_counter()
    n_samples, n_features = trainX.shape
    N = option['N']
    L = option['L']
    C = option['C']
    s = option['scale']
    activation = option['activation']
    renormal = option['renormal']
    norm_type = option['normal_type']
    rho_target = option['spectral_radius']

    weights, biases, mu, sigma = [], [], [], []
    # Feature blocks are collected and stacked once at the end rather than
    # re-copying the growing matrix on every layer.
    feature_blocks = [trainX]
    A1 = trainX

    # Forward propagation through L layers
    for i in range(L):
        fan_in = n_features if i == 0 else N
        # Uniform in [-s, s); the draw order (weights, then bias) is kept so
        # the seeded random stream yields the same parameters as before.
        W_rand = s * (2 * np.random.rand(fan_in, N) - 1)

        # Apply StaR spectral regulation
        W_reg = regulate_weights_using_StaR(W_rand, rho_target)

        b = s * np.random.rand(1, N)

        weights.append(W_reg)
        biases.append(b)

        A1, mu_i, sigma_i = _star_drvfl_layer(
            A1, W_reg, b, activation, renormal, norm_type
        )
        mu.append(mu_i)
        sigma.append(sigma_i)
        feature_blocks.append(A1)

    feature_blocks.append(np.ones((n_samples, 1)))
    A_merge_temp = np.hstack(feature_blocks)
    beta = l2_weights(A_merge_temp, trainY, C)

    predicted_train = _star_drvfl_predict_labels(A_merge_temp @ beta)
    actual_train = np.argmax(trainY, axis=1)

    EVAL_Train = evaluate(actual_train, predicted_train)
    end_train = time.perf_counter()

    # ================= Prediction ================= #
    start_test = time.perf_counter()
    feature_blocks = [testX]
    A1 = testX
    for W_i, b_i, mu_i, sigma_i in zip(weights, biases, mu, sigma):
        # Reuse the training statistics so test features are scaled exactly
        # like the training features.
        A1, _, _ = _star_drvfl_layer(
            A1, W_i, b_i, activation, renormal, norm_type,
            stats=(mu_i, sigma_i),
        )
        feature_blocks.append(A1)

    feature_blocks.append(np.ones((testX.shape[0], 1)))
    A_merge_temp = np.hstack(feature_blocks)

    predicted_test = _star_drvfl_predict_labels(A_merge_temp @ beta)
    actual_test = np.argmax(testY, axis=1)

    EVAL_Test = evaluate(actual_test, predicted_test)
    end_test = time.perf_counter()

    model = {
        "L": L,
        "w": weights,
        "b": biases,
        "beta": beta,
        "mu": mu,
        "sigma": sigma,
    }

    training_time = end_train - start_train
    testing_time = end_test - start_test

    return model, EVAL_Train, EVAL_Test, training_time, testing_time
