import numpy as np

from base_model import *
from learners import *
from dataset_loaders import MNISTLoader
from models import *
from rkhs_weightings import RKHSWeightingClassifier
from rks import RKSClassifier, get_rks_params_from_rkhs_weighting

from sklearn.metrics import mean_squared_error

def get_random_features(model):
    """
    Return the random features held by ``model``.

    Lookup order:
    1. the ``W`` attribute, if the model has one;
    2. otherwise the result of ``model.get_center_params()``.

    Parameters:
    - model: Object to pull the random features from.

    Returns:
    - Whatever ``model.W`` holds, or the value returned by
      ``model.get_center_params()``.

    Raises:
    - ValueError: if the model provides neither ``W`` nor
      ``get_center_params``.
    """
    # Preferred source: an explicit random-feature matrix.
    if hasattr(model, 'W'):
        return model.W
    # Guard clause: nothing else to fall back on.
    if not hasattr(model, 'get_center_params'):
        raise ValueError("Model does not have random features (W) attribute.")
    return model.get_center_params()

def rkhs_rademacher_bound(model: "BaseRKHSWeighting", rho, delta=0.05):
    """
    Compute a Rademacher-complexity bound on the generalization gap of an
    RKHS Weighting model.

    The value returned is

        2 * B * rho * theta / sqrt(m) * (1 + sqrt(log(1 / delta) / 2))

    where ``m = model.data_x.shape[0]`` (number of stored samples),
    ``B = model.norm()`` and ``theta = model.theta()``.

    Parameters:
    - model: The fitted model to evaluate. It must expose ``data_x`` (whose
      first axis indexes samples), ``norm()`` and ``theta()``.
    - rho: Lipschitz constant of the loss function.
    - delta: Confidence parameter of the bound (presumably the failure
      probability, i.e. the bound holds w.p. >= 1 - delta). Must lie in
      (0, 1].

    Returns:
    - Rademacher bound.

    Raises:
    - ValueError: if ``delta`` is outside (0, 1] or the model holds no
      samples.
    """
    # log(1 / delta) is undefined for delta <= 0 and negative for delta > 1
    # (sqrt of a negative -> nan); fail loudly instead of returning inf/nan.
    if not 0 < delta <= 1:
        raise ValueError(f"delta must be in (0, 1], got {delta!r}")
    m = model.data_x.shape[0]  # Number of samples
    if m == 0:
        raise ValueError("model has no training samples (data_x is empty)")
    B = model.norm()  # RKHS norm of the model
    theta = model.theta()
    # Confidence term multiplying the complexity part of the bound.
    log_term = 1 + np.sqrt(np.log(1 / delta) / 2)
    return 2 * B * rho * theta / np.sqrt(m) * log_term

def l2p_rademacher_bound(model, rho, tau, train_X: np.ndarray, delta=0.05):
    """
    Compute a Rademacher-complexity bound on the generalization gap using the
    model's approximate L2P norm.

    The value returned is

        2 * B * rho * tau / sqrt(m) * (1 + sqrt(log(1 / delta) / 2))

    where ``m = train_X.shape[0]`` and ``B = model.l2p_norm_approx()``.
    Any model exposing ``l2p_norm_approx()`` can be used (e.g. an RKHS
    weighting model or an RKS classifier).

    Parameters:
    - model: The model to evaluate; must provide ``l2p_norm_approx()``.
    - rho: Lipschitz constant of the loss function.
    - tau: Data-dependent scale factor entering the bound (supplied by the
      caller).
    - train_X: Training data points; only the number of rows is used.
    - delta: Confidence parameter of the bound (presumably the failure
      probability, i.e. the bound holds w.p. >= 1 - delta). Must lie in
      (0, 1].

    Returns:
    - Rademacher bound.

    Raises:
    - ValueError: if ``delta`` is outside (0, 1] or ``train_X`` has no rows.
    """
    # log(1 / delta) is undefined for delta <= 0 and negative for delta > 1
    # (sqrt of a negative -> nan); fail loudly instead of returning inf/nan.
    if not 0 < delta <= 1:
        raise ValueError(f"delta must be in (0, 1], got {delta!r}")
    m = train_X.shape[0]  # Number of samples
    if m == 0:
        raise ValueError("train_X must contain at least one sample")
    B = model.l2p_norm_approx()
    # Confidence term multiplying the complexity part of the bound.
    log_term = 1 + np.sqrt(np.log(1 / delta) / 2)
    return 2 * B * rho * tau / np.sqrt(m) * log_term

if __name__ == "__main__":
    # Binary MNIST task: digits 1 vs 7.
    X_train, X_test, y_train, y_test = MNISTLoader([1, 7]).load()
    # Largest label magnitude; shared by both Lipschitz estimates below.
    y_max = np.max(np.abs(y_train))

    # --- RKHS weighting model -------------------------------------------
    model = RWExpRelu(X_train)
    learner = LeastSquaresLearner(1000)
    clf = RKHSWeightingClassifier(learner, model).fit(X_train, y_train)
    print(f'RKHS train MSE : {mean_squared_error(y_train, clf.model(X_train))}')
    print(f'RKHS test MSE : {mean_squared_error(y_test, clf.model(X_test))}')
    rho = learner.loss.lipschitz(clf.model.max_output(), y_max=y_max)
    # NOTE(review): tau is estimated on the test split; confirm this is
    # intended for a bound that is otherwise computed from training data.
    tau = clf.model.tau_approx(X_test)
    print("RKHS Rademacher Bound:", rkhs_rademacher_bound(clf.model, rho))
    print("L2P Rademacher Bound:", l2p_rademacher_bound(clf.model, rho, tau, X_train))

    # --- RKS model initialised from the RKHS weighting ------------------
    rks = RKSClassifier(**get_rks_params_from_rkhs_weighting(clf.model, keep_centers=True))
    rks.fit(X_train, y_train)
    # Evaluate each split once; the test outputs are reused for rho_rks.
    rks_train_out = rks.raw_output(X_train)
    print(f'RKS train MSE : {mean_squared_error(y_train, rks_train_out)}')
    rks_test_out = rks.raw_output(X_test)
    print(f'RKS test MSE : {mean_squared_error(y_test, rks_test_out)}')
    # Use the largest output *magnitude*, matching how y_max is computed.
    # The builtin max() took the largest signed value (ignoring large
    # negative outputs, under-estimating rho) and may misbehave on 2-D arrays.
    rho_rks = learner.loss.lipschitz(max_model_output=np.max(np.abs(rks_test_out)), y_max=y_max)
    # tau from the RKHS model is reused here, as in the original script.
    print("RKS Rademacher Bound:", l2p_rademacher_bound(rks, rho_rks, tau, X_train))