import numpy as np
import tensorflow as tf

# Every fc layer in this file is built with init_scale=np.sqrt(INIT_SC); fc
# forwards it to ortho_init, which multiplies the orthogonal matrix by it.
INIT_SC = 0.25
# Hidden width of the policy (pi_fc1/pi_fc2) and Q-function (q*_fc1/q*_fc2) MLPs.
DIM = 256
# Hidden width of the reward-decomposition MLPs built by Reward().
R_DIM = 64

class RandomNormal:
    """Diagonal Gaussian; log-density, entropy and KL are summed over the last axis.

    ``mean`` and ``logstd`` are TF tensors. When ``logstd`` has lower rank than
    ``mean``, it is broadcast up to ``mean``'s shape. ``self.raw_logstd`` keeps
    the tensor exactly as it was passed in.

    NOTE(review): ``self.std`` is floored at 1e-2, but ``log_p``, ``entropy`` and
    ``kl`` use the unfloored ``self.logstd``. The two disagree whenever
    ``logstd < log(1e-2)``. Kept unchanged; confirm whether this is intended.
    """

    def __init__(self, mean, logstd):
        self.raw_logstd = logstd
        mean_rank = len(mean.get_shape().as_list())
        logstd_rank = len(logstd.get_shape().as_list())
        if mean_rank > logstd_rank:
            # Adding a zero tensor shaped like ``mean`` broadcasts logstd up to it.
            logstd = mean * 0.0 + logstd
        self.mean = mean
        self.logstd = logstd
        self.std = tf.maximum(tf.exp(logstd), 1e-2)

    def log_p(self, x):
        """Gaussian log-density of ``x``, summed over the last axis (keepdims)."""
        z = (x - self.mean) / self.std
        per_dim = -0.5 * np.log(2.0 * np.pi) - self.logstd - 0.5 * tf.square(z)
        return tf.reduce_sum(per_dim, axis=-1, keepdims=True)

    def entropy(self):
        """Differential entropy, summed over the last axis (keepdims)."""
        half_log_2pi_e = 0.5 * np.log(2.0 * np.pi * np.e)
        return tf.reduce_sum(self.logstd + half_log_2pi_e, axis=-1, keepdims=True)

    def kl(self, other):
        """KL(self || other) between diagonal Gaussians, summed over the last axis."""
        scale_ratio_sq = tf.square(self.std / other.std)
        mean_gap_sq = tf.square((self.mean - other.mean) / other.std)
        per_dim = -0.5 + other.logstd - self.logstd + 0.5 * scale_ratio_sq + 0.5 * mean_gap_sq
        return tf.reduce_sum(per_dim, axis=-1, keepdims=True)

def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0):
    """Fully connected layer ``x @ w + b`` with variables ``scope/w`` and ``scope/b``.

    Args:
        x: input tensor whose last dimension must be statically known.
        scope: variable-scope name holding the layer's ``w`` and ``b``.
        nh: number of output units.
        init_scale: gain passed to ``ortho_init`` for ``w``.
        init_bias: constant initial value for ``b``.

    Returns:
        ``tf.matmul(x, w) + b``, with last dimension ``nh``.
    """
    with tf.variable_scope(scope):
        # as_list() yields plain ints/None under both the TF1 (Dimension-based)
        # and v2 TensorShape behaviors. ``shape[-1].value`` only works with the former.
        nin = x.get_shape().as_list()[-1]
        w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale))
        b = tf.get_variable("b", [nh], initializer=tf.constant_initializer(init_bias))
        # NOTE(review): for inputs of rank > 2 (Reward() passes rank-3 tensors)
        # this relies on tf.matmul broadcasting the rank-2 ``w`` over the
        # leading dimensions, which is TF-version dependent. Confirm.
        return tf.matmul(x, w) + b

def ortho_init(scale=1.0):
    """Return a TF1-style initializer that produces scaled orthogonal arrays.

    The returned callable accepts ``(shape, dtype, partition_info=None)``.
    ``shape`` must have 2 or 4 dimensions; a 4-D shape is flattened to
    ``(prod(shape[:-1]), shape[-1])`` before orthogonalization. ``dtype`` is
    ignored and the result is always ``np.float32``.
    """
    def _ortho_init(shape, dtype, partition_info=None):
        dims = tuple(shape)
        if len(dims) not in (2, 4):
            raise NotImplementedError
        flat_shape = dims if len(dims) == 2 else (np.prod(dims[:-1]), dims[-1])
        gaussian = np.random.normal(0.0, 1.0, flat_shape)
        left, _, right = np.linalg.svd(gaussian, full_matrices=False)
        # Pick whichever SVD factor already has the target 2-D shape.
        basis = left if left.shape == flat_shape else right
        basis = basis.reshape(dims)
        return (scale * basis[:dims[0], :dims[1]]).astype(np.float32)
    return _ortho_init

