from collections import deque
import time

import gym
from mpi4py import MPI
from scipy.special import softmax
from stable_baselines.common import fmt_row, dataset, tf_util
import numpy as np
from stable_baselines import logger
import tensorflow as tf
from stable_baselines.common.misc_util import flatten_lists
from stable_baselines.common.mpi_adam import MpiAdam

from causal_irl.algorithms.common.base_class import SetVerbosity
from causal_irl.algorithms.common.policies import ActorCriticPolicy
from causal_irl.algorithms.common.runners import traj_segment_generator
from causal_irl.algorithms.gail.model import GAIL


class BC_GAIL(GAIL):

    def __init__(self, policy, env, expert_dataset=None, deterministic=False, lr=1e-4, adam_epsilon=1e-8,
                 val_interval=None, g_step=3, d_step=1, rejection_sampling=False,
                 verbose=0, **kwargs):
        """
        GAIL variant whose generator is fit to the expert data by behavior cloning.

        :param policy: policy class, forwarded to ``GAIL``
        :param env: environment, forwarded to ``GAIL``
        :param expert_dataset: (ExpertDataset) expert demonstrations; forwarded to ``GAIL``
            and also kept on ``self`` for behavior cloning
        :param deterministic: (bool) choose the deterministic BC objective
            (see ``_get_pretrain_placeholders``)
        :param lr: (float) learning rate of the BC Adam optimizer
        :param adam_epsilon: (float) epsilon of the BC Adam optimizer
        :param val_interval: (int) epoch interval for validation reports, passed to ``pretrain``
        :param g_step: (int) number of BC epochs run per ``learn`` iteration
        :param d_step: (int) the sampled segment is split into minibatches of
            ``timesteps_per_batch // d_step`` steps for the discriminator update
        :param rejection_sampling: (bool) if True, ``step`` picks among discriminator-scored
            candidate actions
        :param verbose: (int) verbosity level, forwarded to ``GAIL``
        :param kwargs: extra keyword arguments, forwarded to ``GAIL``
        """
        super().__init__(policy, env, expert_dataset=expert_dataset, verbose=verbose, **kwargs)

        # Behavior-cloning (generator) settings.
        self.expert_dataset = expert_dataset
        self.deterministic = deterministic
        self.lr, self.adam_eps = lr, adam_epsilon
        self.val_interval = val_interval

        # Alternation schedule between generator and discriminator updates.
        self.g_step, self.d_step = g_step, d_step

        # Action-selection mode and lazy BC-graph construction flag.
        self.rejection_sampling = rejection_sampling
        self.setup_pretrain = False

    def _get_pretrain_placeholders(self):
        """
        Return ``(obs_ph, actions_ph, prediction)`` used to build the behavior-cloning loss.

        ``actions_ph`` is a new expert-action placeholder created from the policy's
        probability-distribution type on every call. ``prediction`` depends on the action space:

        * Discrete: the policy's action logits (``policy.policy``), in both deterministic
          and stochastic mode. The BC loss feeds this tensor to
          ``softmax_cross_entropy_with_logits_v2``, so it must be unscaled logits.
          ``policy.policy_proba`` is presumably the softmax of these logits (it is in
          stable_baselines' ActorCriticPolicy); using it there would apply softmax twice.
        * Box, deterministic: ``policy.deterministic_action`` (regressed onto expert actions).
        * Box, stochastic: ``policy.proba_distribution`` (its ``neglogp`` is minimized).

        :return: (tuple) observation placeholder, action placeholder, prediction
        """
        policy = self.policy_pi
        action_ph = policy.pdtype.sample_placeholder([None])
        if isinstance(self.action_space, gym.spaces.Discrete):
            return policy.obs_ph, action_ph, policy.policy
        if self.deterministic:
            return policy.obs_ph, action_ph, policy.deterministic_action
        return policy.obs_ph, action_ph, policy.proba_distribution


    # def get_get_parameter_list(self):
    #     return self.params + self.bc_params

    def setup_model(self):
        """
        Build the TensorFlow graph and session, the policy network and, when GAIL is
        enabled, the discriminator and its MPI Adam optimizer.

        Besides ``self.graph``, ``self.sess``, ``self.policy_pi`` and ``self.params``,
        this installs:

        * ``self.allmean`` -- element-wise mean of a numpy array over all MPI workers;
        * ``self.step`` -- policy step. With ``self.rejection_sampling`` it draws several
          stochastic candidates and picks one with probabilities given by a softmax
          over the discriminator logits;
        * ``self.proba_step`` and ``self.initial_state`` -- taken from the policy.
        """
        # prevent import loops
        from causal_irl.algorithms.gail.adversary import TransitionClassifier

        with SetVerbosity(self.verbose):

            assert issubclass(self.policy, ActorCriticPolicy), "Error: the input policy for the TRPO model must be " \
                                                               "an instance of common.policies.ActorCriticPolicy."

            self.nworkers = MPI.COMM_WORLD.Get_size()
            self.rank = MPI.COMM_WORLD.Get_rank()
            np.set_printoptions(precision=3)

            self.graph = tf.Graph()
            with self.graph.as_default():
                self.set_random_seed(self.seed)
                self.sess = tf_util.make_session(num_cpu=self.n_cpu_tf_sess, graph=self.graph, make_default=True)

                if self.using_gail:
                    self.reward_giver = TransitionClassifier(self.observation_space, self.action_space,
                                                             self.hidden_size_adversary,
                                                             entcoeff=self.adversary_entcoeff)

                # Construct network for new policy
                self.policy_pi = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1,
                                             None, reuse=False, **self.policy_kwargs)

                # Variable initialization, MpiAdam's helper ops and the variable lookups
                # below presumably act on the default graph (they do in stable_baselines),
                # so run them explicitly inside self.graph instead of relying on
                # make_session(make_default=True) having left it installed as the default.
                tf_util.initialize(sess=self.sess)
                with tf.variable_scope("Adam_mpi", reuse=False):
                    if self.using_gail:
                        self.d_adam = MpiAdam(self.reward_giver.get_trainable_variables(), sess=self.sess)
                        # NOTE(review): self.d_adam.sync() is disabled. In stable_baselines, sync
                        # broadcasts rank 0's discriminator weights to the other workers --
                        # confirm skipping it is intended for multi-worker runs.

                self.params = tf_util.get_trainable_vars("model")
                if self.using_gail:
                    # The discriminator is only created when GAIL is enabled.
                    self.params = self.params + self.reward_giver.get_trainable_variables()

            def allmean(arr):
                """Return the element-wise mean of ``arr`` over all MPI workers."""
                assert isinstance(arr, np.ndarray)
                out = np.empty_like(arr)
                MPI.COMM_WORLD.Allreduce(arr, out, op=MPI.SUM)
                out /= self.nworkers
                return out

            self.allmean = allmean

            # Number of candidate policy steps drawn per rejection-sampling step.
            n_candidates = 10

            def step(obs, state=None, mask=None, deterministic=False):
                """
                Policy step, optionally re-ranked by the discriminator.

                Without rejection sampling this is ``self.policy_pi.step``. With it,
                ``deterministic`` is ignored:
                - ``n_candidates`` stochastic steps are drawn for ``obs``;
                - each candidate action is scored with ``self.reward_giver.get_discrim_logit``;
                - one candidate is sampled with probabilities ``softmax(scores)``.

                The squeezed distribution must be 1-D (asserted), which presumably limits
                this mode to a single environment. Returns the chosen candidate's first four
                step outputs (action, value, state, neglogp).
                """
                if not self.rejection_sampling:
                    return self.policy_pi.step(obs, state, mask, deterministic)

                candidates = [self.policy_pi.step(obs, state, mask, False) for _ in range(n_candidates)]
                action_scores = [self.reward_giver.get_discrim_logit(obs, candidate[0]) for candidate in candidates]
                action_dist = softmax(action_scores).squeeze()
                assert len(action_dist.shape) == 1

                # np.random.multinomial casts pvals to float64 and rejects them when
                # sum(pvals[:-1]) exceeds 1 beyond a tiny tolerance. float32 softmax
                # output can trigger that, so renormalize in float64 first.
                action_dist = action_dist.astype(np.float64)
                action_dist /= action_dist.sum()
                idx = np.argmax(np.random.multinomial(1, action_dist))

                chosen = candidates[idx]
                return chosen[0], chosen[1], chosen[2], chosen[3]

            self.step = step
            self.proba_step = self.policy_pi.proba_step
            self.initial_state = self.policy_pi.initial_state


    # def step(self, obs, state=None, mask=None, deterministic=False):
    #     if self.rejection_sampling:
    #         print("rejection_sampling !!!!!")
    #         gen_actions = [self.policy_pi.step(obs, state, mask, False) for _ in range(10)]
    #         action_scores = [self.reward_giver.dis_out(obs, action) for action in gen_actions]
    #         action_dist = softmax(action_scores, axis=-1)
    #         print(action_dist)
    #         pred_action = np.random.multinomial(1, action_dist).dot(gen_actions)
    #         print(pred_action)
    #         return pred_action
    #     else:
    #         return self.policy_pi.step(obs, state, mask, deterministic)

    def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4,
                 adam_epsilon=1e-8, val_interval=None):
        """
        Pretrain a model using behavior cloning:
        supervised learning given an expert dataset.

        NOTE: only Box and Discrete spaces are supported for now.

        The BC loss and its Adam update op are built once and cached on the model.
        They are rebuilt only when the graph, ``learning_rate`` or ``adam_epsilon``
        changes. ``learn`` calls this method every iteration, and rebuilding each
        time would keep adding placeholders, ops and Adam slot variables to the graph.
        Because the optimizer is reused, its moment estimates carry over between calls.

        :param dataset: (ExpertDataset) Dataset manager
        :param n_epochs: (int) Number of iterations on the training set
        :param learning_rate: (float) Learning rate
        :param adam_epsilon: (float) the epsilon value for the adam optimizer
        :param val_interval: (int) Report training and validation losses every n epochs.
            By default, every 10th of the maximum number of epochs. Values below 1 are
            treated as 1.
        :return: (BaseRLModel) the pretrained model
        """
        continuous_actions = isinstance(self.action_space, gym.spaces.Box)
        discrete_actions = isinstance(self.action_space, gym.spaces.Discrete)
        det = "deterministic" if self.deterministic else "stochastic"
        print("Training a {} BC agent".format(det))
        assert discrete_actions or continuous_actions, 'Only Discrete and Box action spaces are supported'

        # Validate the model every 10% of the total number of iterations
        if val_interval is None:
            val_interval = int(n_epochs / 10)
        # Prevent modulo by zero
        val_interval = max(val_interval, 1)

        # Keying on the graph invalidates the cache if setup_model rebuilt it.
        cache_key = (self.graph, learning_rate, adam_epsilon)
        if not getattr(self, 'setup_pretrain', False) or getattr(self, '_pretrain_key', None) != cache_key:
            self._pretrain_ops = self._build_pretrain_graph(learning_rate, adam_epsilon)
            self._pretrain_key = cache_key
            self.setup_pretrain = True
        obs_ph, actions_ph, loss, optim_op = self._pretrain_ops

        if self.verbose > 0:
            print("Pretraining with Behavior Cloning...")

        for epoch_idx in range(int(n_epochs)):
            # Full pass on the training set (loss + parameter update)
            train_loss = self._mean_bc_loss(dataset, 'train', len(dataset.train_loader),
                                            [loss, optim_op], obs_ph, actions_ph)

            if self.verbose > 0 and (epoch_idx + 1) % val_interval == 0:
                # Full pass on the validation set (loss only)
                val_loss = self._mean_bc_loss(dataset, 'val', len(dataset.val_loader),
                                              [loss], obs_ph, actions_ph)
                print("==== Training progress {:.2f}% ====".format(100 * (epoch_idx + 1) / n_epochs))
                print('Epoch {}'.format(epoch_idx + 1))
                print("Training loss: {:.6f}, Validation loss: {:.6f}".format(train_loss, val_loss))
                print()

        if self.verbose > 0:
            print("Pretraining done.")
        return self

    def _build_pretrain_graph(self, learning_rate, adam_epsilon):
        """
        Build the behavior-cloning loss and its Adam update op in ``self.graph``.

        :param learning_rate: (float) Adam learning rate
        :param adam_epsilon: (float) Adam epsilon
        :return: (tuple) ``(obs_ph, actions_ph, loss, optim_op)``; ``actions_ph`` is the
            tensor that expert actions must be fed into
        """
        continuous_actions = isinstance(self.action_space, gym.spaces.Box)
        with self.graph.as_default():
            with tf.variable_scope('pretrain', reuse=tf.AUTO_REUSE):
                if continuous_actions:
                    obs_ph, actions_ph, pred_actions_ph = self._get_pretrain_placeholders()
                    if self.deterministic:
                        # Regress the predicted action onto the expert action.
                        loss = tf.reduce_mean(tf.square(actions_ph - pred_actions_ph))
                    else:
                        # Negative log-likelihood of the expert action under the policy distribution.
                        loss = tf.reduce_mean(pred_actions_ph.neglogp(actions_ph))
                else:
                    obs_ph, actions_ph, actions_logits_ph = self._get_pretrain_placeholders()
                    # actions_ph has a shape of (n_batch,); it is reshaped to (n_batch, 1)
                    # so no additional change is needed in the dataloader. The expanded
                    # tensor is what expert actions are fed into.
                    actions_ph = tf.expand_dims(actions_ph, axis=1)
                    one_hot_actions = tf.one_hot(actions_ph, self.action_space.n)
                    loss = tf.nn.softmax_cross_entropy_with_logits_v2(
                        logits=actions_logits_ph,
                        labels=tf.stop_gradient(one_hot_actions)
                    )
                    loss = tf.reduce_mean(loss)
                optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, epsilon=adam_epsilon)
                # Only variables under the "model" scope (presumably the policy) are updated.
                optim_op = optimizer.minimize(loss, var_list=tf_util.get_trainable_vars("model"))
            tf_util.initialize(sess=self.sess)
        return obs_ph, actions_ph, loss, optim_op

    def _mean_bc_loss(self, expert_data, split, n_batches, fetches, obs_ph, actions_ph):
        """
        Run ``fetches`` on ``n_batches`` expert batches from ``split`` and return the mean loss.

        ``fetches[0]`` must be the loss tensor. Any further fetches (e.g. the optimizer
        op) are run only for their side effects.

        :return: (float) mean of the per-batch losses, or ``nan`` when ``n_batches`` is 0
        """
        if n_batches == 0:
            return float('nan')
        total = 0.0
        for _ in range(n_batches):
            expert_obs, expert_actions = expert_data.get_next_batch(split)
            results = self.sess.run(fetches, {obs_ph: expert_obs, actions_ph: expert_actions})
            total += results[0]
        return total / n_batches

    def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="BC",
              reset_num_timesteps=True):
        """
        Alternate behavior-cloning updates of the policy with discriminator updates.

        Each iteration does the following:
        - samples one segment of ``self.timesteps_per_batch`` steps with the current policy;
        - runs ``self.pretrain`` on ``self.expert_dataset`` for ``self.g_step`` epochs;
        - trains the discriminator (``self.reward_giver``) on minibatches of that segment
          against expert batches;
        - records episode statistics.

        The loop stops once ``total_timesteps`` steps, accumulated from every worker's
        segment, have been collected, or when the segment reports
        ``continue_training=False``.

        :param total_timesteps: (int) number of environment steps to collect
        :param callback: callback, normalized by ``self._init_callback`` and handed to the
            segment generator
        :param log_interval: (int) not referenced in the lines shown here
        :param tb_log_name: (str) not referenced in the lines shown here
        :param reset_num_timesteps: (bool) forwarded to ``self._init_num_timesteps``
        """
        assert self.expert_dataset is not None, "You must pass an expert dataset to BC for training"


        # NOTE(review): new_tb_log is not used in the lines shown (no TensorBoard writer is opened).
        new_tb_log = self._init_num_timesteps(reset_num_timesteps)
        callback = self._init_callback(callback)

        # Generator yielding one dict-like rollout segment per __next__() call.
        seg_gen = traj_segment_generator(self.policy_pi, self.env, self.timesteps_per_batch,
                                         reward_giver=self.reward_giver,
                                         gail=self.using_gail, airl=self.using_airl, callback=callback)

        episodes_so_far = 0
        timesteps_so_far = 0
        iters_so_far = 0
        t_start = time.time()
        len_buffer = deque(maxlen=40)  # rolling buffer for episode lengths
        reward_buffer = deque(maxlen=40)  # rolling buffer for episode rewards

        true_reward_buffer = deque(maxlen=40)  # rolling buffer for true (environment) returns
        self._initialize_dataloader()

        while True:
            if timesteps_so_far >= total_timesteps:
                break

            logger.log("********** Iteration %i ************" % iters_so_far)
            logger.log(" Sampling: ")
            seg = seg_gen.__next__()

            logger.log("Optimizing Policy...")
            # Stop training early (triggered by the callback)
            if not seg.get('continue_training', True):  # pytype: disable=attribute-error
                break
            # Generator update: behavior cloning on the expert dataset for g_step epochs.
            # The segment sampled above is not used for this update.
            logger.log(" Pretraining: ")
            self.pretrain(self.expert_dataset, n_epochs=self.g_step,
                          learning_rate=self.lr,
                          adam_epsilon=self.adam_eps, val_interval=self.val_interval)

            # NOTE(review): seg has not changed since the identical check above, so this
            # check can never break where the first one did not.
            if not seg.get('continue_training', True):  # pytype: disable=attribute-error
                break

            # NOTE(review): lens and rews are reassigned below before being read.
            lens, rews = None, None
            # ------------------ Update D ------------------
            logger.log("Optimizing Discriminator...")
            logger.log(fmt_row(13, self.reward_giver.loss_name))
            observation, action = seg["observations"], seg["actions"]
            assert len(observation) == self.timesteps_per_batch
            # Split the segment into minibatches of timesteps_per_batch // d_step steps.
            batch_size = self.timesteps_per_batch // self.d_step

            # NOTE: the discriminator only sees the segment sampled at the start of this iteration
            d_losses = []  # list of tuples, each of which gives the loss for a minibatch
            # NOTE: for recurrent policies, use shuffle=False?
            # `dataset` here is the stable_baselines.common.dataset module imported at the top of the file.
            for ob_batch, ac_batch in dataset.iterbatches((observation, action),
                                                          include_final_partial_batch=False,
                                                          batch_size=batch_size,
                                                          shuffle=True):
                # One expert batch per generator minibatch; its size presumably comes from
                # the expert dataset's own loader rather than batch_size -- verify.
                ob_expert, ac_expert = self.expert_dataset.get_next_batch()
                # update running mean/std for reward_giver
                if self.reward_giver.normalize:
                    self.reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))

                # Reshape actions if needed when using discrete actions
                if isinstance(self.action_space, gym.spaces.Discrete):
                    if len(ac_batch.shape) == 2:
                        ac_batch = ac_batch[:, 0]
                    if len(ac_expert.shape) == 2:
                        ac_expert = ac_expert[:, 0]

                # lossandgrad returns the loss values followed by the gradient as its last element.
                *newlosses, grad = self.reward_giver.lossandgrad(ob_batch, ac_batch, ob_expert, ac_expert)
                # Average the gradient over MPI workers, then take one MpiAdam step.
                self.d_adam.update(self.allmean(grad), self.d_stepsize)
                d_losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

            # lr: lengths and rewards
            lr_local = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
            list_lr_pairs = MPI.COMM_WORLD.allgather(lr_local)  # list of tuples
            lens, rews, true_rets = map(flatten_lists, zip(*list_lr_pairs))
            true_reward_buffer.extend(true_rets)

            len_buffer.extend(lens)
            reward_buffer.extend(rews)

            if len(len_buffer) > 0:
                logger.record_tabular("EpLenMean", np.mean(len_buffer))
                logger.record_tabular("EpRewMean", np.mean(reward_buffer))
            # NOTE(review): if no episode has finished yet, true_reward_buffer is empty and
            # np.mean returns nan (with a RuntimeWarning).
            if self.using_gail or self.using_airl:
                logger.record_tabular("EpTrueRewMean", np.mean(true_reward_buffer))
            logger.record_tabular("EpThisIter", len(lens))
            episodes_so_far += len(lens)
            # Step count of this iteration over all workers (mpi4py's lowercase allreduce
            # defaults to op=MPI.SUM).
            current_it_timesteps = MPI.COMM_WORLD.allreduce(seg["total_timestep"])
            timesteps_so_far += current_it_timesteps
            self.num_timesteps += current_it_timesteps
            iters_so_far += 1

            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", self.num_timesteps)
            logger.record_tabular("TimeElapsed", time.time() - t_start)

            # Only rank 0 writes the tabular log, and only when verbose.
            if self.verbose >= 1 and self.rank == 0:
                logger.dump_tabular()