from sklearn import model_selection, metrics, datasets
from sklearn.svm import SVC, SVR, NuSVC, NuSVR
import numpy as np
from localglobal.mixed_test_func.xgboost_hp import XGBoostOptTask


class SVMOptTask(XGBoostOptTask):
    """Mixed categorical/continuous hyper-parameter optimisation task for nu-SVMs.

    Reuses the task machinery of ``XGBoostOptTask`` but builds a ``NuSVC``
    (when ``self.reg_or_clf == 'clf'``) or a ``NuSVR`` model otherwise.

    Search space:
        categorical ``h``: kernel (4 choices), gamma mode (2), shrinking (2)
        continuous ``x`` (log10 scale): nu, tol, C (C is only used for
        regression, since NuSVC has no ``C`` parameter)
    """
    problem_type = 'mixed'

    # Choices of the categorical variables, indexed by h[0], h[1] and h[2].
    _KERNELS = ('linear', 'poly', 'rbf', 'sigmoid')
    _GAMMAS = ('scale', 'auto')
    _SHRINKINGS = (True, False)

    def __init__(self, lamda=1e-6, task=None, split=0.3, normalize=False, seed=None):
        super().__init__(lamda=lamda, task=task, split=split, normalize=normalize, seed=seed)
        # Bounds of the continuous inputs, all on a log10 scale:
        #   log10(nu) in [-3, log10(0.999)], log10(tol) in [-6, 0], log10(C) in [-4, 2].
        # Floats so that 10 ** bound is well defined even at the lower bound.
        self.lb = np.array([-3., -6., -4.])
        self.ub = np.array([np.log10(0.999), 0., 2.])
        # Set an upper bound for the fX value. Otherwise we get stuff like 1e18 as loss (for some reason)
        self.fX_ub = 10000.

    def create_model(self, h, x):
        """Build an unfitted NuSVC (classification) or NuSVR (regression) model.

        Args:
            h: categorical choices (see ``convert_input_into_kwargs``).
            x: continuous inputs on a log10 scale.
        """
        model_kwargs = self.convert_input_into_kwargs(h, x)

        if self.reg_or_clf == 'clf':
            model = NuSVC(**model_kwargs)
        else:
            model = NuSVR(**model_kwargs)
        return model

    def convert_input_into_kwargs(self, h, x) -> dict:
        """Translate an optimiser point into sklearn estimator keyword arguments.

        Args:
            h: categorical choices; h[0] indexes the kernel, h[1] the gamma
                mode and h[2] the shrinking flag. Float-valued entries are
                truncated to int before indexing.
            x: continuous inputs on a log10 scale (flattened before use):
                x[0] = log10(nu), x[1] = log10(tol), x[2] = log10(C).

        Returns:
            dict of keyword arguments for NuSVC / NuSVR. ``C`` is only set for
            regression.
        """
        # Cast to float: with an integer array (e.g. self.lb) the negative powers
        # below would raise "Integers to negative integer powers are not allowed".
        x = np.asarray(x, dtype=float).flatten()
        h = np.asarray(h).flatten()

        kwargs = {}

        # setting this so that max time should be around 10-20s if all
        # iters are used. Estimating this using macbook speeds.
        # Must be an int: sklearn's parameter validation rejects a float here.
        kwargs['max_iter'] = int(1e6)

        # Categorical vars (int() so float-encoded indices can index the tuples)
        kwargs['kernel'] = self._KERNELS[int(h[0])]
        kwargs['gamma'] = self._GAMMAS[int(h[1])]
        kwargs['shrinking'] = self._SHRINKINGS[int(h[2])]

        # Continuous vars
        kwargs['nu'] = float(10. ** x[0])  # nu in [1e-3, 0.999]
        kwargs['tol'] = float(np.power(10., x[1]))  # [1e-6, 1] on a log scale
        if self.reg_or_clf == 'reg':
            kwargs['C'] = float(10. ** x[2])  # [1e-4, 100] on a log scale

        print(f"Create SVM with kwargs:\n{kwargs}")
        return kwargs

    def get_bnds(self):
        """Describe the search space for the optimiser.

        Sets ``original_x_bounds`` (one ``[low, high]`` row per continuous
        variable, log10 scale), the positions of the categorical and continuous
        variables in the full input vector, and the number of choices of each
        categorical variable.
        """
        self.original_x_bounds = np.array([[-3, np.log10(1 - 1e-3)],  # log10(nu)
                                           [-6, 0],                   # log10(tol)
                                           [-4, 2]])                  # log10(C)

        self.categorical_dims = np.array([0, 1, 2])
        self.continuous_dims = np.array([3, 4, 5])
        self.n_vertices = np.array([len(self._KERNELS),
                                    len(self._GAMMAS),
                                    len(self._SHRINKINGS)])

    def compute(self, X, normalize=False):
        """Evaluate the loss for each row of ``X``.

        The parent's ``compute`` is always called with normalisation disabled;
        when ``normalize`` is True the result is standardised here using
        ``self.mean`` / ``self.std`` (presumably set by the parent class —
        verify).
        """
        res = super().compute(X, False)
        # Try log transform
        # res = np.log(np.clip(res, 0., self.fX_ub))
        if normalize:
            res = (res - self.mean) / self.std
        return res


if __name__ == '__main__':
    # Smoke test: evaluate a single point made of 3 categorical choices
    # followed by 3 continuous values.
    categorical_part = [0, 0, 1]
    continuous_part = [0.5] * 3
    t = SVMOptTask(task='boston', split=0.2)
    X = np.atleast_2d(categorical_part + continuous_part)
    y = t.compute(X)