def policy_forward(obs, actdim, max_ac, name_space, reuse=True):
    """Build the policy MLP under ``name_space`` and return its pre-squash Gaussian.

    The network has two ReLU layers of width ``DIM`` (``pi_fc1``, ``pi_fc2``) and
    a linear head ``pi`` of width ``2 * actdim``. The first half of the head is the
    mean; the second half is the log-std, clipped to ``[-20, 2]``.
    ``max_ac`` is accepted but not used here.
    """
    gain = np.sqrt(INIT_SC)
    with tf.variable_scope(name_space, reuse=reuse):
        hidden = obs
        for layer_name in ('pi_fc1', 'pi_fc2'):
            hidden = tf.nn.relu(fc(hidden, layer_name, nh=DIM, init_scale=gain))
        head = fc(hidden, 'pi', actdim + actdim, init_scale=gain)
        mean, unclipped_logstd = head[:, :actdim], head[:, actdim:]
        logstd = tf.clip_by_value(unclipped_logstd, -20.0, 2.0)
        dist = RandomNormal(mean=mean, logstd=logstd)
    return dist

class MlpPolicy(object):
    """Tanh-squashed Gaussian policy graph (TF1) with a log-temperature variable.

    Builds the policy variables under the ``'policy'`` variable scope, an
    observation placeholder ``self.X``, a noise placeholder ``self.pi_noise_ph``,
    and the sampled action / log-probability tensors ``self.ac`` /
    ``self.ac_log_p``. ``self.step(ob, test=False)`` evaluates one action batch.
    """
    def __init__(self, sess, ob_space, ac_space, max_ac):
        """Build the policy graph.

        Args:
            sess: session used by ``self.step`` to evaluate the action tensor.
            ob_space: indexed as ``ob_space[0]`` to get the observation width
                (presumably a shape tuple; confirm against callers).
            ac_space: object with a ``.shape`` attribute; ``shape[0]`` is the
                action dimension.
            max_ac: stored on ``self.max_ac`` and forwarded to
                ``policy_forward``, which does not use it in this file.
        """
        self.ob_shape = (None, ob_space[0])
        self.actdim = ac_space.shape[0]
        # No initializer is passed, so the enclosing variable scope's default
        # initializer is used. The explicit 'policy/' prefix also makes this
        # variable match tf.trainable_variables("policy") below.
        self.log_alpha = tf.get_variable(name='policy/log_alpha', shape=[1], dtype=np.float32, trainable=True)
        # Temperature used as a constant: gradients through alpha never reach log_alpha.
        self.alpha = tf.stop_gradient(tf.exp(self.log_alpha))
        X = tf.placeholder(tf.float32, self.ob_shape, name='Ob')
        # Standard-normal noise fed at run time; the sample is mean + noise * std.
        self.pi_noise_ph = tf.placeholder(shape=(None, self.actdim), dtype=tf.float32, name='pi_noise')
        # NOTE(review): not used inside this class; presumably fed by code elsewhere.
        self.pi_noise_target_ph = tf.placeholder(shape=(None, self.actdim), dtype=tf.float32, name='pi_target_noise')
        # First call creates the policy variables (reuse=False); later calls reuse them.
        policy_forward(X, self.actdim, max_ac, 'policy', reuse=False)
        self.max_ac = max_ac
        self.X = X
        pi_act, pi_log_p = self.policy_func(X, self.pi_noise_ph, 'policy', reuse=True)
        self.ac = pi_act
        self.ac_log_p = pi_log_p
        # NOTE(review): besides the network weights, this prefix filter also
        # picks up 'policy/log_alpha' created above. Confirm this is intended.
        self.pi_params = tf.trainable_variables("policy")

        def step(ob, test=False):
            """Return ``[action_batch]`` (a one-element list) for observations ``ob``.

            With ``test=True`` the noise is all zeros, so the action is
            ``tanh(mean)``. Otherwise standard-normal noise is drawn with NumPy.
            Assumes ``ob`` is an array whose axis 0 is the batch (TODO confirm).
            """
            if test:
                noise = np.zeros((ob.shape[0],self.actdim))
            else:
                noise = np.random.normal(0,1,size=(ob.shape[0],self.actdim))
            return sess.run([pi_act], {X:ob, self.pi_noise_ph:noise})

        self.step = step

    def policy_func(self, obs, pi_noise_ph, namespace, reuse=True):
        """Build a reparameterized tanh-Gaussian sample and its log-probability.

        Returns ``(pi_act, pi_log_p)``, where ``pi_act = tanh(mean + noise * std)``.
        ``pi_log_p`` is the Gaussian log-density of the pre-tanh sample minus the
        tanh change-of-variables term ``sum(log(1 - pi_act**2 + 1e-6))``, summed
        over the last axis with ``keepdims=True``.
        """
        def get_pi_log_p(pi, pi_sample, pi_act):
            # The 1e-6 keeps the log finite when tanh saturates at +/-1.
            return pi.log_p(pi_sample) - tf.reduce_sum(tf.log(1-tf.square(pi_act)+1e-6), axis=-1, keepdims=True)

        pi = policy_forward(obs, self.actdim, self.max_ac, namespace, reuse=reuse)
        pi_sample = pi.mean+pi_noise_ph*pi.std
        pi_act = tf.tanh(pi_sample)
        pi_log_p = get_pi_log_p(pi, pi_sample, pi_act)

        return pi_act, pi_log_p

