import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH
from ConfigSpace.conditions import InCondition

from sklearn import svm
from sklearn.linear_model import SGDClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression

def init_model(name_clf, cfg, seed):
    """Instantiate the scikit-learn classifier for a sampled configuration.

    Args:
        name_clf: Short classifier key: ``"lr"``, ``"dt"``, ``"svm"``,
            ``"knn"`` or ``"sgd"``.
        cfg: Mapping of hyperparameter name -> value, as sampled from the
            matching ``get_*_params`` search space. For ``"dt"`` and
            ``"svm"`` a filtered copy is built, so the caller's mapping is
            not modified.
        seed: Random seed forwarded as ``random_state`` to every estimator
            built here that accepts one.

    Returns:
        An unfitted scikit-learn classifier.

    Raises:
        ValueError: If ``name_clf`` is not one of the supported keys.
    """
    if name_clf == "lr":
        return LogisticRegression(**cfg, verbose=0, n_jobs=1, random_state=seed)
    elif name_clf == "dt":
        # Drop falsy entries so sklearn falls back to its own defaults for them.
        cfg = {k: cfg[k] for k in cfg if cfg[k]}
        # The search space encodes sklearn's ``None`` as the string "None".
        if cfg.get("max_features") == "None":
            cfg["max_features"] = None
        return DecisionTreeClassifier(**cfg, random_state=seed)
    elif name_clf == "svm":
        # Drop falsy entries (e.g. coef0 == 0.0) so sklearn uses its defaults.
        cfg = {k: cfg[k] for k in cfg if cfg[k]}
        # The search space encodes this boolean as the strings "true"/"false".
        cfg["shrinking"] = cfg["shrinking"] == "true"
        if "gamma" in cfg:
            # gamma is either "auto" or an explicit float carried by the
            # auxiliary "gamma_value" entry, which SVC does not accept.
            cfg["gamma"] = cfg["gamma_value"] if cfg["gamma"] == "value" else "auto"
            cfg.pop("gamma_value", None)
        return svm.SVC(**cfg, random_state=seed, verbose=0)
    elif name_clf == "knn":
        # KNeighborsClassifier is not seeded here (it takes no random_state).
        return KNeighborsClassifier(**cfg, n_jobs=1)
    elif name_clf == "sgd":
        # Seed SGD like the other stochastic estimators so runs are reproducible.
        return SGDClassifier(**cfg, n_jobs=1, random_state=seed)
    else:
        raise ValueError(f"Not implemented: {name_clf!r}")

def get_logistic_regression_params(seed, params):
    """Build the ConfigSpace search space for ``LogisticRegression``.

    Args:
        seed: Seed for the configuration space's sampler.
        params: Not used by this builder; accepted so every ``get_*_params``
            function shares the same signature.

    Returns:
        A ``ConfigurationSpace`` over warm_start, fit_intercept, tol, C
        (log scale), solver and max_iter.
    """
    space = CS.ConfigurationSpace(seed=seed)
    space.add_hyperparameters([
        CSH.CategoricalHyperparameter('warm_start', choices=[True, False]),
        CSH.CategoricalHyperparameter('fit_intercept', choices=[True, False]),
        CSH.UniformFloatHyperparameter('tol', lower=0.00001, upper=0.0001, log=False),
        CSH.UniformFloatHyperparameter('C', lower=1e-4, upper=1e4, log=True),
        CSH.CategoricalHyperparameter(
            'solver',
            choices=['newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga'],
        ),
        CSH.UniformIntegerHyperparameter(name='max_iter', lower=5, upper=1000, log=False),
    ])
    return space

