from collections import deque
import time

import gym
from mpi4py import MPI
from stable_baselines.common import fmt_row, dataset, tf_util
import numpy as np
from stable_baselines import logger
import tensorflow as tf
from stable_baselines.common.misc_util import flatten_lists
from stable_baselines.common.mpi_adam import MpiAdam

from causal_irl.algorithms.airl import AIRL
from causal_irl.algorithms.airl.airl_adversary import AIRL_Discriminator
from causal_irl.algorithms.common.base_class import SetVerbosity
from causal_irl.algorithms.common.policies import ActorCriticPolicy
from causal_irl.algorithms.common.runners import traj_segment_generator


class BC_AIRL(AIRL):

    def __init__(self, policy, env, expert_dataset=None, deterministic=False, lr=1e-3, adam_epsilon=1e-8,
                 val_interval=None, g_step=3, d_step=1,
                 verbose=0, **kwargs):
        """
        Behaviour cloning for the policy combined with an AIRL discriminator.

        :param policy: policy class, forwarded unchanged to the parent constructor
        :param env: environment, forwarded unchanged to the parent constructor
        :param expert_dataset: expert demonstrations; ``learn`` asserts that this is not None
        :param deterministic: (bool) selects which policy output is cloned
            (see ``_get_pretrain_placeholders``) and which loss ``pretrain`` builds
        :param lr: (float) learning rate handed to ``pretrain`` by ``learn``
        :param adam_epsilon: (float) Adam epsilon handed to ``pretrain`` by ``learn``
        :param val_interval: (int or None) validation interval handed to ``pretrain`` by ``learn``
        :param g_step: (int) number of behaviour-cloning epochs run per ``learn`` iteration
        :param d_step: (int) divisor of ``timesteps_per_batch`` giving the discriminator
            minibatch size in ``learn``
        :param verbose: (int) verbosity level, also forwarded to the parent constructor
        :param kwargs: extra keyword arguments forwarded to the parent constructor
        """
        super().__init__(policy, env, expert_dataset=expert_dataset, verbose=verbose, **kwargs)

        # Adversary mode flags: this class always runs in AIRL mode.
        self.using_airl, self.using_gail = True, False

        # Behaviour-cloning settings.
        self.expert_dataset = expert_dataset
        self.deterministic = deterministic
        self.lr, self.adam_eps = lr, adam_epsilon
        self.val_interval = val_interval

        # Per-iteration schedule: policy (generator) epochs and discriminator steps.
        self.g_step, self.d_step = g_step, d_step

    def _get_pretrain_placeholders(self):
        """
        Return the tensors needed to build the behaviour-cloning loss.

        The third element depends on the action space and on ``self.deterministic``:

        - Discrete, deterministic: ``policy_pi.policy``
        - Discrete, stochastic: ``policy_pi.policy_proba``
        - otherwise, deterministic: ``policy_pi.deterministic_action``
        - otherwise, stochastic: ``policy_pi.proba_distribution``

        :return: (tuple) observation placeholder, a freshly created action placeholder,
            and the selected policy output
        """
        pi = self.policy_pi
        # A new action placeholder is created on every call.
        action_ph = pi.pdtype.sample_placeholder([None])

        if isinstance(self.action_space, gym.spaces.Discrete):
            target = pi.policy if self.deterministic else pi.policy_proba
        else:
            target = pi.deterministic_action if self.deterministic else pi.proba_distribution

        return pi.obs_ph, action_ph, target



    def setup_model(self):
        """
        Build the TensorFlow graph and session, the policy network, the AIRL
        discriminator and the MPI Adam optimizer used to train the discriminator.

        All graph-dependent work (variable initialisation, optimizer construction,
        trainable-variable lookup) is done while ``self.graph`` is the default graph,
        so it does not depend on the session leaving that graph installed as default.
        """
        # prevent import loops
        # NOTE(review): TransitionClassifier is not referenced in this method; the import is
        # kept as-is in case it is needed for its side effects -- confirm and drop if not.
        from causal_irl.algorithms.gail.adversary import TransitionClassifier

        with SetVerbosity(self.verbose):

            assert issubclass(self.policy, ActorCriticPolicy), "Error: the input policy for the TRPO model must be " \
                                                               "an instance of common.policies.ActorCriticPolicy."

            self.nworkers = MPI.COMM_WORLD.Get_size()
            self.rank = MPI.COMM_WORLD.Get_rank()
            np.set_printoptions(precision=3)

            self.graph = tf.Graph()
            with self.graph.as_default():
                self.set_random_seed(self.seed)
                self.sess = tf_util.make_session(num_cpu=self.n_cpu_tf_sess, graph=self.graph, make_default=True)

                # Construct network for new policy
                self.policy_pi = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1,
                                             None, reuse=False, **self.policy_kwargs)
                self.reward_giver = AIRL_Discriminator(self.observation_space, self.action_space, self.policy_pi)

                # Kept inside the graph context: these calls create ops and read
                # variable collections from whatever graph is currently the default.
                tf_util.initialize(sess=self.sess)
                with tf.variable_scope("Adam_mpi", reuse=False):
                    if self.using_airl:
                        self.d_adam = MpiAdam(self.reward_giver.get_trainable_variables(), sess=self.sess)
                        # self.d_adam.sync()

                self.params = tf_util.get_trainable_vars("model") + self.reward_giver.get_trainable_variables()

            def allmean(arr):
                """Average a numpy array element-wise over all MPI workers."""
                assert isinstance(arr, np.ndarray)
                out = np.empty_like(arr)
                MPI.COMM_WORLD.Allreduce(arr, out, op=MPI.SUM)
                out /= self.nworkers
                return out

            self.allmean = allmean
            self.step = self.policy_pi.step
            self.proba_step = self.policy_pi.proba_step
            self.initial_state = self.policy_pi.initial_state


    def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4,
                 adam_epsilon=1e-8, val_interval=None):
        """
        Pretrain a model using behavior cloning:
        supervised learning given an expert dataset.

        NOTE: only Box and Discrete spaces are supported for now.

        The loss and optimizer ops are built once per (graph, action-space kind,
        ``self.deterministic``, ``learning_rate``, ``adam_epsilon``) and cached on the
        instance, so repeated calls (``learn`` calls this every iteration) do not keep
        adding nodes and optimizer variables to the graph. When cached ops are reused,
        the optimizer variables are re-initialised first, so every call still starts
        from a fresh Adam state.

        :param dataset: (ExpertDataset) Dataset manager
        :param n_epochs: (int) Number of iterations on the training set
        :param learning_rate: (float) Learning rate
        :param adam_epsilon: (float) the epsilon value for the adam optimizer
        :param val_interval: (int) Report training and validation losses every n epochs.
            By default, every 10th of the maximum number of epochs.
        :return: (BaseRLModel) the pretrained model
        """
        continuous_actions = isinstance(self.action_space, gym.spaces.Box)
        discrete_actions = isinstance(self.action_space, gym.spaces.Discrete)
        det = "deterministic" if self.deterministic else "stochastic"
        print("Training a {} BC agent".format(det))
        assert discrete_actions or continuous_actions, 'Only Discrete and Box action spaces are supported'

        # Validate the model every 10% of the total number of iteration
        if val_interval is None:
            # Prevent modulo by zero
            val_interval = 1 if n_epochs < 10 else int(n_epochs / 10)

        # Per-instance cache of the pretraining ops; dropped whenever the model graph is rebuilt.
        cache = getattr(self, '_pretrain_cache', None)
        if cache is None or cache['graph'] is not self.graph:
            cache = {'graph': self.graph, 'ops': {}}
            self._pretrain_cache = cache
        cache_key = (continuous_actions, bool(self.deterministic), learning_rate, adam_epsilon)
        cached_ops = cache['ops'].get(cache_key)

        if cached_ops is None:
            with self.graph.as_default():
                with tf.variable_scope('pretrain', reuse=tf.AUTO_REUSE):
                    if continuous_actions:
                        obs_ph, actions_ph, pred_actions_ph = self._get_pretrain_placeholders()
                        if self.deterministic:
                            loss = tf.reduce_mean(tf.square(actions_ph - pred_actions_ph))
                        else:
                            loss = tf.reduce_mean(pred_actions_ph.neglogp(actions_ph))  # change this to logpdf
                    else:
                        obs_ph, actions_ph, actions_logits_ph = self._get_pretrain_placeholders()
                        # NOTE(review): in the stochastic case the tensor used as logits below is
                        # policy_pi.policy_proba; if that is already a softmax output, softmax is
                        # applied twice here -- confirm against the policy class.
                        # actions_ph has a shape if (n_batch,), we reshape it to (n_batch, 1)
                        # so no additional changes is needed in the dataloader
                        actions_ph = tf.expand_dims(actions_ph, axis=1)
                        one_hot_actions = tf.one_hot(actions_ph, self.action_space.n)
                        loss = tf.nn.softmax_cross_entropy_with_logits_v2(
                            logits=actions_logits_ph,
                            labels=tf.stop_gradient(one_hot_actions)
                        )
                        loss = tf.reduce_mean(loss)
                    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, epsilon=adam_epsilon)
                    optim_op = optimizer.minimize(loss, var_list=tf_util.get_trainable_vars("model"))
                    # Op that puts the Adam slots / beta powers back to their initial values.
                    reset_optim_op = tf.variables_initializer(optimizer.variables())
                # Initialises the freshly created optimizer variables.
                tf_util.initialize(sess=self.sess)
            cache['ops'][cache_key] = (obs_ph, actions_ph, loss, optim_op, reset_optim_op)
        else:
            obs_ph, actions_ph, loss, optim_op, reset_optim_op = cached_ops
            # Reusing cached ops: start from a fresh optimizer state, exactly as a
            # newly built optimizer would.
            self.sess.run(reset_optim_op)

        if self.verbose > 0:
            print("Pretraining with Behavior Cloning...")

        for epoch_idx in range(int(n_epochs)):
            train_loss = 0.0
            n_train_batches = len(dataset.train_loader)
            # Full pass on the training set
            for _ in range(n_train_batches):
                expert_obs, expert_actions = dataset.get_next_batch('train')
                feed_dict = {
                    obs_ph: expert_obs,
                    actions_ph: expert_actions,
                }
                train_loss_, _ = self.sess.run([loss, optim_op], feed_dict)
                train_loss += train_loss_

            train_loss /= n_train_batches

            if self.verbose > 0 and (epoch_idx + 1) % val_interval == 0:
                val_loss = 0.0
                n_val_batches = len(dataset.val_loader)
                # Full pass on the validation set
                for _ in range(n_val_batches):
                    expert_obs, expert_actions = dataset.get_next_batch('val')
                    val_loss_, = self.sess.run([loss], {obs_ph: expert_obs,
                                                        actions_ph: expert_actions})
                    val_loss += val_loss_

                val_loss /= n_val_batches
                print("==== Training progress {:.2f}% ====".format(100 * (epoch_idx + 1) / n_epochs))
                print('Epoch {}'.format(epoch_idx + 1))
                print("Training loss: {:.6f}, Validation loss: {:.6f}".format(train_loss, val_loss))
                print()
            # Free memory
            del expert_obs, expert_actions
        if self.verbose > 0:
            print("Pretraining done.")
        return self

    def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="BC",
              reset_num_timesteps=True):
        """
        Alternate behaviour-cloning updates of the policy with AIRL discriminator updates.

        Each iteration samples one trajectory segment with the current policy, runs
        ``self.g_step`` behaviour-cloning epochs on the expert dataset via ``pretrain``,
        then updates the discriminator on minibatches of the sampled segment paired with
        expert batches, and finally logs episode statistics gathered across MPI workers.

        :param total_timesteps: (int) stop once this many timesteps (summed over workers)
            have been sampled
        :param callback: callback passed through ``self._init_callback`` and on to the
            trajectory generator
        :param log_interval: (int) not used by this method
        :param tb_log_name: (str) not used by this method
        :param reset_num_timesteps: (bool) forwarded to ``self._init_num_timesteps``
        :return: (BC_AIRL) this model
        """
        assert self.expert_dataset is not None, "You must pass an expert dataset to BC for training"

        self._init_num_timesteps(reset_num_timesteps)
        callback = self._init_callback(callback)

        seg_gen = traj_segment_generator(self.policy_pi, self.env, self.timesteps_per_batch,
                                         reward_giver=self.reward_giver,
                                         gail=self.using_gail, airl=self.using_airl, callback=callback)

        def next_observations(obs):
            """Shift ``obs`` one step forward along axis 0; the last row is zero-padded."""
            obs = np.asarray(obs)
            return np.concatenate((obs[1:], np.zeros_like(obs[:1])), axis=0)

        def calculate_lprobs(obs, actions, policy):
            # NOTE(review): despite the name this returns exp(-proba_obs_action(...)), i.e. it
            # assumes proba_obs_action yields a negative log-probability -- confirm what the
            # discriminator expects.
            nlogprob = policy.proba_obs_action(obs, actions)
            return np.exp(-nlogprob)

        episodes_so_far = 0
        timesteps_so_far = 0
        iters_so_far = 0
        t_start = time.time()
        len_buffer = deque(maxlen=40)  # rolling buffer for episode lengths
        reward_buffer = deque(maxlen=40)  # rolling buffer for episode rewards
        true_reward_buffer = deque(maxlen=40)  # rolling buffer for true episode rewards
        self._initialize_dataloader()

        while timesteps_so_far < total_timesteps:
            logger.log("********** Iteration %i ************" % iters_so_far)
            logger.log(" Sampling: ")
            seg = next(seg_gen)

            logger.log("Optimizing Policy...")
            # Stop training early (triggered by the callback)
            if not seg.get('continue_training', True):  # pytype: disable=attribute-error
                break
            logger.log(" Pretraining: ")
            self.pretrain(self.expert_dataset, n_epochs=self.g_step,
                          learning_rate=self.lr,
                          adam_epsilon=self.adam_eps, val_interval=self.val_interval)

            # ------------------ Update D ------------------
            logger.log("Optimizing Discriminator...")
            logger.log(fmt_row(13, self.reward_giver.loss_name))
            observation, action = seg["observations"], seg["actions"]
            assert len(observation) == self.timesteps_per_batch
            batch_size = self.timesteps_per_batch // self.d_step

            # Successor observations must be taken from the time-ordered segment *before*
            # iterbatches shuffles it; shifting a shuffled minibatch pairs each observation
            # with an unrelated one.
            # NOTE(review): episode boundaries inside the segment are not treated specially.
            next_observation = next_observations(observation)

            d_losses = []  # list of tuples, each of which gives the loss for a minibatch
            # NOTE: for recurrent policies, use shuffle=False?
            for ob_batch, ac_batch, nob_batch in dataset.iterbatches((observation, action, next_observation),
                                                                     include_final_partial_batch=False,
                                                                     batch_size=batch_size,
                                                                     shuffle=True):
                ob_expert, ac_expert = self.expert_dataset.get_next_batch()
                # update running mean/std for reward_giver
                if self.reward_giver.normalize:
                    self.reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))

                # Reshape actions if needed when using discrete actions
                if isinstance(self.action_space, gym.spaces.Discrete):
                    if len(ac_batch.shape) == 2:
                        ac_batch = ac_batch[:, 0]
                    if len(ac_expert.shape) == 2:
                        ac_expert = ac_expert[:, 0]

                # NOTE(review): expert successors are taken by shifting the expert batch, which
                # is only correct if the expert loader yields consecutive transitions -- confirm.
                nob_expert = next_observations(ob_expert)

                lprob_batch = calculate_lprobs(ob_batch, ac_batch, self.policy_pi)[:, None]
                lprob_expert = calculate_lprobs(ob_expert, ac_expert, self.policy_pi)[:, None]

                *newlosses, grad = self.reward_giver.lossandgrad(ob_batch, ac_batch, nob_batch, lprob_batch,
                                                                 ob_expert, ac_expert, nob_expert, lprob_expert)

                self.d_adam.update(self.allmean(grad), self.d_stepsize)
                d_losses.append(newlosses)
                logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

            # lr: lengths and rewards -- gathered once per iteration so that the buffers
            # receive each episode exactly once, whatever the number of minibatches.
            lr_local = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
            list_lr_pairs = MPI.COMM_WORLD.allgather(lr_local)  # list of tuples
            lens, rews, true_rets = map(flatten_lists, zip(*list_lr_pairs))
            true_reward_buffer.extend(true_rets)
            len_buffer.extend(lens)
            reward_buffer.extend(rews)

            if len_buffer:
                logger.record_tabular("EpLenMean", np.mean(len_buffer))
                logger.record_tabular("EpRewMean", np.mean(reward_buffer))
            if self.using_gail or self.using_airl:
                logger.record_tabular("EpTrueRewMean", np.mean(true_reward_buffer))
            logger.record_tabular("EpThisIter", len(lens))
            episodes_so_far += len(lens)
            current_it_timesteps = MPI.COMM_WORLD.allreduce(seg["total_timestep"])
            timesteps_so_far += current_it_timesteps
            self.num_timesteps += current_it_timesteps
            iters_so_far += 1

            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", self.num_timesteps)
            logger.record_tabular("TimeElapsed", time.time() - t_start)

            if self.verbose >= 1 and self.rank == 0:
                logger.dump_tabular()

        return self