def Reward(R_X, name_space, reuse=False):
    """Build the twin per-step and pairwise reward heads on ``R_X``.

    Per-step heads (``r_fc1``, ``r_fc2``, ``r``) read ``R_X`` directly. Pairwise
    heads (``r_second_fc1``, ``r_second_fc2``, ``r_second``) read the concatenation
    of each step with the next one along axis 1, so they see one fewer step.
    Head 1 lives in scope ``name_space + '1'`` and head 2 in ``name_space + '2'``.
    Variables are created in the order r1, r2, rs1, rs2.

    Returns:
        ``(r1, r2, rs1, rs2)``, each produced by a final 1-unit ``fc`` layer.
    """
    gain = np.sqrt(INIT_SC)
    paired_X = tf.concat((R_X[:, :-1], R_X[:, 1:]), axis=-1)

    def _head(inputs, scope_suffix, prefix):
        # Two ReLU layers of width R_DIM followed by a scalar linear output.
        with tf.variable_scope(name_space + scope_suffix, reuse=reuse):
            hidden = tf.nn.relu(fc(inputs, prefix + '_fc1', nh=R_DIM, init_scale=gain))
            hidden = tf.nn.relu(fc(hidden, prefix + '_fc2', nh=R_DIM, init_scale=gain))
            return fc(hidden, prefix, 1, init_scale=gain)

    r1 = _head(R_X, '1', 'r')
    r2 = _head(R_X, '2', 'r')
    rs1 = _head(paired_X, '1', 'r_second')
    rs2 = _head(paired_X, '2', 'r_second')
    return r1, r2, rs1, rs2

def Qfunc(Q_X, name_space, reuse=False):
    """Build twin Q heads on ``Q_X`` in scopes ``name_space + '1'`` and ``+ '2'``.

    Head ``k`` is two ReLU layers of width ``DIM`` (``qk_fc1``, ``qk_fc2``)
    followed by a 1-unit linear output ``qk``.

    Returns:
        ``(q1, q2)``.
    """
    gain = np.sqrt(INIT_SC)
    outputs = []
    for idx in ("1", "2"):
        prefix = "q" + idx
        with tf.variable_scope(name_space + idx, reuse=reuse):
            hidden = tf.nn.relu(fc(Q_X, prefix + "_fc1", nh=DIM, init_scale=gain))
            hidden = tf.nn.relu(fc(hidden, prefix + "_fc2", nh=DIM, init_scale=gain))
            outputs.append(fc(hidden, prefix, 1, init_scale=gain))
    q1, q2 = outputs
    return q1, q2