def get_decision_tree_params(seed, params):
    """Build the ConfigSpace search space for ``DecisionTreeClassifier``.

    Args:
        seed: Seed for the configuration space's sampler.
        params: Not used by this builder; accepted so every ``get_*_params``
            function shares the same signature.

    Returns:
        A ``ConfigurationSpace`` over criterion, splitter, max_features,
        min_samples_split and min_samples_leaf.
    """
    cs = CS.ConfigurationSpace(seed=seed)
    criterion = CSH.CategoricalHyperparameter('criterion', choices=["gini", "entropy"])
    splitter = CSH.CategoricalHyperparameter('splitter', choices=["best", "random"])
    # sklearn's ``None`` is encoded as the string "None"; init_model maps it back.
    max_features = CSH.CategoricalHyperparameter('max_features', choices=["sqrt", "log2", "None"])
    # Float values are treated by sklearn as fractions of the number of samples.
    min_samples_split = CSH.UniformFloatHyperparameter('min_samples_split', lower=0.01, upper=0.99, log=False)
    min_samples_leaf = CSH.UniformFloatHyperparameter('min_samples_leaf', lower=0.01, upper=0.5, log=False)

    # min_samples_leaf must be registered too; otherwise it is defined but
    # never sampled.
    cs.add_hyperparameters([criterion, splitter, max_features, min_samples_split, min_samples_leaf])
    return cs


def get_svm_params(seed, params):
    """Build the ConfigSpace search space for ``sklearn.svm.SVC``.

    Adapted from the SMAC3 SVM example:
    https://automl.github.io/SMAC3/master/examples/SMAC4HPO_svm.html

    Conditional structure:
      * ``degree`` is active only for the ``poly`` kernel.
      * ``coef0`` is active only for the ``poly`` and ``sigmoid`` kernels.
      * ``gamma`` is active for every kernel except ``linear``; when it is
        ``"value"``, the auxiliary ``gamma_value`` carries the actual float.

    ``shrinking`` ("true"/"false") and ``gamma``/``gamma_value`` are encoded
    for the search and must be translated before being passed to ``SVC``
    (init_model does this).

    Args:
        seed: Seed for the configuration space's sampler.
        params: Not used by this builder; accepted so every ``get_*_params``
            function shares the same signature.

    Returns:
        The conditional ``ConfigurationSpace`` described above.
    """
    cs = CS.ConfigurationSpace(seed=seed)
    kernel = CSH.CategoricalHyperparameter("kernel", choices=["linear", "rbf", "poly", "sigmoid"], default_value="poly")
    C = CSH.UniformFloatHyperparameter("C", lower=0.0001, upper=10000.0, default_value=1.0, log=True)
    shrinking = CSH.CategoricalHyperparameter("shrinking", choices=["true", "false"], default_value="true")
    degree = CSH.UniformIntegerHyperparameter("degree", 1, 5, default_value=3)
    coef0 = CSH.UniformFloatHyperparameter("coef0", 0.0, 10.0, default_value=0.0)
    use_degree = InCondition(child=degree, parent=kernel, values=["poly"])
    use_coef0 = InCondition(child=coef0, parent=kernel, values=["poly", "sigmoid"])
    gamma = CSH.CategoricalHyperparameter("gamma", choices=["auto", "value"], default_value="auto")
    gamma_value = CSH.UniformFloatHyperparameter("gamma_value", lower=0.0001, upper=8, default_value=1)
    max_iter = CSH.UniformIntegerHyperparameter(name='max_iter', lower=5, upper=1000, log=False)

    cs.add_hyperparameter(kernel)
    cs.add_hyperparameters([C, shrinking, max_iter])
    cs.add_hyperparameters([degree, coef0])
    cs.add_conditions([use_degree, use_coef0])
    cs.add_hyperparameters([gamma, gamma_value])
    cs.add_condition(InCondition(child=gamma_value, parent=gamma, values=["value"]))
    cs.add_condition(InCondition(child=gamma, parent=kernel, values=["rbf", "poly", "sigmoid"]))
    return cs


def get_knn_params(seed, params):
    """Build the ConfigSpace search space for ``KNeighborsClassifier``.

    Args:
        seed: Seed for the configuration space's sampler.
        params: Not used by this builder; accepted so every ``get_*_params``
            function shares the same signature.

    Returns:
        A ``ConfigurationSpace`` over n_neighbors (integer in [1, 100],
        log scale), p (1 or 2) and weights ("uniform" / "distance").
    """
    space = CS.ConfigurationSpace(seed=seed)
    space.add_hyperparameters([
        CSH.UniformIntegerHyperparameter('n_neighbors', lower=1, upper=100, log=True),
        # Under sklearn's default Minkowski metric: 1 = Manhattan, 2 = Euclidean.
        CSH.UniformIntegerHyperparameter('p', lower=1, upper=2, log=False),
        CSH.CategoricalHyperparameter(
            "weights",
            choices=["uniform", "distance"],
            default_value="uniform",
        ),
    ])
    return space


