from copy import copy
from functools import reduce
import functools
import os

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
from tensorflow.keras import regularizers
from ddpg import logger
from ddpg.common.mpi_adam import MpiAdam
import ddpg.common.tf_util as U
from ddpg.common.mpi_running_mean_std import RunningMeanStd
try:
    from mpi4py import MPI
except ImportError:
    MPI = None

def normalize(x, stats):
    """Standardize ``x`` using running statistics.

    Args:
        x: value(s) to standardize.
        stats: object exposing ``mean`` and ``std`` attributes, or ``None``.

    Returns:
        ``x`` unchanged when ``stats`` is None; otherwise
        ``(x - stats.mean) / (stats.std + 1e-8)``. The epsilon keeps the
        division finite when the standard deviation is zero.
    """
    if stats is not None:
        centered = x - stats.mean
        return centered / (stats.std + 1e-8)
    return x


def denormalize(x, stats):
    """Map standardized ``x`` back to the original scale.

    Args:
        x: standardized value(s).
        stats: object exposing ``mean`` and ``std`` attributes, or ``None``.

    Returns:
        ``x`` unchanged when ``stats`` is None; otherwise
        ``x * stats.std + stats.mean``. This is the inverse of ``normalize``
        up to the small epsilon that ``normalize`` adds to the std.
    """
    if stats is None:
        return x
    scaled = x * stats.std
    return scaled + stats.mean

def reduce_std(x, axis=None, keepdims=False):
    """Population standard deviation of ``x`` along ``axis``.

    Computed as the square root of ``reduce_var`` with the same ``axis`` and
    ``keepdims`` arguments.
    """
    variance = reduce_var(x, axis=axis, keepdims=keepdims)
    return tf.sqrt(variance)

def reduce_var(x, axis=None, keepdims=False):
    """Population (divide-by-N) variance of ``x`` along ``axis``.

    Args:
        x: input tensor.
        axis: axis/axes to reduce over; ``None`` reduces over all elements.
        keepdims: keep reduced dimensions with length 1 in the result.
    """
    # The mean always keeps its dims so it broadcasts against x for any axis.
    center = tf.reduce_mean(x, axis=axis, keepdims=True)
    return tf.reduce_mean(tf.square(x - center), axis=axis, keepdims=keepdims)

def get_target_updates(vars, target_vars, tau):
    """Build hard-copy and soft (Polyak) update ops from ``vars`` into ``target_vars``.

    Args:
        vars: source variables, paired positionally with ``target_vars``.
        target_vars: variables to update; must have the same length as ``vars``.
        tau: soft-update coefficient; the soft op assigns
            ``(1 - tau) * target + tau * source``.

    Returns:
        ``(init_op, soft_op)``: grouped ops that copy the sources exactly and
        that blend them in by ``tau``, respectively.
    """
    logger.info('setting up target updates ...')
    assert len(vars) == len(target_vars)
    copy_ops = []
    blend_ops = []
    for source, target in zip(vars, target_vars):
        logger.info('  {} <- {}'.format(target.name, source.name))
        copy_ops.append(tf.assign(target, source))
        blended = (1. - tau) * target + tau * source
        blend_ops.append(tf.assign(target, blended))
    assert len(copy_ops) == len(vars)
    assert len(blend_ops) == len(vars)
    return tf.group(*copy_ops), tf.group(*blend_ops)


def get_perturbed_actor_updates(actor, perturbed_actor, param_noise_stddev):
    """Build ops that load ``actor``'s weights into ``perturbed_actor`` with noise.

    Variables listed in ``actor.perturbable_vars`` receive additive Gaussian
    noise with standard deviation ``param_noise_stddev``; all other variables
    are copied unchanged. Variables are paired positionally via ``.vars``.

    Returns:
        A single grouped op containing one assignment per actor variable.
    """
    assert len(actor.vars) == len(perturbed_actor.vars)
    assert len(actor.perturbable_vars) == len(perturbed_actor.perturbable_vars)

    assign_ops = []
    for source, dest in zip(actor.vars, perturbed_actor.vars):
        if source not in actor.perturbable_vars:
            # Non-perturbable variables are mirrored exactly.
            logger.info('  {} <- {}'.format(dest.name, source.name))
            assign_ops.append(tf.assign(dest, source))
            continue
        logger.info('  {} <- {} + noise'.format(dest.name, source.name))
        noise = tf.random_normal(tf.shape(source), mean=0., stddev=param_noise_stddev)
        assign_ops.append(tf.assign(dest, source + noise))
    assert len(assign_ops) == len(actor.vars)
    return tf.group(*assign_ops)

