import tensorflow as tf
from utils.models import Actor, Critic
from utils.utils_tools import Adam
from utils.wrappers import *
from algo.LatentActor import LatentAlgo
import functools


class Combo(LatentAlgo):
    """Actor-critic trained on real plus synthetic transitions with a COMBO-style penalty.

    The two critic heads (``qf1``/``qf2``) minimise a TD loss on the
    concatenation of the real and synthetic batches, plus a conservative term
    scaled by ``config.combo_beta``:

        beta * (mean_s logsumexp_a Q(s, a)  -  mean_{(s,a) in real} Q(s, a))

    where ``s`` ranges over synthetic observations and ``a`` over
    ``config.cql_samples`` uniform draws in [-1, 1] plus the current policy
    action. The actor maximises the minimum of the two critic heads on the
    concatenated observations.

    NOTE(review): ``synthetic_data`` is presumably produced by a learned
    dynamics model (as in COMBO); this class only consumes it.
    """

    def __init__(self, config):
        super().__init__(config)
        self._config = config
        # models
        self._actor = Actor(config)
        self._critic = Critic(config)

        self._num_actor_critic_train_step = 0

        # optimizers share the same weight-decay / gradient-clipping settings
        make_optimizer = functools.partial(Adam,
                                           wd=self._config['weight_decay'],
                                           clip=self._config['grad_clip'])
        self._critic_opt = make_optimizer('critic', [self._critic.qf1, self._critic.qf2], self._config.critic_lr)
        self._actor_opt = make_optimizer('actor', [self._actor._actor], self._config.actor_lr)

    def actor_critic_train_step(self, real_data, synthetic_data):
        """Run one critic update followed by one actor update.

        Args:
            real_data: mapping with 'obs', 'actions', 'next_obs', 'rewards' and
                'terminals' tensors for the real batch.
            synthetic_data: mapping with the same keys for the synthetic batch.

        Every ``config.target_update_interval`` steps the target critics are
        soft-updated; every ``config.log_every`` steps summaries are written.
        """
        # Real rows come first, then synthetic rows, for every field.
        obs, actions, next_obs, rewards, terminals = [
            tf.concat([real_data[key], synthetic_data[key]], axis=0)
            for key in ('obs', 'actions', 'next_obs', 'rewards', 'terminals')]
        real_inputs = tf.concat([real_data['obs'], real_data['actions']], axis=-1)

        # Neither the bootstrap target nor the sampled penalty actions carry a
        # critic gradient, so they are built outside the critic tape.
        q_target = self._td_target(rewards, next_obs, terminals)
        penalty_inputs = self._penalty_inputs(synthetic_data['obs'])

        with tf.GradientTape() as q_tape:
            q1_pred, q2_pred = self._q_heads(obs, actions)

            qf1_loss = (self._conservative_penalty(self._critic.qf1, penalty_inputs, real_inputs)
                        + 0.5 * tf.reduce_mean(tf.square(q1_pred - q_target)))
            qf2_loss = (self._conservative_penalty(self._critic.qf2, penalty_inputs, real_inputs)
                        + 0.5 * tf.reduce_mean(tf.square(q2_pred - q_target)))
            q_loss = qf1_loss + qf2_loss

        with tf.GradientTape() as actor_tape:
            q1_new, q2_new = self._q_heads(obs, self._actor(obs))
            actor_loss = -tf.reduce_mean(tf.minimum(q1_new, q2_new))

        q_norm = self._critic_opt(q_tape, q_loss)
        actor_norm = self._actor_opt(actor_tape, actor_loss)
        self._num_actor_critic_train_step += 1

        if self._num_actor_critic_train_step % self._config.target_update_interval == 0:
            self._update_target_critics()

        if self._num_actor_critic_train_step % self._config.log_every == 0:
            agent_summaries = dict()
            agent_summaries['agent/Q1_value'] = tf.reduce_mean(q1_pred)
            agent_summaries['agent/Q2_value'] = tf.reduce_mean(q2_pred)
            agent_summaries['agent/Q_target'] = tf.reduce_mean(q_target)
            agent_summaries['agent/Q_loss'] = q_loss
            agent_summaries['agent/actor_loss'] = actor_loss
            agent_summaries['agent/Q_grad_norm'] = q_norm
            agent_summaries['agent/actor_grad_norm'] = actor_norm
            self._write_summaries(agent_summaries, self._num_actor_critic_train_step)

    def _q_heads(self, obs, actions):
        """Return ``(qf1, qf2)`` evaluated on ``concat([obs, actions], -1)``."""
        inputs = tf.concat([obs, actions], axis=-1)
        return self._critic.qf1(inputs), self._critic.qf2(inputs)

    def _td_target(self, rewards, next_obs, terminals):
        """Clipped double-Q bootstrap target, with gradients stopped."""
        next_actions = self._actor(next_obs)
        next_inputs = tf.concat([next_obs, next_actions], axis=-1)
        target_q_values = tf.minimum(self._critic.target_qf1(next_inputs),
                                     self._critic.target_qf2(next_inputs))
        # cast is a no-op for float terminals and makes bool terminals usable
        not_done = 1.0 - tf.cast(terminals, target_q_values.dtype)
        # NOTE(review): assumes rewards/terminals have the same shape as the
        # critic output (e.g. (batch, 1)); a (batch,) tensor would broadcast
        # to (batch, batch) here -- confirm against the data pipeline.
        return tf.stop_gradient(rewards + self._config.discount * not_done * target_q_values)

    def _penalty_inputs(self, synthetic_obs):
        """Critic inputs for the conservative penalty at synthetic states.

        Stacks ``config.cql_samples`` uniform actions in [-1, 1] plus the
        current policy action for every synthetic observation on a new leading
        axis, and pairs them with the correspondingly tiled observations.
        """
        num_random = self._config.cql_samples
        policy_actions = self._actor(synthetic_obs)
        # Dynamic shape/dtype: works with an unknown batch dimension (e.g.
        # inside tf.function) and always matches the policy action tensor.
        random_shape = tf.concat([[num_random], tf.shape(policy_actions)], axis=0)
        random_actions = tf.random.uniform(random_shape, minval=-1, maxval=1,
                                           dtype=policy_actions.dtype)
        sampled_actions = tf.concat([random_actions, tf.expand_dims(policy_actions, 0)], axis=0)

        # assumes synthetic_obs is rank 2 (batch, obs_dim), as the original tiling did
        tiled_obs = tf.tile(tf.expand_dims(synthetic_obs, 0), [num_random + 1, 1, 1])
        return tf.concat([tiled_obs, sampled_actions], axis=-1)

    def _conservative_penalty(self, q_fn, penalty_inputs, real_inputs):
        """``combo_beta * (mean logsumexp_a Q(s, a) - mean Q(real s, real a))`` for one head."""
        pushed_down = tf.reduce_mean(tf.math.reduce_logsumexp(q_fn(penalty_inputs), axis=0))
        pushed_up = tf.reduce_mean(q_fn(real_inputs))
        return self._config.combo_beta * (pushed_down - pushed_up)

    def _update_target_critics(self):
        """Polyak-average each online critic head into its target with rate ``config.tau``."""
        # A Python float keeps the update valid for any floating variable dtype.
        tau = float(self._config.tau)
        self._soft_update(self._critic.qf1, self._critic.target_qf1, tau)
        self._soft_update(self._critic.qf2, self._critic.target_qf2, tau)

    @staticmethod
    def _soft_update(source, target, tau):
        """Assign ``tau * source + (1 - tau) * target`` to each paired trainable variable.

        Raises:
            ValueError: if the two networks expose different numbers of
                trainable variables (``zip`` would otherwise silently skip the
                update, e.g. when the target network is marked non-trainable).
        """
        source_vars = source.trainable_variables
        target_vars = target.trainable_variables
        if len(source_vars) != len(target_vars):
            raise ValueError(
                'Target critic has %d trainable variables but source has %d'
                % (len(target_vars), len(source_vars)))
        for source_weight, target_weight in zip(source_vars, target_vars):
            target_weight.assign(tau * source_weight + (1.0 - tau) * target_weight)

    def select_action(self, feat):
        """Return the actor's action for ``feat``."""
        return self._actor(feat)

