import tensorflow as tf

from rllab.core.serializable import Serializable
from sandbox.rocky.tf.core.parameterized import Parameterized

from sac.misc.mlp import MLPFunction
from sac.misc import tf_utils

class NNVFunction(MLPFunction):
    """MLP whose only input is a batch of flat observations (default name 'vf').

    The constructor creates an observation placeholder of shape
    ``[None, observation_space.flat_dim]`` plus a scalar float placeholder
    for dropout, then delegates network construction to ``MLPFunction``.
    """

    def __init__(self, env_spec, hidden_layer_sizes=(100, 100), name='vf', batchnormvf=False, dropoutvf_keep_prob=1.0):
        """Build the placeholders and initialise the underlying MLP.

        Args:
            env_spec: object exposing ``observation_space.flat_dim``.
            hidden_layer_sizes: forwarded unchanged to ``MLPFunction``.
            name: forwarded to ``MLPFunction`` as the network name.
            batchnormvf: forwarded to ``MLPFunction`` as ``batchnorm``.
            dropoutvf_keep_prob: forwarded to ``MLPFunction``; the dropout
                placeholder is only handed over when this is below 1.0.
        """
        # Must run before anything else so the constructor arguments are
        # captured for serialization.
        Serializable.quick_init(self, locals())

        obs_dim = env_spec.observation_space.flat_dim
        self._Do = obs_dim
        self._obs_pl = tf.placeholder(
            tf.float32, shape=[None, obs_dim], name='observation')

        # The scalar placeholder is always created, even when it ends up
        # not being passed to the base class.
        self.dropoutvf_placeholder = tf.placeholder(tf.float32, shape=[])

        if dropoutvf_keep_prob < 1.0:
            dropout_pl = self.dropoutvf_placeholder
        else:
            dropout_pl = None

        super(NNVFunction, self).__init__(
            name,
            (self._obs_pl,),
            hidden_layer_sizes,
            batchnorm=batchnormvf,
            dropoutvf_placeholder=dropout_pl,
            dropoutvf_keep_prob=dropoutvf_keep_prob,
        )


class NNQFunction(MLPFunction):
    """MLP taking a batch of flat observations and actions (default name 'qf').

    Two placeholders are created, ``[None, observation_space.flat_dim]`` and
    ``[None, action_space.flat_dim]``, and passed together as the inputs of
    the underlying ``MLPFunction``.
    """

    def __init__(self, env_spec, hidden_layer_sizes=(100, 100), name='qf', batchnormvf=False):
        """Build the placeholders and initialise the underlying MLP.

        Args:
            env_spec: object exposing ``action_space.flat_dim`` and
                ``observation_space.flat_dim``.
            hidden_layer_sizes: forwarded unchanged to ``MLPFunction``.
            name: forwarded to ``MLPFunction`` as the network name.
            batchnormvf: accepted but ignored; ``batchnorm=False`` is always
                passed to the base class (see the comment below).
        """
        # Must run first so the constructor arguments are captured for
        # serialization.
        Serializable.quick_init(self, locals())

        self._Da = env_spec.action_space.flat_dim
        self._Do = env_spec.observation_space.flat_dim

        obs_shape = [None, self._Do]
        act_shape = [None, self._Da]
        self._obs_pl = tf.placeholder(
            tf.float32, shape=obs_shape, name='observation')
        self._action_pl = tf.placeholder(
            tf.float32, shape=act_shape, name='actions')

        network_inputs = (self._obs_pl, self._action_pl)

        # We currently do not use batchnorm on Q network, as it deteriorates
        # performance, so `batchnormvf` is deliberately not forwarded.
        super(NNQFunction, self).__init__(
            name,
            network_inputs,
            hidden_layer_sizes,
            batchnorm=False,
        )


class NNDiscriminatorFunction(MLPFunction):
    """MLP over (observation, action) with ``num_skills`` output units.

    Unlike the sibling classes in this file, this constructor does not call
    ``MLPFunction.__init__``; it sets ``_name``, ``_input_pls``,
    ``_layer_sizes`` and ``_output_t`` itself so that the final layer size
    can be ``num_skills``.
    """

    def __init__(self, env_spec, hidden_layer_sizes=(100, 100), num_skills=None):
        """Build the placeholders and the discriminator network output.

        Args:
            env_spec: object exposing ``action_space.flat_dim`` and
                ``observation_space.flat_dim``.
            hidden_layer_sizes: sizes of the hidden layers; ``num_skills`` is
                appended to form the full list of layer sizes.
            num_skills: size of the output layer. Required despite the
                ``None`` default.

        Raises:
            AssertionError: if ``num_skills`` is ``None``.
        """
        assert num_skills is not None, 'num_skills must be provided'
        Serializable.quick_init(self, locals())
        # `Parameterized` must be imported at module level; without that
        # import this line raised NameError on every construction.
        Parameterized.__init__(self)

        self._Da = env_spec.action_space.flat_dim
        self._Do = env_spec.observation_space.flat_dim

        self._obs_pl = tf.placeholder(
            tf.float32,
            shape=[None, self._Do],
            name='observation',
        )
        self._action_pl = tf.placeholder(
            tf.float32,
            shape=[None, self._Da],
            name='actions',
        )

        # NOTE(review): MLPFunction.__init__ is bypassed here, so any other
        # attributes it normally sets (e.g. batchnorm / dropout settings
        # that get_output_for may read) are not initialised -- confirm
        # against sac.misc.mlp.
        self._name = 'discriminator'
        self._input_pls = (self._obs_pl, self._action_pl)
        self._layer_sizes = list(hidden_layer_sizes) + [num_skills]
        self._output_t = self.get_output_for(*self._input_pls)
