import tensorflow as tf
import numpy as np


# from inverse_rl.models.fusion_manager import RamFusionDistr
# from inverse_rl.models.imitation_learning import SingleTimestepIRL
from scipy.stats import norm
from stable_baselines.common import tf_util

from causal_irl.algorithms.common.architectures import relu_net
from stable_baselines.common.mpi_running_mean_std import RunningMeanStd
# from inverse_rl.utils import TrainingIterator



class AIRL_Discriminator(object):
    """
    Adversarial IRL (AIRL) discriminator built as a TensorFlow 1.x graph.

    For a transition (s, a, s') the graph computes
    f(s, a, s') = r + gamma * V(s') - V(s) and the discriminator
    D = exp(f) / (exp(f) + exp(lprobs)), where ``lprobs`` is fed in by the
    caller (presumably log pi(a|s) of the current policy -- confirm at call site).

    Args:
        obs_space: Observation space; only ``.shape`` is read.
        acs_space: Action space; only ``.shape`` is read.
        policy: Stored on ``self.policy``; not used by this class.
        reward_arch: Callable ``(input, dout=1, **reward_arch_args)`` building the reward net.
        reward_arch_args (dict): Extra keyword arguments forwarded to ``reward_arch``.
        value_fn_arch: Callable ``(input, dout=1)`` building the shaping net V; must not be None.
        score_discrim (bool): Stored on ``self.score_discrim``; this class does not read it
            (``get_reward``'s ``use_reward`` argument selects the output instead).
        discount (float): gamma used in the shaping term.
        state_only (bool): If True the learned reward depends only on the observation;
            otherwise it depends on the flattened observation concatenated with the
            flattened action.
        normalize (bool): Standardize observations with a ``RunningMeanStd`` filter.
        scope (str): Variable scope holding all discriminator variables.
    """
    def __init__(self, obs_space, acs_space, policy,
                 reward_arch=relu_net,
                 reward_arch_args=None,
                 value_fn_arch=relu_net,
                 score_discrim=False,
                 discount=1.0,
                 state_only=True,
                 normalize=True,
                 scope='airl_adversary'):
        if reward_arch_args is None:
            reward_arch_args = {}
        assert value_fn_arch is not None
        self.observation_shape = obs_space.shape
        self.action_shape = acs_space.shape
        self.normalize = normalize
        self.policy = policy
        self.scope = scope
        self.reward_arch = reward_arch
        self.reward_arch_args = reward_arch_args
        self.value_fn_arch = value_fn_arch
        self.score_discrim = score_discrim
        self.gamma = discount
        self.state_only = state_only
        self.obs_rms = None

        # Generator (policy) batch placeholders.
        self.gen_obs_t = tf.placeholder(tf.float32, (None,) + self.observation_shape, name='obs')
        self.gen_nobs_t = tf.placeholder(tf.float32, (None,) + self.observation_shape, name='nobs')
        self.gen_act_t = tf.placeholder(tf.float32, (None,) + self.action_shape, name='act')

        # Expert batch placeholders.
        self.expert_obs_t = tf.placeholder(tf.float32, (None,) + self.observation_shape, name='expert_obs')
        self.expert_nobs_t = tf.placeholder(tf.float32, (None,) + self.observation_shape, name='expert_nobs')
        self.expert_act_t = tf.placeholder(tf.float32, (None,) + self.action_shape, name='expert_act')
        self.lprobs = tf.placeholder(tf.float32, [None, 1], name='log_probs')
        self.expert_lprobs = tf.placeholder(tf.float32, [None, 1], name='expert_log_probs')

        # Both calls share variables (second uses reuse=True). Each returns
        # (log D, log(1 - D), reward).
        generator_logits, generator_logits2, gen_reward = self.build_graph(
            self.gen_obs_t, self.gen_act_t, self.gen_nobs_t, self.lprobs, reuse=False)
        expert_logits, expert_logits2, expert_reward = self.build_graph(
            self.expert_obs_t, self.expert_act_t, self.expert_nobs_t, self.expert_lprobs, reuse=True)
        # NOTE: build_graph overwrites self.log_p_tau, self.d_tau and self.obs_rms on
        # every call, so after construction they refer to the expert-side graph.

        # Accuracy: fraction of generator samples with D < 0.5 and of expert
        # samples with D > 0.5 (the logits are log D, hence the exp).
        generator_acc = tf.reduce_mean(tf.cast(tf.exp(generator_logits) < 0.5, tf.float32))
        expert_acc = tf.reduce_mean(tf.cast(tf.exp(expert_logits) > 0.5, tf.float32))

        # Binary cross-entropy written directly in log space:
        # generator samples are pushed towards log(1 - D), expert samples towards log D.
        generator_loss = -tf.reduce_mean(generator_logits2)
        expert_loss = -tf.reduce_mean(expert_logits)

        # Exposed for debugging / evaluation.
        self.gen_logits = generator_logits
        self.gen_logits2 = generator_logits2
        self.expert_logits = expert_logits
        self.expert_logits2 = expert_logits2

        self.losses = [generator_loss, expert_loss, generator_acc, expert_acc]
        self.loss_name = ["generator_loss", "expert_loss", "generator_acc", "expert_acc"]
        self.total_loss = generator_loss + expert_loss

        # Learned reward r evaluated on the generator placeholders.
        self.reward = gen_reward
        var_list = self.get_trainable_variables()
        # Callable returning the four loss/accuracy terms plus the flattened
        # gradient of total_loss w.r.t. the discriminator's trainable variables.
        self.lossandgrad = tf_util.function(
            [self.gen_obs_t, self.gen_act_t, self.lprobs, self.gen_nobs_t,
             self.expert_obs_t, self.expert_act_t, self.expert_lprobs, self.expert_nobs_t],
            self.losses + [tf_util.flatgrad(self.total_loss, var_list)])

    def build_graph(self, obs_ph, acs_ph, nobs_ph, lprobs, reuse=False):
        """
        Build the discriminator graph for one batch of transitions.

        :param obs_ph: (tf.Tensor) observations s
        :param acs_ph: (tf.Tensor) actions a (only consumed when ``state_only`` is False)
        :param nobs_ph: (tf.Tensor) next observations s'
        :param lprobs: (tf.Tensor) log-probabilities of the actions, shape [batch, 1]
        :param reuse: (bool) reuse the variables created by a previous call
        :return: (log D, log(1 - D), reward) tensors
        """
        with tf.variable_scope(self.scope):
            if reuse:
                tf.get_variable_scope().reuse_variables()

            if self.normalize:
                with tf.variable_scope("obfilter"):
                    self.obs_rms = RunningMeanStd(shape=self.observation_shape)
                if obs_ph.dtype == tf.float64:
                    obs_ph = tf.cast(obs_ph, tf.float32)
                    nobs_ph = tf.cast(nobs_ph, tf.float32)
                # s and s' are standardized with the same running statistics.
                obs = (obs_ph - self.obs_rms.mean) / self.obs_rms.std
                nobs = (nobs_ph - self.obs_rms.mean) / self.obs_rms.std
            else:
                obs = obs_ph
                nobs = nobs_ph

            if self.state_only:
                rew_input = obs
            else:
                # Flatten both inputs to rank 2 so they can be concatenated; a
                # scalar action shape () flattens to width 1.
                obs_dim = int(np.prod(self.observation_shape))
                act_dim = int(np.prod(self.action_shape))
                rew_input = tf.concat([tf.reshape(obs, [-1, obs_dim]),
                                       tf.reshape(acs_ph, [-1, act_dim])], axis=1)

            with tf.variable_scope('reward'):
                reward = self.reward_arch(rew_input, dout=1, **self.reward_arch_args)
            # Shaping term: the same 'vfn' network is applied to s' and (reused) to s.
            with tf.variable_scope('vfn'):
                fitted_value_fn_n = self.value_fn_arch(nobs, dout=1)
            with tf.variable_scope('vfn', reuse=True):
                fitted_value_fn = self.value_fn_arch(obs, dout=1)
            # f(s, a, s') = r + gamma * V(s') - V(s)
            log_p_tau = reward + self.gamma * fitted_value_fn_n - fitted_value_fn
            self.log_p_tau = log_p_tau
            log_q_tau = lprobs
            # log(exp(f) + exp(lprobs)): the discriminator's denominator.
            log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0)
            self.d_tau = tf.exp(log_p_tau - log_pq)
        return log_p_tau - log_pq, log_q_tau - log_pq, reward

    def get_trainable_variables(self):
        """
        Get all the trainable variables from the graph

        :return: ([tf.Tensor]) the variables
        """
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)

    def get_reward(self, obs, actions, lprobs, nobs, use_reward=False):
        """
        Evaluate either the discriminator score or the learned reward.

        Unbatched observations and actions are promoted to a batch of one.

        :param obs: (np.ndarray) observation(s)
        :param actions: (np.ndarray) action(s)
        :param lprobs: (np.ndarray) log-probabilities of ``actions``. A scalar or a
            1-D array (one entry per sample) is reshaped to (batch, 1) to match
            the placeholder.
        :param nobs: (np.ndarray) next observation(s)
        :param use_reward: (bool) if True return the learned reward r; otherwise
            return log(D) - log(1 - D) (with a 1e-8 epsilon)
        :return: (np.ndarray) one value per sample
        """
        sess = tf.get_default_session()
        if np.ndim(nobs) == 1:
            nobs = np.expand_dims(nobs, 0)
        if np.ndim(lprobs) <= 1:
            # Placeholder is [None, 1]: a flat vector holds one log-prob per sample,
            # so it must become a column, not a single row.
            lprobs = np.reshape(lprobs, (-1, 1))
        if np.ndim(obs) == 1:
            obs = np.expand_dims(obs, 0)
        if np.ndim(actions) <= 1:
            # Covers a single continuous action (1-D) and one discrete action (0-D).
            actions = np.expand_dims(actions, 0)
        feed_dict = {self.gen_obs_t: obs, self.gen_act_t: actions,
                     self.lprobs: lprobs, self.gen_nobs_t: nobs}
        if use_reward:
            return sess.run(self.reward, feed_dict)[:, 0]
        # gen_logits is log D; recover D and form the log-odds.
        scores = np.exp(sess.run(self.gen_logits, feed_dict))
        score = np.log(scores + 1e-8) - np.log(1 - scores + 1e-8)
        return score[:, 0]

    def debug(self, obs, acs, obsn, lprob, obs_e, acs_e, obsn_e, lprob_e):
        """Print the mean discriminator output D on a generator batch and an expert batch."""
        sess = tf.get_default_session()
        feed_dict = {self.gen_obs_t: obs,
                     self.gen_act_t: acs,
                     self.gen_nobs_t: obsn,
                     self.lprobs: lprob,
                     self.expert_obs_t: obs_e,
                     self.expert_act_t: acs_e,
                     self.expert_nobs_t: obsn_e,
                     self.expert_lprobs: lprob_e,
                     }
        # Exponentiate in NumPy: building tf.exp ops here would add new nodes
        # to the graph on every call.
        gen_log_d, expert_log_d = sess.run([self.gen_logits, self.expert_logits], feed_dict=feed_dict)
        print(np.mean(np.exp(gen_log_d)), np.mean(np.exp(expert_log_d)))


