import numpy as np
import tensorflow as tf
import pickle


class AbstractRewardEstimator:
    """Base interface for models estimating reward over a context space.

    Subclasses override ``update_model`` and ``compute_value_map``; the base
    versions are no-ops that return ``None``.
    """

    def __init__(self, sess):
        # Session handle made available to subclasses; stored exactly as given.
        self.sess = sess

    def update_model(self, samples, rewards):
        """Update the model with new ``samples`` and their ``rewards`` (no-op here)."""
        return None

    def compute_value_map(self):
        """Compute a value density map over the whole context space (no-op here)."""
        return None


class RewardEstimatorNN(AbstractRewardEstimator):
    def __init__(self,
                 context_bounds,
                 scope,  # TF scope name
                 lr=1e-3,  # learning rate for optimizer
                 eps=1e-5,  # eps for optimizer
                 rel_err=0.05,  # relative error to be reached during optimization
                 min_steps=10,
                 max_steps=2000,
                 sess=None,
                 net_arch=None,  # Network architecture, in stable_baseline style, default [64, 64] act_func=relu
                 # if not None, pass a integer to create a network with network weights 20 iterations ago (For intrinsic motivation)
                 min_val=-np.inf,
                 max_val=np.inf,
                 reuse=False
                 ):

        init_self = False
        if sess is None:
            sess = tf.Session()
            init_self = True

        super(RewardEstimatorNN, self).__init__(sess)
        self.context_bounds = context_bounds
        context_dim = context_bounds[0].shape[0]
        self.contexts_ph = tf.placeholder(dtype=tf.float64, shape=[None, context_dim])
        self.lr = lr
        self.eps = eps
        self.rel_err = rel_err
        self.min_steps = min_steps
        self.max_steps = max_steps
        self.net_arch = net_arch
        self.scope = scope
        self.value_fn = self._build_model(scope, self.contexts_ph, net_arch=net_arch, reuse=reuse)
        self.value_flat = self.value_fn[:, 0]
        self.setup_model(scope, reuse=reuse)
        self.model_weights = tf.trainable_variables(scope=self.scope)
        self.min_val = min_val
        self.max_val = max_val

        if init_self:
            self.sess.run(tf.global_variables_initializer())

    def save(self, path):
        weights = {}
        for model_weight in self.model_weights:
            weights[model_weight.name] = model_weight.eval(self.sess)
        with open(path, "wb") as f:
            pickle.dump(weights, f)

    def load(self, path):
        with open(path, "rb") as f:
            weights = pickle.load(f)

        for model_weight in self.model_weights:
            assign_op = model_weight.assign(weights[model_weight.name])
            self.sess.run(assign_op)

    def get_prediction_graph(self, inputs):
        return self._build_model(self.scope, inputs, net_arch=self.net_arch, reuse=True)[:, 0]

    def __call__(self, inputs):
        return np.clip(self.sess.run(self.value_flat, {self.contexts_ph: inputs}), self.min_val, self.max_val)

    def setup_model(self, scope, reuse=False):
        with tf.variable_scope("loss_" + scope, reuse=reuse):
            self.vals_ph = tf.placeholder(tf.float64, [None], name="values_ph")
            loss = tf.square(self.vals_ph - self.value_flat)
            self.loss = .5 * tf.reduce_mean(loss)
        with tf.variable_scope(scope, reuse=reuse):
            self.params = tf.trainable_variables()
            grads = tf.gradients(self.loss, self.params)
            grads = list(zip(grads, self.params))
            trainer = tf.train.AdamOptimizer(learning_rate=self.lr, epsilon=self.eps)
            self._train = trainer.apply_gradients(grads)

    def update_model(self, samples,
                     rewards):  # Since TF doesn't support multi-input network, len(samples) must be equal to len(context)
        # Can be only 1 if we are working with 1d contexts
        if len(samples.shape) != 2:
            samples = samples[:, None]

        norm = np.maximum(1., np.abs(rewards))
        # FIrst we optimize for the specified number of iterations
        for i in range(self.min_steps):
            self.sess.run(self._train, {self.contexts_ph: samples, self.vals_ph: rewards})

        # Then we continue until reaching the desired minimum accuracy
        preds = self.sess.run(self.value_flat, {self.contexts_ph: samples, self.vals_ph: rewards})
        err = np.mean(np.abs(rewards - preds) / norm)
        count = 10
        while err > self.rel_err and count < self.max_steps:
            self.sess.run(self._train, {self.contexts_ph: samples, self.vals_ph: rewards})
            preds = self.sess.run(self.value_flat, {self.contexts_ph: samples, self.vals_ph: rewards})
            err = np.mean(np.abs(rewards - preds) / norm)
            count += 1

        if err > self.rel_err:
            print("WARNING! Relative prediction error of the value model is still %.3e > %.3e" % (err, self.rel_err))

    def _build_model(self, scope, input, net_arch=None, reuse=False):
        # First add a normalization layer such that the inputs are iwhtin [-1, 1]
        scale = tf.constant(self.context_bounds[1] - self.context_bounds[0], dtype=tf.float64)
        offset = tf.constant(self.context_bounds[0], dtype=tf.float64)
        # Now in [0, 1]
        input = (input - offset) / scale
        # Now in [-1, 1]
        input = 2 * (input - 0.5)

        if net_arch is None:
            layers = [64, 64]
            act_func = tf.nn.relu
            single_act_func = True
        else:
            layers = net_arch['layers']
            if isinstance(net_arch['act_func'], list):
                assert len(net_arch['act_func'],
                           len(
                               layers)), 'Value Estimator Error: Activation functions number must agree with the network layers'
                act_func = net_arch['act_func']
                single_act_func = False
            else:
                act_func = net_arch['act_func']
                single_act_func = True
        initializer = tf.initializers.glorot_normal()
        with tf.variable_scope(scope, reuse=reuse):
            if single_act_func:
                for i, layer_size in enumerate(layers):
                    if i == 0:
                        vf_h = act_func(tf.layers.dense(input, layer_size, name='value_estimator_l_' + str(i),
                                                        kernel_initializer=initializer))
                    else:
                        vf_h = act_func(tf.layers.dense(vf_h, layer_size, name='value_estimator_l_' + str(i),
                                                        kernel_initializer=initializer))

            else:
                for i, (layer_size, activ) in enumerate(zip(layers, act_func)):
                    if i == 0:
                        vf_h = activ(tf.layers.dense(input, layer_size, name='value_estimator_l_' + str(i),
                                                     kernel_initializer=initializer))
                    else:
                        vf_h = activ(tf.layers.dense(vf_h, layer_size, name='value_estimator_l_' + str(i),
                                                     kernel_initializer=initializer))
            return tf.layers.dense(vf_h, 1, name='value_function', kernel_initializer=initializer)


