import numpy as np
from scipy.linalg import pinv
import time
import math
from sklearn.metrics import confusion_matrix
from metrics import evaluate
from l2_weights import l2_weights


# Activation functions
def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``, applied element-wise.

    For very negative ``x``, ``exp(-x)`` overflows to ``inf``. That still yields
    the correct limit ``1 / inf == 0.0``, so the overflow RuntimeWarning NumPy
    would emit is harmless noise and is silenced here.
    """
    with np.errstate(over="ignore"):
        return 1.0 / (1.0 + np.exp(-x))

def relu(x):
    """Rectified linear unit: element-wise maximum of ``x`` and zero.

    Negative entries become 0; non-negative entries pass through unchanged.
    """
    lower_bound = 0
    return np.maximum(lower_bound, x)

def selu(x, lambda_=1.0507, alpha=1.67326):
    """Scaled exponential linear unit, applied element-wise.

    Returns ``lambda_ * x`` where ``x > 0`` and ``lambda_ * alpha * (exp(x) - 1)``
    elsewhere. The defaults look like rounded SELU constants
    (NOTE(review): presumably lambda ~= 1.0507009873554805 and
    alpha ~= 1.6732632423543772; kept as-is for reproducibility).
    """
    # np.where evaluates both branches for every element. Clamping the exponent
    # at 0 keeps np.exp from overflowing (and warning) on large positive inputs,
    # whose exponential branch is discarded anyway. expm1 avoids the
    # cancellation of exp(x) - 1 for x near 0.
    return lambda_ * np.where(x > 0, x, alpha * np.expm1(np.minimum(x, 0)))

# # L2-regularized closed-form output weights
# def l2_weights(A, b, C):
#     n_samples, n_features = A.shape
#     if C == 0:
#         return pinv(A) @ b
#     if n_features < n_samples:
#         return np.linalg.solve((np.eye(n_features) / C + A.T @ A), A.T @ b)
#     else:
#         return A.T @ np.linalg.solve((np.eye(n_samples) / C + A @ A.T), b)

# # Metric calculation function
# def evaluate(actual, predicted):
#     actual = np.asarray(actual).flatten()
#     predicted = np.asarray(predicted).flatten()

#     cm = confusion_matrix(actual, predicted, labels=np.unique(actual))
#     if cm.shape[0] == 2:
#         tn, fp, fn, tp = cm.ravel()
#     else:
#         tp = np.sum(actual == predicted)
#         tn = fp = fn = 0

#     p = tp + fn
#     n = tn + fp
#     N = p + n

#     sensitivity = 100 * tp / p if p else 0
#     specificity = 100 * tn / n if n else 0
#     accuracy = 100 * (tp + tn) / N if N else 0
#     precision = 100 * tp / (tp + fp) if (tp + fp) else 0
#     f_measure = 2 * (precision * sensitivity) / (precision + sensitivity) if (precision + sensitivity) else 0
#     gmean = 100 * math.sqrt((tp / p) * (tn / n)) if p and n else 0

#     return [accuracy, sensitivity, specificity, precision, f_measure, gmean]

# Unified training + prediction routine
def train_and_evaluate_drvfl(trainX, trainY, testX, testY, option):
    """Train a deep random-vector functional-link (dRVFL) classifier and evaluate it.

    Hidden-layer weights are random and never trained. Only the output weights
    ``beta`` are fit, by ``l2_weights``, on the concatenation of the raw input,
    every hidden layer's output and a constant bias column.

    Parameters
    ----------
    trainX, testX : 2-D ndarray, shape (n_samples, n_features)
    trainY, testY : 2-D ndarray
        Targets. Labels are recovered with ``argmax`` along axis 1, so these
        are presumably one-hot encoded (TODO confirm with callers).
    option : dict
        ``'N'``: hidden units per layer. ``'L'``: number of hidden layers.
        ``'C'``: passed through to ``l2_weights``. ``'scale'``: weights are
        drawn uniformly from ``[-scale, scale)`` and biases from ``[0, scale)``.
        ``'activation'``: 1 sigmoid, 2 relu, 3 selu, any other value is identity.
        ``'renormal'``: whether to standardize each layer's columns.
        ``'normal_type'``: 0 standardizes before the activation, 1 after.
        ``'seed'`` (optional, default 0): seed for the random layers.

    Returns
    -------
    model : dict with keys "L", "w", "b", "beta", "mu", "sigma"
    EVAL_Train, EVAL_Test : results of ``evaluate`` on train / test predictions
    training_time, testing_time : float, elapsed seconds
    """
    start_train = time.perf_counter()
    _, n_features = trainX.shape
    N = option['N']
    L = option['L']
    C = option['C']
    s = option['scale']
    activation = option['activation']
    renormal = option['renormal']
    norm_type = option['normal_type']

    # A local generator yields exactly the stream of np.random.seed(seed) +
    # np.random.rand, without reseeding NumPy's global RNG as a side effect.
    rng = np.random.RandomState(option.get('seed', 0))

    # Draw every random layer up front. The draw order (w_1, b_1, w_2, b_2, ...)
    # is the same as when drawing inside the forward loop.
    weights, biases = [], []
    fan_in = n_features
    for _ in range(L):
        weights.append(s * (2 * rng.rand(fan_in, N) - 1))
        biases.append(s * rng.rand(1, N))
        fan_in = N

    design, mu, sigma = _drvfl_features(
        trainX, weights, biases, activation, renormal, norm_type
    )
    beta = l2_weights(design, trainY, C)

    # Softmax is strictly monotone per row, so the arg-max of the raw scores is
    # the predicted class; computing the probabilities is unnecessary.
    predicted_train = np.argmax(design @ beta, axis=1)
    actual_train = np.argmax(trainY, axis=1)
    EVAL_Train = evaluate(actual_train, predicted_train)
    training_time = time.perf_counter() - start_train

    # ================= Prediction ================= #
    start_test = time.perf_counter()
    design, _, _ = _drvfl_features(
        testX, weights, biases, activation, renormal, norm_type, mu, sigma
    )
    predicted_test = np.argmax(design @ beta, axis=1)
    actual_test = np.argmax(testY, axis=1)
    EVAL_Test = evaluate(actual_test, predicted_test)
    testing_time = time.perf_counter() - start_test

    model = {
        "L": L,
        "w": weights,
        "b": biases,
        "beta": beta,
        "mu": mu,
        "sigma": sigma,
    }

    return model, EVAL_Train, EVAL_Test, training_time, testing_time


def _apply_activation(H, activation):
    """Apply the activation chosen by code: 1 sigmoid, 2 relu, 3 selu, else identity."""
    if activation == 1:
        return sigmoid(H)
    if activation == 2:
        return relu(H)
    if activation == 3:
        return selu(H)
    return H


def _standardize(H, mu_i, sigma_i, fit):
    """Standardize the columns of H and return ``(H_std, mu_i, sigma_i)``.

    With ``fit`` the column mean and (population) std are estimated from H. A
    zero std (constant column, e.g. a ReLU unit that never fires) is replaced
    by 1, so that column maps to zeros instead of 0/0 = NaN.
    Without ``fit`` the supplied statistics are reused, and H is returned
    unchanged if none were stored.
    """
    if fit:
        mu_i = np.mean(H, axis=0)
        sigma_i = np.std(H, axis=0)
        sigma_i = np.where(sigma_i == 0, 1.0, sigma_i)
    elif mu_i is None:
        return H, mu_i, sigma_i
    return (H - mu_i) / sigma_i, mu_i, sigma_i


def _drvfl_features(X, weights, biases, activation, renormal, norm_type,
                    mu=None, sigma=None):
    """Build the dRVFL design matrix ``[X, H_1, ..., H_L, 1]`` for input X.

    When ``mu``/``sigma`` are None, per-layer normalization statistics are
    estimated from X (training). Otherwise the given lists are reused
    (prediction) and are not modified. Returns ``(design, mu, sigma)``; the
    lists hold None for layers that were not standardized.
    """
    fit = mu is None
    n_layers = len(weights)
    mu = [None] * n_layers if fit else list(mu)
    sigma = [None] * n_layers if fit else list(sigma)

    blocks = [X]  # raw inputs are fed straight to the output layer
    H = X
    for i in range(n_layers):
        H = H @ weights[i] + biases[i]
        if renormal and norm_type == 0:
            H, mu[i], sigma[i] = _standardize(H, mu[i], sigma[i], fit)
        H = _apply_activation(H, activation)
        if renormal and norm_type == 1:
            H, mu[i], sigma[i] = _standardize(H, mu[i], sigma[i], fit)
        blocks.append(H)
    blocks.append(np.ones((X.shape[0], 1)))

    # A single hstack avoids re-copying the growing matrix once per layer.
    return np.hstack(blocks), mu, sigma
