import tensorflow as tf
from utils.models import Actor, Critic
from utils.utils_tools import Adam
from utils.wrappers import *
from algo.LatentActor import LatentAlgo
import functools


class Lompo(LatentAlgo):
    """Actor-critic agent layered on top of ``LatentAlgo``.

    Owns an ``Actor`` and a ``Critic`` exposing two Q-functions (``qf1``,
    ``qf2``) together with target copies (``target_qf1``, ``target_qf2``).
    The critics regress onto a bootstrapped target that uses the element-wise
    minimum of the two target critics; the actor maximizes the element-wise
    minimum of the two online critics.
    """

    def __init__(self, config):
        """Build the actor, the critic and one optimizer for each of them.

        Args:
            config: Experiment configuration. This class reads
                ``weight_decay`` and ``grad_clip`` by key, and ``critic_lr``,
                ``actor_lr``, ``discount``, ``tau``,
                ``target_update_interval`` and ``log_every`` by attribute.
        """
        super().__init__(config)
        self._config = config
        # models
        self._actor = Actor(config)
        self._critic = Critic(config)

        # Counts completed calls to ``actor_critic_train_step``; drives the
        # target-update and logging cadence.
        self._num_actor_critic_train_step = 0

        # optimizer
        Optimizer = functools.partial(Adam,
                                      wd=self._config['weight_decay'],
                                      clip=self._config['grad_clip'])
        self._critic_opt = Optimizer('critic', [self._critic.qf1, self._critic.qf2], self._config.critic_lr)
        self._actor_opt = Optimizer('actor', [self._actor._actor], self._config.actor_lr)

    def actor_critic_train_step(self, data):
        """Run one critic update followed by one actor update.

        Both losses are evaluated before either optimizer is applied, so the
        actor loss is computed against the critics as they were at the start
        of the step.

        Args:
            data: Mapping with the keys ``'obs'``, ``'actions'``,
                ``'next_obs'``, ``'rewards'`` and ``'terminals'``.
                NOTE(review): presumably batched float tensors whose shapes
                broadcast against the critic outputs -- confirm against the
                caller.
        """
        obs = data['obs']
        actions = data['actions']
        next_obs = data['next_obs']
        rewards = data['rewards']
        terminals = data['terminals']

        # The bootstrap target is a constant as far as the critic loss is
        # concerned, so it is built outside the tape: nothing from the actor
        # or the target critics needs to be recorded for differentiation.
        next_inputs = tf.concat([next_obs, self._actor(next_obs)], axis=-1)
        target_q_values = tf.reduce_min(
            [self._critic.target_qf1(next_inputs),
             self._critic.target_qf2(next_inputs)],
            axis=0)
        q_target = tf.stop_gradient(
            rewards + self._config.discount * (1.0 - terminals) * target_q_values)

        with tf.GradientTape() as q_tape:
            critic_inputs = tf.concat([obs, actions], axis=-1)
            q1_pred = self._critic.qf1(critic_inputs)
            q2_pred = self._critic.qf2(critic_inputs)

            qf1_loss = tf.reduce_mean((q1_pred - q_target) ** 2)
            qf2_loss = tf.reduce_mean((q2_pred - q_target) ** 2)
            q_loss = qf1_loss + qf2_loss

        with tf.GradientTape() as actor_tape:
            policy_inputs = tf.concat([obs, self._actor(obs)], axis=-1)
            q_new_actions = tf.reduce_min(
                [self._critic.qf1(policy_inputs),
                 self._critic.qf2(policy_inputs)],
                axis=0)
            # Minimizing the negated value maximizes the critics' estimate.
            actor_loss = -tf.reduce_mean(q_new_actions)

        q_norm = self._critic_opt(q_tape, q_loss)
        actor_norm = self._actor_opt(actor_tape, actor_loss)
        self._num_actor_critic_train_step += 1

        if self._num_actor_critic_train_step % self._config.target_update_interval == 0:
            self._update_target_critics()

        if self._num_actor_critic_train_step % self._config.log_every == 0:
            agent_summaries = {
                'agent/Q1_value': tf.reduce_mean(q1_pred),
                'agent/Q2_value': tf.reduce_mean(q2_pred),
                'agent/Q_target': tf.reduce_mean(q_target),
                'agent/Q_loss': q_loss,
                'agent/actor_loss': actor_loss,
                'agent/Q_grad_norm': q_norm,
                'agent/actor_grad_norm': actor_norm,
            }
            self._write_summaries(agent_summaries, self._num_actor_critic_train_step)

    def _update_target_critics(self):
        """Move each target critic towards its online critic.

        Every target weight becomes
        ``tau * source + (1 - tau) * target`` with ``tau`` read from the
        config.
        """
        tau = tf.constant(self._config.tau)
        critic_pairs = (
            (self._critic.qf1, self._critic.target_qf1),
            (self._critic.qf2, self._critic.target_qf2),
        )
        for source_net, target_net in critic_pairs:
            for source_weight, target_weight in zip(source_net.trainable_variables,
                                                    target_net.trainable_variables):
                target_weight.assign(tau * source_weight + (1.0 - tau) * target_weight)

    def select_action(self, feat):
        """Return the actor's output for ``feat``."""
        return self._actor(feat)