class b_Qfunc(object):
    """Twin Q-functions plus a decomposed-reward correction term (TF1 graph).

    Creates the variable scopes ``qfunc1``/``qfunc2`` and ``target_qfunc1``/``2``
    (via ``Qfunc``), and ``reward1``/``reward2`` and ``target_reward1``/``2``
    (via ``Reward``). Also creates a reward input placeholder ``self.R_X`` of shape
    ``(None, reward_freq, obs_dim + act_dim)`` and the masks ``self.coef`` /
    ``self.target_coef`` used to weight the reward heads.
    """
    def __init__(self, ob_space, ac_space, reward_freq):
        """Create all critic/reward variables and the reward heads on ``self.R_X``.

        Args:
            ob_space: indexed as ``ob_space[0]`` for the observation width.
            ac_space: object whose ``.shape[0]`` is the action width.
            reward_freq: number of steps per reward window (axis 1 of ``R_X``).
        """
        self.input_size = (reward_freq, ob_space[0] + ac_space.shape[0])
        q_input_shape = (None, self.input_size[1])
        r_input_shape = (None, reward_freq, self.input_size[1])
        # Q_X only serves to create the Q variables; the *_forward methods
        # rebuild the heads on new inputs with reuse=True.
        Q_X = tf.placeholder(tf.float32, q_input_shape, name='Q_input')
        Qfunc(Q_X, 'qfunc', reuse=False)
        Qfunc(Q_X, 'target_qfunc', reuse=False)
        coef = tf.ones([reward_freq, reward_freq], dtype=tf.float32)
        # Strictly lower-triangular ones (diagonal removed).
        self.coef = tf.linalg.band_part(coef, -1, 0) - tf.eye(reward_freq, dtype=tf.float32)
        # Lower-triangular ones including the diagonal.
        self.target_coef = tf.linalg.band_part(coef, -1, 0)
        self.R_X = tf.placeholder(tf.float32, r_input_shape, name='R_input')
        self.r1, self.r2, self.rs1, self.rs2 = Reward(self.R_X, 'reward', reuse=False)
        self.target_r1, self.target_r2, self.target_rs1, self.target_rs2 = Reward(self.R_X, 'target_reward', reuse=False)
        # Each critic's parameter list groups its Q head with its reward head,
        # in variable-creation order.
        self.q1_params = tf.trainable_variables("qfunc1") + tf.trainable_variables("reward1")
        self.q2_params = tf.trainable_variables("qfunc2") + tf.trainable_variables("reward2")
        self.qt1_params = tf.trainable_variables("target_qfunc1") + tf.trainable_variables("target_reward1")
        self.qt2_params = tf.trainable_variables("target_qfunc2") + tf.trainable_variables("target_reward2")

    def Qvalue_forward(self, ob, ac, coef, reuse=True):
        """Return online ``(b1 + q1, b2 + q2)``.

        ``coef`` is combined with the strictly lower-triangular ``self.coef``
        (summed over axis 1), then used to weight the online reward heads. The
        reward terms are built on ``self.R_X``, not on ``ob``/``ac``.
        Presumably ``coef`` is ``(batch, reward_freq, 1)``; confirm against callers.
        """
        coef = tf.expand_dims(tf.reduce_sum(coef * self.coef, axis=1), -1)
        b1, b2 = self.get_b(coef, 'reward')
        q_input = tf.concat((ob, ac), axis=-1)
        q1, q2 = Qfunc(q_input, 'qfunc', reuse=reuse)
        return b1 + q1, b2 + q2

    def Qfunc_forward(self, ob, ac, reuse=True):
        """Return the online Q heads ``(q1, q2)`` on ``concat(ob, ac)``."""
        q_input = tf.concat((ob, ac), axis=-1)
        q1, q2 = Qfunc(q_input, 'qfunc', reuse=reuse)
        return q1, q2

    def min_target_Q(self, ob, ac, coef, r_sig, reuse=True):
        """Return ``min`` over the target heads of ``(1 - r_sig) * bt + qt``.

        ``coef`` is masked with the lower-triangular ``self.target_coef``
        (diagonal included). ``r_sig`` scales out the reward term; it is
        presumably a 0/1 flag, so confirm against callers.
        """
        target_coef = tf.expand_dims(tf.reduce_sum(coef * self.target_coef, axis=1), -1)
        bt1, bt2 = self.get_b(target_coef, 'target_reward')
        q_input = tf.concat((ob, ac), axis=-1)
        qt1, qt2 = Qfunc(q_input, 'target_qfunc', reuse=reuse)
        qt1, qt2 = (1. - r_sig) * bt1 + qt1, (1. - r_sig) * bt2 + qt2
        return tf.minimum(qt1, qt2)

    def get_b(self, coef, namespace):
        """Weight the reward heads by ``coef`` and sum over axis 1.

        For each head: ``sum(coef * r, axis=1) + sum(coef[:, 1:] * rs, axis=1)``.
        The pairwise head ``rs`` (one step shorter) is weighted by ``coef``
        shifted by one step.

        Args:
            coef: weights broadcastable against the reward-head outputs.
            namespace: ``'reward'`` for the online heads, ``'target_reward'``
                for the target heads.

        Returns:
            ``(b1, b2)`` for the two heads.

        Raises:
            ValueError: if ``namespace`` is neither ``'reward'`` nor ``'target_reward'``.
        """
        # Compare by value. ``is`` on str literals depends on interning (and
        # triggers a SyntaxWarning on Python 3.8+), so an equal but
        # non-interned string used to fall through and return None.
        if namespace == 'reward':
            b1 = tf.reduce_sum(coef * self.r1, axis=1) + tf.reduce_sum(coef[:, 1:] * self.rs1, axis=1)
            b2 = tf.reduce_sum(coef * self.r2, axis=1) + tf.reduce_sum(coef[:, 1:] * self.rs2, axis=1)
            return b1, b2
        elif namespace == 'target_reward':
            bt1 = tf.reduce_sum(coef * self.target_r1, axis=1) + tf.reduce_sum(coef[:, 1:] * self.target_rs1, axis=1)
            bt2 = tf.reduce_sum(coef * self.target_r2, axis=1) + tf.reduce_sum(coef[:, 1:] * self.target_rs2, axis=1)
            return bt1, bt2
        raise ValueError(
            "get_b: namespace must be 'reward' or 'target_reward', got %r" % (namespace,)
        )

