import numpy as np
import tensorflow as tf
from tensorflow_gan.python.losses import losses_impl as tfgan_losses
import tensorflow_probability as tfp
from tf_agents.networks import network, utils
from tf_agents.specs.tensor_spec import TensorSpec
import pickle
import os

# float32 machine epsilon: a tiny additive guard (default atanh/log clipping
# margin in TanhActor, and added to input-gradient tensors in DemoDICE.update).
EPS = np.finfo(np.float32).eps
# Larger guard used when turning sigmoid(cost logit) back into a log-odds-like
# value, log(1 / (sigmoid(x) + EPS2) - 1 + EPS2), so the log stays finite.
EPS2 = 1e-3


class TanhActor(network.Network):
    """Diagonal-Gaussian policy whose samples are squashed into (-1, 1) by tanh.

    An MLP with two hidden layers of ``hidden_size`` units feeds two dense heads
    giving the (clipped) mean and log standard deviation of a Gaussian over
    pre-tanh actions.
    """

    def __init__(self, state_dim, action_dim, hidden_size=256, name='TanhNormalPolicy',
                 mean_range=(-7., 7.), logstd_range=(-5., 2.), eps=EPS, initial_std_scaler=1,
                 kernel_initializer='he_normal', activation_fn=tf.nn.relu):
        """
        Args:
            state_dim: Width of the state input; used to build the input spec.
            action_dim: Number of action dimensions.
            hidden_size: Width of each of the two hidden layers.
            name: Network name.
            mean_range: (min, max) the pre-tanh mean is clipped to.
            logstd_range: (min, max) the log standard deviation is clipped to.
            eps: Margin keeping atanh / log finite when |action| approaches 1.
            initial_std_scaler: Multiplier applied to the std after clipping.
            kernel_initializer: Initializer for every dense kernel.
            activation_fn: Activation of the hidden layers.
        """
        self._input_specs = TensorSpec(state_dim)
        self._action_dim = action_dim
        self._initial_std_scaler = initial_std_scaler

        super(TanhActor, self).__init__(self._input_specs, state_spec=(), name=name)

        hidden_sizes = (hidden_size, hidden_size)

        self._fc_layers = utils.mlp_layers(fc_layer_params=hidden_sizes, activation_fn=activation_fn,
                                           kernel_initializer=kernel_initializer, name='mlp')
        self._fc_mean = tf.keras.layers.Dense(action_dim, name='policy_mean/dense',
                                              kernel_initializer=kernel_initializer)
        self._fc_logstd = tf.keras.layers.Dense(action_dim, name='policy_logstd/dense',
                                                kernel_initializer=kernel_initializer)

        self.mean_min, self.mean_max = mean_range
        self.logstd_min, self.logstd_max = logstd_range
        self.eps = eps

    def _pretanh_distribution(self, inputs, training):
        """Run the MLP and return ``(clipped_mean, pre-tanh Gaussian)``."""
        h = inputs
        for layer in self._fc_layers:
            h = layer(h, training=training)

        mean = tf.clip_by_value(self._fc_mean(h), self.mean_min, self.mean_max)
        logstd = tf.clip_by_value(self._fc_logstd(h), self.logstd_min, self.logstd_max)
        std = tf.exp(logstd) * self._initial_std_scaler
        return mean, tfp.distributions.MultivariateNormalDiag(mean, std)

    def call(self, inputs, step_type=(), network_state=(), training=True):
        """Sample an action for each input state.

        Returns:
            ``((tanh(mean), sampled_action, log_prob), network_state)`` where
            ``log_prob`` includes the tanh change-of-variables correction.
        """
        del step_type  # unused
        mean, pretanh_action_dist = self._pretanh_distribution(inputs, training)
        pretanh_action = pretanh_action_dist.sample()
        action = tf.tanh(pretanh_action)
        log_prob, _ = self.log_prob(pretanh_action_dist, pretanh_action, is_pretanh_action=True)

        return (tf.tanh(mean), action, log_prob), network_state

    def log_prob(self, pretanh_action_dist, action, is_pretanh_action=True):
        """Log-density of a tanh-squashed action.

        Args:
            pretanh_action_dist: Gaussian over pre-tanh actions.
            action: Pre-tanh action if ``is_pretanh_action`` else a squashed
                action in [-1, 1] (clipped by ``eps`` before ``atanh``).
            is_pretanh_action: Which space ``action`` lives in.
        Returns:
            ``(log_prob, pretanh_log_prob)``; the first subtracts
            sum(log(1 - tanh(u)^2 + eps)) from the Gaussian log-density.
        """
        if is_pretanh_action:
            pretanh_action = action
            action = tf.tanh(pretanh_action)
        else:
            # Clip away from +-1 so atanh stays finite.
            pretanh_action = tf.atanh(tf.clip_by_value(action, -1 + self.eps, 1 - self.eps))

        pretanh_log_prob = pretanh_action_dist.log_prob(pretanh_action)
        log_prob = pretanh_log_prob - tf.reduce_sum(tf.math.log(1 - action ** 2 + self.eps), axis=-1)

        return log_prob, pretanh_log_prob

    def get_log_prob(self, states, actions, training=True):
        """Evaluate log probs for actions conditioned on states.

        Args:
            states: A batch of states.
            actions: A batch of squashed actions (in [-1, 1]) to evaluate.
            training: Forwarded to the MLP layers (default keeps old behavior).
        Returns:
            Log probabilities of actions with a trailing axis of size 1.
        """
        _, pretanh_action_dist = self._pretanh_distribution(states, training)
        log_probs, _ = self.log_prob(pretanh_action_dist, actions, is_pretanh_action=False)
        return tf.expand_dims(log_probs, -1)  # To avoid broadcasting


