from collections import OrderedDict

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import tf_slim as slim

from softlearning.misc.kernel import adaptive_isotropic_gaussian_kernel

from .rl_algorithm import RLAlgorithm

EPS = 1e-6


def assert_shape(tensor, expected_shape):
    """Assert that the static shape of `tensor` equals `expected_shape`.

    Args:
        tensor: Object exposing `shape.as_list()` (e.g. a `tf.Tensor`);
            unknown dimensions are reported by `as_list()` as `None`.
        expected_shape: Sequence of expected dimensions. Use `None` for a
            dimension expected to be unknown at graph-construction time.

    Raises:
        AssertionError: If the ranks differ or any dimension differs.
    """
    actual_dims = tensor.shape.as_list()
    assert len(actual_dims) == len(expected_shape)
    for actual_dim, expected_dim in zip(actual_dims, expected_shape):
        assert actual_dim == expected_dim


class SQL(RLAlgorithm):
    """Soft Q-learning (SQL).

    Example:
        See `examples/mujoco_all_sql.py`.

    Reference:
        [1] Tuomas Haarnoja, Haoran Tang, Pieter Abbeel, and Sergey Levine,
        "Reinforcement Learning with Deep Energy-Based Policies," International
        Conference on Machine Learning, 2017. https://arxiv.org/abs/1702.08165
    """

    def __init__(
            self,
            training_environment,
            evaluation_environment,
            policy,
            Qs,
            pool,
            plotter=None,

            policy_lr=3e-4,
            Q_lr=3e-4,
            value_n_particles=16,
            target_update_interval=1,
            kernel_fn=adaptive_isotropic_gaussian_kernel,
            kernel_n_particles=16,
            kernel_update_ratio=0.5,
            discount=0.99,
            tau=5e-3,
            reward_scale=1,
            use_saved_Q=False,
            use_saved_policy=False,
            save_full_state=False,
            train_Q=True,
            train_policy=True,
            **kwargs,
    ):
        """
        Args:
            training_environment (`SoftlearningEnv`): Environment used for
                training. Its `active_observation_shape` and
                `action_space.shape` must both be rank-1.
            evaluation_environment (`SoftlearningEnv`): Environment used for
                evaluation.
            policy: A policy function approximator. The SVGD update asserts
                that `policy._squash` is truthy.
            Qs: Q-function approximators. The min of these
                approximators will be used. Usage of at least two Q-functions
                improves performance by reducing overestimation bias.
            pool (`PoolBase`): Replay pool to add gathered samples to.
            plotter (`QFPolicyPlotter`): Plotter instance to be used for
                visualizing Q-function during training.
            policy_lr (`float`): Learning rate used for the policy.
            Q_lr (`float`): Learning rate used for the Q-function approximators.
            value_n_particles (`int`): The number of action samples used for
                estimating the value of next state.
            target_update_interval (`int`): How often (in training
                iterations) the target networks are updated towards the
                current Q-functions.
            kernel_fn (function object): A function object that represents
                a kernel function.
            kernel_n_particles (`int`): Total number of particles per state
                used in SVGD updates.
            kernel_update_ratio (`float`): The ratio of SVGD particles used for
                the computation of the inner/outer empirical expectation.
            discount (`float`): Discount factor.
            tau (`float`): Polyak-averaging weight of the online Q-function
                weights in each target update.
            reward_scale (`float`): A factor that scales the raw rewards.
                Useful for adjusting the temperature of the optimal Boltzmann
                distribution.
            use_saved_Q (`bool`): If true, use the initial parameters provided
                in the Q-functions instead of reinitializing.
            use_saved_policy (`bool`): If true, use the initial parameters
                provided in the policy instead of reinitializing.
            save_full_state (`bool`): If true, saves the full algorithm
                state, including the replay pool.
            train_Q (`bool`): If false, no Q-function training ops are built
                and target networks are never soft-updated during training.
            train_policy (`bool`): If false, no policy training op is built.
            **kwargs: Forwarded to `RLAlgorithm.__init__`.
        """
        super(SQL, self).__init__(**kwargs)

        self._training_environment = training_environment
        self._evaluation_environment = evaluation_environment
        self._policy = policy

        self._Qs = Qs
        self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

        self._pool = pool
        self._plotter = plotter

        self._Q_lr = Q_lr
        self._policy_lr = policy_lr
        self._discount = discount
        self._tau = tau
        self._reward_scale = reward_scale

        self._value_n_particles = value_n_particles
        self._Q_target_update_interval = target_update_interval

        self._kernel_fn = kernel_fn
        self._kernel_n_particles = kernel_n_particles
        self._kernel_update_ratio = kernel_update_ratio

        self._save_full_state = save_full_state
        self._train_Q = train_Q
        self._train_policy = train_policy

        observation_shape = training_environment.active_observation_shape
        action_shape = training_environment.action_space.shape

        assert len(observation_shape) == 1, observation_shape
        self._observation_shape = observation_shape
        assert len(action_shape) == 1, action_shape
        self._action_shape = action_shape

        self._create_placeholders()

        self._training_ops = []

        self._create_td_update()
        self._create_svgd_update()

        # Snapshot the provided weights before the global initializer below
        # overwrites them, then restore them afterwards.
        if use_saved_Q:
            saved_Q_weights = tuple(Q.get_weights() for Q in self._Qs)
        if use_saved_policy:
            saved_policy_weights = policy.get_weights()

        self._session.run(tf.global_variables_initializer())

        if use_saved_Q:
            for Q, Q_weights in zip(self._Qs, saved_Q_weights):
                Q.set_weights(Q_weights)
        if use_saved_policy:
            self._policy.set_weights(saved_policy_weights)

    def _create_placeholders(self):
        """Create all necessary placeholders."""

        self._observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='observations')

        self._next_observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._observation_shape),
            name='next_observations')

        self._actions_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._action_shape),
            name='actions')

        self._next_actions_ph = tf.placeholder(
            tf.float32,
            shape=(None, *self._action_shape),
            name='next_actions')

        self._rewards_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='rewards')

        self._terminals_ph = tf.placeholder(
            tf.float32,
            shape=(None, 1),
            name='terminals')

    def _create_td_update(self):
        """Create a minimization operation for Q-function update.

        Builds the soft Bellman target from the target Q-functions, one
        squared-error loss per Q-function, and, if `train_Q` is set, one
        training op per Q-function.
        """

        # Repeat every next observation `value_n_particles` times so that
        # each one is paired with every sampled action below.
        next_observations = tf.tile(
            self._next_observations_ph[:, tf.newaxis, :],
            (1, self._value_n_particles, 1))
        next_observations = tf.reshape(
            next_observations, (-1, *self._observation_shape))

        # Uniform proposal actions on [-1, 1]^d; the same set of particles is
        # shared by every element of the batch.
        target_actions = tf.random_uniform(
            (1, self._value_n_particles, *self._action_shape), -1, 1)
        target_actions = tf.tile(
            target_actions, (tf.shape(self._next_observations_ph)[0], 1, 1))
        target_actions = tf.reshape(target_actions, (-1, *self._action_shape))

        Q_next_targets = tuple(
            Q([next_observations, target_actions])
            for Q in self._Q_targets)

        min_Q_next_targets = tf.reduce_min(Q_next_targets, axis=0)

        assert_shape(min_Q_next_targets, (None, 1))

        min_Q_next_target = tf.reshape(
            min_Q_next_targets, (-1, self._value_n_particles))

        assert_shape(min_Q_next_target, (None, self._value_n_particles))

        # Equation 10:
        next_value = tf.reduce_logsumexp(
            min_Q_next_target, keepdims=True, axis=1)
        assert_shape(next_value, [None, 1])

        # Importance weights add just a constant to the value: subtract
        # log(N) to turn the logsumexp into a log-mean, and add d * log(2)
        # because the uniform proposal on [-1, 1]^d has density 2^-d.
        next_value -= tf.log(tf.to_float(self._value_n_particles))
        next_value += np.prod(self._action_shape) * np.log(2)

        # \hat Q in Equation 11:
        Q_target = tf.stop_gradient(
            self._reward_scale
            * self._rewards_ph
            + (1 - self._terminals_ph)
            * self._discount
            * next_value)
        assert_shape(Q_target, [None, 1])

        Q_values = self._Q_values = tuple(
            Q([self._observations_ph, self._actions_ph])
            for Q in self._Qs)

        for Q_value in self._Q_values:
            assert_shape(Q_value, [None, 1])

        # Equation 11:
        Q_losses = self._Q_losses = tuple(
            tf.losses.mean_squared_error(
                labels=Q_target, predictions=Q_value, weights=0.5)
            for Q_value in Q_values)

        # Defined unconditionally so that `tf_saveables` also works when the
        # Q-functions are not trained (`train_Q=False`).
        self._Q_optimizers = ()

        if self._train_Q:
            self._Q_optimizers = tuple(
                tf.train.AdamOptimizer(
                    learning_rate=self._Q_lr,
                    name='{}_{}_optimizer'.format(Q.name, i)
                ) for i, Q in enumerate(self._Qs))
            Q_training_ops = tuple(
                slim.optimize_loss(
                    Q_loss,
                    None,
                    learning_rate=self._Q_lr,
                    optimizer=Q_optimizer,
                    variables=Q.trainable_variables,
                    increment_global_step=False,
                    summaries=())
                for Q, Q_loss, Q_optimizer
                in zip(self._Qs, Q_losses, self._Q_optimizers))

            self._training_ops.append(tf.group(Q_training_ops))

    def _create_svgd_update(self):
        """Create a minimization operation for policy update (SVGD).

        Always creates `self._policy_optimizer`. The policy training op is
        only added to `self._training_ops` when `train_policy` is set.
        """

        actions = self._policy.actions([
            tf.reshape(
                tf.tile(
                    self._observations_ph[:, None, :],
                    (1, self._kernel_n_particles, 1)),
                (-1, *self._observation_shape))
        ])
        actions = tf.reshape(
            actions,
            (-1, self._kernel_n_particles, *self._action_shape))

        assert_shape(
            actions, (None, self._kernel_n_particles, *self._action_shape))

        # SVGD requires computing two empirical expectations over actions
        # (see Appendix C1.1.). To that end, we first sample a single set of
        # actions, and later split them into two sets: `fixed_actions` are used
        # to evaluate the expectation indexed by `j` and `updated_actions`
        # the expectation indexed by `i`.
        n_updated_actions = int(
            self._kernel_n_particles * self._kernel_update_ratio)
        n_fixed_actions = self._kernel_n_particles - n_updated_actions

        fixed_actions, updated_actions = tf.split(
            actions, [n_fixed_actions, n_updated_actions], axis=1)
        fixed_actions = tf.stop_gradient(fixed_actions)
        assert_shape(fixed_actions,
                     [None, n_fixed_actions, *self._action_shape])
        assert_shape(updated_actions,
                     [None, n_updated_actions, *self._action_shape])

        Q_log_targets = tuple(
            Q([
                tf.reshape(
                    tf.tile(
                        self._observations_ph[:, None, :],
                        (1, n_fixed_actions, 1)),
                    (-1, *self._observation_shape)),
                tf.reshape(fixed_actions, (-1, *self._action_shape))
            ])
            for Q in self._Qs)
        min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0)
        svgd_target_values = tf.reshape(
            min_Q_log_target,
            (-1, n_fixed_actions, 1))

        # Target log-density. Q_soft in Equation 13:
        assert self._policy._squash
        # Log-Jacobian correction for the squashed action space;
        # EPS guards against log(0) at the boundaries.
        squash_correction = tf.reduce_sum(
            tf.log(1 - fixed_actions ** 2 + EPS), axis=-1, keepdims=True)
        log_probs = svgd_target_values + squash_correction

        grad_log_probs = tf.gradients(log_probs, fixed_actions)[0]
        grad_log_probs = tf.expand_dims(grad_log_probs, axis=2)
        grad_log_probs = tf.stop_gradient(grad_log_probs)
        assert_shape(grad_log_probs,
                     [None, n_fixed_actions, 1, *self._action_shape])

        kernel_dict = self._kernel_fn(xs=fixed_actions, ys=updated_actions)

        # Kernel function in Equation 13:
        kappa = kernel_dict["output"][..., tf.newaxis]
        assert_shape(kappa, [None, n_fixed_actions, n_updated_actions, 1])

        # Stein Variational Gradient in Equation 13:
        action_gradients = tf.reduce_mean(
            kappa * grad_log_probs + kernel_dict["gradient"], axis=1)
        assert_shape(action_gradients,
                     [None, n_updated_actions, *self._action_shape])

        # Propagate the gradient through the policy network (Equation 14).
        gradients = tf.gradients(
            updated_actions,
            self._policy.trainable_variables,
            grad_ys=action_gradients)

        # The gradient of this surrogate w.r.t. each policy variable equals
        # the (stopped) SVGD gradient computed above.
        surrogate_loss = tf.reduce_sum([
            tf.reduce_sum(w * tf.stop_gradient(g))
            for w, g in zip(self._policy.trainable_variables, gradients)
        ])

        self._policy_optimizer = tf.train.AdamOptimizer(
            learning_rate=self._policy_lr,
            name='policy_optimizer'
        )

        if self._train_policy:
            # Maximize the surrogate (ascend along the SVGD direction).
            svgd_training_op = self._policy_optimizer.minimize(
                loss=-surrogate_loss,
                var_list=self._policy.trainable_variables)
            self._training_ops.append(svgd_training_op)

    def train(self, *args, **kwargs):
        """Initiate training of the SQL instance."""
        return self._train(*args, **kwargs)

    def _init_training(self):
        """Hard-copy the Q-function weights into the target networks."""
        self._update_target(tau=1.0)

    def _update_target(self, tau=None):
        """Polyak-average the Q-function weights into the target networks.

        Args:
            tau (`float`, optional): Weight given to the online Q-function
                weights; `1.0` performs a hard copy. Falls back to
                `self._tau` only when `None` (an explicit `0.0` is honored).
        """
        tau = self._tau if tau is None else tau

        for Q, Q_target in zip(self._Qs, self._Q_targets):
            source_params = Q.get_weights()
            target_params = Q_target.get_weights()
            Q_target.set_weights([
                tau * source + (1.0 - tau) * target
                for source, target in zip(source_params, target_params)
            ])

    def _do_training(self, iteration, batch):
        """Run the operations for updating training and target ops."""

        feed_dict = self._get_feed_dict(batch)
        self._session.run(self._training_ops, feed_dict)

        if iteration % self._Q_target_update_interval == 0 and self._train_Q:
            self._update_target()

    def _get_feed_dict(self, batch):
        """Construct a TensorFlow feed dictionary from a sample batch."""

        feeds = {
            self._observations_ph: batch['observations'],
            self._actions_ph: batch['actions'],
            self._next_observations_ph: batch['next_observations'],
            self._rewards_ph: batch['rewards'],
            self._terminals_ph: batch['terminals'],
        }

        return feeds

    def get_diagnostics(self,
                        iteration,
                        batch,
                        evaluation_paths,
                        training_paths):
        """Record diagnostic information.

        Records the mean and standard deviation of the Q-function values and
        the mean of the Q-function losses (mean squared Bellman error) for a
        sample batch, plus the policy's own diagnostics under `policy/`.

        Also call the `draw` method of the plotter, if plotter is defined.
        """

        feeds = self._get_feed_dict(batch)
        Q_values, Q_losses = self._session.run(
            [self._Q_values, self._Q_losses], feeds)

        diagnostics = OrderedDict({
            'Q-avg': np.mean(Q_values),
            'Q-std': np.std(Q_values),
            'Q_loss': np.mean(Q_losses),
        })

        policy_diagnostics = self._policy.get_diagnostics(batch['observations'])
        diagnostics.update({
            f'policy/{key}': value
            for key, value in policy_diagnostics.items()
        })

        if self._plotter:
            self._plotter.draw()

        return diagnostics

    def get_snapshot(self, epoch):
        """Return loggable snapshot of the SQL algorithm.

        If `self._save_full_state == True`, returns snapshot including the
        replay pool. If `self._save_full_state == False`, returns snapshot
        of policy, Q-functions, and environment instances.
        """

        state = {
            'epoch': epoch,
            'policy': self._policy,
            # The algorithm keeps a tuple of Q-functions (`self._Qs`); there
            # is no single `self._Q` attribute.
            'Qs': self._Qs,
            'training_environment': self._training_environment,
            'evaluation_environment': self._evaluation_environment,
        }

        if self._save_full_state:
            state.update({'replay_pool': self._pool})

        return state

    @property
    def tf_saveables(self):
        """Optimizers whose state should be checkpointed, keyed by name."""
        return {
            '_policy_optimizer': self._policy_optimizer,
            **{
                f'Q_optimizer_{i}': optimizer
                for i, optimizer in enumerate(self._Q_optimizers)
            },
        }