class MovingAverage:
    """Per-bin moving average of rewards, linearly interpolated between bins.

    Each of ``bins`` evenly spaced positions in ``bounds`` keeps a buffer of
    recent rewards. Every buffer starts with a single ``0.`` entry. Calling the
    object interpolates the per-bin means at the given sample positions.
    """

    def __init__(self, bins, buffer_size, bounds=(0., 1.)):
        self.buffer_size = buffer_size
        self.positions = np.linspace(bounds[0], bounds[1], bins)
        self.values = np.zeros(bins)
        self.reward_buffer = [[0.] for _ in range(bins)]

    def update_model(self, idxs, rewards):
        """Append each reward to its bin's buffer and refresh the bin means.

        When a buffer exceeds ``buffer_size`` after an append, its oldest entry
        is dropped.
        """
        for bin_idx, value in zip(idxs, rewards):
            history = self.reward_buffer[bin_idx]
            history.append(value)
            if len(history) > self.buffer_size:
                self.reward_buffer[bin_idx] = history[1:]

        self.values = np.array(list(map(np.mean, self.reward_buffer)))

    def __call__(self, samples):
        """Interpolate the bin means at ``samples`` (singleton axes are squeezed)."""
        flat = np.squeeze(samples)
        return np.interp(flat, self.positions, self.values)
