import numpy as np
from sklearn.datasets import make_classification, make_regression
from sklearn.preprocessing import StandardScaler
from typing import Type


from kernel import GaussianKernel, ExponentialKernel, IndicatorGaussianKernel
from loss import LogisticLoss, MSE, Hinge
from learners import SFGDLearner, LeastSquaresLearner, LassoLearner
from base_model import BaseRKHSWeighting
from models.stump import RWStumps
from models.relu import RWRelu, RWExpRelu
from models.sign import RWSign, RWExpSign
from rkhs_weightings import RKHSWeightingClassifier

# --- Model classes under test -------------------------------------------
# Every RKHS-weighting instantiation imported above.
INSTANTIATIONS = [RWSign, RWRelu, RWStumps, RWExpSign, RWExpRelu]
# Subsets of INSTANTIATIONS used by more specific tests.
# NOTE(review): presumably "analytical" means a closed-form computation is
# available, and "FS" a further restricted feature — confirm with the tests
# that consume these lists.
ANALYTICAL_INSTANTIATIONS = [RWSign, RWRelu, RWStumps, RWExpSign]
FS_INSTANTIATIONS = [RWSign, RWRelu, RWExpSign]

# --- Learners, losses and kernels under test ----------------------------
LEARNERS = [SFGDLearner, LeastSquaresLearner, LassoLearner]
LOSSES = [LogisticLoss, MSE, Hinge]
KERNELS = [GaussianKernel, ExponentialKernel, IndicatorGaussianKernel]

# --- Problem sizes ------------------------------------------------------
SAMPLE_SIZE = 200
N_ITER = 100
SMALL_SAMPLE_SIZE = 20
SMALL_N_ITER = 10
# N_DIM used to be assigned twice in a row (5, then 4) with no use in
# between, so only the second value was ever observable. The dead first
# assignment is removed and the effective value is kept unchanged.
N_DIM = 4

# --- Sampling settings --------------------------------------------------
# NOTE(review): "MC" presumably stands for Monte Carlo (precision / number
# of draws) — confirm against the tests that read these values.
MC_PRECISION = 1
N_MC = 1000
LARGE_N_MC = 5000

# Single module-level generator seeded with 0. Every consumer shares (and
# advances) the same state, so results depend on the order of use.
RNG = np.random.default_rng(0)

def make_scaled_classification(n_samples=SAMPLE_SIZE, n_features=N_DIM, random_state=0):
    """Build a standardized synthetic classification set with labels in {+1, -1}.

    Parameters
    ----------
    n_samples : number of samples to generate (default ``SAMPLE_SIZE``).
    n_features : number of features per sample (default ``N_DIM``).
    random_state : seed forwarded to ``make_classification``.

    Returns
    -------
    (X, y) where ``X`` holds the features after ``StandardScaler`` and ``y``
    is +1 for every sample whose raw label equals that of the first sample,
    and -1 otherwise.
    """
    features, raw_labels = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_clusters_per_class=1,
        random_state=random_state,
    )
    scaled_features = StandardScaler().fit_transform(features)

    # The first sample's class is the reference: it always ends up as +1.
    same_as_first = raw_labels == raw_labels[0]
    signed_labels = np.where(same_as_first, 1, -1)
    return scaled_features, signed_labels

def make_scaled_regression(n_samples=SAMPLE_SIZE, n_features=N_DIM, random_state=0):
    """Build a synthetic regression set with standardized features and targets.

    Parameters
    ----------
    n_samples : number of samples to generate (default ``SAMPLE_SIZE``).
    n_features : number of features per sample (default ``N_DIM``).
    random_state : seed forwarded to ``make_regression``.

    Returns
    -------
    (X, y) where ``X`` holds the features after ``StandardScaler`` and ``y``
    is the target vector shifted by its mean and divided by its standard
    deviation (``np.std`` with its default settings).
    """
    features, targets = make_regression(
        n_samples=n_samples,
        n_features=n_features,
        random_state=random_state,
    )
    scaled_features = StandardScaler().fit_transform(features)

    centered_targets = targets - np.mean(targets)
    scaled_targets = centered_targets / np.std(targets)
    return scaled_features, scaled_targets

def basic_classification(model_class: Type[BaseRKHSWeighting] = RWSign):
    """Fit a least-squares RKHS-weighting classifier on the default toy data.

    Parameters
    ----------
    model_class : RKHS-weighting class to instantiate (default ``RWSign``).
        It is called with ``data_x``, ``data_y`` and ``rng`` keyword arguments.

    Returns
    -------
    (clf, X, y) where ``clf`` is whatever ``RKHSWeightingClassifier.fit``
    returns (presumably the fitted classifier — confirm) and ``X``, ``y``
    are the data produced by ``make_scaled_classification()``.
    """
    features, labels = make_scaled_classification()

    # The model is built before the learner on purpose: both receive the
    # shared module-level RNG, so construction order is kept as it was.
    weighting = model_class(data_x=features, data_y=labels, rng=RNG)
    solver = LeastSquaresLearner(n_iter=N_ITER, rng=RNG)

    classifier = RKHSWeightingClassifier(solver, weighting)
    fitted = classifier.fit(features, labels)
    return fitted, features, labels