def get_sgd_params(seed, params):
    """Build the conditional ConfigSpace search space for ``SGDClassifier``.

    eta0 is active for the "invscaling"/"constant" learning rates, power_t
    only for "invscaling", epsilon only for the "modified_huber" loss, and
    l1_ratio only for the "elasticnet" penalty.

    Args:
        seed: Seed for the configuration space's sampler.
        params: Not used by this builder; accepted so every ``get_*_params``
            function shares the same signature.

    Returns:
        The conditional ``ConfigurationSpace`` described above.
    """
    space = CS.ConfigurationSpace(seed=seed)

    # Always-active hyperparameters.
    alpha = CSH.UniformFloatHyperparameter("alpha", lower=1e-7, upper=0.1, default_value=0.0001, log=True)
    average = CSH.CategoricalHyperparameter('average', choices=[True, False])
    fit_intercept = CSH.CategoricalHyperparameter('fit_intercept', choices=[True, False])
    learning_rate = CSH.CategoricalHyperparameter('learning_rate', choices=["optimal", "invscaling", "constant"])
    # NOTE(review): newer scikit-learn renamed the "log" loss to "log_loss";
    # confirm against the installed version.
    loss = CSH.CategoricalHyperparameter('loss', choices=["hinge", "log", "modified_huber", "squared_hinge", "perceptron"])
    penalty = CSH.CategoricalHyperparameter('penalty', choices=["l1", "l2", "elasticnet"])
    tol = CSH.UniformFloatHyperparameter("tol", lower=1e-05, upper=0.1, default_value=0.0001, log=True)

    # Hyperparameters that are only active under a parent condition.
    eta0 = CSH.UniformFloatHyperparameter("eta0", lower=1e-7, upper=0.1, default_value=0.0001, log=True)
    power_t = CSH.UniformFloatHyperparameter("power_t", lower=1e-5, upper=1, log=False)
    # NOTE(review): sklearn documents epsilon as used only by the huber /
    # epsilon-insensitive losses, so under "modified_huber" it likely has no
    # effect — confirm before relying on this dimension.
    epsilon = CSH.UniformFloatHyperparameter("epsilon", lower=1e-5, upper=0.1, log=True)
    l1_ratio = CSH.UniformFloatHyperparameter("l1_ratio", lower=1e-9, upper=1, log=True)

    space.add_hyperparameters([
        alpha, average, fit_intercept, learning_rate, loss, penalty,
        tol, eta0, power_t, epsilon, l1_ratio,
    ])
    space.add_conditions([
        InCondition(child=eta0, parent=learning_rate, values=["invscaling", "constant"]),
        InCondition(child=power_t, parent=learning_rate, values=["invscaling"]),
        InCondition(child=epsilon, parent=loss, values=["modified_huber"]),
        InCondition(child=l1_ratio, parent=penalty, values=["elasticnet"]),
    ])
    return space


def get_search_sapce(seed, params):
    """Return the ConfigSpace search space for the classifier in ``params.name``.

    The (misspelled) function name is kept as-is for existing callers.

    Args:
        seed: Seed forwarded to the selected ``get_*_params`` builder.
        params: Object with a ``name`` attribute (``"lr"``, ``"dt"``,
            ``"svm"``, ``"knn"`` or ``"sgd"``); also forwarded to the builder.

    Returns:
        The ``ConfigurationSpace`` built for that classifier.

    Raises:
        ValueError: If ``params.name`` is not a supported classifier key,
            instead of silently returning ``None``.
    """
    name = params.name
    if name == "lr":
        return get_logistic_regression_params(seed, params)
    elif name == "dt":
        return get_decision_tree_params(seed, params)
    elif name == "svm":
        return get_svm_params(seed, params)
    elif name == "knn":
        return get_knn_params(seed, params)
    elif name == "sgd":
        return get_sgd_params(seed, params)
    raise ValueError(f"Unknown classifier for search space: {name!r}")
