import tensorflow as tf
from utils.models import ActorOraac, DistributionalCritic
from utils.utils_tools import Adam
from tensorflow_probability import distributions as tfd
from utils.wrappers import *
from tensorflow.keras import losses as tfkl
from utils.utils_tools import quantile_huber_loss
from algo.LatentActor import LatentAlgo
import functools


class Oraac(LatentAlgo):
    """Agent with an imitation actor, a CVaR actor and twin distributional critics.

    Training is exposed through two entry points:

    * ``imit_actor_training_step`` fits the imitation actor to dataset actions
      with a mean-squared-error loss.
    * ``actor_critic_train_step`` trains both online critics with a quantile
      Huber loss every call and, every ``actor_update_frequency`` calls, takes
      one CVaR-actor step and softly updates the target networks.

    ``config`` is read both by attribute (``config.alpha_cvar``) and by key
    (``config['weight_decay']``), so it must support both access styles.
    """

    def __init__(self, config):
        super().__init__(config)
        self._config = config
        self._actor = ActorOraac(config)
        self._critic = DistributionalCritic(config)
        self.imitation_loss = tfkl.MeanSquaredError()
        self._num_imit_actor_train_step = 0
        # Quantile fractions for critic training: U(0, 1).
        self.distr_taus_uniform = tfd.Uniform()
        # Quantile fractions for the actor objective: U(0, alpha_cvar), i.e. the lower tail only.
        self.distr_taus_risk = tfd.Uniform(low=0.0, high=config.alpha_cvar)
        self._num_actor_critic_train_step = 0

        # All optimizers share weight decay and gradient clipping settings.
        Optimizer = functools.partial(Adam,
                                      wd=self._config['weight_decay'],
                                      clip=self._config['grad_clip'])
        self._critic_opt = Optimizer('critic', [self._critic.q1, self._critic.q2], self._config.critic_lr)
        self._actor_opt = Optimizer('actor', [self._actor._cvar_actor], self._config.actor_lr)
        self._imit_actor_opt = Optimizer('imit_actor', [self._actor._imit_actor], self._config.imit_actor_lr)

    def imit_actor_training_step(self, data):
        """Take one MSE step fitting the imitation actor to dataset actions.

        Args:
            data: mapping with keys ``'obs'`` (actor input) and ``'actions'``
                (regression targets).
        """
        state = data['obs']
        target_action = data['actions']
        with tf.GradientTape() as imit_tape:
            action = self._actor._imit_actor(state)
            imitation_loss = self.imitation_loss(target_action, action)
        # The optimizer wrapper applies the update and returns the gradient norm.
        imitation_norm = self._imit_actor_opt(imit_tape, imitation_loss)

        self._num_imit_actor_train_step += 1

        if self._num_imit_actor_train_step % self._config.log_every == 0:
            summaries = dict()
            summaries['imitation_loss'] = imitation_loss
            summaries['imitation_norm'] = imitation_norm
            self._write_summaries(summaries, self._num_imit_actor_train_step)

    def save_imit_actor(self, datadir):
        """Delegate saving the imitation actor to the actor container."""
        self._actor.save_imit_actor(datadir)

    def load_imit_actor(self, datadir):
        """Delegate loading the imitation actor to the actor container."""
        self._actor.load_imit_actor(datadir)

    def actor_critic_train_step(self, data):
        """Run one critic update and, periodically, one actor + target update.

        Args:
            data: mapping with keys ``'obs'``, ``'actions'``, ``'next_obs'``,
                ``'rewards'`` and ``'terminals'``. ``rewards`` and
                ``terminals`` must be broadcastable to the sampled quantile
                tensor (presumably shape ``(batch, 1)`` against
                ``(batch, N_quantile_critic)`` — TODO confirm).
        """
        obs = data['obs']
        actions = data['actions']
        next_obs = data['next_obs']
        rewards = data['rewards']
        terminals = data['terminals']

        # ---- Critic training ----
        # Independent quantile fractions for the online estimate (tau_k) and
        # for the bootstrap target (tau_k_).
        tau_k = self.distr_taus_uniform.sample((self._config.N_quantile_critic,))
        tau_k_ = self.distr_taus_uniform.sample((self._config.N_quantile_critic,))

        # The target evaluates Z(s', a'), so a' must be chosen for the NEXT
        # state. Using the current `obs` here would bootstrap from an action
        # that does not belong to s'.
        next_action = self._actor.next_action(next_obs)

        # Clipped double-Q: element-wise minimum over the two target critics.
        Z_next_tau_k = tf.reduce_min([self._critic.target_q1.get_sampled_Z(next_obs, tau_k_, next_action),
                                      self._critic.target_q2.get_sampled_Z(next_obs, tau_k_, next_action)],
                                     axis=0)

        # The target does not depend on the online critics, so it is built
        # outside the tape. The dynamic tf.shape() (instead of the static
        # `.shape`) keeps broadcasting valid when the batch dimension is
        # unknown, e.g. inside tf.function. The casts let bool/int terminal
        # flags and differently-typed rewards combine with the quantiles.
        target_shape = tf.shape(Z_next_tau_k)
        terminals = tf.broadcast_to(tf.cast(terminals, Z_next_tau_k.dtype), target_shape)
        rewards = tf.broadcast_to(tf.cast(rewards, Z_next_tau_k.dtype), target_shape)
        target_Z_tau_k = tf.stop_gradient(
            rewards + self._config.critic_gamma * Z_next_tau_k * (1 - terminals))

        with tf.GradientTape() as critic_tape:
            Z_tau_K1 = self._critic.q1.get_sampled_Z(obs, tau_k, actions)
            Z_tau_K2 = self._critic.q2.get_sampled_Z(obs, tau_k, actions)
            critic_loss = (quantile_huber_loss(target_Z_tau_k, Z_tau_K1, tau_k)
                           + quantile_huber_loss(target_Z_tau_k, Z_tau_K2, tau_k))

        critic_norm = self._critic_opt(critic_tape, critic_loss)

        # ---- Delayed actor training ----
        # The counter starts at 0, which always satisfies the modulo. That
        # guarantees self.actor_loss / self.actor_norm exist before logging
        # reads them. Between actor updates the logged values are stale.
        if self._num_actor_critic_train_step % self._config.actor_update_frequency == 0:
            with tf.GradientTape() as actor_tape:
                # Maximize CVaR by minimizing its negation.
                self.actor_loss = -self._compute_actor_loss(obs)
            self.actor_norm = self._actor_opt(actor_tape, self.actor_loss)

            # Soft target updates happen only on actor-update steps.
            self._critic.update_target(self._config.target_update_tau)
            self._actor.update_cvar_target(self._config.target_update_tau)

        self._num_actor_critic_train_step += 1

        if self._num_actor_critic_train_step % self._config.log_every == 0:
            summaries = dict()
            summaries['critic_loss'] = critic_loss
            summaries['critic_norm'] = critic_norm
            summaries['actor_loss'] = self.actor_loss
            summaries['actor_norm'] = self.actor_norm
            self._write_summaries(summaries, self._num_actor_critic_train_step)

    def _compute_actor_loss(self, obs):
        """Return the CVaR objective (to be maximized) of the CVaR actor at ``obs``.

        Quantile fractions are drawn uniformly from ``[0, alpha_cvar]``, and the
        critic's sampled return values at those fractions are averaged. This
        is a Monte-Carlo estimate of the lower-tail mean, assuming
        ``get_sampled_Z`` returns quantile values (TODO confirm). Despite its
        name, this returns the positive objective; the caller negates it.
        """
        action = self._actor(obs, cvar_training=True)
        tau_actor_k = self.distr_taus_risk.sample((self._config.n_quantile_policy,))
        tail_samples = self._critic.get_sampled_Z(obs, tau_actor_k, action)
        cvar = tf.reduce_mean(tail_samples)
        return cvar

    def select_action(self, feat):
        """Return the actor's action for ``feat`` (called with ``cvar_training=True``)."""
        return self._actor(feat, cvar_training=True)