def samp(act):
    """Draw one Bernoulli sample per probability in ``act``.

    Args:
        act: probability or array of probabilities in [0, 1].

    Returns:
        A float64 numpy array with the same shape as ``act`` (0-d for a
        scalar) holding 0.0 or 1.0.
    """
    # binomial(n=1) already produces only 0/1 integers. The float64 cast
    # replaces two np.where passes that changed nothing but the dtype.
    return np.asarray(np.random.binomial(1, act), dtype=np.float64)

def entry_stop_gradients(target, mask):
    """Block gradients through the entries of ``target`` not selected by ``mask``.

    For a 0/1 ``mask``, entries where mask == 1 pass gradients normally and
    entries where mask == 0 are treated as constants. For mask values in
    [0, 1] the forward value equals ``target``.
    """
    complement = tf.abs(mask - 1)
    frozen_part = tf.stop_gradient(complement * target)
    live_part = mask * target
    return frozen_part + live_part


class DDPG(object):
    """Actor-critic agent (DDPG skeleton with a "QA2C" actor objective), TF1 graph mode.

    Observations are passed as ``[variable_features, n_vars]``:

    * ``variable_features`` feeds the ``(None, 20)`` float placeholders
      ``var0`` (current) / ``var1`` (next).
    * ``n_vars`` feeds a ``(None,)`` int32 placeholder shared by both.

    Actions are ``(None, 1)`` floats.
    """

    def __init__(self, actor, critic, memory, param_noise=None, action_noise=None,
        gamma=0.99, tau=0.001, normalize_returns=False, enable_popart=False, normalize_observations=True,
        batch_size=128, observation_range=(-np.inf, np.inf), action_range=(0.2, 0.8), return_range=(-np.inf, np.inf),
        critic_l2_reg=0., actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1.):
        """Create placeholders, networks, losses, optimizers and update ops.

        Args:
            actor, critic: network objects. This class calls ``.call(...)`` on
                them and reads ``.variables``; PopArt also reads
                ``critic.output_vars``.
            memory: replay buffer exposing ``append(...)`` and
                ``sample(batch_size=...)``.
            param_noise: parameter-noise controller (``current_stddev``,
                ``adapt``, ``get_stats``), or None to disable.
            action_noise: callable returning additive action noise and
                providing ``reset()``, or None.
            gamma: discount applied to the next-state value in the critic target.
            tau: soft target-update coefficient.
            normalize_returns: keep running return statistics in ``ret_rms``.
            enable_popart: apply PopArt output rescaling in ``train``
                (only active together with ``normalize_returns``).
            normalize_observations: accepted for interface compatibility.
                Observation normalization is not implemented (``obs_rms`` is
                None), so this flag currently has no effect.
            batch_size: replay batch size used by ``train``/``get_stats``.
            observation_range: stored; not used by this class.
            action_range: ``(low, high)`` clip bounds for returned actions.
            return_range: clip bounds applied to the normalized critic outputs.
            critic_l2_reg: L2 weight on critic variables whose name lacks 'output'.
            actor_lr, critic_lr: Adam step sizes.
            clip_norm: gradient-norm clip passed to ``U.flatgrad``.
            reward_scale: multiplier applied to rewards in ``store_transition``.
        """
        # Inputs: current (var0) and next (var1) variable features. n_vars is
        # shared by both observations.
        self.var0 = tf.placeholder(tf.float32, shape=(None,) + (20,), name='var0')
        self.var1 = tf.placeholder(tf.float32, shape=(None,) + (20,), name='var1')

        self.n_vars = tf.placeholder(tf.int32, shape=(None,), name='n_var')

        # NOTE(review): terminals1 is never fed or used. The critic target
        # below does not mask terminal transitions.
        self.terminals1 = tf.placeholder(tf.float32, shape=(None, 1), name='terminals1')
        self.rewards = tf.placeholder(tf.float32, shape=(None, 1), name='rewards')
        self.actions = tf.placeholder(tf.float32, shape=(None,) + (1,), name='actions')
        self.critic_target = tf.placeholder(tf.float32, shape=(None, 1), name='critic_target')
        self.param_noise_stddev = tf.placeholder(tf.float32, shape=(), name='param_noise_stddev')
        self.next_actions = tf.placeholder(tf.float32, shape=(None,) + (1,), name='next_actions')

        # QA2C: per-sample critic value that weights the actor objective.
        self.Q0 = tf.placeholder(tf.float32, shape=(None, 1), name='Q0')

        # Parameters.
        self.gamma = gamma
        self.tau = tau
        self.memory = memory
        self.normalize_observations = normalize_observations
        self.normalize_returns = normalize_returns
        self.action_noise = action_noise
        self.param_noise = param_noise
        self.action_range = action_range
        self.return_range = return_range
        self.observation_range = observation_range
        self.critic = critic
        self.actor = actor
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.clip_norm = clip_norm
        self.enable_popart = enable_popart
        self.reward_scale = reward_scale
        self.batch_size = batch_size
        self.stats_sample = None
        self.critic_l2_reg = critic_l2_reg

        obs0 = [self.var0, self.n_vars]
        obs1 = [self.var1, self.n_vars]

        # Return normalization.
        if self.normalize_returns:
            with tf.variable_scope('ret_rms'):
                self.ret_rms = RunningMeanStd()
        else:
            self.ret_rms = None

        # TODO: observation normalization. setup_stats and store_transition
        # consult obs_rms, so it must exist. Keeping it None makes
        # normalize_observations=True a no-op instead of an AttributeError.
        self.obs_rms = None

        # Create target networks.
        # NOTE(review): copy() is shallow. Unless the network classes create
        # fresh variables for each copy, the targets share weights with the
        # online networks and the soft updates are no-ops. Confirm against
        # the actor/critic implementation.
        target_actor = copy(actor)
        self.target_actor = target_actor
        target_critic = copy(critic)
        self.target_critic = target_critic

        # Create networks and core TF parts that are shared across setup parts.
        logger.info('network inputs: {}'.format(obs0))
        self.actor_tf = actor.call(obs0)

        self.normalized_critic_tf = critic.call(obs0, self.actions)
        self.critic_tf = denormalize(tf.clip_by_value(self.normalized_critic_tf, self.return_range[0], self.return_range[1]), self.ret_rms)
        # Critic evaluated on the online actor's own action.
        self.normalized_critic_with_actor_tf = critic.call(obs0, self.actor_tf)
        self.critic_with_actor_tf = denormalize(tf.clip_by_value(self.normalized_critic_with_actor_tf, self.return_range[0], self.return_range[1]), self.ret_rms)

        # Next-state value: the target critic on obs1 with externally fed next
        # actions. target_act exposes the target actor's output on obs1.
        self.target_act = target_actor.call(obs1)
        Q_obs1 = denormalize(target_critic.call(obs1, self.next_actions), self.ret_rms)

        # One-step bootstrapped target. The (1 - terminals1) masking variant
        # is not used.
        self.target_Q = self.rewards + gamma * Q_obs1

        # QA2C: online critic value of the taken actions. train() evaluates
        # it and feeds the result back through the Q0 placeholder.
        self.Q_obs0 = denormalize(critic.call(obs0, self.actions), self.ret_rms)

        # Set up parts.
        if self.param_noise is not None:
            self.setup_param_noise(obs0)
        self.setup_actor_optimizer()
        self.setup_critic_optimizer()
        if self.normalize_returns and self.enable_popart:
            self.setup_popart()
        self.setup_stats()
        self.setup_target_network_updates()

        self.initial_state = None # recurrent architectures not supported yet

    def setup_target_network_updates(self):
        """Build hard-copy (init) and soft update ops for actor and critic targets."""
        actor_init_updates, actor_soft_updates = get_target_updates(self.actor.variables, self.target_actor.variables, self.tau)
        critic_init_updates, critic_soft_updates = get_target_updates(self.critic.variables, self.target_critic.variables, self.tau)
        self.target_init_updates = [actor_init_updates, critic_init_updates]
        self.target_soft_updates = [actor_soft_updates, critic_soft_updates]

    def setup_param_noise(self, normalized_obs0):
        """Build the perturbed actor used for exploration and the copy used to adapt its stddev."""
        assert self.param_noise is not None

        # Configure perturbed actor.
        param_noise_actor = copy(self.actor)
        param_noise_actor.name = 'param_noise_actor'
        self.perturbed_actor_tf = param_noise_actor(normalized_obs0)
        logger.info('setting up param noise')
        self.perturb_policy_ops = get_perturbed_actor_updates(self.actor, param_noise_actor, self.param_noise_stddev)

        # Configure separate copy for stddev adoption.
        adaptive_param_noise_actor = copy(self.actor)
        adaptive_param_noise_actor.name = 'adaptive_param_noise_actor'
        adaptive_actor_tf = adaptive_param_noise_actor(normalized_obs0)
        self.perturb_adaptive_policy_ops = get_perturbed_actor_updates(self.actor, adaptive_param_noise_actor, self.param_noise_stddev)
        # RMS distance between unperturbed and perturbed actions.
        self.adaptive_policy_distance = tf.sqrt(tf.reduce_mean(tf.square(self.actor_tf - adaptive_actor_tf)))

    def setup_actor_optimizer(self):
        """Build the QA2C actor loss, its flat gradient and the MPI Adam optimizer."""
        logger.info('setting up actor optimizer')
        # QA2C objective (plain DDPG would use -mean(critic_with_actor_tf)).
        # Assuming actions are 0/1 samples (see samp), |actor_tf - (1 - action)|
        # is actor_tf when action == 1 and 1 - actor_tf when action == 0, i.e.
        # the probability assigned to the taken action. It is clipped away
        # from 0/1 before the log, averaged over axis 1 and weighted by Q0.
        self.prob1 = 1-self.actions
        self.actor_tff = tf.math.abs(self.actor_tf - self.prob1)
        self.actor_tff = tf.clip_by_value(self.actor_tff, 0.001, 0.999)
        self.inter = tf.math.log(self.actor_tff)

        self.actor_loss = -tf.reduce_mean(tf.reduce_mean(self.inter,1)*self.Q0)

        actor_shapes = [var.get_shape().as_list() for var in self.actor.variables]
        actor_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in actor_shapes])
        logger.info('  actor shapes: {}'.format(actor_shapes))
        logger.info('  actor params: {}'.format(actor_nb_params))

        self.actor_grads = U.flatgrad(self.actor_loss, self.actor.variables, clip_norm=self.clip_norm)

        self.actor_optimizer = MpiAdam(var_list=self.actor.variables,
            beta1=0.9, beta2=0.999, epsilon=1e-08)

    def setup_critic_optimizer(self):
        """Build the critic MSE loss (plus optional L2), its flat gradient and optimizer."""
        logger.info('setting up critic optimizer')
        normalized_critic_target_tf = tf.clip_by_value(normalize(self.critic_target, self.ret_rms), self.return_range[0], self.return_range[1])
        self.critic_loss = tf.reduce_mean(tf.square(self.normalized_critic_tf - normalized_critic_target_tf))
        if self.critic_l2_reg > 0.:
            # Output-layer variables are excluded from regularization.
            critic_reg_vars = [var for var in self.critic.variables if 'output' not in var.name]
            for var in critic_reg_vars:
                logger.info('  regularizing: {}'.format(var.name))
            logger.info('  applying l2 regularization with {}'.format(self.critic_l2_reg))
            l2_regularizer = regularizers.l2(self.critic_l2_reg)
            critic_reg = sum([tf.reduce_sum(l2_regularizer(var)) for var in critic_reg_vars])
            self.critic_loss += critic_reg
        critic_shapes = [var.get_shape().as_list() for var in self.critic.variables]
        critic_nb_params = sum([reduce(lambda x, y: x * y, shape) for shape in critic_shapes])
        logger.info('  critic shapes: {}'.format(critic_shapes))
        logger.info('  critic params: {}'.format(critic_nb_params))
        self.critic_grads = U.flatgrad(self.critic_loss, self.critic.variables, clip_norm=self.clip_norm)
        self.critic_optimizer = MpiAdam(var_list=self.critic.variables,
            beta1=0.9, beta2=0.999, epsilon=1e-08)

    def setup_popart(self):
        """Build ops that rescale critic output layers when return statistics change.

        See https://arxiv.org/pdf/1602.07714.pdf for details.
        """
        self.old_std = tf.placeholder(tf.float32, shape=[1], name='old_std')
        new_std = self.ret_rms.std
        self.old_mean = tf.placeholder(tf.float32, shape=[1], name='old_mean')
        new_mean = self.ret_rms.mean

        self.renormalize_Q_outputs_op = []
        # NOTE(review): if target_critic shares variables with critic (see the
        # shallow-copy note in __init__), the same output vars get rescaled twice.
        for vs in [self.critic.output_vars, self.target_critic.output_vars]:
            assert len(vs) == 2
            M, b = vs
            assert 'kernel' in M.name
            assert 'bias' in b.name
            assert M.get_shape()[-1] == 1
            assert b.get_shape()[-1] == 1
            self.renormalize_Q_outputs_op += [M.assign(M * self.old_std / new_std)]
            self.renormalize_Q_outputs_op += [b.assign((b * self.old_std + self.old_mean - new_mean) / new_std)]

    def setup_stats(self):
        """Collect the scalar diagnostic tensors reported by ``get_stats``."""
        ops = []
        names = []

        if self.normalize_returns:
            ops += [self.ret_rms.mean, self.ret_rms.std]
            names += ['ret_rms_mean', 'ret_rms_std']

        # obs_rms is None until observation normalization is implemented.
        if self.normalize_observations and self.obs_rms is not None:
            ops += [tf.reduce_mean(self.obs_rms.mean), tf.reduce_mean(self.obs_rms.std)]
            names += ['obs_rms_mean', 'obs_rms_std']

        ops += [tf.reduce_mean(self.critic_tf)]
        names += ['reference_Q_mean']
        ops += [reduce_std(self.critic_tf)]
        names += ['reference_Q_std']

        ops += [tf.reduce_mean(self.critic_with_actor_tf)]
        names += ['reference_actor_Q_mean']
        ops += [reduce_std(self.critic_with_actor_tf)]
        names += ['reference_actor_Q_std']

        ops += [tf.reduce_mean(self.actor_tf)]
        names += ['reference_action_mean']
        ops += [reduce_std(self.actor_tf)]
        names += ['reference_action_std']

        if self.param_noise is not None:
            ops += [tf.reduce_mean(self.perturbed_actor_tf)]
            names += ['reference_perturbed_action_mean']
            ops += [reduce_std(self.perturbed_actor_tf)]
            names += ['reference_perturbed_action_std']

        self.stats_ops = ops
        self.stats_names = names

    def step(self, obs, apply_noise=True, compute_Q=True):
        """Compute the (optionally noisy) actor action for the current observation.

        Args:
            obs: ``(variable_features, n_vars)`` pair fed to ``var0``/``n_vars``.
            apply_noise: use the perturbed actor (if param noise is configured)
                and add ``action_noise`` (if configured).
            compute_Q: also return the critic's value of the actor's action.

        Returns:
            ``(action, q, None, None)``. The action is clipped to
            ``action_range``; q is None unless ``compute_Q``.
        """
        if self.param_noise is not None and apply_noise:
            actor_tf = self.perturbed_actor_tf
        else:
            actor_tf = self.actor_tf

        updated_v_features, n_vs_per_sample = obs

        feed_dict = {self.var0: updated_v_features, self.n_vars:n_vs_per_sample}
        if compute_Q:
            action, q = self.sess.run([actor_tf, self.critic_with_actor_tf], feed_dict=feed_dict)
        else:
            action = self.sess.run(actor_tf, feed_dict=feed_dict)
            q = None
        if self.action_noise is not None and apply_noise:
            noise = self.action_noise()
            noise = noise[:,np.newaxis]
            assert noise.shape == action[0].shape
            action += noise

        action = np.clip(action, self.action_range[0], self.action_range[1])
        return action, q, None, None

    def next_step(self, obs, apply_noise=True, compute_Q=False):
        """Compute the target actor's action for a next observation.

        Args:
            obs: ``(variable_features, n_vars)`` pair, same layout as ``step``.
                It is fed to ``var1``/``n_vars``.
            apply_noise: add ``action_noise`` (if configured). Parameter noise
                is never applied to the target actor.
            compute_Q: also return ``critic_with_actor_tf`` evaluated on the
                same observation.

        Returns:
            ``(action, q, None, None)``, matching ``step``. q is None unless
            ``compute_Q``.
        """
        updated_v_features, n_vs_per_sample = obs
        feed_dict = {self.var1: updated_v_features, self.n_vars: n_vs_per_sample}
        if compute_Q:
            # critic_with_actor_tf is built on the var0 inputs.
            feed_dict[self.var0] = updated_v_features
            action, q = self.sess.run([self.target_act, self.critic_with_actor_tf], feed_dict=feed_dict)
        else:
            action = self.sess.run(self.target_act, feed_dict=feed_dict)
            q = None

        if self.action_noise is not None and apply_noise:
            noise = self.action_noise()
            assert noise.shape == action[0].shape
            action += noise
        action = np.clip(action, self.action_range[0], self.action_range[1])

        return action, q, None, None

    def store_transition(self, states, action, reward, dynamic_feature0, dynamic_feature1, next_action, ins, variable_num):
        """Scale the reward by ``reward_scale`` and append the transition to memory."""
        reward *= self.reward_scale

        self.memory.append(states, action, reward, dynamic_feature0, dynamic_feature1, next_action, ins, variable_num)
        # obs_rms is None until observation normalization is implemented.
        if self.normalize_observations and self.obs_rms is not None:
            self.obs_rms.update(np.array([dynamic_feature0]))
            self.obs_rms.update(np.array([dynamic_feature1]))

    def _sample_batch_with_reward(self, max_attempts):
        """Sample batches until one has a reward with |r| > 1e-9.

        Gives up after ``max_attempts`` draws and returns the last batch;
        ``None`` means unbounded.
        """
        attempts = 0
        while True:
            batch = self.memory.sample(batch_size=self.batch_size)
            attempts += 1
            if np.any(np.abs(batch['rewards']) > 1e-9):
                return batch
            if max_attempts is not None and attempts >= max_attempts:
                logger.info('no non-zero reward after {} batches; training on the last one'.format(attempts))
                return batch

    def train(self, max_resample_attempts=1000):
        """Run one synchronized actor/critic gradient step on a replay batch.

        Args:
            max_resample_attempts: maximum number of batches drawn while
                looking for one with a non-zero reward. Once exhausted, the
                last batch is used. ``None`` resamples without bound.

        Returns:
            ``(critic_loss, actor_loss)`` evaluated before the update.
        """
        batch = self._sample_batch_with_reward(max_resample_attempts)

        v_features0, v_features1, n_vs_per_sample = batch['variable_features0'], batch['variable_features1'], batch['variable_num']
        rewards, actions, actions_next = batch['rewards'], batch['actions'], batch['next_actions']

        feed_dict = {self.var0: v_features0, self.var1: v_features1,
                     self.n_vars: n_vs_per_sample, self.rewards: rewards,
                     self.actions: actions, self.next_actions: actions_next}

        if self.normalize_returns and self.enable_popart:
            old_mean, old_std, target_Q = self.sess.run([self.ret_rms.mean, self.ret_rms.std, self.target_Q], feed_dict=feed_dict)
            self.ret_rms.update(target_Q.flatten())
            # PopArt: rescale output layers for the updated return statistics.
            self.sess.run(self.renormalize_Q_outputs_op, feed_dict={
                self.old_std : np.array([old_std]),
                self.old_mean : np.array([old_mean]),
            })
        else:
            target_Q = self.sess.run(self.target_Q, feed_dict=feed_dict)
        feed_dict[self.critic_target] = target_Q

        # QA2C: actor_loss depends on the Q0 placeholder, so it is fed on both paths.
        feed_dict[self.Q0] = self.sess.run(self.Q_obs0, feed_dict=feed_dict)

        # Get all gradients and perform a synced update.
        ops = [self.actor_grads, self.actor_loss, self.critic_grads, self.critic_loss]
        actor_grads, actor_loss, critic_grads, critic_loss = self.sess.run(ops, feed_dict=feed_dict)
        print(actor_loss)

        self.actor_optimizer.update(actor_grads, stepsize=self.actor_lr)
        self.critic_optimizer.update(critic_grads, stepsize=self.critic_lr)
        return critic_loss, actor_loss

    def initialize(self, sess):
        """Attach ``sess``, initialize variables, sync optimizers and copy weights into targets."""
        self.sess = sess

        # save/load helpers bound to this session.
        self.save = functools.partial(U.save_variables, sess=self.sess)
        self.load = functools.partial(U.load_variables, sess=self.sess)

        self.sess.run(tf.global_variables_initializer())
        self.actor_optimizer.sync()
        self.critic_optimizer.sync()
        self.sess.run(self.target_init_updates)

    def update_target_net(self):
        """Apply one soft (tau-weighted) update to the target networks."""
        self.sess.run(self.target_soft_updates)

    def get_stats(self):
        """Evaluate the diagnostic tensors from ``setup_stats`` on a fixed sample.

        Returns:
            Dict mapping stat names to values, merged with
            ``param_noise.get_stats()`` when parameter noise is enabled.
        """
        if self.stats_sample is None:
            # Get a sample and keep that fixed for all further computations.
            # This allows us to estimate the change in value for the same set of inputs.
            self.stats_sample = self.memory.sample(batch_size=self.batch_size)
        values = self.sess.run(self.stats_ops, feed_dict={
            self.var0: self.stats_sample['variable_features0'],
            self.n_vars: self.stats_sample['variable_num'],
            self.actions: self.stats_sample['actions'],
        })

        names = self.stats_names[:]
        assert len(names) == len(values)
        stats = dict(zip(names, values))

        if self.param_noise is not None:
            stats = {**stats, **self.param_noise.get_stats()}

        return stats

    def adapt_param_noise(self):
        """Adapt the parameter-noise stddev from the perturbed/unperturbed action distance.

        Returns:
            The (MPI-averaged when available) distance, or 0. when parameter
            noise is disabled.
        """
        if self.param_noise is None:
            return 0.

        # Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation.
        batch = self.memory.sample(batch_size=self.batch_size)
        self.sess.run(self.perturb_adaptive_policy_ops, feed_dict={
            self.param_noise_stddev: self.param_noise.current_stddev,
        })
        distance = self.sess.run(self.adaptive_policy_distance, feed_dict={
            self.var0: batch['variable_features0'],
            self.n_vars: batch['variable_num'],
            self.param_noise_stddev: self.param_noise.current_stddev,
        })

        # MPI is the module-level optional import (None without mpi4py).
        if MPI is not None:
            mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
        else:
            mean_distance = distance

        self.param_noise.adapt(mean_distance)
        return mean_distance

    def reset(self):
        """Reset action noise and re-perturb the exploration policy after an episode."""
        if self.action_noise is not None:
            self.action_noise.reset()
        if self.param_noise is not None:
            self.sess.run(self.perturb_policy_ops, feed_dict={
                self.param_noise_stddev: self.param_noise.current_stddev,
            })