class DiscreteActor(network.Network):
    """Categorical policy over ``action_dim`` actions, using one-hot encodings."""

    def __init__(self, state_dim, action_dim, hidden_size=256, name='DiscretePolicy',
                 kernel_initializer='he_normal', activation_fn=tf.nn.relu):
        """
        Args:
            state_dim: Width of the state input; used to build the input spec.
            action_dim: Number of discrete actions.
            hidden_size: Width of each of the two hidden layers.
            name: Network name.
            kernel_initializer: Initializer for every dense kernel.
            activation_fn: Activation of the hidden layers.
        """
        self._input_specs = TensorSpec(state_dim)
        self._action_dim = action_dim

        super(DiscreteActor, self).__init__(self._input_specs, state_spec=(), name=name)

        hidden_sizes = (hidden_size, hidden_size)

        self._fc_layers = utils.mlp_layers(fc_layer_params=hidden_sizes, activation_fn=activation_fn,
                                           kernel_initializer=kernel_initializer, name='mlp')
        self._logit_layer = tf.keras.layers.Dense(action_dim, name='logits/dense',
                                                  kernel_initializer=kernel_initializer)

    def _logits(self, inputs, training):
        """Run the MLP and the logit head on ``inputs``."""
        h = inputs
        for layer in self._fc_layers:
            h = layer(h, training=training)
        return self._logit_layer(h)

    def call(self, inputs, step_type=(), network_state=(), training=True):
        """Sample a one-hot action for each input state.

        Returns:
            ``((greedy_one_hot, sampled_one_hot, log_prob), network_state)``.
        """
        del step_type  # unused
        logits = self._logits(inputs, training)
        dist = tfp.distributions.OneHotCategorical(logits)
        action = tf.cast(dist.sample(), tf.float32)
        # argmax over the last (action) axis, so extra leading batch dims work too.
        greedy_action = tf.one_hot(tf.argmax(logits, axis=-1), self._action_dim)
        log_prob = dist.log_prob(action)

        return (greedy_action, action, log_prob), network_state

    def get_log_prob(self, states, actions, training=True):
        """Evaluate log probs for actions conditioned on states.

        Args:
            states: A batch of states.
            actions: A batch of one-hot actions to evaluate log probs on.
            training: Forwarded to the MLP layers.
        Returns:
            Log probabilities of actions with a trailing axis of size 1.
        """
        logits = self._logits(states, training)
        dist = tfp.distributions.OneHotCategorical(logits)
        return tf.expand_dims(dist.log_prob(actions), -1)  # To avoid broadcasting


