import numpy as np
from overrides import overrides
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import StratifiedKFold
from sklearn.svm import SVC

from symbols.symbols.conditional_symbol import ConditionalSymbol




class SupportVectorClassifier(ConditionalSymbol):
    """
    A classifier built on a support vector machine whose hyper-parameters are chosen by grid search.
    When ``probabilistic`` is set, the final model is fitted with ``probability=True`` (Platt scaling)
    so that class probabilities can be queried.
    """

    def __init__(self,
                 mask,
                 data,
                 labels,
                 probabilistic=True,
                 c_range=np.logspace(0.01, 0.5, 10),
                 gamma_range=np.logspace(0.1, 1, 10),
                 use_mask=True
                 ):
        """
        Fits a new classifier to the given data
        :param mask: the option's mask; used to index the columns of the data when use_mask is set
        :param data: the training data
        :param labels: the associated labels
        :param probabilistic: if True, refit an SVC with probability estimates enabled using the best
        parameters found by the grid search; otherwise keep the grid search's best estimator as is
        :param c_range: a range to search over to find the optimal penalty parameter of the error term
        :param gamma_range: a range to search over to find the optimal kernel coefficient (gamma)
        :param use_mask: if True, select the masked columns of the data and flatten each sample by
        concatenating its selected entries; if False, the data is used exactly as given
        """

        # TODO: We need to construct a single classifier over multiple objects.
        #      We CANNOT treat them individually (see notebook for counterexample)
        #      We need to store information so that if we get a new state from a different task, we can shove it in here
        # Parameter set noted by the original author:
        # {'C': 1.2589254117941673, 'gamma': 1.9952623149688795}

        self._mask = mask
        if use_mask:
            selected = data[:, self.mask]
            # Each sample is a collection of per-object arrays: join them into one flat feature vector
            data = np.array([np.concatenate(row).ravel() for row in selected])

        self.n_dimensions = data.shape[1]

        # The search itself runs with probability estimates turned off
        search = GridSearchCV(
            SVC(probability=False, class_weight='balanced'),
            param_grid={'gamma': gamma_range, 'C': c_range},
            cv=StratifiedKFold(n_splits=3),
        )
        search.fit(data, labels)

        if not probabilistic:
            self._classifier = search.best_estimator_
        else:
            best = search.best_params_
            print(best)
            # Refit with probability estimates enabled, reusing the best C and gamma
            platt_svc = SVC(probability=True, class_weight='balanced', **best)
            platt_svc.fit(data, labels)
            self._classifier = platt_svc
            print("SCORE:", platt_svc.score(data, labels))

        self.name = 'SVC'
        self.probabilistic = probabilistic

    @property
    def mask(self):
        """
        :return: the mask that was supplied at construction
        """
        return self._mask

    @overrides
    def probability(self,
                    state,
                    use_mask=True):
        """
        Queries the classifier for a single state
        :param state: the state to evaluate
        :param use_mask: if True, index the state with the mask before flattening it
        :return: if the classifier is probabilistic, the entry at index 1 of the predicted class
        probabilities; otherwise the predicted label
        """
        features = state[self.mask] if use_mask else state
        # Flatten into a single row, since the classifier expects a 2D (samples x features) input
        row = np.hstack(features).reshape(1, -1)
        if not self.probabilistic:
            return self._classifier.predict(row)[0]
        return self._classifier.predict_proba(row)[0][1]

    def __str__(self):
        return self.name
