from collections import OrderedDict

import numpy as np
import math
import tensorflow as tf
import tensorflow_probability as tfp
from flatten_dict import flatten
from numbers import Number


from softlearning.models.utils import flatten_input_structure
from softlearning.misc.kernel import adaptive_isotropic_gaussian_kernel
from softlearning.utils.tensorflow import nest
from softlearning.utils.gym import is_continuous_space, is_discrete_space
from .utils import gumbel_softmax

from .rl_algorithm import RLAlgorithm
from .sac import td_target
import pdb
from .utils import f_gan_conjuate,f_gan_activation

EPS = 1e-6



def assert_shape(tensor, expected_shape):
    """Assert that the static shape of `tensor` equals `expected_shape`.

    Args:
        tensor: Object whose `.shape.as_list()` returns its static
            dimensions.
        expected_shape: Sequence of expected dimensions, compared
            element-wise with `==`.

    Raises:
        AssertionError: If the ranks differ or any dimension differs.
    """
    actual_dims = tensor.shape.as_list()
    assert len(actual_dims) == len(expected_shape)
    for actual, expected in zip(actual_dims, expected_shape):
        assert actual == expected

def clip_but_pass_gradient(input_, lower=-1., upper=1.):
    """Clip `input_` to `[lower, upper]` without blocking its gradient.

    The clipping correction is wrapped in `tf.stop_gradient`, so the forward
    value is clipped while the gradient with respect to `input_` is the
    identity. The out-of-range masks are cast to `tf.float32`.

    Args:
        input_: Tensor to clip.
        lower: Lower bound.
        upper: Upper bound.

    Returns:
        A tensor holding the clipped values of `input_`.
    """
    above_mask = tf.cast(input_ > upper, tf.float32)
    below_mask = tf.cast(input_ < lower, tf.float32)
    correction = (upper - input_) * above_mask + (lower - input_) * below_mask
    return input_ + tf.stop_gradient(correction)

def heuristic_target_entropy(action_space):
    """Return the default target entropy for `action_space`.

    For continuous spaces this is the negated product of the action shape.

    Args:
        action_space: Action space of the environment.

    Returns:
        `-np.prod(action_space.shape)` for a continuous space.

    Raises:
        NotImplementedError: For discrete spaces and for any other space
            type.
    """
    if is_continuous_space(action_space):
        return -np.prod(action_space.shape)

    if is_discrete_space(action_space):
        raise NotImplementedError(
            "TODO(hartikainen): implement for discrete spaces.")

    raise NotImplementedError((type(action_space), action_space))