class Critic(network.Network):
    """Two-hidden-layer ReLU MLP mapping ``concat(state, action)`` to a value."""

    def __init__(self, state_dim, action_dim, hidden_size=256, output_activation_fn=None, use_last_layer_bias=False,
                 output_dim=None, kernel_initializer='he_normal', name='ValueNetwork'):
        """
        Args:
            state_dim: State width; input width is ``state_dim + action_dim``.
            action_dim: Action width (0 for a state-only value function).
            hidden_size: Width of each of the two hidden layers.
            output_activation_fn: Activation of the output layer (None = linear).
            use_last_layer_bias: If True, the output layer has a bias and both
                its kernel and bias are drawn from U(-3e-3, 3e-3); otherwise it
                has no bias and uses ``kernel_initializer``.
            output_dim: Output width; if None, one value per example is returned
                as a flat vector.
            kernel_initializer: Initializer for the hidden-layer kernels.
            name: Network name.
        """
        self._input_specs = TensorSpec(state_dim + action_dim)
        self._output_dim = output_dim

        super(Critic, self).__init__(self._input_specs, state_spec=(), name=name)

        hidden_sizes = (hidden_size, hidden_size)

        self._fc_layers = utils.mlp_layers(fc_layer_params=hidden_sizes, activation_fn=tf.nn.relu,
                                           kernel_initializer=kernel_initializer, name='mlp')
        if use_last_layer_bias:
            # Use a separate initializer instance for the kernel and for the bias.
            # Recent Keras warns that calling one unseeded initializer instance
            # several times returns identical values, which would correlate the
            # kernel and the bias.
            last_layer_kwargs = dict(
                kernel_initializer=tf.keras.initializers.RandomUniform(-3e-3, 3e-3),
                bias_initializer=tf.keras.initializers.RandomUniform(-3e-3, 3e-3))
        else:
            last_layer_kwargs = dict(use_bias=False, kernel_initializer=kernel_initializer)
        self._last_layer = tf.keras.layers.Dense(output_dim or 1, activation=output_activation_fn,
                                                 name='value', **last_layer_kwargs)

    def call(self, inputs, step_type=(), network_state=(), training=False):
        """Return ``(values, network_state)``; values are flattened if ``output_dim`` is None."""
        del step_type  # unused
        h = inputs
        for layer in self._fc_layers:
            h = layer(h, training=training)
        h = self._last_layer(h)

        if self._output_dim is None:
            h = tf.reshape(h, [-1])

        return h, network_state


