import tensorflow as tf
from utils.models import Actor, DistributionalCritic
from utils.utils_tools import Adam
from utils.wrappers import *
from utils.utils_tools import quantile_regression_loss, distortion_de
from algo.LatentActor import LatentAlgo
import functools



class Lodac(LatentAlgo):
    """Actor / distributional-critic learner built on top of ``LatentAlgo``.

    The critic is trained with a quantile-regression loss plus a conservative
    log-sum-exp penalty (labelled "CODAC penalty" below). The actor is trained
    on a weighted sum of the critic's quantile outputs together with an
    entropy term scaled by ``alpha``.
    """

    def __init__(self, config):
        """Create the networks, optimizers and optional learnable multipliers.

        Args:
            config: project configuration object. It is read both by attribute
                (e.g. ``config.critic_lr``) and by key
                (``config['weight_decay']``), so it must support both styles.
        """
        super().__init__(config)
        self._config = config
        self._actor = Actor(config)
        self._critic = DistributionalCritic(config)

        if self._config.use_automatic_entropy:
            self.log_alpha = tf.Variable(0.0, trainable=True, dtype=tf.float32)
            self._alpha_opt = tf.keras.optimizers.Adam(self._config.alpha_lr, name='alpha_opt')
            # Define ``alpha`` immediately so it exists before the first
            # training step; it is refreshed after every alpha update.
            self.alpha = tf.math.exp(self.log_alpha)
        else:
            self.alpha = self._config.alpha_entropy

        if self._config.with_lagrange:
            self.target_action_gap = tf.constant(self._config.target_action_gap, dtype=tf.float32)
            self.log_alpha_prime = tf.Variable(0.0, trainable=True, dtype=tf.float32)
            self.alpha_prime_optimizer = tf.keras.optimizers.Adam(self._config.alpha_prime_lr, name='alpha_prime_opt')

        # Counts calls to ``actor_critic_train_step``; drives the actor-update
        # and logging schedules.
        self._num_actor_critic_train_step = 0

        Optimizer = functools.partial(Adam,
                                      wd=self._config['weight_decay'],
                                      clip=self._config['grad_clip'])
        self._critic_opt = Optimizer('critic', [self._critic.q1, self._critic.q2], self._config.critic_lr)
        self._actor_opt = Optimizer('actor', [self._actor], self._config.actor_lr)

    def get_tau(self, obs, actions, fp=None):
        """Build the quantile fractions fed to the critic.

        Args:
            obs: batch of observations; only used when ``tau_type == 'fqf'``.
            actions: batch of actions; ``len(actions)`` sets the batch size.
            fp: callable ``fp(obs, actions)`` producing per-quantile
                probabilities for ``tau_type == 'fqf'``. Falls back to
                ``self.fp`` when ``None``.

        Returns:
            Tuple ``(tau, tau_hat, presum_tau)``:
            ``presum_tau`` holds the per-quantile increments, ``tau`` is their
            cumulative sum along the last axis, and ``tau_hat`` holds the
            midpoints between consecutive entries of ``tau`` (the first entry
            is ``tau[:, 0] / 2``). ``tau_hat`` carries no gradient.

        Raises:
            ValueError: if ``config.tau_type`` is not 'fix', 'iqn' or 'fqf'.
        """
        tau_type = self._config.tau_type
        n_quantiles = self._config.N_quantile_critic

        if tau_type == 'fix':
            presum_tau = tf.zeros([len(actions), n_quantiles]) + 1. / n_quantiles
        elif tau_type == 'iqn':
            presum_tau = tf.random.uniform([len(actions), n_quantiles]) + 0.1
            presum_tau /= tf.reduce_sum(presum_tau, axis=-1, keepdims=True)
        elif tau_type == 'fqf':
            if fp is None:
                fp = self.fp
            presum_tau = fp(obs, actions)
        else:
            raise ValueError(
                "Unknown tau_type {!r}; expected 'fix', 'iqn' or 'fqf'".format(tau_type))

        tau = tf.math.cumsum(presum_tau, axis=-1)
        # Midpoints computed with TF ops (no NumPy round-trip). stop_gradient
        # keeps tau_hat a constant w.r.t. differentiation.
        tau_hat = tf.stop_gradient(
            tf.concat([tau[:, :1] / 2., (tau[:, 1:] + tau[:, :-1]) / 2.], axis=-1))
        return tau, tau_hat, presum_tau

    def actor_critic_train_step(self, data):
        """Run one update of alpha (optional), the critic and (periodically) the actor.

        Args:
            data: mapping with keys ``'obs'``, ``'actions'``, ``'next_obs'``,
                ``'rewards'`` and ``'terminals'``. ``rewards`` and
                ``terminals`` are broadcast to the shape of the target
                quantiles — presumably they are (batch, 1); confirm with caller.

        Side effects:
            Applies optimizer steps, increments the step counter, soft-updates
            the target critics whenever the actor is updated, and writes
            summaries every ``config.log_every`` steps.
        """
        obs = data['obs']
        actions = data['actions']
        next_obs = data['next_obs']
        rewards = data['rewards']
        terminals = data['terminals']

        # ---- Alpha training ----
        if self._config.use_automatic_entropy:
            action, log_pi = self._actor.action_log_prob(obs)
            # Target entropy is minus the size of the action's last dimension.
            target_entropy = - tf.constant(action.shape[-1], dtype=tf.float32)
            with tf.GradientTape() as alpha_tape:
                alpha_tape.watch(self.log_alpha)
                alpha_loss = - tf.reduce_mean(self.log_alpha * tf.stop_gradient((log_pi + target_entropy)))
            grad = alpha_tape.gradient(alpha_loss, [self.log_alpha])
            self._alpha_opt.apply_gradients(zip(grad, [self.log_alpha]))
            self.alpha = tf.math.exp(self.log_alpha)

        # ---- Critic training ----
        with tf.GradientTape() as critic_tape:

            _, tau_hat, _ = self.get_tau(obs, actions, fp=self._config.fp)
            Z_tau_K1 = self._critic.q1(obs, actions, tau_hat)
            Z_tau_K2 = self._critic.q2(obs, actions, tau_hat)

            next_action, next_log_pi = self._actor.action_log_prob(next_obs)
            _, next_tau_hat, next_presum_hat = self.get_tau(next_obs, next_action, fp=self._config.target_fp)

            # Element-wise minimum of the two target critics, minus the
            # entropy term.
            target_q_value = tf.reduce_min([self._critic.target_q1(next_obs, next_action, next_tau_hat),
                                            self._critic.target_q2(next_obs, next_action, next_tau_hat)],
                                           axis=0)
            target_q_value = target_q_value - self.alpha * next_log_pi

            terminals = tf.broadcast_to(terminals, shape=target_q_value.shape)
            rewards = tf.broadcast_to(rewards, shape=target_q_value.shape)
            target_Z_tau = tf.stop_gradient(rewards + self._config.critic_gamma * target_q_value * (1-terminals))

            # CODAC penalty, evaluated at a single randomly chosen quantile.
            penalty_index = np.random.randint(0, self._config.N_quantile_critic)
            truncated_tau_hat = tau_hat[:, penalty_index: penalty_index+1]

            random_actions = tf.random.uniform([actions.shape[0]*self._config.num_random_actions, actions.shape[1]], minval=-1, maxval=1)
            current_actions, log_pi_current = self._actions_in_expected_shape(obs, self._config.num_random_actions)
            new_current_actions, new_log_pi = self._actions_in_expected_shape(next_obs, self._config.num_random_actions)

            z1_rand, z2_rand = self._q_prediction_in_expected_shape(obs, random_actions, truncated_tau_hat)
            z1_current, z2_current = self._q_prediction_in_expected_shape(obs, current_actions, truncated_tau_hat)
            z1_next_actions, z2_next_actions = self._q_prediction_in_expected_shape(obs, new_current_actions, truncated_tau_hat)

            # Log-density of the uniform proposal on [-1, 1]^d: log(0.5 ** d).
            random_density = np.log(0.5 ** current_actions.shape[-1])
            cat_z1 = tf.concat(
                [z1_rand - random_density, z1_next_actions - tf.stop_gradient(new_log_pi),
                 z1_current - tf.stop_gradient(log_pi_current)], axis=1)
            cat_z2 = tf.concat(
                [z2_rand - random_density, z2_next_actions - tf.stop_gradient(new_log_pi),
                 z2_current - tf.stop_gradient(log_pi_current)], axis=1
            )
            q1_penalty = tf.math.reduce_logsumexp(cat_z1, axis=1)
            q2_penalty = tf.math.reduce_logsumexp(cat_z2, axis=1)

            min_q1_loss = self._config.min_z_weight * (tf.reduce_mean(q1_penalty) - tf.reduce_mean(Z_tau_K1))
            min_q2_loss = self._config.min_z_weight * (tf.reduce_mean(q2_penalty) - tf.reduce_mean(Z_tau_K2))

            if self._config.with_lagrange:
                # Learnable multiplier on the gap between the penalty and
                # ``target_action_gap``; it is updated here, before the critic
                # step, using its own tape.
                with tf.GradientTape() as alpha_prime_tape:
                    alpha_prime = tf.clip_by_value(tf.math.exp(self.log_alpha_prime), clip_value_min=0.0,
                                                   clip_value_max=1e6)
                    min_q1_loss = alpha_prime * (min_q1_loss - self.target_action_gap)
                    min_q2_loss = alpha_prime * (min_q2_loss - self.target_action_gap)
                    alpha_prime_loss = (-min_q1_loss - min_q2_loss) * 0.5
                grad = alpha_prime_tape.gradient(alpha_prime_loss, [self.log_alpha_prime])
                self.alpha_prime_optimizer.apply_gradients(zip(grad, [self.log_alpha_prime]))

            qf1_loss = min_q1_loss + quantile_regression_loss(Z_tau_K1, target_Z_tau, tau_hat, next_presum_hat)
            qf2_loss = min_q2_loss + quantile_regression_loss(Z_tau_K2, target_Z_tau, tau_hat, next_presum_hat)
            critic_loss = qf1_loss + qf2_loss

        critic_norm = self._critic_opt(critic_tape, critic_loss)

        # ---- Actor training ----
        if self._num_actor_critic_train_step % self._config.actor_update_frequency == 0:
            with tf.GradientTape() as actor_tape:
                self.actor_loss = self._compute_actor_loss(obs)
            self.actor_norm = self._actor_opt(actor_tape, self.actor_loss)

            # Update target networks (softly)
            self._critic.update_target(self._config.target_update_tau)

        self._num_actor_critic_train_step += 1

        if self._num_actor_critic_train_step % self._config.log_every == 0:
            # actor_loss / actor_norm hold the values from the most recent
            # actor update, which may be older than this step.
            summaries = dict()
            summaries['critic_loss'] = critic_loss
            summaries['critic_norm'] = critic_norm
            summaries['actor_loss'] = self.actor_loss
            summaries['actor_norm'] = self.actor_norm
            self._write_summaries(summaries, self._num_actor_critic_train_step)

    def _actions_in_expected_shape(self, state, num_actions_repeat):
        """Sample ``num_actions_repeat`` policy actions for every state.

        Args:
            state: rank-2 tensor with a static shape (``state.shape[0]`` and
                ``state.shape[1]`` are both read).
            num_actions_repeat: number of actions sampled per state.

        Returns:
            ``(action, log_pi_actions)``: ``action`` is returned exactly as the
            actor produced it for the ``batch * num_actions_repeat`` repeated
            states; ``log_pi_actions`` is reshaped to
            ``(batch, num_actions_repeat, 1)``.
        """
        obs_temp = tf.expand_dims(state, axis=1)
        obs_temp = tf.tile(obs_temp, [1, num_actions_repeat, 1])
        obs_temp = tf.reshape(obs_temp, [state.shape[0]*num_actions_repeat, state.shape[1]])
        action, log_pi_actions = self._actor.action_log_prob(obs_temp)
        log_pi_actions = tf.reshape(log_pi_actions, [state.shape[0], num_actions_repeat, 1])
        return action, log_pi_actions

    def _q_prediction_in_expected_shape(self, state, actions, tau):
        """Evaluate both critics on several actions per state.

        Args:
            state: rank-2 tensor ``(batch, state_dim)`` with a static shape.
            actions: tensor whose leading dimension is a multiple of ``batch``;
                the multiple is the number of actions per state.
            tau: rank-2 tensor ``(batch, n_tau)`` of quantile fractions; each
                row is repeated for every action of the matching state.

        Returns:
            ``(pred1, pred2)``, the outputs of ``q1`` and ``q2`` reshaped to
            ``(batch, num_repeat, -1)``.
        """
        # Integer division: the repeat count is an exact multiple.
        num_repeat = actions.shape[0] // state.shape[0]

        # reshape state
        state_temp = tf.expand_dims(state, axis=1)
        state_temp = tf.tile(state_temp, [1, num_repeat, 1])
        state_temp = tf.reshape(state_temp, [state.shape[0]*num_repeat, state.shape[1]])

        # tau
        tau_temp = tf.expand_dims(tau, axis=1)
        tau_temp = tf.tile(tau_temp, [1, num_repeat, 1])
        tau_temp = tf.reshape(tau_temp, [tau.shape[0]*num_repeat, tau.shape[1]])

        pred1 = self._critic.q1(state_temp, actions, tau_temp)
        pred2 = self._critic.q2(state_temp, actions, tau_temp)

        pred1 = tf.reshape(pred1, [state.shape[0], num_repeat, -1])
        pred2 = tf.reshape(pred2, [state.shape[0], num_repeat, -1])

        return pred1, pred2

    def _compute_actor_loss(self, obs):
        """Return the scalar actor loss for a batch of observations.

        Fresh actions are sampled from the actor, both critics are evaluated
        at the resulting quantile midpoints, and each critic's quantiles are
        collapsed into one value per sample using
        ``distortion_de(tau_hat, config.alpha_cvar) * presum_tau`` as weights.
        The loss is ``mean(alpha * log_pi - min(q1, q2))``.
        """
        new_actions, log_pi = self._actor.action_log_prob(obs)
        _, new_tau_hat, new_presum_tau = self.get_tau(obs, new_actions, fp=self._config.fp)
        z1_new_actions = self._critic.q1(obs, new_actions, new_tau_hat)
        z2_new_actions = self._critic.q2(obs, new_actions, new_tau_hat)
        risk_weight = distortion_de(new_tau_hat, self._config.alpha_cvar)

        q1_new = tf.math.reduce_sum(risk_weight*new_presum_tau * z1_new_actions, axis=-1, keepdims=True)
        q2_new = tf.math.reduce_sum(risk_weight*new_presum_tau * z2_new_actions, axis=-1, keepdims=True)
        q = tf.minimum(q1_new, q2_new)

        actor_loss = tf.reduce_mean(self.alpha * log_pi - q)
        return actor_loss

    def select_action(self, feat):
        """Return the actor's output for the features ``feat``."""
        return self._actor(feat)


