

import numpy as np
import tensorflow as tf
from policies import MlpPolicy, b_Qfunc


class Model(object):
    """Training graph for a stochastic policy and a pair of Q-functions (TF1 graph mode).

    The constructor builds every training op once, initialises all global
    variables in the default session and runs the target-update ops a first
    time. After construction, ``self.train(iter, batch, log=False)`` performs
    one optimisation step.

    NOTE(review): the loss structure (two critics, ``min_target_Q``, an entropy
    term weighted by ``policy.alpha``) looks like a Soft Actor-Critic variant;
    the exact semantics live in ``policies.MlpPolicy`` / ``policies.b_Qfunc``,
    which are not visible from this file.
    """

    def __init__(self, *, policy, ob_space, ac_space, batch_size, policy_freq, max_ac, reward_n, reg,
                 scale_reward, tau, gamma, lr_ac, lr_cr, target_update_interval):
        """Store hyper-parameters, build the graph and initialise variables.

        All arguments are keyword-only.

        Args:
            policy: callable invoked as ``policy(sess, ob_space, ac_space, max_ac)``.
            ob_space: indexed as ``ob_space[0]`` for the observation width.
            ac_space: object whose ``.shape[0]`` is the action width.
            batch_size: fixed leading dimension of every placeholder.
            policy_freq: the policy is trained when ``iter % policy_freq == 0``.
            max_ac: forwarded to ``policy``.
            reward_n: size of the second axis of the action/coef/mask placeholders.
            reg: weight of the ``b1``/``b2`` regression loss in each critic loss.
            scale_reward: multiplier applied to the reward in the TD target.
            tau: interpolation factor of the target-network update.
            gamma: discount factor used in the TD target.
            lr_ac: Adam learning rate for the policy.
            lr_cr: Adam learning rate for both critics.
            target_update_interval: target ops run when
                ``iter % target_update_interval == 0``.
        """
        # A default session must already be active; tf.get_default_session()
        # returns None otherwise.
        self.sess = tf.get_default_session()
        self.batch_size = batch_size
        self.policy = policy(self.sess, ob_space, ac_space, max_ac)
        self.Q = b_Qfunc(ob_space, ac_space, reward_n)
        self.ob_space = ob_space
        self.ac_space = ac_space
        self.scale_reward = scale_reward
        # Attribute name is misspelled ("poicy"); kept as-is because it is a
        # publicly visible attribute that other code may read.
        self.poicy_freq = policy_freq
        self.tau = tau
        self.reg = reg
        self.reward_n = reward_n
        self.gamma = gamma
        self.lr_ac = lr_ac
        self.lr_cr = lr_cr
        self.target_update_interval = target_update_interval
        self.init_update()
        self.init_target_ops()

        tf.global_variables_initializer().run(session=self.sess)
        # NOTE(review): this first run is a *soft* update (weight tau), not a
        # hard copy of the online critics into the targets. Confirm that
        # b_Qfunc initialises target and online parameters consistently.
        self.sess.run(self.target_ops)

    def init_update(self):
        """Build placeholders, losses and optimisers, and bind ``self.train``.

        Graph construction order is kept exactly as written because the
        sub-networks are created with ``reuse=True`` and may depend on it.
        """
        batch = self.batch_size
        # NOTE(review): observation width is read as ob_space[0] while action
        # width is read as ac_space.shape[0]; confirm ob_space is a plain
        # shape tuple and not a gym space.
        ob_dim = self.ob_space[0]
        ac_dim = self.ac_space.shape[0]

        DONE = tf.placeholder(shape=(batch, 1), dtype=tf.float32, name='done')
        REW = tf.placeholder(shape=(batch, 1), dtype=tf.float32, name='reward')
        REW_SUM = tf.placeholder(shape=(batch, 1), dtype=tf.float32, name='rew_sum')
        ACTION = tf.placeholder(shape=(batch, self.reward_n, ac_dim), dtype=tf.float32, name='action')
        N_OBS = tf.placeholder(shape=(batch, ob_dim), dtype=tf.float32, name='n_obs')
        OBS = tf.placeholder(shape=(batch, ob_dim), dtype=tf.float32, name='obs')
        COEF = tf.placeholder(shape=(batch, self.reward_n, 1), dtype=tf.float32, name='coef')
        MASK = tf.placeholder(shape=(batch, self.reward_n, 1), dtype=tf.float32, name='mask')
        R_SIG = tf.placeholder(shape=(batch, 1), dtype=tf.float32, name='r_sig')

        # Bootstrap value at the next observation, with the entropy term subtracted.
        n_ac, pi_log_p_target = self.policy.policy_func(N_OBS, self.policy.pi_noise_target_ph, 'policy', reuse=True)
        target_Q = self.Q.min_target_Q(N_OBS, n_ac, COEF, R_SIG, reuse=True) - self.policy.alpha * pi_log_p_target

        # Per-sample weight: MASK weighted by COEF and summed over the reward_n axis.
        mask = tf.reduce_sum(MASK * COEF, axis=1)
        # TD target; gradients are blocked so only q1/q2 are trained by the TD losses.
        ys = tf.stop_gradient(self.scale_reward * REW + (1 - DONE) * self.gamma * target_Q)

        # COEF-weighted sum of the stored actions over the reward_n axis.
        ac = tf.reduce_sum(ACTION * COEF, axis=1)
        q1, q2 = self.Q.Qvalue_forward(OBS, ac, COEF, reuse=True)

        # Lower bound keeps the normalisation finite when the batch is almost fully masked.
        mask_mean = tf.maximum(tf.reduce_mean(mask), 1E-3)

        b1, b2 = self.Q.get_b(MASK, 'reward')

        td_loss_1 = 0.5 * tf.reduce_mean(mask * (ys - q1) ** 2) / mask_mean
        td_loss_2 = 0.5 * tf.reduce_mean(mask * (ys - q2) ** 2) / mask_mean
        r1_loss = 0.5 * tf.reduce_mean((b1 - REW_SUM) ** 2)
        r2_loss = 0.5 * tf.reduce_mean((b2 - REW_SUM) ** 2)

        qf1_train_op = tf.train.AdamOptimizer(self.lr_cr, epsilon=1E-6).minimize(
            loss=td_loss_1 + self.reg * r1_loss,
            var_list=self.Q.q1_params
        )

        qf2_train_op = tf.train.AdamOptimizer(self.lr_cr, epsilon=1E-6).minimize(
            loss=td_loss_2 + self.reg * r2_loss,
            var_list=self.Q.q2_params
        )

        actions = self.policy.ac
        q, _ = self.Q.Qfunc_forward(OBS, actions, reuse=True)
        policy_loss = - tf.reduce_mean(mask * (q - self.policy.alpha * self.policy.ac_log_p)) / mask_mean
        alpha_loss = - self.policy.log_alpha * tf.stop_gradient(
            tf.reduce_mean(mask * self.policy.ac_log_p) / mask_mean - ac_dim
        )

        # NOTE(review): alpha_loss only trains log_alpha if that variable is
        # part of policy.pi_params - confirm in policies.py.
        policy_train_op = tf.train.AdamOptimizer(self.lr_ac, epsilon=1E-6).minimize(
            loss=policy_loss + alpha_loss,
            var_list=self.policy.pi_params
        )

        # Diagnostics are built once here. Creating them inside train() would
        # add new ops to the graph on every logged step (unbounded graph
        # growth, and an error on a finalized graph).
        q_mean = tf.reduce_mean(q1)
        mu_mean = tf.reduce_mean(tf.abs(self.policy.ac))
        log_fetches = [q_mean, mu_mean, td_loss_1, r1_loss]
        critic_train_ops = [qf1_train_op, qf2_train_op]

        def train(iter, batch, log=False):
            """Run one optimisation step.

            Args:
                iter: integer step counter; gates the policy update and the
                    target-network update. (Name shadows the builtin but is
                    kept so keyword callers continue to work.)
                batch: mapping with keys 'obs', 'action', 'n_obs', 'done',
                    'rew_sum', 'reward', 'r_sig', 'pi_obs', 'coef', 'mask'.
                log: when true, diagnostics are fetched ahead of the train ops.

            Returns:
                The list returned by ``sess.run`` for the fetched ops, in
                order: ``[q_mean, mu_mean, td_loss_1, r1_loss]`` first when
                ``log`` is true, then one entry per executed train op.
            """
            n_rows = batch['obs'].shape[0]
            noise_shape = (n_rows, self.ac_space.shape[0])

            reward_input = np.concatenate((batch['obs'], batch['action']), axis=-1)
            data = {
                ACTION: batch['action'],
                N_OBS: batch['n_obs'],
                DONE: batch['done'],
                REW_SUM: batch['rew_sum'],
                REW: batch['reward'],
                R_SIG: batch['r_sig'],
                # Both the critic's OBS input and the policy input are fed
                # from 'pi_obs'; 'obs' is only used for the reward-model input.
                OBS: batch['pi_obs'],
                COEF: batch['coef'],
                MASK: batch['mask'],
                self.policy.X: batch['pi_obs'],
                self.Q.R_X: reward_input,
                self.policy.pi_noise_ph: np.random.normal(0, 1, size=noise_shape),
                self.policy.pi_noise_target_ph: np.random.normal(0, 1, size=noise_shape),
            }

            op = list()

            if log:
                op.extend(log_fetches)

            op.extend(critic_train_ops)

            # Delayed policy update: only every `poicy_freq` steps.
            if iter % self.poicy_freq == 0:
                op.extend([policy_train_op])

            logger = self.sess.run(op, data)

            if iter % self.target_update_interval == 0:
                self.sess.run(self.target_ops)

            return logger

        self.train = train

    def init_target_ops(self):
        """Create ``self.target_ops``: one assign op per target-critic variable.

        Each op sets ``target <- (1 - tau) * target + tau * source``, first for
        the (qt1, q1) parameter pairs and then for the (qt2, q2) pairs.
        """
        pairs = list(zip(self.Q.qt1_params, self.Q.q1_params))
        pairs.extend(zip(self.Q.qt2_params, self.Q.q2_params))

        self.target_ops = [
            tf.assign(target, (1 - self.tau) * target + self.tau * source)
            for target, source in pairs
        ]