class DemoDICE(tf.keras.layers.Layer):
    """Class that implements DemoDICE training.

    Three networks are trained jointly in ``update``:
      * ``cost``: a discriminator over concat(state, action), expert vs. union data;
      * ``critic``: a state-only value network nu(s);
      * ``actor``: the policy, trained by advantage-weighted behavior cloning
        on union data.
    """

    def __init__(self, state_dim, action_dim, is_discrete_action: bool, config):
        """
        Args:
            state_dim: State width expected by the networks. ``train`` appends
                one extra zero column to every state batch, so this must equal
                the buffers' state width + 1 when training through ``train``.
            action_dim: Action width (number of actions if discrete).
            is_discrete_action: Selects ``DiscreteActor`` vs. ``TanhActor``.
            config: Dict with keys 'hidden_size', 'critic_lr', 'actor_lr',
                'grad_reg_coeffs', 'gamma', 'alpha', 'use_last_layer_bias_cost',
                'use_last_layer_bias_critic', 'kernel_initializer'.
        """
        super(DemoDICE, self).__init__()
        hidden_size = config['hidden_size']
        critic_lr = config['critic_lr']
        actor_lr = config['actor_lr']
        self.is_discrete_action = is_discrete_action
        self.grad_reg_coeffs = config['grad_reg_coeffs']
        self.discount = config['gamma']
        self.non_expert_regularization = config['alpha'] + 1.

        self.cost = Critic(state_dim, action_dim, hidden_size=hidden_size,
                           use_last_layer_bias=config['use_last_layer_bias_cost'],
                           kernel_initializer=config['kernel_initializer'])
        self.critic = Critic(state_dim, 0, hidden_size=hidden_size,
                             use_last_layer_bias=config['use_last_layer_bias_critic'],
                             kernel_initializer=config['kernel_initializer'])
        if self.is_discrete_action:
            # Honor the configured width, as the continuous branch already does.
            self.actor = DiscreteActor(state_dim, action_dim, hidden_size=hidden_size)
        else:
            self.actor = TanhActor(state_dim, action_dim, hidden_size=hidden_size)

        self.cost.create_variables()
        self.critic.create_variables()
        self.actor.create_variables()

        self.cost_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_lr)
        self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_lr)
        self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)

    def eval(self):
        """No-op; returns None."""

    @staticmethod
    def _append_zero_column(x):
        """Append a column of zeros to a 2-D array and cast it to float32.

        NOTE(review): presumably an absorbing-state indicator (0 = regular
        state); confirm against the data pipeline.
        """
        return np.c_[x, np.zeros(x.shape[0])].astype(np.float32)

    def train(self, buffer_expert=None, buffer_union=None, union_init_state=None, iters=1, batch_size=100):
        """Run ``iters`` update steps on minibatches sampled from the buffers.

        Args:
            buffer_expert: Buffer whose ``sample(batch_size)`` returns
                (state, action, next_state, reward, not_done, flag) tensors that
                support ``.cpu().numpy()``.
            buffer_union: Buffer with the same interface, for union data.
            union_init_state: Indexable array of initial states.
            iters: Number of update steps.
            batch_size: Minibatch size for every sampled batch.
        Returns:
            None if both buffers are None; otherwise a dict mapping each logged
            quantity to a list with one entry per iteration.
        Raises:
            ValueError: If exactly one buffer is None, or if
                ``union_init_state`` is missing while buffers are given.
        """
        if buffer_expert is None and buffer_union is None:
            return
        if buffer_expert is None or buffer_union is None:
            raise ValueError("Both buffer_expert and buffer_union should be either None or not None")
        if union_init_state is None:
            raise ValueError("union_init_state is required when buffers are given")

        log_keys = ('cost_loss', 'nu_loss', 'actor_loss', 'expert_nu', 'union_nu', 'init_nu', 'union_adv')
        log_dict = {key: [] for key in log_keys}
        for _ in range(iters):
            union_init_indices = np.random.randint(0, len(union_init_state), size=batch_size)
            init_state = union_init_state[union_init_indices]
            batch_e = [b.cpu().numpy() for b in buffer_expert.sample(batch_size)]
            batch_u = [b.cpu().numpy() for b in buffer_union.sample(batch_size)]

            state_e, action_e, next_state_e, reward_e, not_done_e, flag_e = batch_e
            state_u, action_u, next_state_u, reward_u, not_done_u, flag_u = batch_u

            step_info = self.update(self._append_zero_column(init_state),
                                    self._append_zero_column(state_e),
                                    action_e,
                                    self._append_zero_column(next_state_e),
                                    self._append_zero_column(state_u),
                                    action_u,
                                    self._append_zero_column(next_state_u))
            for key in log_keys:
                log_dict[key].append(step_info[key])

        return log_dict

    @tf.function
    def update(self, init_states, expert_states, expert_actions, expert_next_states,
               union_states, union_actions, union_next_states):
        """One joint gradient step on the cost, critic (nu) and actor networks.

        Returns:
            Dict of scalar tensors: 'cost_loss', 'nu_loss', 'actor_loss' and the
            batch means 'expert_nu', 'union_nu', 'init_nu', 'union_adv'.
        """
        with tf.GradientTape(watch_accessed_variables=False, persistent=True) as tape:
            tape.watch(self.cost.variables)
            tape.watch(self.actor.variables)
            tape.watch(self.critic.variables)

            # define inputs
            expert_inputs = tf.concat([expert_states, expert_actions], -1)
            union_inputs = tf.concat([union_states, union_actions], -1)

            # call cost functions
            expert_cost_val, _ = self.cost(expert_inputs)
            union_cost_val, _ = self.cost(union_inputs)
            # Random interpolations (expert<->union and union<->shuffled union)
            # at which the cost gradient penalty is evaluated.
            unif_rand = tf.random.uniform(shape=(expert_states.shape[0], 1))
            mixed_inputs1 = unif_rand * expert_inputs + (1 - unif_rand) * union_inputs
            mixed_inputs2 = unif_rand * tf.random.shuffle(union_inputs) + (1 - unif_rand) * union_inputs
            mixed_inputs = tf.concat([mixed_inputs1, mixed_inputs2], 0)

            # gradient penalty for cost: push the input-gradient norm towards 1
            with tf.GradientTape(watch_accessed_variables=False) as tape2:
                tape2.watch(mixed_inputs)
                cost_output, _ = self.cost(mixed_inputs)
                # Guarded log((1 - D) / D) with D = sigmoid(cost logit).
                cost_output = tf.math.log(1 / (tf.nn.sigmoid(cost_output) + EPS2) - 1 + EPS2)
            cost_mixed_grad = tape2.gradient(cost_output, [mixed_inputs])[0] + EPS
            cost_grad_penalty = tf.reduce_mean(
                tf.square(tf.norm(cost_mixed_grad, axis=-1, keepdims=True) - 1))
            # Expert logits are the "real" side, union logits the "generated" side.
            cost_loss = tfgan_losses.minimax_discriminator_loss(expert_cost_val, union_cost_val, label_smoothing=0.) \
                        + self.grad_reg_coeffs[0] * cost_grad_penalty
            union_cost = tf.math.log(1 / (tf.nn.sigmoid(union_cost_val) + EPS2) - 1 + EPS2)

            # nu learning
            init_nu, _ = self.critic(init_states)
            expert_nu, _ = self.critic(expert_states)
            union_nu, _ = self.critic(union_states)
            union_next_nu, _ = self.critic(union_next_states)
            # The cost is treated as a fixed reward signal for the critic update.
            union_adv_nu = - tf.stop_gradient(union_cost) + self.discount * union_next_nu - union_nu

            non_linear_loss = self.non_expert_regularization * tf.reduce_logsumexp(
                union_adv_nu / self.non_expert_regularization)
            linear_loss = (1 - self.discount) * tf.reduce_mean(init_nu)
            nu_loss = non_linear_loss + linear_loss

            # weighted BC: exp(adv / (alpha + 1)), max-shifted for stability and
            # normalized to mean 1 over the batch.
            weight = tf.expand_dims(tf.math.exp((union_adv_nu - tf.reduce_max(union_adv_nu)) / self.non_expert_regularization), 1)
            weight = weight / tf.reduce_mean(weight)
            pi_loss = - tf.reduce_mean(
                tf.stop_gradient(weight) * self.actor.get_log_prob(union_states, union_actions))

            # gradient penalty for nu: push the input-gradient norm towards 0
            if self.grad_reg_coeffs[1] is not None:
                unif_rand2 = tf.random.uniform(shape=(expert_states.shape[0], 1))
                nu_inter = unif_rand2 * expert_states + (1 - unif_rand2) * union_states
                nu_next_inter = unif_rand2 * expert_next_states + (1 - unif_rand2) * union_next_states

                nu_inter = tf.concat([union_states, nu_inter, nu_next_inter], 0)

                with tf.GradientTape(watch_accessed_variables=False) as tape3:
                    tape3.watch(nu_inter)
                    nu_output, _ = self.critic(nu_inter)

                nu_mixed_grad = tape3.gradient(nu_output, [nu_inter])[0] + EPS
                nu_grad_penalty = tf.reduce_mean(
                    tf.square(tf.norm(nu_mixed_grad, axis=-1, keepdims=True)))
                nu_loss += self.grad_reg_coeffs[1] * nu_grad_penalty

        nu_grads = tape.gradient(nu_loss, self.critic.variables)
        cost_grads = tape.gradient(cost_loss, self.cost.variables)
        pi_grads = tape.gradient(pi_loss, self.actor.variables)
        self.critic_optimizer.apply_gradients(zip(nu_grads, self.critic.variables))
        self.cost_optimizer.apply_gradients(zip(cost_grads, self.cost.variables))
        self.actor_optimizer.apply_gradients(zip(pi_grads, self.actor.variables))
        info_dict = {
            'cost_loss': cost_loss,
            'nu_loss': nu_loss,
            'actor_loss': pi_loss,
            'expert_nu': tf.reduce_mean(expert_nu),
            'union_nu': tf.reduce_mean(union_nu),
            'init_nu': tf.reduce_mean(init_nu),
            'union_adv': tf.reduce_mean(union_adv_nu),
        }
        del tape
        return info_dict

    @tf.function
    def select_action(self, observation, device='cuda', deterministic: bool = True):
        """Return the action for a single (unbatched) observation.

        ``device`` is unused and kept only for interface compatibility.
        ``deterministic`` selects the actor's greedy/mean action instead of a sample.
        """
        observation = tf.convert_to_tensor([observation], dtype=tf.float32)
        all_actions, _ = self.actor(observation)
        if deterministic:
            actions = all_actions[0]
        else:
            actions = all_actions[1]
        return actions

    def get_training_state(self):
        """Snapshot all network and optimizer variables as (name, ndarray) lists."""
        def _snapshot(variables):
            return [(variable.name, variable.value().numpy()) for variable in variables]

        training_state = {
            'cost_params': _snapshot(self.cost.variables),
            'critic_params': _snapshot(self.critic.variables),
            'actor_params': _snapshot(self.actor.variables),
            'cost_optimizer_state': _snapshot(self.cost_optimizer.variables()),
            'critic_optimizer_state': _snapshot(self.critic_optimizer.variables()),
            'actor_optimizer_state': _snapshot(self.actor_optimizer.variables()),
        }
        return training_state

    def set_training_state(self, training_state):
        """Restore variables from a dict produced by ``get_training_state``.

        Raises:
            ValueError: If the variable count or any variable name differs.
        """
        def _assign_values(variables, params):
            if len(variables) != len(params):
                raise ValueError('Checkpoint holds {} values but the model has {} variables'.format(
                    len(params), len(variables)))
            for variable, (name, value) in zip(variables, params):
                if variable.name != name:
                    raise ValueError('Variable name mismatch: model has {!r}, checkpoint has {!r}'.format(
                        variable.name, name))
                variable.assign(value)

        _assign_values(self.cost.variables, training_state['cost_params'])
        _assign_values(self.critic.variables, training_state['critic_params'])
        _assign_values(self.actor.variables, training_state['actor_params'])
        _assign_values(self.cost_optimizer.variables(), training_state['cost_optimizer_state'])
        _assign_values(self.critic_optimizer.variables(), training_state['critic_optimizer_state'])
        _assign_values(self.actor_optimizer.variables(), training_state['actor_optimizer_state'])

    def init_dummy(self, state_dim, action_dim):
        """Run one update on zero inputs so optimizer variables get created."""
        dummy_state = np.zeros((1, state_dim), dtype=np.float32)
        dummy_action = np.zeros((1, action_dim), dtype=np.float32)
        self.update(dummy_state, dummy_state, dummy_action, dummy_state, dummy_state, dummy_action, dummy_state)

    def save(self, filepath):
        """Pickle the training state to ``filepath`` via a temp file + atomic replace."""
        print('Save checkpoint: ', filepath)
        training_state = self.get_training_state()
        data = {
            'training_state': training_state,
        }
        with open(filepath + '.tmp', 'wb') as f:
            pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
        # os.replace overwrites an existing target atomically on every OS
        # (os.rename fails on Windows when the target exists).
        os.replace(filepath + '.tmp', filepath)
        print('Saved!')

    def load(self, filepath):
        """Load a checkpoint written by ``save`` and restore it; returns the raw dict.

        SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        """
        print('Load checkpoint:', filepath)
        with open(filepath, 'rb') as f:
            data = pickle.load(f)
        self.set_training_state(data['training_state'])
        return data
