import time
import numpy as np

from typing import List, Dict
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from ConfigSpace.hyperparameters import UniformFloatHyperparameter, UniformIntegerHyperparameter, Constant
from tlbo.model.util_funcs import get_rng, get_types
from tlbo.acquisition_function.acquisition import EI
from tlbo.config_space import ConfigurationSpace, Configuration
from tlbo.facade.notl import NoTL
from tlbo.optimizer.ei_offline_optimizer import OfflineSearch
from tlbo.optimizer.random_configuration_chooser import ChooserProb
from tlbo.config_space.util import convert_configurations_to_array
from tlbo.utils.constants import MAXINT, SUCCESS, FAILDED
from tlbo.utils.normalization import zero_mean_unit_var_normalization, zero_one_normalization
from tlbo.acquisition_function.ta_acquisition import TAQ_EI
from tlbo.framework.smbo import BasePipeline
from tlbo.facade.base_facade import BaseFacade
import ConfigSpace as CS
from sklearn.linear_model import LogisticRegression
import random


class SMBO_SEARCH_SPACE_Enlarge_Continuous(BasePipeline):
    """SMBO loop whose acquisition candidates come from a search space learned on source tasks.

    The target objective is the black-box ``target_func``, called on a flat numpy
    vector built from a configuration (see ``evaluate``). Each iteration draws
    random configurations from ``config_space`` and filters them by one of two
    mechanisms, selected by ``mode``:

    * ``'box'`` / ``'ellipsoid'``: keep samples whose continuous coordinates fall
      inside the bounding box, or the minimum-volume enclosing ellipsoid, of the
      source-task incumbents ("Learning search spaces for Bayesian optimization:
      Another view of hyperparameter transfer learning", NIPS 2019).
    * every other mode: train one binary classifier per selected source task
      (label = "better than a weight-dependent percentile of that task's
      observations"), and keep samples that all selected classifiers predict as
      promising.

    The surviving candidates are ranked with EI on ``surrogate_model``, and the
    best not-yet-evaluated configuration is evaluated.
    """

    def __init__(self,
                 target_hpo_data: Dict,
                 config_space: ConfigurationSpace,
                 surrogate_model: BaseFacade,
                 target_func = None,
                 dims = None,
                 acq_func: str = 'ei',
                 mode='best',
                 model='gp',
                 source_hpo_data=None,
                 enable_init_design=False,
                 num_src_hpo_trial=50,
                 surrogate_type='rf',
                 max_runs=200,
                 logging_dir='./logs',
                 initial_runs=3,
                 task_id=None,
                 random_seed=None):
        """Set up the BO state, the acquisition function and (for box/ellipsoid) the learned region.

        Parameters
        ----------
        target_hpo_data
            Mapping whose keys are target-task configurations; the keys become
            ``self.configuration_list``. May be ``None``/empty, in which case the
            list is empty.
        config_space
            Search space that candidates are sampled from.
        surrogate_model
            Transfer-learning surrogate. It is trained with ``train(X, Y)``. This class
            also reads ``method_id``, ``w`` (per-task weights; the last entry is
            used as the target weight), ``target_weight`` and, optionally,
            ``correct_rate`` from it.
        target_func
            Callable applied to a flat numpy vector. Returns either ``perf`` or
            ``(perf, test_perf)``.
        dims
            Length of the vector passed to ``target_func``.
        acq_func
            Only ``'ei'`` is supported (checked in ``choose_next``).
        mode
            Candidate-space strategy: ``'box'``, ``'ellipsoid'``, ``'best'``, ``'all'``,
            ``'all+'``, ``'all+-sample'``, ``'all+-threshold'``, ``'all+-sample+'``,
            ``'all+-sample+-threshold'``, ``'sample'`` or ``'sample-new'``.
        model
            Space classifier type: ``'svm'``, ``'rf'``, ``'knn'``, ``'gp'`` or ``'lr'``.
        source_hpo_data
            Sequence of per-source-task mappings ``configuration -> performance``.
            Required for every mode (box/ellipsoid read it in ``__init__``).
        enable_init_design
            Must be False; True raises ``NotImplementedError``.
        num_src_hpo_trial, surrogate_type
            Stored as attributes; not read elsewhere in this class.
        max_runs
            Number of iterations performed by ``run``.
        logging_dir, task_id
            Forwarded to ``BasePipeline``.
        initial_runs
            Number of random configurations evaluated before the model is used
            (not applied in box/ellipsoid modes).
        random_seed
            Seed for ``self.rng``, the config space and the global numpy RNG.
            A random seed is drawn when None.
        """
        super().__init__(config_space, task_id, output_dir=logging_dir)
        self.logger = super()._get_logger(self.__class__.__name__)
        if random_seed is None:
            _, rng = get_rng()
            random_seed = rng.randint(MAXINT)
        self.random_seed = random_seed
        self.rng = np.random.RandomState(self.random_seed)
        self.source_hpo_data = source_hpo_data
        self.num_src_hpo_trial = num_src_hpo_trial
        self.surrogate_type = surrogate_type
        self.acq_func = acq_func
        self.mode = mode
        self.clf_type = model
        self.target_func = target_func
        self.dims = dims
        print(mode, model)

        self.max_iterations = max_runs
        self.iteration_id = 0
        self.default_obj_value = MAXINT

        self.target_hpo_measurements = target_hpo_data
        self.configuration_list = list(self.target_hpo_measurements.keys()) if self.target_hpo_measurements else []
        print('Target problem space: %d configurations' % len(self.configuration_list))
        self.configurations = list()
        self.failed_configurations = list()
        self.perfs = list()
        self.test_perfs = list()
        self.reduce_cnt = 0

        if enable_init_design:
            raise NotImplementedError
        else:
            self.initial_configurations = None

        if self.initial_configurations is None:
            self.init_num = initial_runs
        else:
            self.init_num = len(self.initial_configurations)

        # Initialize the basic component in BO.
        self.config_space.seed(self.random_seed)
        np.random.seed(self.random_seed)
        self.model = surrogate_model
        self.acquisition_function = EI(self.model)
        self.space_classifier = None
        self.random_configuration_chooser = ChooserProb(
            prob=0.1,
            rng=np.random.RandomState(self.random_seed)
        )

        # Percentile range used to label "good" source observations.
        self.p_min = 5
        self.p_max = 50
        self.use_correct_rate = False

        if self.mode in ['box', 'ellipsoid']:
            continuous_types = (UniformFloatHyperparameter, UniformIntegerHyperparameter)
            # One flag per hyperparameter (config-space order): True for numeric ones.
            self.continuous_mask = [isinstance(hp, continuous_types) for hp in self.config_space.get_hyperparameters()]
            self.calculate_box_area()

        # Minimum fraction of samples the intersection must keep in '*-threshold' modes.
        self.space_threshold = 0.04
        self.box_cs = None

    def get_inc_y(self):
        """Return the best (lowest) observed objective value; raises on an empty history."""
        y_inc = np.min(self.perfs)
        return y_inc

    def run(self):
        """Run ``iterate`` until ``max_iterations`` iterations have been performed."""
        while self.iteration_id < self.max_iterations:
            self.iterate()

    def evaluate(self, config):
        """Evaluate ``config`` with ``target_func``.

        The configuration's values are packed into a flat numpy vector of length
        ``self.dims``. NOTE(review): this assumes ``get_dictionary()`` yields the
        values in the order ``target_func`` expects its inputs — confirm against
        how the config space was built.

        Returns
        -------
        (perf, test_perf)
            ``test_perf`` is -1 when ``target_func`` returns a single value.
        """
        x = np.array(list(config.get_dictionary().values())).reshape(-1, self.dims)[0]
        value = self.target_func(x)
        try:
            perf, test_perf = value
        except (TypeError, ValueError):
            # Scalar (non-iterable) result, or an iterable not of length two.
            perf = value
            test_perf = -1
        return perf, test_perf

    def iterate(self):
        """Run one BO iteration: choose a configuration, evaluate it, record the outcome.

        Returns
        -------
        (config, trial_state, perf, trial_info)
            ``trial_state`` is SUCCESS or FAILDED. ``trial_info`` is a message on
            failure, otherwise None. A configuration that was already evaluated is
            not re-run: its stored perf (or MAXINT if it had failed) is reported.
        """
        if len(self.configurations) == 0:
            X = np.array([])
        else:
            X = convert_configurations_to_array(self.configurations)
        Y = np.array(self.perfs, dtype=np.float64)
        config = self.choose_next(X, Y)

        trial_state = SUCCESS
        trial_info = None

        if config not in (self.configurations + self.failed_configurations):
            # Evaluate this configuration.
            perf, test_perf = self.evaluate(config)
            if perf == MAXINT:
                trial_info = 'failed configuration evaluation.'
                trial_state = FAILDED
                self.logger.error(trial_info)

            if trial_state == SUCCESS and perf < MAXINT:
                if len(self.configurations) == 0:
                    # The first successful value is the baseline for the logged improvement.
                    self.default_obj_value = perf

                self.configurations.append(config)
                self.perfs.append(perf)
                self.test_perfs.append(test_perf)
                self.history_container.add(config, perf)
            else:
                self.failed_configurations.append(config)
        else:
            self.logger.debug('This configuration has been evaluated! Skip it.')
            if config in self.configurations:
                config_idx = self.configurations.index(config)
                trial_state, perf = SUCCESS, self.perfs[config_idx]
            else:
                trial_state, perf = FAILDED, MAXINT

        self.iteration_id += 1
        self.logger.info(
            'Iteration-%d, objective improvement: %.4f' % (self.iteration_id, max(0, self.default_obj_value - perf)))
        return config, trial_state, perf, trial_info

    def sample_random_config(self, config_set=None, config_num=1):
        """Sample ``config_num`` random configurations from ``self.config_space``.

        Configurations already evaluated (successfully or not), or already drawn
        in this call, are rejected. After 200 consecutive rejections the last
        draw is accepted anyway, so the loop always terminates.

        ``config_set`` is accepted for call-site compatibility but is not used.
        Sampling is always from the full config space.
        """
        configs = list()
        evaluated = self.configurations + self.failed_configurations
        reject_cnt = 0
        while len(configs) < config_num:
            config = self.config_space.sample_configuration()
            if config not in evaluated and config not in configs:
                configs.append(config)
                reject_cnt = 0
                continue
            reject_cnt += 1
            if reject_cnt >= 200:
                configs.append(config)
                reject_cnt = 0
        return configs

    def choose_next(self, X: np.ndarray, Y: np.ndarray):
        """
        Step 1. sample a batch of random configs.
        Step 2. identify and preserve the configs in the good regions (union)
        Step 3. calculate their acquisition functions and choose the config with the largest value.

        Parameters
        ----------
        X
            Array of evaluated configurations (empty when nothing was evaluated yet).
        Y
            Objective values matching ``X``.

        Returns
        -------
        the config to evaluate next.

        Raises
        ------
        ValueError
            If ``self.acq_func`` is not 'ei'.
        """
        _config_num = X.shape[0]
        if self.mode not in ['box', 'ellipsoid'] and _config_num < self.init_num:
            if self.initial_configurations is None:
                return self.sample_random_config()[0]
            print('This is a config for warm-start!')
            return self.initial_configurations[_config_num]

        if _config_num > 0:
            start_time = time.time()
            self.model.train(X, Y)
            print('Training surrogate model took %.3f' % (time.time() - start_time))

            # Normalize Y the same way the surrogate presumably does, so that the EI
            # incumbent is on the model's scale — TODO confirm per method_id.
            if self.model.method_id in ['tst', 'tstm', 'pogpe']:
                y_, _, _ = zero_one_normalization(Y)
            elif self.model.method_id in ['scot']:
                y_ = Y.copy()
            else:
                y_, _, _ = zero_mean_unit_var_normalization(Y)
            incumbent_value = np.min(y_)

            if self.acq_func == 'ei':
                self.acquisition_function.update(model=self.model, eta=incumbent_value,
                                                 num_data=len(self.history_container.data))
            else:
                raise ValueError('invalid acquisition function ~ %s.' % self.acq_func)

        # Select space
        X_candidate = self.get_X_candidate()
        if _config_num == 0:
            # Use the seeded RNG (not the unseeded stdlib `random`) so the first
            # pick is reproducible for a given random_seed.
            return X_candidate[self.rng.randint(len(X_candidate))]

        # Debug hook: self.check_space(X_candidate)

        if self.rng.rand() < self.get_random_prob(self.iteration_id):
            # NOTE: sample_random_config currently ignores `config_set`; the
            # excluded set is still passed so the call site stays unchanged.
            candidate_set = set(X_candidate)
            excluded_set = [_config for _config in self.configuration_list if _config not in candidate_set]
            if len(excluded_set) == 0:
                excluded_set = self.configuration_list

            config = self.sample_random_config(config_set=excluded_set)[0]
            # Append a target weight for this random step (repeat the last one, or 0.
            # if none yet) — presumably keeps one entry per iteration; confirm.
            if len(self.model.target_weight) == 0:
                self.model.target_weight.append(0.)
            else:
                self.model.target_weight.append(self.model.target_weight[-1])
            print('Config sampled randomly.')
            return config

        acq_optimizer = OfflineSearch(X_candidate,
                                      self.acquisition_function,
                                      self.config_space,
                                      rng=np.random.RandomState(self.random_seed)
                                      )

        start_time = time.time()
        sorted_configs = acq_optimizer.maximize(
            runhistory=self.history_container,
            num_points=5000
        )
        print('Optimizing Acq. func took %.3f' % (time.time() - start_time))

        # Built once: the history does not change while scanning candidates.
        evaluated = self.configurations + self.failed_configurations
        for _config in sorted_configs:
            if _config not in evaluated:
                return _config

        print('[Warning] Reach unexpected?')
        candidate_set = set(X_candidate)
        excluded_set = [
            _config for _config in self.configuration_list
            if _config not in candidate_set and _config not in evaluated
        ]
        if len(excluded_set) == 0:
            excluded_set = self.configuration_list
        return self.sample_random_config(config_set=excluded_set)[0]

    @staticmethod
    def _all_positive_columns(pred_mat):
        """Return the column indexes where every row of ``pred_mat`` equals 1."""
        return np.flatnonzero((np.asarray(pred_mat) == 1).all(axis=0)).tolist()

    def get_X_candidate(self) -> List[Configuration]:
        """Build the candidate set that the acquisition function is maximized over.

        Box/ellipsoid modes delegate to ``get_X_candidate_box``. Otherwise:

        1. select source tasks from the surrogate's weights, according to ``mode``;
        2. train one space classifier per selected task (``prepare_classifier``);
        3. for 3 rounds, draw 10000 random configurations and keep those that all
           selected classifiers label as promising.

        In the ``all+`` family, rows of the least-weighted tasks are dropped while the
        intersection is empty (or below ``space_threshold`` in '*-threshold' modes).
        Once more than 10000 candidates have accumulated, 10000 are returned,
        sampled without replacement. When no source task is selected, the target
        configurations (``configuration_list``) are used. If nothing survives,
        10000 unfiltered random configurations are returned.

        Raises
        ------
        ValueError
            If ``self.mode`` is not a supported mode.
        """
        if self.mode in ['box', 'ellipsoid']:
            return self.get_X_candidate_box()

        # Do task selection.
        if self.use_correct_rate:
            weights = self.model.correct_rate.copy()
            print('use correct rate:', weights)
        else:
            weights = self.model.w.copy()

        if self.mode == 'best':
            weights = weights[:-1]  # exclude target weight
            task_indexes = np.argsort(weights)[-1:]  # space
            task_indexes = [idx_ for idx_ in task_indexes if weights[idx_] > 0.]
        elif self.mode in ['all', 'all+-sample', 'all+', 'all+-threshold']:
            weights = weights[:-1]  # exclude target weight
            task_indexes = np.argsort(weights)  # space-all, ascending weight
            task_indexes = [idx_ for idx_ in task_indexes if weights[idx_] > 0.]
        elif self.mode in ['all+-sample+', 'all+-sample+-threshold']:
            n_src_samples = 5
            w_norm = 4
            weights_ = np.array(weights[:-1])
            # Only tasks with weight > 0.5 get a non-zero sampling weight.
            weights_ = np.clip(weights_ - 0.5, 0.0, 1.0) ** w_norm
            if np.sum(weights_ > 0) <= n_src_samples:
                task_indexes = np.where(weights_ > 0)[0]    # can be empty!
            else:
                total = sum(weights_)
                weights_ = [x / total for x in weights_]  # space-sample-new
                task_indexes = np.random.choice(list(range(len(weights_))), n_src_samples, p=weights_, replace=False)
            # Ascending weight, so the least related task is dropped first below.
            task_indexes = sorted(task_indexes, key=lambda x: weights_[x])
        elif self.mode == 'sample':
            # Target excluded
            weights = weights[:-1]
            total = sum(weights)
            weights_ = [x / total for x in weights]  # space-sample
            task_indexes = np.random.choice(list(range(len(weights))), 1, p=weights_)
        elif self.mode == 'sample-new':
            # Target should also be sampled
            total = sum(weights)
            weights_ = [x / total for x in weights]  # space-sample-new
            task_indexes = np.random.choice(list(range(len(weights))), 1, p=weights_)
        else:
            raise ValueError(self.mode)

        print('Task Indexes', task_indexes)

        # In 'sample-new', the index one past the last source task denotes the target.
        if self.mode == 'sample-new' and task_indexes[0] == len(self.source_hpo_data):
            return self.choose_config_target_space()

        # Calculate the percentiles: higher task weight -> lower percentile -> tighter region.
        p_min = self.p_min
        p_max = self.p_max
        percentiles = [p_max] * len(self.source_hpo_data)
        for _task_id in task_indexes:
            if self.use_correct_rate:
                _p = p_min + (1 - 2 * max(weights[_task_id] - 0.5, 0)) * (p_max - p_min)
            else:
                _p = p_min + (1 - weights[_task_id]) * (p_max - p_min)
            percentiles[_task_id] = _p

        print('Percentiles', percentiles)

        self.prepare_classifier(task_indexes, percentiles)

        self.update_configuration_list()  # for online benchmark

        space_transfer_modes = ['all+', 'all+-sample', 'all+-threshold', 'all+-sample+', 'all+-sample+-threshold']
        X_candidate = list()
        num_samples = 10000
        for _round in range(3):
            X_ALL_list = [self.config_space.sample_configuration() for _ in range(num_samples)]
            X_ALL = convert_configurations_to_array(X_ALL_list)
            # One prediction row per selected task; empty when no task was selected,
            # which routes to the target-space fallback at the end of this loop.
            y_pred = [self.space_classifier[_task_id].predict(X_ALL) for _task_id in task_indexes]

            if self.mode in space_transfer_modes and len(y_pred) > 0:
                if self.mode == 'all+-sample' and np.random.random_sample() < self.model.w[-1]:
                    print('Use the target space!')
                    if len(self.configurations) <= 20:
                        # Copy: later rounds append to X_candidate.
                        X_candidate = list(self.configuration_list)
                    else:
                        X_candidate = self.choose_config_target_space()
                else:
                    print('Use space transfer!')
                    pred_mat = np.array(y_pred)
                    while True:
                        # Count the #intersection.
                        config_indexes = self._all_positive_columns(pred_mat)
                        _cnt = len(config_indexes)
                        print('The intersection of candidate space is %d.' % _cnt)

                        # Rows are in ascending task weight, so pred_mat[1:] drops the least related task.
                        if self.mode in ['all+-threshold', 'all+-sample+-threshold']:
                            if _cnt < self.space_threshold * len(X_ALL_list) and len(pred_mat) > 1:
                                print('Threhold not meet!')
                                pred_mat = pred_mat[1:]
                                continue
                        elif _cnt == 0 and len(pred_mat) > 1:
                            print('Delete the least related task!')
                            pred_mat = pred_mat[1:]
                            continue

                        X_candidate.extend(X_ALL_list[_idx] for _idx in config_indexes)
                        print('The candidate space size is %d.' % len(X_candidate))

                        if len(X_candidate) > num_samples:
                            # Without replacement, so the returned candidates are distinct draws.
                            cand_idx = np.random.choice(len(X_candidate), num_samples, replace=False)
                            return [X_candidate[idx] for idx in cand_idx]
                        break
            elif len(y_pred) > 0:
                # Count the #intersection.
                config_indexes = self._all_positive_columns(np.array(y_pred))
                print('The intersection of candidate space is %d.' % len(config_indexes))

                # Column indexes refer to the configurations sampled in this round.
                X_candidate.extend(X_ALL_list[_idx] for _idx in config_indexes)
                print('The candidate space size is %d.' % len(X_candidate))

                if len(X_candidate) == 0:
                    print('[Warning] Intersect=0, please check!')
                    # Deal with the space with no candidates.
                    if len(self.configurations) <= 20:
                        X_candidate = list(self.configuration_list)
                    else:
                        print('[Warning] len(y_pred)=0. choose_config_target_space, please check!')
                        X_candidate = self.choose_config_target_space()
            else:
                X_candidate = self.choose_config_target_space()

        if len(X_candidate) > 0:
            return X_candidate
        return [self.config_space.sample_configuration() for _ in range(num_samples)]

    def prepare_classifier(self, task_ids, percentiles):
        """Train one binary "promising region" classifier per task in ``task_ids``.

        For each task, the observed performances are standardized. Configurations
        below the task's percentile are labelled True. If that labels everything
        identically, the mean is used as the threshold instead. A
        StandardScaler + ``self.clf_type`` pipeline is then fitted. Results are
        stored in ``self.space_classifier`` (None for tasks not in ``task_ids``).

        Raises
        ------
        ValueError
            If labels cannot be separated, or ``self.clf_type`` is unknown.
        """
        # Train the binary classifier.
        print('Train binary classifiers.')
        start_time = time.time()
        self.space_classifier = [None] * len(self.source_hpo_data)
        normalize = 'standardize'

        for _task_id in task_ids:
            hpo_evaluation_data = self.source_hpo_data[_task_id]
            percentile_v = percentiles[_task_id]

            print('.', end='')
            _X = list(hpo_evaluation_data.keys())
            _y = list(hpo_evaluation_data.values())
            X = convert_configurations_to_array(_X)
            y = np.array(_y, dtype=np.float64)

            if normalize == 'standardize':
                # Avoid zero variance before normalizing constant observations.
                if (y == y[0]).all():
                    y[0] += 1e-4
                y, _, _ = zero_mean_unit_var_normalization(y)
            elif normalize == 'scale':
                if (y == y[0]).all():
                    y[0] += 1e-4
                y, _, _ = zero_one_normalization(y)
                y = 2 * y - 1.
            else:
                raise ValueError('Invalid parameter in norm.')

            percentile = np.percentile(y, percentile_v)
            unique_ys = sorted(set(y))
            # Ensure at least the best value gets a positive label.
            if len(unique_ys) >= 2 and percentile <= unique_ys[0]:
                percentile = unique_ys[1]

            space_label = y < percentile
            if (y == percentile).all():
                raise ValueError('Assertion violation: The same eval values!')
            if (space_label[0] == space_label).all():
                space_label = y < np.mean(y)
                if (space_label[0] == space_label).all():
                    raise ValueError('Warning: Label treatment triggers!')
                else:
                    print('Warning: Label treatment triggers!')

            if self.clf_type == 'svm':
                clf = make_pipeline(StandardScaler(), SVC(gamma='auto'))
            elif self.clf_type == 'rf':
                clf = make_pipeline(StandardScaler(), RandomForestClassifier(n_estimators=50, max_depth=4))
            elif self.clf_type == 'knn':
                clf = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=5))
            elif self.clf_type == 'gp':
                clf = make_pipeline(StandardScaler(), GaussianProcessClassifier())
            elif self.clf_type == 'lr':
                clf = make_pipeline(StandardScaler(), LogisticRegression())
            else:
                raise ValueError('Invalid classifier type ~ %s.' % self.clf_type)
            clf.fit(X, space_label)
            self.space_classifier[_task_id] = clf
        print('Building base classifier took %.3fs.' % (time.time() - start_time))

    def get_random_prob(self, iter_id):
        """Probability of a random (non-acquisition) step: 0 for box/ellipsoid, else 0.1."""
        if self.mode in ['ellipsoid', 'box']:
            return 0
        return 0.1

    def update_configuration_list(self):
        """Hook for online benchmarks to refresh ``configuration_list``; no-op here."""
        return

    def check_space(self, X_candidate):
        """Debug helper: report whether the best target value is reachable within ``X_candidate``.

        Every candidate is evaluated with ``evaluate`` (expensive). The best value
        found is compared with the best value in ``target_hpo_measurements``,
        which must therefore be a non-empty mapping.
        """
        all_perfs = np.array(list(self.target_hpo_measurements.values()))
        if all_perfs.ndim == 2:
            all_perfs = all_perfs[:, 0]
        print('Global Optimum:' + str(np.min(all_perfs)))
        # Start from +inf so objectives >= 1 are handled correctly.
        best_perf_in_candidate = np.inf
        for config in X_candidate:
            perf, _ = self.evaluate(config)
            if perf < best_perf_in_candidate:
                best_perf_in_candidate = perf
        print('Current Optimum:' + str(best_perf_in_candidate))
        print('Optimum in space:' + str(bool(np.min(all_perfs) == best_perf_in_candidate)))

    def choose_config_target_space(self):
        """Return a copy of the target configurations.

        A copy is returned so that callers that append to the result cannot
        mutate ``self.configuration_list``.
        """
        return list(self.configuration_list)

    def calculate_box_area(self):
        """
        [NIPS 2019] Learning search spaces for Bayesian optimization: Another view of hyperparameter transfer learning

        Collect the best (argmin) configuration of every source task and fit a
        region around their continuous coordinates:

        * 'box': per-dimension min/max -> ``src_X_min_`` / ``src_X_max_``;
        * 'ellipsoid': solve ``min -log det(A)`` s.t. ``||A x_i + b|| <= 1`` with cvxpy,
          giving ``src_A`` / ``src_b``.

        Raises
        ------
        ValueError
            If ``self.mode`` is neither 'box' nor 'ellipsoid'.
        """
        incumbent_src_configs = []
        for hpo_evaluation_data in self.source_hpo_data:
            configs = list(hpo_evaluation_data.keys())
            perfs = list(hpo_evaluation_data.values())
            idx = np.argmin(perfs)
            incumbent_src_configs.append(configs[idx])
        X_incumbents = convert_configurations_to_array(incumbent_src_configs)
        # exclude categorical params
        X_incumbents_ = X_incumbents[:, self.continuous_mask]
        print(X_incumbents_)

        if self.mode == 'ellipsoid':
            import cvxpy as cp
            lenth = X_incumbents_.shape[1]
            self.lenth = lenth
            A = cp.Variable((lenth, lenth), PSD=True)
            b = cp.Variable(lenth)
            objective = cp.Minimize(-cp.log_det(A))
            constraint = [cp.norm(A @ X_incumbents_[i] + b) <= 1 for i in range(X_incumbents_.shape[0])]
            prob = cp.Problem(objective, constraint)
            prob.solve(qcp=True)
            self.src_A = A.value
            self.src_b = b.value
            print('A', self.src_A, 'b', self.src_b)
        elif self.mode == 'box':
            self.src_X_min_ = np.min(X_incumbents_, axis=0)
            self.src_X_max_ = np.max(X_incumbents_, axis=0)
        else:
            raise ValueError(self.mode)

    def create_config_space(self, lb, ub):
        """Build a ConfigSpace with one float hyperparameter ``x{i}`` per (lb[i], ub[i]) pair.

        A pair with ``lb[i] == ub[i]`` becomes a ``Constant`` instead.
        """
        cs = CS.ConfigurationSpace()
        for i, (lower, upper) in enumerate(zip(lb, ub)):
            if lower == upper:
                hyperparameter = Constant(f"x{i}", lower)
            else:
                hyperparameter = UniformFloatHyperparameter(f"x{i}", lower, upper)
            cs.add_hyperparameter(hyperparameter)
        return cs

    def _fixed_box_dims(self):
        """Return ``(column, hp_name, value)`` for each box column whose min equals its max.

        ``src_X_min_`` / ``src_X_max_`` are indexed over the continuous columns only,
        so each column is mapped back to its index in the full configuration space
        before looking up the hyperparameter.
        NOTE(review): ``value`` assumes a linear (non-log) mapping from the
        normalized vector to the hyperparameter range — confirm for log-scaled params.
        """
        continuous_hp_idx = np.flatnonzero(self.continuous_mask)
        fixed_dims = []
        for dim, (min_val, max_val) in enumerate(zip(self.src_X_min_, self.src_X_max_)):
            if min_val != max_val:
                continue
            param_name = self.config_space.get_hyperparameter_by_idx(int(continuous_hp_idx[dim]))
            param = self.config_space.get_hyperparameter(param_name)
            fixed_dims.append((dim, param_name, (param.upper - param.lower) * min_val + param.lower))
        return fixed_dims

    def _box_valid_idx(self, configuration_list, fixed_dims):
        """Return indexes of ``configuration_list`` inside the learned box.

        Degenerate (fixed) dimensions are first pinned to their single value, both
        in the array and in the configurations themselves.
        """
        X_ALL_ = convert_configurations_to_array(configuration_list)[:, self.continuous_mask]
        for dim, param_name, fixed_value in fixed_dims:
            X_ALL_[:, dim] = self.src_X_min_[dim]
            for config in configuration_list:
                config[param_name] = fixed_value
        valid_mask = np.logical_and(self.src_X_min_ <= X_ALL_, X_ALL_ <= self.src_X_max_).all(axis=1)
        return np.where(valid_mask)[0]

    def _ellipsoid_valid_idx(self, configuration_list):
        """Return indexes of ``configuration_list`` with ``||A x + b|| <= 1`` on continuous columns."""
        X_ALL_ = convert_configurations_to_array(configuration_list)[:, self.continuous_mask]
        transformed_X_ALL = np.dot(self.src_A, X_ALL_.T).T + self.src_b
        norms = np.linalg.norm(transformed_X_ALL, axis=1)
        return np.where(norms <= 1)[0]

    def get_X_candidate_box(self) -> List[Configuration]:
        """
        [NIPS 2019] Learning search spaces for Bayesian optimization: Another view of hyperparameter transfer learning

        Draw up to 3 batches of 10000 random configurations and keep those inside
        the region computed by ``calculate_box_area``. At most 10000 candidates
        are returned. If none is accepted, 10000 unfiltered random configurations
        are returned instead.

        Raises
        ------
        ValueError
            If ``self.mode`` is neither 'box' nor 'ellipsoid'.
        """
        nums_restart = 3
        num_samples = 10000

        if self.mode == 'ellipsoid':
            select_valid = self._ellipsoid_valid_idx
        elif self.mode == 'box':
            fixed_dims = self._fixed_box_dims()

            def select_valid(configuration_list):
                return self._box_valid_idx(configuration_list, fixed_dims)
        else:
            raise ValueError(self.mode)

        X_candidate = []
        for _ in range(nums_restart):
            configuration_list = [self.config_space.sample_configuration() for _ in range(num_samples)]
            valid_idx = select_valid(configuration_list)
            # Never exceed num_samples candidates in total.
            num_remain = num_samples - len(X_candidate)
            X_candidate.extend(configuration_list[idx] for idx in valid_idx[:num_remain])
            if len(X_candidate) >= num_samples:
                return X_candidate

        if len(X_candidate) > 0:
            return X_candidate
        return [self.config_space.sample_configuration() for _ in range(num_samples)]

    def get_bounds(self):
        """Return ``[(lower, upper), ...]`` for every hyperparameter in config-space order.

        Every hyperparameter must expose ``lower``/``upper`` (numeric ones do).
        Others raise AttributeError.
        """
        return [(hp.lower, hp.upper) for hp in self.config_space.get_hyperparameters()]

    def solve_ellipoid(self):
        """Find configurations inside the learned ellipsoid by SLSQP feasibility search.

        Runs ``scipy.optimize.minimize`` from 100 random starts in ``[0, 1]^d``,
        with a zero objective and the constraint ``||A x + b|| <= 1``. Each
        successful solution is mapped from the unit cube to the continuous
        hyperparameters' ranges (linearly) and wrapped in a ``Configuration``.
        NOTE(review): non-continuous hyperparameters get no value here, so
        ``Configuration`` construction will presumably fail if the space has any.
        """
        from scipy.optimize import minimize

        A = np.array(self.src_A)
        b = np.array(self.src_b)
        # The ellipsoid lives in the continuous (masked) coordinates, so the
        # search dimension must match A rather than self.dims.
        n_dims = A.shape[1]

        def objective(X):
            return 0

        def constraint(X):
            return 1 - np.linalg.norm(A @ X + b)

        lb = np.zeros(n_dims)
        ub = np.ones(n_dims)
        bounds = [(lb[i], ub[i]) for i in range(n_dims)]
        con = {'type': 'ineq', 'fun': constraint}

        solutions = []
        for _ in range(100):
            X0 = np.random.uniform(low=lb, high=ub)
            sol = minimize(objective, X0, bounds=bounds, constraints=con)
            if sol.success:
                solutions.append(sol.x)

        continuous_hps = [
            hp for hp, is_continuous in zip(self.config_space.get_hyperparameters(), self.continuous_mask)
            if is_continuous
        ]
        candidates = []
        for sol in solutions:
            values_dict = {
                hp.name: hp.lower + (hp.upper - hp.lower) * norm_value
                for hp, norm_value in zip(continuous_hps, sol)
            }
            candidates.append(Configuration(self.config_space, values=values_dict))
        return candidates