class SQLDM(RLAlgorithm):
    """Soft Q-learning via Divergence Minimization(SQLDM).

    Example:
        See `examples/development.py`.

    References
    ----------
    [1] Tuomas Haarnoja, Haoran Tang, Pieter Abbeel, and Sergey Levine,
        "Reinforcement Learning with Deep Energy-Based Policies," International
        Conference on Machine Learning, 2017. https://arxiv.org/abs/1702.08165
    """
    def __init__(
            self,
            training_environment,
            evaluation_environment,
            policy,
            Qs,
            critic,
            pool,
            value=False,
            extractor=None,
            hyper_extractor=None,
            plotter=None,
            latent_dim = 10,
            components = 4,
            policy_lr=3e-4,
            critic_lr=3e-1, # learning rate of the critic optimizer
            value_lr = 3e-4,
            mcmc_lr=1,
            Q_lr=3e-4,
            target_entropy='auto',
            value_n_particles=16,
            target_update_interval=1,
            kernel_fn=adaptive_isotropic_gaussian_kernel,
            kernel_n_particles=16,
            kernel_update_ratio=0.5,
            discount=0.99,
            noise_mean = 0.0,
            noise_std = 0.005,
            l_steps = 20, # number of Langevin dynamics steps
            tau=5e-3,
            reward_scale=1,
            proj_norm = 0.0,
            gp_weight = 1.0,
            proj_type = "li",
            critic_mode = "jsd",
            use_saved_Q=False,
            use_saved_policy=False,
            save_full_state=False,
            train_Q=True,
            train_policy=True,
            train_critic=True,
            train_value=False,
            mcmc_on_latent=False,
            **kwargs,
    ):
        """Store the hyperparameters, clone the target Qs and build the graph.

        Args:
            training_environment (`SoftlearningEnv`): Environment object used
                for training. Its `action_space` feeds
                `heuristic_target_entropy` when `target_entropy == 'auto'`.
            evaluation_environment: Environment stored for evaluation.
            policy: A policy function approximator. Its `_mcmc_on_latent`
                attribute must equal `mcmc_on_latent` (asserted below).
            Qs: Q-function approximators. The min of these
                approximators will be used. Usage of at least two Q-functions
                improves performance by reducing overestimation bias. Each
                one is cloned with `tf.keras.models.clone_model` to create
                the target Q-functions.
            critic: A critic for estimating the dual form of the divergence.
            pool (`PoolBase`): Replay pool to add gathered samples to.
            value: Not referenced by this constructor.
            extractor: Stored as `self._extractor`.
            hyper_extractor: Stored as `self._hyper_extractor`.
            plotter (`QFPolicyPlotter`): Plotter instance to be used for
                visualizing Q-function during training.
            latent_dim: Stored as `self._latent_dim`.
            components: Stored as `self._components`.
            policy_lr (`float`): Learning rate of the policy optimizer.
            critic_lr (`float`): Learning rate of the critic optimizer.
            value_lr: Not referenced by this constructor.
            mcmc_lr: Step size of the Langevin update. It also fixes the
                standard deviation of the Langevin noise to
                `sqrt(2 * mcmc_lr)`.
            Q_lr (`float`): Learning rate used for the Q-function approximator.
            target_entropy: Either `'auto'`, in which case
                `heuristic_target_entropy` is applied to the training
                environment's action space, or an explicit value that is
                stored unchanged.
            value_n_particles (`int`): The number of action samples used for
                estimating the value of next state.
            target_update_interval (`int`): How often the target network is
                updated to match the current Q-function.
            kernel_fn (function object): A function object that represents
                a kernel function.
            kernel_n_particles (`int`): Total number of particles per state
                used in SVGD updates. The MCMC update draws
                `2 * kernel_n_particles` actions per state.
            kernel_update_ratio ('float'): The ratio of SVGD particles used for
                the computation of the inner/outer empirical expectation.
            discount ('float'): Discount factor.
            noise_mean: Mean of the Gaussian noise added at every Langevin
                step.
            noise_std: Currently ignored; the noise standard deviation is
                derived from `mcmc_lr` instead.
            l_steps: Number of Langevin steps (stored as `self._num_steps`).
            tau: Stored as `self._tau`.
            reward_scale ('float'): A factor that scales the raw rewards.
                Useful for adjusting the temperature of the optimal Boltzmann
                distribution.
            proj_norm: Stored as `self._proj_norm`; `0.0` disables the
                gradient projection in `_init_mcmc_update`.
            gp_weight: Weight of the gradient penalty in the critic loss.
            proj_type: Gradient projection type, `'l2'` or `'li'`.
            critic_mode: Name of the critic loss, resolved by
                `_build_critic_mode`.
            use_saved_Q ('boolean'): If true, use the initial parameters provided
                in the Q-function instead of reinitializing.
            use_saved_policy ('boolean'): If true, use the initial parameters provided
                in the policy instead of reinitializing.
            save_full_state ('boolean'): If true, saves the full algorithm
                state, including the replay pool.
            train_Q ('boolean'): If true, create the Q-function training ops.
            train_policy ('boolean'): If true, create the policy training op.
            train_critic ('boolean'): If true, create the critic training op.
            train_value: Not referenced by this constructor.
            mcmc_on_latent ('boolean'): If true, the Langevin chain runs on
                the output of `policy.raw_actions` and the Q-function inputs
                are squashed with `tanh`.
            **kwargs: Forwarded to the `RLAlgorithm` constructor.
        """
        super(SQLDM, self).__init__(**kwargs)

        self._training_environment = training_environment
        self._evaluation_environment = evaluation_environment
        self._policy = policy
        self._critic = critic
        self._extractor = extractor
        self._hyper_extractor = hyper_extractor

        self._Qs = Qs
        self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

        self._pool = pool
        self._plotter = plotter

        self._Q_lr = Q_lr
        self._policy_lr = policy_lr
        self._critic_lr = critic_lr
        self._critic_mode = critic_mode
        self._mcmc_lr = mcmc_lr
        # The temperature learning rate is fixed; it is not a parameter.
        self._alpha_lr = 0.001
        self._discount = discount
        self._tau = tau
        self._reward_scale = reward_scale
        self._gpw = gp_weight
        self._target_entropy = (
            heuristic_target_entropy(self._training_environment.action_space)
            if target_entropy == 'auto'
            else target_entropy)

        self._value_n_particles = value_n_particles
        self._Q_target_update_interval = target_update_interval

        self._kernel_fn = kernel_fn
        self._kernel_n_particles = kernel_n_particles
        self._kernel_update_ratio = kernel_update_ratio

        self._save_full_state = save_full_state
        self._train_Q = train_Q
        self._train_policy = train_policy
        self._train_critic = train_critic

        self._mcmc_noise_mean = noise_mean
        # NOTE: `noise_std` is not used; the noise scale is tied to the
        # Langevin step size instead of `self._mcmc_noise_std = noise_std`.
        self._mcmc_noise_std = math.sqrt(2 * self._mcmc_lr)
        self._num_steps = l_steps
        self._proj_norm = proj_norm
        self._proj_norm_type = proj_type

        self._latent_dim = latent_dim
        self._components = components

        self._mcmc_on_latent = mcmc_on_latent
        # The policy and the algorithm must agree on the sampling space.
        assert self._mcmc_on_latent == self._policy._mcmc_on_latent

        # Snapshot the provided weights so the global initializer below does
        # not discard them.
        if use_saved_Q:
            saved_Q_weights = tuple(Q.get_weights() for Q in self._Qs)
        if use_saved_policy:
            saved_policy_weights = policy.get_weights()

        self._session.run(tf.compat.v1.global_variables_initializer())

        # Restore the snapshots taken before initialization.
        if use_saved_Q:
            for Q, Q_weights in zip(self._Qs, saved_Q_weights):
                Q.set_weights(Q_weights)
        if use_saved_policy:
            self._policy.set_weights(saved_policy_weights)

        # Diagnostic statistic; `_init_mcmc_update_v2` overwrites it with the
        # gradient-penalty slopes when a penalty is built.
        self._grad_slopes = None

        self._build()

    def _build(self):
        """Build the base-class graph, then the SQLDM training ops.

        The SVGD policy update (`_init_svgd_update`) is deliberately left out;
        the MCMC update is used for the policy instead. Diagnostics come last
        because they read attributes created by the TD update.
        """
        super(SQLDM, self)._build()

        build_steps = (
            self._init_td_update,
            self._init_mcmc_update_v2,
            self._init_diagnostics_ops,
        )
        for build_step in build_steps:
            build_step()
    
    def _build_critic_mode(self):
        """Select the critic and generator losses for `self._critic_mode`.

        Supported modes are "jsd", the f-GAN divergences "KLD", "RKL", "CHI"
        and "SQH", 'ls', 'wasserstein' and 'hinge'.

        Returns:
            A tuple `(discriminator_loss_fn, generator_loss_fn)`.
            `discriminator_loss_fn(real_logit, fake_logit)` returns the pair
            `(real_loss, fake_loss)`; `generator_loss_fn(fake_logit)` returns
            the generator loss.

        Raises:
            NotImplementedError: If `self._critic_mode` is not supported.
        """
        # Same namespace as the Q-function loss in `_init_td_update`; the
        # top-level `tf.losses` alias is not available in every TF release.
        losses = tf.compat.v1.losses
        mode = self._critic_mode

        if mode == "jsd":
            def discriminator_loss_fn(real_logit, fake_logit):
                real_loss = losses.sigmoid_cross_entropy(
                    multi_class_labels=tf.ones_like(real_logit),
                    logits=real_logit)
                fake_loss = losses.sigmoid_cross_entropy(
                    multi_class_labels=tf.zeros_like(fake_logit),
                    logits=fake_logit)
                return real_loss, fake_loss

            def generator_loss_fn(fake_logit):
                return losses.sigmoid_cross_entropy(
                    multi_class_labels=tf.ones_like(fake_logit),
                    logits=fake_logit)

        elif mode in ["KLD", "RKL", "CHI", "SQH"]:
            # NOTE(review): no reduction is applied here, and the fake term
            # composes activation(conjugate(logit)); the f-GAN objective is
            # usually written conjugate(activation(logit)). Confirm against
            # the helper implementations in `.utils`.
            def discriminator_loss_fn(real_logit, fake_logit):
                real_loss = -f_gan_activation(mode, real_logit)
                fake_loss = f_gan_activation(
                    mode, f_gan_conjuate(mode, fake_logit))
                return real_loss, fake_loss

            def generator_loss_fn(fake_logit):
                return -f_gan_activation(mode, fake_logit)

        elif mode == 'ls':
            def discriminator_loss_fn(real_logit, fake_logit):
                real_loss = losses.mean_squared_error(
                    labels=tf.ones_like(real_logit), predictions=real_logit)
                fake_loss = losses.mean_squared_error(
                    labels=tf.zeros_like(fake_logit), predictions=fake_logit)
                return real_loss, fake_loss

            def generator_loss_fn(fake_logit):
                return losses.mean_squared_error(
                    labels=tf.ones_like(fake_logit), predictions=fake_logit)

        elif mode == 'wasserstein':
            def discriminator_loss_fn(real_logit, fake_logit):
                real_loss = - tf.reduce_mean(real_logit)
                fake_loss = tf.reduce_mean(fake_logit)
                return real_loss, fake_loss

            def generator_loss_fn(fake_logit):
                return - tf.reduce_mean(fake_logit)

        elif mode == 'hinge':
            def discriminator_loss_fn(real_logit, fake_logit):
                real_loss = tf.reduce_mean(tf.maximum(1 - real_logit, 0))
                fake_loss = tf.reduce_mean(tf.maximum(1 + fake_logit, 0))
                return real_loss, fake_loss

            def generator_loss_fn(fake_logit):
                # The generator uses the plain (non-hinged) objective.
                return tf.reduce_mean(- fake_logit)

        else:
            raise NotImplementedError(
                "Unsupported critic_mode: {!r}".format(mode))

        return discriminator_loss_fn, generator_loss_fn

    def _get_Q_target(self):
        """Build the soft Q-learning TD target.

        For every next observation, `self._value_n_particles` actions are
        drawn uniformly from [-1, 1), evaluated with the target Q-functions
        and combined with a log-sum-exp into a soft value estimate.

        Returns:
            The TD target with gradients stopped.
        """
        n_particles = self._value_n_particles
        next_observations = self._placeholders['next_observations']
        batch_actions = self._placeholders['actions']

        def repeat_per_particle(observation):
            # (batch, dim) -> (batch * n_particles, dim)
            tiled = tf.tile(
                observation[:, tf.newaxis, :], (1, n_particles, 1))
            return tf.reshape(tiled, (-1, *observation.shape[1:]))

        tiled_next_observations = {
            key: repeat_per_particle(next_observations[key])
            for key in self._Qs[0].observation_keys
        }

        action_dims = batch_actions.shape[1:].as_list()

        # One shared set of uniform particles, repeated for every batch row.
        sampled_actions = tf.random.uniform(
            (1, n_particles, *action_dims), -1, 1)
        sampled_actions = tf.tile(
            sampled_actions, (tf.shape(batch_actions)[0], 1, 1))
        sampled_actions = tf.reshape(sampled_actions, (-1, *action_dims))

        target_inputs = flatten_input_structure(
            {**tiled_next_observations, 'actions': sampled_actions})
        target_Q_values = tuple(
            Q_target(target_inputs) for Q_target in self._Q_targets)

        pessimistic_Q = tf.reduce_min(target_Q_values, axis=0)
        assert_shape(pessimistic_Q, (None, 1))

        Q_per_particle = tf.reshape(pessimistic_Q, (-1, n_particles))
        assert_shape(Q_per_particle, (None, n_particles))

        # Equation 10 in [1]:
        soft_values = tf.reduce_logsumexp(
            Q_per_particle, keepdims=True, axis=1)
        assert_shape(soft_values, [None, 1])

        # Importance weights add just a constant to the value.
        soft_values = soft_values - tf.math.log(
            tf.cast(n_particles, tf.float32))
        soft_values = soft_values + np.prod(action_dims) * np.log(2)
        assert_shape(soft_values, [None, 1])

        terminals = tf.cast(
            self._placeholders['terminals'], soft_values.dtype)

        # \hat Q in Equation 11 in [1]:
        target = td_target(
            reward=self._reward_scale * self._placeholders['rewards'],
            discount=self._discount,
            next_value=(1 - terminals) * soft_values)

        return tf.stop_gradient(target)

    def _init_td_update(self):
        """Create the Q-function losses and, optionally, their training ops.

        Sets `self._Q_values` and `self._Q_losses`. When `self._train_Q` is
        true it also creates one Adam optimizer per Q-function and registers
        the grouped minimization op in `self._training_ops` under 'Q'.
        """
        td_target_values = self._get_Q_target()
        assert_shape(td_target_values, [None, 1])

        observations = self._placeholders['observations']
        critic_inputs = flatten_input_structure({
            **{key: observations[key]
               for key in self._Qs[0].observation_keys},
            'actions': self._placeholders['actions'],
        })

        self._Q_values = tuple(Q(critic_inputs) for Q in self._Qs)
        for predicted in self._Q_values:
            assert_shape(predicted, [None, 1])

        # Equation 11 in [1]:
        self._Q_losses = tuple(
            tf.compat.v1.losses.mean_squared_error(
                labels=td_target_values, predictions=predicted, weights=0.5)
            for predicted in self._Q_values)

        if not self._train_Q:
            return

        self._Q_optimizers = tuple(
            tf.compat.v1.train.AdamOptimizer(
                learning_rate=self._Q_lr,
                name='{}_{}_optimizer'.format(Q._name, index))
            for index, Q in enumerate(self._Qs))

        minimize_ops = tuple(
            optimizer.minimize(loss=loss, var_list=Q.trainable_variables)
            for Q, loss, optimizer
            in zip(self._Qs, self._Q_losses, self._Q_optimizers))

        self._training_ops.update({'Q': tf.group(minimize_ops)})

    def _init_svgd_update(self):
        """Create a minimization operation for policy update (SVGD).

        `_build` does not call this method (the call there is commented out).
        When `self._train_policy` is true, the resulting op is registered in
        `self._training_ops` under the key 'svgd'. The method always sets
        `self._policy_optimizer`.
        """

        # Repeat every observation once per particle so the policy emits
        # `self._kernel_n_particles` actions per state.
        policy_inputs = flatten_input_structure({
            name: tf.reshape(
                tf.tile(
                    self._placeholders['observations'][name][:, None, :],
                    (1, self._kernel_n_particles, 1)),
                (-1, *self._placeholders['observations'][name].shape[1:]))
            for name in self._policy.observation_keys
        })
        actions = self._policy.actions(policy_inputs)
        action_shape = actions.shape[1:]
        actions = tf.reshape(
            actions, (-1, self._kernel_n_particles, *action_shape))

        assert_shape(
            actions, (None, self._kernel_n_particles, *action_shape))

        # SVGD requires computing two empirical expectations over actions
        # (see Appendix C1.1.). To that end, we first sample a single set of
        # actions, and later split them into two sets: `fixed_actions` are used
        # to evaluate the expectation indexed by `j` and `updated_actions`
        # the expectation indexed by `i`.
        n_updated_actions = int(
            self._kernel_n_particles * self._kernel_update_ratio)
        n_fixed_actions = self._kernel_n_particles - n_updated_actions

        fixed_actions, updated_actions = tf.split(
            actions, [n_fixed_actions, n_updated_actions], axis=1)
        # No gradient flows into the policy through the fixed particles.
        fixed_actions = tf.stop_gradient(fixed_actions)
        assert_shape(fixed_actions,
                     [None, n_fixed_actions, *action_shape])
        assert_shape(updated_actions,
                     [None, n_updated_actions, *action_shape])

        Q_observations = {
            name: tf.reshape(
                    tf.tile(
                        self._placeholders['observations'][name][:, None, :],
                        (1, n_fixed_actions, 1)),
                    (-1, *self._placeholders['observations'][name].shape[1:]))
            for name in self._policy.observation_keys
        }
        Q_actions = tf.reshape(fixed_actions, (-1, *action_shape))
        Q_inputs = flatten_input_structure({
            **Q_observations, 'actions': Q_actions})
        Q_log_targets = tuple(Q(Q_inputs) for Q in self._Qs)
        # Element-wise minimum over the Q-functions.
        min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0)
        svgd_target_values = tf.reshape(
            min_Q_log_target, (-1, n_fixed_actions, 1))

        # Target log-density. Q_soft in Equation 13:
        assert self._policy._squash
        # `1 - a^2` is clipped to [0, 1] and offset by EPS before the log.
        squash_correction = tf.reduce_sum(
            tf.math.log(clip_but_pass_gradient(1 - fixed_actions ** 2,0.,1.0)+ EPS), axis=-1, keepdims=True)
        log_probs = svgd_target_values + squash_correction

        grad_log_probs = tf.gradients(log_probs, fixed_actions)[0]
        # Extra axis so the gradient broadcasts against the kernel matrix.
        grad_log_probs = tf.expand_dims(grad_log_probs, axis=2)
        grad_log_probs = tf.stop_gradient(grad_log_probs)
        assert_shape(grad_log_probs,
                     [None, n_fixed_actions, 1, *action_shape])

        kernel_dict = self._kernel_fn(xs=fixed_actions, ys=updated_actions)

        # Kernel function in Equation 13:
        kappa = kernel_dict["output"][..., tf.newaxis]
        assert_shape(kappa, [None, n_fixed_actions, n_updated_actions, 1])

        # Stein Variational Gradient in Equation 13:
        action_gradients = tf.reduce_mean(
            kappa * grad_log_probs + kernel_dict["gradient"], axis=1)
        assert_shape(action_gradients,
                     [None, n_updated_actions, *action_shape])

        # Propagate the gradient through the policy network (Equation 14).
        gradients = tf.gradients(
            updated_actions,
            self._policy.trainable_variables,
            grad_ys=action_gradients)

        # Surrogate whose gradient with respect to each policy variable is
        # the (stopped) gradient computed above.
        surrogate_loss = tf.reduce_sum([
            tf.reduce_sum(w * tf.stop_gradient(g))
            for w, g in zip(self._policy.trainable_variables, gradients)
        ])

        self._policy_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=self._policy_lr,
            name='policy_optimizer'
        )

        if self._train_policy:
            # Minimizing the negated surrogate ascends along the gradient.
            svgd_training_op = self._policy_optimizer.minimize(
                loss=-surrogate_loss,
                var_list=self._policy.trainable_variables)
            self._training_ops.update({
                'svgd': svgd_training_op
            })

    def _init_mcmc_update(self):
        """Create the policy and critic training ops (single-particle MCMC).

        `_build` does not call this method; it builds `_init_mcmc_update_v2`
        instead. The graph built here:

        1. samples one action per observation from the policy ("fake"),
        2. refines it with `self._num_steps` Langevin steps along the
           gradient of the minimum over `self._Qs` ("real"),
        3. trains the critic to separate real from fake actions and the
           policy to fool the critic.

        The WGAN-GP losses are used when `self._use_wgan` is true. Nothing in
        this class assigns that attribute, so when it is missing it defaults
        to `self._critic_mode == 'wasserstein'`.

        Sets `self._policy_optimizer`, `self._critic_optimizer`,
        `self._gen_losses`, `self._disc_losses` and, for WGAN,
        `self._gradient_norm`. Registers 'policy_train_op' and
        'critic_train_op' in `self._training_ops` according to
        `self._train_policy` and `self._train_critic`.

        Raises:
            NotImplementedError: If a projection is requested
                (`self._proj_norm != 0.0`) with a type other than 'l2' or
                'li'.
        """
        observations = self._placeholders['observations']
        # One particle per state, so the observations are used as they are.
        observation_inputs = {
            name: observations[name]
            for name in self._policy.observation_keys
        }

        actions = self._policy.actions(
            flatten_input_structure(observation_inputs))
        action_shape = actions.shape[1:]
        actions = tf.reshape(actions, (-1, *action_shape))
        assert_shape(actions, (None, *action_shape))

        def langevin_step(counter, actions):
            """Apply one noisy gradient-ascent step on min-Q to `actions`."""
            actions = actions + tf.random.normal(
                tf.shape(actions),
                mean=self._mcmc_noise_mean,
                stddev=self._mcmc_noise_std)

            Q_inputs = flatten_input_structure({
                **observation_inputs,
                'actions': tf.reshape(actions, (-1, *action_shape))})
            Q_log_targets = tuple(Q(Q_inputs) for Q in self._Qs)
            min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0)
            energy_target = tf.squeeze(tf.reshape(
                min_Q_log_target, (-1, 1, 1)), axis=-1)
            grad_log_probs = tf.gradients(energy_target, actions)[0]
            grad_log_probs = tf.stop_gradient(grad_log_probs)
            assert_shape(grad_log_probs, [None, *action_shape])

            if self._proj_norm != 0.0:
                if self._proj_norm_type == 'l2':
                    # NOTE(review): without `axes`, `tf.clip_by_norm` uses
                    # the norm of the whole batch, not a per-sample norm.
                    grad_log_probs = tf.clip_by_norm(
                        grad_log_probs, self._proj_norm)
                elif self._proj_norm_type == 'li':
                    grad_log_probs = tf.clip_by_value(
                        grad_log_probs, -self._proj_norm, self._proj_norm)
                else:
                    # Raise explicitly; an `assert` is stripped under `-O`.
                    raise NotImplementedError(
                        "Unsupported projection type: {!r}".format(
                            self._proj_norm_type))

            actions = actions + grad_log_probs * self._mcmc_lr
            actions = tf.clip_by_value(actions, -1, 1)

            return counter + 1, actions

        _, action_updated = tf.while_loop(
            lambda i, x: tf.less(i, self._num_steps),
            langevin_step,
            (tf.constant(0), actions))

        dis_real_inputs = flatten_input_structure({
            **observation_inputs, 'actions': action_updated})
        dis_fake_inputs = flatten_input_structure({
            **observation_inputs, 'actions': actions})

        disc_real = self._critic.critic_model(dis_real_inputs)
        disc_fake = self._critic.critic_model(dis_fake_inputs)

        use_wgan = getattr(
            self, '_use_wgan', self._critic_mode == 'wasserstein')

        if use_wgan:
            gen_cost = -tf.reduce_mean(disc_fake)
            disc_cost = -tf.reduce_mean(disc_real) + tf.reduce_mean(disc_fake)

            # Gradient penalty on random interpolates between the refined
            # and the policy actions.
            epsilon = tf.random.uniform(
                shape=[tf.shape(disc_real)[0], 1],
                minval=0.,
                maxval=1.)
            action_hat = tf.clip_by_value(
                action_updated + epsilon * (actions - action_updated), -1, 1)
            inputs_hat = flatten_input_structure({
                **observation_inputs, 'actions': action_hat})
            grad_D_X_hat = self._critic._gradient(inputs_hat)
            # NOTE(review): the sum runs over every axis, so `slopes` is one
            # norm for the whole batch; WGAN-GP uses per-sample norms.
            # Confirm against what `self._critic._gradient` returns.
            slopes = tf.sqrt(tf.reduce_sum(tf.square(grad_D_X_hat)))
            gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
            disc_cost += gradient_penalty * self._gpw
            self._gradient_norm = slopes
        else:
            gen_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_fake,
                labels=tf.ones_like(disc_fake)
            ))

            disc_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_fake,
                labels=tf.zeros_like(disc_fake)
            ))
            disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_real,
                labels=tf.ones_like(disc_real)
            ))

        self._policy_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=self._policy_lr,
            name='policy_optimizer'
        )
        self._critic_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=self._critic_lr,
            name='critic_optimizer'
        )

        # `disc_cost` already contains the weighted gradient penalty, so the
        # reported loss is exactly the one being optimized.
        self._gen_losses = gen_cost
        self._disc_losses = disc_cost

        if self._train_policy:
            policy_train_op = self._policy_optimizer.minimize(
                loss=gen_cost,
                var_list=self._policy.trainable_variables)
            self._training_ops.update({
                'policy_train_op': policy_train_op
            })
        if self._train_critic:
            critic_train_op = self._critic_optimizer.minimize(
                loss=disc_cost,
                var_list=self._critic.trainable_variables)
            self._training_ops.update({
                'critic_train_op': critic_train_op
            })

    def _init_mcmc_update_v2(self):
        """Create the policy, critic and temperature training ops (MCMC).

        For every observation, `2 * self._kernel_n_particles` actions are
        sampled from the policy and split into two halves:

        * the first half is kept unchanged as the generator ("fake") samples;
        * the second half is refined with `self._num_steps` Langevin steps
          along the gradient of the minimum over `self._Qs` and serves as
          the "real" samples.

        When `self._mcmc_on_latent` is true the chain runs on the output of
        `policy.raw_actions`, the Q inputs are squashed with `tanh`, and a
        squash correction is added to the Langevin target.

        The critic is trained with the losses from `_build_critic_mode`, plus
        a gradient penalty (skipped for "jsd") and a small penalty on the
        real logits. The policy is trained on the generator loss plus
        `alpha * mean(log_pi)`. `log_alpha` is trained only when
        `self._target_entropy` is a number.

        Sets `self._langevin_action_altered`, `self._langevin_Q_altered`,
        `self._gen_losses`, `self._disc_losses`, `self._disc_real`,
        `self._disc_fake`, `self._policy_optimizer`,
        `self._critic_optimizer`, `self._log_alpha` and `self._alpha`.
        Registers 'temperature_alpha', 'policy_train_op' and
        'critic_train_op' in `self._training_ops` when applicable.
        """
        n_particles = self._kernel_n_particles
        observations = self._placeholders['observations']

        def tiled_observations(repeats):
            # Repeat every observation `repeats` times:
            # (batch, dim) -> (batch * repeats, dim).
            return {
                name: tf.reshape(
                    tf.tile(
                        observations[name][:, None, :], (1, repeats, 1)),
                    (-1, *observations[name].shape[1:]))
                for name in self._policy.observation_keys
            }

        policy_inputs = flatten_input_structure(
            tiled_observations(2 * n_particles))
        if self._mcmc_on_latent:
            actions = self._policy.raw_actions(policy_inputs)
        else:
            actions = self._policy.actions(policy_inputs)

        action_shape = actions.shape[1:]
        actions = tf.reshape(
            actions, (-1, 2 * n_particles, *action_shape))
        assert_shape(
            actions, (None, 2 * n_particles, *action_shape))

        g_actions, updated_actions = tf.split(
            actions, [n_particles, n_particles], axis=1)
        g_actions = tf.reshape(g_actions, (-1, *action_shape))
        updated_actions = tf.reshape(updated_actions, (-1, *action_shape))

        # Loop-invariant, so built once and shared by the Langevin steps,
        # the critic and the diagnostics.
        Q_observations = tiled_observations(n_particles)

        def lag_steps(counter, actions_):
            """Apply one Langevin step to the particles `actions_`."""
            actions_ = actions_ + tf.random.normal(
                tf.shape(actions_),
                mean=self._mcmc_noise_mean,
                stddev=self._mcmc_noise_std)

            Q_actions = tf.reshape(actions_, (-1, *action_shape))
            if self._mcmc_on_latent:
                # The Q-functions are always evaluated in action space.
                Q_actions = tf.tanh(Q_actions)

            Q_inputs = flatten_input_structure({
                **Q_observations, 'actions': Q_actions})
            Q_log_targets = tuple(Q(Q_inputs) for Q in self._Qs)
            log_probs = tf.reduce_min(Q_log_targets, axis=0)

            if self._mcmc_on_latent:
                # Squash correction; only applied on the latent chain.
                squash_correction = tf.reduce_sum(
                    tf.math.log(
                        clip_but_pass_gradient(
                            1 - tf.tanh(actions_) ** 2, 0., 1.0) + EPS),
                    axis=-1,
                    keepdims=True)
                log_probs = log_probs + squash_correction

            grad_log_probs = tf.gradients(log_probs, actions_)[0]
            grad_log_probs = tf.stop_gradient(grad_log_probs)
            actions_ = actions_ + grad_log_probs * self._mcmc_lr
            # NOTE(review): this clip is applied on the latent chain too;
            # confirm that bounding the pre-tanh values is intended.
            actions_ = tf.clip_by_value(actions_, -1 + EPS, 1 - EPS)

            return counter + 1, actions_

        _, updated_actions = tf.while_loop(
            lambda i, x: tf.less(i, self._num_steps),
            lag_steps,
            (tf.constant(0), updated_actions))

        self._langevin_action_altered = updated_actions - g_actions

        dis_real_inputs = flatten_input_structure({
            **Q_observations, 'actions': updated_actions})
        dis_fake_inputs = flatten_input_structure({
            **Q_observations, 'actions': g_actions})

        disc_logit_real = self._critic.critic_model(dis_real_inputs)
        disc_logit_fake = self._critic.critic_model(dis_fake_inputs)

        dis_loss_fn, gen_loss_fn = self._build_critic_mode()

        gen_loss = gen_loss_fn(disc_logit_fake)
        disc_real_loss, disc_fake_loss = dis_loss_fn(
            disc_logit_real, disc_logit_fake)

        def _gradient_penalty(Q_observations, real_action, fake_action):
            """Return the critic gradient penalty (0 in "jsd" mode)."""
            if self._critic_mode == "jsd":
                return 0
            epsilon = tf.random.uniform(
                shape=[tf.shape(real_action)[0], 1],
                minval=0.,
                maxval=1.)
            action_hat = real_action + epsilon * (fake_action - real_action)
            inputs_hat = flatten_input_structure({
                **Q_observations, 'actions': action_hat})
            grad_D_X_hat = self._critic._gradient(inputs_hat)
            # NOTE(review): the sum runs over every axis, so `slopes` is one
            # norm for the whole batch; confirm against what
            # `self._critic._gradient` returns.
            slopes = tf.sqrt(tf.reduce_sum(tf.square(grad_D_X_hat)))
            self._grad_slopes = slopes

            return tf.reduce_mean((slopes - 1.) ** 2)

        gradient_penalty = _gradient_penalty(
            Q_observations, updated_actions, g_actions)

        disc_loss = (
            disc_real_loss + disc_fake_loss + self._gpw * gradient_penalty)

        # Small penalty on the magnitude of the real logits.
        logit_regularization = 0.0001 * tf.reduce_mean(
            tf.square(disc_logit_real))
        disc_loss += logit_regularization

        self._gen_losses = gen_loss
        self._disc_losses = disc_loss
        self._disc_real = disc_logit_real
        self._disc_fake = disc_logit_fake

        # Diagnostic: change in Q between the refined and policy particles.
        Q_after = tuple(Q(dis_real_inputs) for Q in self._Qs)
        Q_before = tuple(Q(dis_fake_inputs) for Q in self._Qs)
        self._langevin_Q_altered = tuple(
            Q_a - Q_b for Q_b, Q_a in zip(Q_before, Q_after))

        self._policy_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=self._policy_lr,
            name='policy_optimizer'
        )
        self._critic_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=self._critic_lr,
            name='critic_optimizer'
        )

        # A separate single-sample pass provides the entropy term.
        p_inputs = flatten_input_structure({
            name: observations[name]
            for name in self._policy.observation_keys
        })

        if self._mcmc_on_latent:
            acts = self._policy.raw_actions(p_inputs)
        else:
            acts = self._policy.actions(p_inputs)
        log_pis_ = self._policy.log_pis(p_inputs, acts)

        log_alpha = self._log_alpha = tf.compat.v1.get_variable(
            'log_alpha',
            dtype=tf.float32,
            initializer=0.0)
        alpha = tf.exp(log_alpha)

        if isinstance(self._target_entropy, Number):
            alpha_loss = -tf.reduce_mean(
                log_alpha * tf.stop_gradient(log_pis_ + self._target_entropy))

            self._alpha_optimizer = tf.compat.v1.train.AdamOptimizer(
                self._alpha_lr, name='alpha_optimizer')
            self._alpha_train_op = self._alpha_optimizer.minimize(
                loss=alpha_loss, var_list=[log_alpha])

            self._training_ops.update({
                'temperature_alpha': self._alpha_train_op
            })

        self._alpha = alpha

        policy_loss = gen_loss + self._alpha * tf.reduce_mean(log_pis_)

        if self._train_policy:
            policy_train_op = self._policy_optimizer.minimize(
                loss=policy_loss,
                var_list=self._policy.trainable_variables)
            self._training_ops.update({
                'policy_train_op': policy_train_op
            })
        if self._train_critic:
            critic_train_op = self._critic_optimizer.minimize(
                loss=disc_loss,
                var_list=self._critic.trainable_variables)
            self._training_ops.update({
                'critic_train_op': critic_train_op
            })

    def _init_diagnostics_ops(self):
        """Build the summary ops that are evaluated for diagnostics.

        Every tracked quantity is reduced with each metric (`mean`, `std`).
        The resulting ops are stored in `self._diagnostics_ops`, keyed as
        `<quantity>-<metric>`, with quantities in the order listed below and
        metrics nested inside each quantity.
        """
        def _stddev(values):
            # `sample_axis=None` asks for the deviation over all elements.
            return tfp.stats.stddev(values, sample_axis=None)

        reducers = OrderedDict()
        reducers['mean'] = tf.reduce_mean
        reducers['std'] = _stddev

        tracked = OrderedDict()
        tracked['Q_value'] = self._Q_values
        tracked['Q_loss'] = self._Q_losses
        tracked['critic/dis_loss'] = self._disc_losses
        tracked['critic/dis_real'] = self._disc_real
        tracked['critic/dis_fake'] = self._disc_fake
        tracked['critic/gen_loss'] = self._gen_losses
        tracked['critic/mcmc/noise_mean'] = self._mcmc_noise_mean
        tracked['critic/mcmc/noise_std'] = self._mcmc_noise_std
        tracked['critic/mcmc/proj_norm'] = self._proj_norm
        tracked['critic/mcmc/Q_altered'] = self._langevin_Q_altered
        tracked['critic/mcmc/actions_altered'] = tf.math.abs(
            self._langevin_action_altered)

        # The gradient norm is only reported when it has been recorded.
        if self._grad_slopes is not None:
            tracked['critic/gradient_norm'] = self._grad_slopes

        ops = OrderedDict()
        for quantity, tensor in tracked.items():
            for metric_label, reduce_fn in reducers.items():
                ops[f'{quantity}-{metric_label}'] = reduce_fn(tensor)

        self._diagnostics_ops = ops

    def _init_training(self):
        """Prepare for training by running a target update with `tau=1.0`."""
        hard_update_tau = 1.0
        self._update_target(tau=hard_update_tau)

    def _update_target(self, tau=None):
        """Blend the Q-function weights into the target Q-function weights.

        For every (Q, target Q) pair, each target weight is replaced by
        `tau * source + (1 - tau) * target`.

        Args:
            tau: Interpolation coefficient. `None` (the default) selects
                `self._tau`. An explicit `0.0` is honoured and leaves the
                target weights numerically unchanged.
        """
        # Compare against None instead of relying on truthiness, so that an
        # explicit `tau=0.0` is not silently replaced by `self._tau`.
        if tau is None:
            tau = self._tau

        for Q, Q_target in zip(self._Qs, self._Q_targets):
            blended = [
                tau * source + (1.0 - tau) * target
                for source, target in zip(
                    Q.get_weights(), Q_target.get_weights())
            ]
            Q_target.set_weights(blended)

    def _do_training(self, iteration, batch):
        """Run the training ops once and refresh the target Qs when due.

        Args:
            iteration: Current training iteration; also decides whether a
                target update falls on this step.
            batch: Sample batch used to build the feed dictionary.
        """
        feeds = self._get_feed_dict(iteration, batch)
        self._session.run(self._training_ops, feeds)

        # Targets are only refreshed on multiples of the update interval...
        if iteration % self._Q_target_update_interval != 0:
            return
        # ...and only while the Q-functions themselves are being trained.
        if self._train_Q:
            self._update_target()

    def _get_feed_dict(self, iteration, batch):
        """Pair each placeholder with the matching entry of a sample batch.

        Both the batch and `self._placeholders` are flattened; a placeholder
        is fed only when the batch has an entry under the same flattened key.

        Args:
            iteration: Value fed to the `'iteration'` placeholder, or `None`
                to leave that placeholder out of the feed.
            batch: Nested sample batch.

        Returns:
            A dict mapping placeholders to the values to feed.
        """
        batch_flat = flatten(batch)
        placeholders_flat = flatten(self._placeholders)

        feed_dict = {}
        for key, placeholder in placeholders_flat.items():
            if key in batch_flat:
                feed_dict[placeholder] = batch_flat[key]

        if iteration is not None:
            feed_dict[self._placeholders['iteration']] = iteration

        return feed_dict

    def get_diagnostics(self,
                        iteration,
                        batch,
                        evaluation_paths,
                        training_paths):
        """Collect diagnostics for the given batch.

        Evaluates the diagnostics ops in the session, then adds the policy
        diagnostics (prefixed `policy/`), the MCMC settings, and the critic
        diagnostics (prefixed `critic/`). If a plotter is set, its `draw`
        method is called before returning.

        Args:
            iteration: Current training iteration, used for the feed dict.
            batch: Sample batch; its `'observations'` entry supplies the
                policy and critic inputs.
            evaluation_paths: Not used by this method.
            training_paths: Not used by this method.

        Returns:
            The mapping returned by the session run, extended with the
            entries described above.
        """
        observations = batch['observations']

        def policy_inputs():
            # Rebuilt on every call so each consumer gets a fresh structure.
            return flatten_input_structure({
                name: observations[name]
                for name in self._policy.observation_keys
            })

        feed_dict = self._get_feed_dict(iteration, batch)
        diagnostics = self._session.run(self._diagnostics_ops, feed_dict)

        policy_diagnostics = self._policy.get_diagnostics(policy_inputs())
        diagnostics.update(
            (f'policy/{key}', value)
            for key, value in policy_diagnostics.items()
        )

        # With MCMC on the latent space the raw actions are sampled instead.
        if self._mcmc_on_latent:
            actions = self._policy.raw_actions_np(policy_inputs())
        else:
            actions = self._policy.actions_np(policy_inputs())

        diagnostics['critic/mcmc/mcmc_steps'] = self._num_steps
        diagnostics['critic/mcmc/mcmc_lr'] = self._mcmc_lr
        diagnostics['critic/mcmc/critic_lr'] = self._critic_lr

        critic_input = flatten_input_structure({
            **{
                name: observations[name]
                for name in self._critic.observation_keys
            },
            'actions': actions,
        })

        critic_diagnostics = self._critic.get_diagnostics(critic_input)
        diagnostics.update(
            (f'critic/{key}', value)
            for key, value in critic_diagnostics.items()
        )

        if self._plotter:
            self._plotter.draw()

        return diagnostics

    @property
    def tf_saveables(self):
        """Return a dict of the objects to save, keyed by name.

        Holds the policy and critic optimizers, one `Q_optimizer_<i>` entry
        per Q optimizer, and the `log_alpha` variable, in that order.
        """
        saveables = {
            '_policy_optimizer': self._policy_optimizer,
            '_critic_optimizer': self._critic_optimizer,
        }
        for i, optimizer in enumerate(self._Q_optimizers):
            saveables[f'Q_optimizer_{i}'] = optimizer
        saveables['_log_alpha'] = self._log_alpha
        return saveables
