import os

import gym
from stable_baselines.common import tf_util

from causal_irl.algorithms.gail.model import GAIL
import tensorflow as tf
import numpy as np

from causal_irl.algorithms.trpo_mpi.trpo_mpi import TRPO
from causal_irl.envs.my_observation_wrapper import MyObservationWrapper


class BC(TRPO):
    """
    Behavior-cloning agent that reuses the TRPO base class for its bookkeeping.

    Only a policy network is built (see ``setup_model``); training is purely
    supervised and is driven by ``pretrain`` / ``learn`` on an expert dataset.
    """

    def __init__(self, policy, env, expert_dataset=None, deterministic=True, lr=1e-4, adam_epsilon=1e-8, action_noise=0.0, ob_noise=0.0,
                 val_interval=None, save_path=None, exp_code='000', dropout_rate=0.0,
                 verbose=0, **kwargs):
        """
        :param policy: policy class forwarded to the parent constructor
        :param env: environment forwarded to the parent constructor
        :param expert_dataset: dataset used by ``learn``; must not be None when ``learn`` is called
        :param deterministic: (bool) selects which policy output the loss is computed on
            (see ``_get_pretrain_placeholders``)
        :param lr: (float) learning rate handed to Adam by ``learn``
        :param adam_epsilon: (float) epsilon handed to Adam by ``learn``
        :param action_noise: (float) scale of the Gaussian noise added to expert actions
        :param ob_noise: (float) scale of the Gaussian noise added to expert observations
        :param val_interval: (int) validate every n epochs; None lets ``pretrain`` pick a value
        :param save_path: (str) directory in which the best checkpoint is written
        :param exp_code: (str) prefix of the checkpoint file name
        :param dropout_rate: (float) value fed to the policy's ``rate_ph`` while training
        :param verbose: (int) verbosity level; values > 0 enable validation and progress output
        :param kwargs: extra keyword arguments forwarded to the parent constructor
        """
        super().__init__(policy, env, verbose=verbose, **kwargs)
        self.expert_dataset = expert_dataset
        self.action_noise = action_noise
        self.ob_noise = ob_noise
        self.deterministic = deterministic
        self.lr = lr
        self.adam_eps = adam_epsilon
        self.val_interval = val_interval
        self.trained_epochs = 0
        self.save_path = save_path
        self.exp_code = exp_code
        # Human-readable tag used in log messages and checkpoint names.
        self.det = "deterministic" if self.deterministic else "stochastic"
        self.drop_rate = dropout_rate

    def setup_model(self):
        """Create the graph, the session and the policy network, then expose the policy's step functions."""
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.set_random_seed(self.seed)
            self.sess = tf_util.make_session(num_cpu=self.n_cpu_tf_sess, graph=self.graph, make_default=True)

            # Construct network for new policy
            self.policy_pi = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1,
                                         None, reuse=False, **self.policy_kwargs)

            # Kept inside the graph context: both helpers look variables up
            # through the *default* graph, which is only guaranteed to be
            # self.graph within this ``with`` block.
            tf_util.initialize(sess=self.sess)
            self.params = tf_util.get_trainable_vars("model")

        self.step = self.policy_pi.step
        self.proba_step = self.policy_pi.proba_step
        self.initial_state = self.policy_pi.initial_state

    def _get_pretrain_placeholders(self):
        """
        Return the tensors needed to build the behavior-cloning loss.

        :return: tuple ``(obs_ph, action_ph, policy_output, rate_ph)`` where
            ``policy_output`` depends on the action space and on ``self.deterministic``:
            Discrete -> ``policy.policy`` or ``policy.policy_proba``;
            otherwise -> ``policy.deterministic_action`` or ``policy.proba_distribution``.
        """
        policy = self.policy_pi
        action_ph = policy.pdtype.sample_placeholder([None])
        if isinstance(self.action_space, gym.spaces.Discrete):
            if self.deterministic:
                return policy.obs_ph, action_ph, policy.policy, policy.rate_ph
            return policy.obs_ph, action_ph, policy.policy_proba, policy.rate_ph
        if self.deterministic:
            return policy.obs_ph, action_ph, policy.deterministic_action, policy.rate_ph
        return policy.obs_ph, action_ph, policy.proba_distribution, policy.rate_ph

    def _perturb_batch(self, expert_obs, expert_actions):
        """
        Add the configured Gaussian noise to one batch of expert data.

        :param expert_obs: batch of expert observations
        :param expert_actions: batch of expert actions
        :return: tuple ``(noisy_obs, noisy_actions)``
        """
        # NOTE(review): a single noise sample shaped like ONE action / observation
        # is drawn per call and added to the batch, so (assuming batches are
        # (n_batch, *space.shape)) every row receives the same perturbation.
        # Actions are perturbed first to keep the RNG call order stable.
        expert_actions = expert_actions + np.random.randn(*self.action_space.shape) * self.action_noise
        expert_obs = expert_obs + np.random.randn(*self.observation_space.shape) * self.ob_noise
        return expert_obs, expert_actions

    def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4,
                 adam_epsilon=1e-8, val_interval=None):
        """
        Pretrain a model using behavior cloning:
        supervised learning given an expert dataset.

        NOTE: only Box and Discrete spaces are supported for now.

        :param dataset: (ExpertDataset) Dataset manager
        :param n_epochs: (int) Number of iterations on the training set
            (truncated with ``int`` if a float is given)
        :param learning_rate: (float) Learning rate
        :param adam_epsilon: (float) the epsilon value for the adam optimizer
        :param val_interval: (int) Report training and validation losses every n epochs.
            By default, every 10th of the maximum number of epochs.
        :return: tuple ``(epochs, train_losses, val_losses)`` where ``epochs`` is
            ``range(n_epochs)``, ``train_losses`` holds the mean training loss of every
            epoch and ``val_losses`` the mean validation loss of every validated epoch
        """
        continuous_actions = isinstance(self.action_space, gym.spaces.Box)
        discrete_actions = isinstance(self.action_space, gym.spaces.Discrete)

        print("Training a {} BC agent".format(self.det))
        assert discrete_actions or continuous_actions, 'Only Discrete and Box action spaces are supported'

        # The loop and the returned range both need an int; callers may pass
        # a float such as 1e3.
        total_epochs = int(n_epochs)

        # Validate the model every 10% of the total number of iteration
        if val_interval is None:
            # Prevent modulo by zero
            if total_epochs < 10:
                val_interval = 1
            else:
                val_interval = int(total_epochs / 10)

        with self.graph.as_default():
            with tf.variable_scope('pretrain', reuse=tf.AUTO_REUSE):
                if continuous_actions:
                    obs_ph, actions_ph, pred_actions_ph, rate_ph = self._get_pretrain_placeholders()
                    if self.deterministic:
                        # Mean squared error against the policy's deterministic action.
                        loss = tf.reduce_mean(tf.square(actions_ph - pred_actions_ph))
                    else:
                        # Negative log-likelihood of the expert actions under the policy distribution.
                        loss = tf.reduce_mean(pred_actions_ph.neglogp(actions_ph))  # change this to logpdf
                else:
                    obs_ph, actions_ph, actions_logits_ph, rate_ph = self._get_pretrain_placeholders()
                    # actions_ph has a shape of (n_batch,), we reshape it to (n_batch, 1)
                    # so no additional change is needed in the dataloader
                    actions_ph = tf.expand_dims(actions_ph, axis=1)
                    one_hot_actions = tf.one_hot(actions_ph, self.action_space.n)
                    loss = tf.nn.softmax_cross_entropy_with_logits_v2(
                        logits=actions_logits_ph,
                        labels=tf.stop_gradient(one_hot_actions)
                    )
                    loss = tf.reduce_mean(loss)
                optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, epsilon=adam_epsilon)
                optim_op = optimizer.minimize(loss, var_list=tf_util.get_trainable_vars("model"))

            # Initialise the optimizer slots created just above.
            tf_util.initialize(sess=self.sess)

        if self.verbose > 0:
            print("Pretraining with Behavior Cloning...")

        best_val_loss = np.inf
        train_losses = []
        val_losses = []
        for epoch_idx in range(total_epochs):
            train_loss = 0.0
            # Full pass on the training set
            for _ in range(len(dataset.train_loader)):
                expert_obs, expert_actions = dataset.get_next_batch('train')
                expert_obs, expert_actions = self._perturb_batch(expert_obs, expert_actions)
                feed_dict = {
                    obs_ph: expert_obs,
                    actions_ph: expert_actions,
                    rate_ph: self.drop_rate
                }
                train_loss_, _ = self.sess.run([loss, optim_op], feed_dict)
                train_loss += train_loss_

            train_loss /= len(dataset.train_loader)
            train_losses.append(train_loss)
            # NOTE(review): validation, and therefore best-model checkpointing,
            # only happens when verbose > 0 - confirm this coupling is intended.
            if self.verbose > 0 and (epoch_idx + 1) % val_interval == 0:
                val_loss = 0.0
                # Full pass on the validation set
                for _ in range(len(dataset.val_loader)):
                    expert_obs, expert_actions = dataset.get_next_batch('val')
                    expert_obs, expert_actions = self._perturb_batch(expert_obs, expert_actions)
                    # rate_ph must be fed here too; 0.0 evaluates the network
                    # with dropout switched off.
                    val_loss_, = self.sess.run([loss], {obs_ph: expert_obs,
                                                        actions_ph: expert_actions,
                                                        rate_ph: 0.0})
                    val_loss += val_loss_

                val_loss /= len(dataset.val_loader)
                val_losses.append(val_loss)
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    # The dropout rate lives in self.drop_rate (set in __init__).
                    model_name = 'bc' + \
                                 ('' if self.drop_rate == 0.0 else str(self.drop_rate)) + \
                                 ('' if self.action_noise == 0.0 else '-action-noise{}'.format(self.action_noise)) + \
                                 ('' if self.ob_noise == 0.0 else '-ob-noise{}'.format(self.ob_noise))
                    env_mode = self.env.mode if isinstance(self.env, MyObservationWrapper) else 'testing'
                    file_name = self.exp_code + "_best_{}_{}_traj{}_{}_{}_{}".format(
                        model_name,
                        self.env.unwrapped.spec.id,
                        self.expert_dataset.num_traj,
                        self.det,
                        env_mode,
                        self.seed)
                    self.save(os.path.join(self.save_path, file_name))

                print("==== Training progress {:.2f}% ====".format(100 * (epoch_idx + 1) / total_epochs))
                print('Epoch {}'.format(epoch_idx + 1))
                print("Training loss: {:.6f}, Validation loss: {:.6f}".format(train_loss, val_loss))
                print()
            # Free memory
            del expert_obs, expert_actions
            self.trained_epochs += 1
        if self.verbose > 0:
            print("Pretraining done.")
        return range(total_epochs), train_losses, val_losses

    def learn(self, total_timesteps=10, callback=None, log_interval=100, tb_log_name="BC",
              reset_num_timesteps=True):
        """
        Train the agent by behavior cloning on ``self.expert_dataset``.

        :param total_timesteps: (int) number of training epochs (passed to ``pretrain`` as ``n_epochs``)
        :param callback: unused, kept for signature compatibility
        :param log_interval: unused, kept for signature compatibility
        :param tb_log_name: unused, kept for signature compatibility
        :param reset_num_timesteps: unused, kept for signature compatibility
        :return: the ``(epochs, train_losses, val_losses)`` tuple produced by ``pretrain``
        """
        assert self.expert_dataset is not None, "You must pass an expert dataset to BC for training"

        return self.pretrain(self.expert_dataset, n_epochs=total_timesteps,
                             learning_rate=self.lr,
                             adam_epsilon=self.adam_eps, val_interval=self.val_interval)

    def save(self, save_path, cloudpickle=False):
        """
        Write the agent's settings and current parameters to ``save_path``.

        :param save_path: destination passed to ``_save_to_file``
        :param cloudpickle: (bool) forwarded to ``_save_to_file``
        """
        data = {
            "verbose": self.verbose,
            "action_noise": self.action_noise,
            "ob_noise": self.ob_noise,
            "policy": self.policy,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "n_envs": self.n_envs,
            "n_cpu_tf_sess": self.n_cpu_tf_sess,
            "seed": self.seed,
            "_vectorize_action": self._vectorize_action,
            "trained_epochs": self.trained_epochs,
            "val_interval": self.val_interval,
            "save_path": self.save_path,
            "deterministic": self.deterministic,
            "lr": self.lr,
            "adam_eps": self.adam_eps,
            "det": self.det,
            "exp_code": self.exp_code,
            "drop_rate": self.drop_rate,
            "policy_kwargs": self.policy_kwargs
        }

        params_to_save = self.get_parameters()

        self._save_to_file(save_path, data=data, params=params_to_save, cloudpickle=cloudpickle)