import warnings

import numpy as np
from overrides import overrides
from sklearn import svm
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import KFold

from symbols.symbols.reward_symbol import RewardSymbol




class SupportVectorRegressor(RewardSymbol):
    """
    An implementation of a reward symbol that uses support vector regression with an RBF kernel to estimate rewards.

    The penalty parameter ``C`` and the RBF coefficient ``gamma`` are selected by a cross-validated grid search
    when the object is constructed; the best-scoring regressor is kept for prediction.
    """

    def __init__(self,
                 states,
                 reward,
                 c_range=None,
                 gamma_range=None):
        """
        Learn a new reward estimate for the given states
        :param states: the states for which we want to learn a reward function
        :param reward: the actual rewards received from training data
        :param c_range: values to search over for the penalty parameter C of the error term.
                        Defaults to ``np.arange(2, 50, 4)``
        :param gamma_range: values to search over for the coefficient gamma of the RBF kernel.
                            Defaults to ``np.arange(2, 50, 4)``
        :raises ValueError: if fewer than two states are given (cross-validation needs at least two folds)
        """
        # Build the default grids per call instead of as default arguments, so no array object
        # is shared between calls.
        if c_range is None:
            c_range = np.arange(2, 50, 4)
        if gamma_range is None:
            gamma_range = np.arange(2, 50, 4)
        self._svr = SupportVectorRegressor._internal_fit(states, reward, c_range, gamma_range)
        self.name = 'SVR'

    @staticmethod
    def _internal_fit(data,
                      labels,
                      c_range,
                      gamma_range):
        """
        Fits the regressor to the given data and labels
        :param data: the states from the training data
        :param labels: the reward for the given states
        :param c_range: values to search over for the penalty parameter C of the error term
        :param gamma_range: values to search over for the coefficient gamma of the RBF kernel
        :return: the support vector regressor that performed the best
        :raises ValueError: if fewer than two training samples are given
        """
        n_samples = len(data)
        if n_samples < 2:
            # KFold requires n_splits >= 2 and at least one sample per split, so a grid search
            # cannot be cross-validated on less data than this.
            raise ValueError(
                "SVR needs at least 2 training states for cross-validation, got {}".format(n_samples))

        # Use 3-fold cross-validation, shrinking to one sample per fold when there is less data.
        n_splits = min(3, n_samples)
        if n_splits < 3:
            warnings.warn("Very little data to do SVR")

        param_grid = dict(gamma=gamma_range, C=c_range)
        cv = KFold(n_splits=n_splits)
        grid = GridSearchCV(svm.SVR(kernel='rbf'), param_grid=param_grid, cv=cv)
        grid.fit(data, labels)
        return grid.best_estimator_

    @overrides
    def predict_reward(self,
                       state):
        """
        Predict the reward for a single state
        :param state: a single state; it is reshaped into one row before prediction
        :return: a one-element array containing the predicted reward
        """
        return self._svr.predict(state.reshape(1, -1))

    @overrides
    def expected_reward(self,
                        distribution_symbol):
        """
        Estimate the average predicted reward over states drawn from a distribution symbol.

        The columns selected by ``distribution_symbol.flat_mask`` are filled from
        ``distribution_symbol.sample()``; every other column is filled with uniform random values in [0, 1).
        :param distribution_symbol: an object providing ``sample()`` (rows of masked state values)
                                    and ``flat_mask`` (the columns those values occupy)
        :return: the mean predicted reward over all sampled states
        :raises ValueError: if the distribution produced no samples
        """
        masked_samples = distribution_symbol.sample()
        n_samples = masked_samples.shape[0]
        if n_samples == 0:
            raise ValueError("Cannot compute an expected reward from zero samples")

        dim_x = self._svr.support_vectors_.shape[1]
        # NOTE(review): dimensions outside the mask are filled with uniform noise in [0, 1) --
        # presumably a stand-in for "don't care" features; confirm this matches the state scaling.
        samples = np.random.random([n_samples, dim_x])
        samples[:, distribution_symbol.flat_mask] = masked_samples

        # One batched prediction instead of a per-row Python loop.
        predictions = self._svr.predict(samples)
        return predictions.mean()

    def __str__(self):
        return self.name