import argparse
from argparse import Namespace
import json
import pprint

import torch.nn.functional as F
from gym.spaces import Discrete
from stable_baselines3.common import logger
from stable_baselines3.common.utils import configure_logger
from stable_baselines3.common.callbacks import CallbackList, EvalCallback
from torch.distributions import Normal, Categorical
from tqdm import tqdm

from buf import *
from env_utils import *
from irl import AIRLDiscriminator, AIRLInvDiscriminator, GAILDiscriminator
from ppo import PPO
from preference import PreferenceReward
from utils import *


class Trainer:
    def __init__(self, opt):
        """Set up environments, models, rollout buffers and TensorBoard logging.

        :param opt: parsed command line options
        """
        self.opt = opt
        self.cuda = torch.cuda.is_available()
        # Resuming continues in the previous run directory; create_models reads the
        # checkpoints from opt.resume as well. (Previously this read the global
        # ``args`` instead of ``opt``, which only worked when run as a script.)
        if opt.resume is None:
            self.output_dir = self.create_output_dir(opt)
        else:
            self.output_dir = opt.resume
        # expert demonstrations are only needed when a discriminator is trained
        self.expert_demos = None
        if opt.discriminator_type is not None:
            self.expert_demos = load_expert_demos(opt, opt.env_name, opt.env_kwargs)
        # Order matters: create_models wraps the envs and reads self.testing_env,
        # make_bufs reads self.env_dict and self.ppo_dict.
        self.env_dict, self.testing_env = self.make_envs(opt)
        self.ppo_dict, self.d_dict, self.pref_reward_model, self.callback_dict = self.create_models(self.env_dict, opt)
        self.buf_dict = self.make_bufs(opt)
        self.summary_writer = SummaryWriter(self.output_dir)
        self.logging_interval = opt.logging_interval

        self.summary_writer.add_text('params', str(opt))

    def make_bufs(self, opt):
        """Return one rollout buffer per environment key of ``self.env_dict``.

        With ``opt.use_sb_ppo`` the stable-baselines model's own ``rollout_buffer`` is
        reused; otherwise a fresh ``RolloutBuffer`` sized ``n_steps * n_envs`` is made.
        """
        def _buffer_for(env_key):
            if not opt.use_sb_ppo:
                return RolloutBuffer(capacity=opt.n_steps * opt.n_envs)
            return self.ppo_dict[env_key].rollout_buffer

        return {env_key: _buffer_for(env_key) for env_key in self.env_dict}

    def make_envs(self, opt):
        """Create the vectorised training environments and the testing environment.

        Without ``env_kwargs`` a single env set is built under the key ``'all'``.
        Otherwise one env set is built per spec in ``env_kwargs`` (keyed by ``str(spec)``)
        and the testing env returned by the *last* ``make_venv`` call is kept.
        NOTE(review): an empty ``env_kwargs`` list builds nothing and the final return
        raises UnboundLocalError.

        :return: tuple ``(env_dict, testing_env)``
        """
        specs = self.opt.env_kwargs
        if specs is None:
            train_envs, testing_env = make_venv(self.opt,
                                                self.opt.n_envs,
                                                spec=None,
                                                spec_test=None,
                                                wrapper_kwargs=None,
                                                use_subprocess=opt.use_subprocess,
                                                use_rank=opt.use_seed_ranking)
            return {'all': train_envs}, testing_env

        env_dict = {}
        # one dict object is shared by every make_venv call below
        extra_wrapper_kwargs = {}
        for spec in specs:
            env_dict[str(spec)], testing_env = make_venv(self.opt,
                                                         self.opt.n_envs,
                                                         spec,
                                                         self.opt.env_spec_test,
                                                         extra_wrapper_kwargs,
                                                         use_subprocess=opt.use_subprocess,
                                                         use_rank=opt.use_seed_ranking)
        return env_dict, testing_env

    # create preference learning model
    def create_pref_model(self, opt):
        """Return a ``PreferenceReward`` model, warm-started from disk when possible.

        The checkpoint location is the fixed relative path ``reward_models/pref_reward.th``
        (resolved against the current working directory). If that file does not exist
        the freshly initialised model is returned unchanged.
        NOTE(review): ``opt.pretrained_pref_reward`` is not used as the path here.
        """
        checkpoint_path = 'reward_models/pref_reward.th'
        model = PreferenceReward(opt)
        if not os.path.exists(checkpoint_path):
            return model
        model.load_state_dict(torch.load(checkpoint_path))
        return model

    # create a dict of ppo and discriminator models (keys=envs)
    def create_models(self, env_dict, opt):
        """Create the policy and (optionally) discriminator models for every environment.

        For each key ``k`` of ``env_dict``:

        * if ``opt.discriminator_type`` is set, a discriminator is built (see
          ``_create_discriminator``) and ``env_dict[k]`` is replaced in place by the
          ``repack_vecenv`` wrapper around it. This happens before the policy is
          created because the wrapper supplies the IRL reward.
        * a policy is built: a stable-baselines model plus its ``CallbackList`` when
          ``opt.use_sb_ppo``, otherwise the local ``PPO`` implementation.

        With ``opt.resume`` set, weights are restored from that directory.

        :param env_dict: dict mapping env keys to vectorised envs (mutated in place
            when a discriminator is used)
        :param opt: command line options
        :return: tuple ``(ppo_dict, discriminator_dict, pref_reward_model, callback_dict)``;
            ``pref_reward_model`` is None unless ``opt.pretrained_pref_reward`` is set
        :raises ValueError: if ``opt.discriminator_type`` is set to an unsupported value
        """
        ppo_dict = {}
        discriminator_dict = {}
        callback_dict = {}
        # with opt.single_disc every env is wrapped with the first env's discriminator
        first_key = next(iter(env_dict), None)

        for k in env_dict.keys():
            if isinstance(env_dict[k].action_space, Discrete):
                pi_dist = 'Categorical'
                nr_actions = env_dict[k].action_space.n
            else:
                pi_dist = 'Normal'
                nr_actions = 1

            ac_base_type = opt.ac_base_type
            use_cnn_base = False
            if 'MiniGrid' in opt.env_name:
                # MiniGrid: action count is hard-coded to 7; image observations use a CNN base
                nr_actions = 7
                if opt.minigrid_wrapper == 'img':
                    ac_base_type = 'minigridcnn'
                    use_cnn_base = True

            # Discriminator(s) (*per-env only matters for the IRM Games formulation) must
            # exist before the policy because the env wrapper needs the IRL reward.
            if opt.discriminator_type is not None:
                discriminator_dict[k] = self._create_discriminator(env_dict[k], opt, use_cnn_base)
                # only load a discriminator checkpoint when discriminators are in use
                # (previously this also ran for plain PPO runs and raised KeyError)
                if opt.resume is not None:
                    discriminator_dict[k].load_state_dict(torch.load(os.path.join(opt.resume, 'disc' + str(k))))
                # (a bit hacky..) repack the venv so it is wrapped with the discriminator
                disc_key = first_key if opt.single_disc else k
                env_dict[k] = repack_vecenv(env_dict[k], disc=discriminator_dict[disc_key])

            # create policy
            if opt.use_sb_ppo:
                # yaml.load without an explicit Loader is deprecated since PyYAML 5.1 and a
                # TypeError on 6.x; FullLoader is what the bare call defaulted to.
                # NOTE(review): prefer yaml.safe_load if the config is not fully trusted.
                with open(opt.sb_config) as sb_yml:
                    sb_args = yaml.load(sb_yml, Loader=yaml.FullLoader)[opt.env_name]
                policy = sb_args['policy']
                sb_args = process_sb_args(sb_args)
                ppo_dict[k] = PPOSB(policy, env_dict[k], **sb_args, tensorboard_log=self.output_dir)
                if opt.resume is not None:
                    # stable-baselines3 ``load`` is a classmethod returning a *new* model;
                    # it does not modify the instance it is called on, so keep the result.
                    ppo_dict[k] = ppo_dict[k].load(os.path.join(opt.resume, "best_model.zip"), env_dict[k])
                eval_callback = EvalCallback(self.testing_env, best_model_save_path=self.output_dir,
                                             log_path=self.output_dir, eval_freq=10000,
                                             deterministic=True, render=False)
                callback_list = CallbackList([CustomCallback(id=k, log_path=self.output_dir),
                                              eval_callback])
                callback_list.init_callback(ppo_dict[k])
                callback_dict[k] = callback_list
                configure_logger(tensorboard_log=self.output_dir)
            else:
                ppo_dict[k] = PPO(env_dict[k], opt.ac_layer_dims, nr_actions, opt.lr,
                                  opt.ppo_eps, opt.ppo_vcoeff, opt.ppo_entcoeff, opt.max_grad_norm,
                                  dist=pi_dist, ac_base_type=ac_base_type, vf_clip=opt.ppo_vfclip)
                if opt.resume is not None:
                    ppo_dict[k].ac.load_state_dict(torch.load(os.path.join(opt.resume, 'ppo' + str(k))))

        # A single preference reward model is returned (and used for every env), so it is
        # built once rather than once per env; this also works for an empty env_dict.
        if opt.pretrained_pref_reward is not None:
            pref_reward_model = self.create_pref_model(opt)
        else:
            pref_reward_model = None

        return ppo_dict, discriminator_dict, pref_reward_model, callback_dict

    def _create_discriminator(self, env, opt, use_cnn_base):
        """Build a discriminator of type ``opt.discriminator_type`` for ``env``.

        :raises ValueError: for a type other than 'airl', 'airl_invrat' or 'gail'
        """
        shared_kwargs = dict(use_actions=opt.use_actions,
                             irm_coeff=opt.irm_coeff,
                             lip_coeff=opt.lip_coeff,
                             use_cnn_base=use_cnn_base)
        if opt.discriminator_type == 'airl':
            return AIRLDiscriminator(env, opt.d_layer_dims, lr=opt.lr, gamma=opt.gamma, **shared_kwargs)
        if opt.discriminator_type == 'airl_invrat':
            return AIRLInvDiscriminator(env, opt.d_layer_dims, lr=opt.lr, gamma=opt.gamma, **shared_kwargs)
        if opt.discriminator_type == 'gail':
            # no gamma argument is passed to the GAIL discriminator
            return GAILDiscriminator(env, opt.d_layer_dims, lr=opt.lr, **shared_kwargs)
        raise ValueError('Unknown discriminator_type: {!r}'.format(opt.discriminator_type))

    def create_output_dir(self, opt):
        """Return a timestamped run directory path under ``opt.output_dir``.

        Only the path is built; nothing is created on disk. Layout of the last component:

        * with a discriminator: ``<discriminator_type>_<irm_flag>_<YYYYmmdd_HHMMSS>_<exp_id>``,
          where ``irm_flag`` is ``'irmg'``/``'irm'`` when ``irm_coeff > 0`` and empty otherwise;
        * without: ``ppo_<YYYYmmdd_HHMMSS>_<exp_id>``.
        """
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        if opt.discriminator_type is None:
            run_name = 'ppo_' + stamp + '_' + opt.exp_id
        else:
            irm_flag = ('irmg' if opt.irmg else 'irm') if opt.irm_coeff > 0 else ''
            run_name = opt.discriminator_type + '_' + irm_flag + '_' + stamp + '_' + opt.exp_id
        return os.path.join(opt.output_dir, run_name)

    def gather_experience(self, ob, env_id, total_numsteps):
        """Roll out the current policy for ``opt.n_steps`` steps and refill the env's buffer.

        Steps the vectorised env ``self.env_dict[env_id]`` with actions from
        ``self.ppo_dict[env_id]``, computes advantages / value targets with
        ``compute_gae`` and pushes every transition into ``self.buf_dict[env_id]``
        (which is cleared first). GAE uses the discriminator (IRL) rewards when
        ``opt.discriminator_type`` is set and the env's ground-truth rewards otherwise.
        Episode and mean rewards are written to TensorBoard.

        :param ob: current observation batch as a numpy array (converted with
            ``torch.from_numpy``); presumably shaped (n_envs, ...) -- TODO confirm
        :param env_id: key into ``self.env_dict`` / ``self.ppo_dict`` / ``self.buf_dict``
        :param total_numsteps: global step counter used as the TensorBoard x-axis
        :return: None (results are stored in ``self.buf_dict[env_id]``)
        """
        env = self.env_dict[env_id]
        ppo = self.ppo_dict[env_id]
        # the discriminator is only looked up (and only used) when IRL rewards are trained
        if self.opt.discriminator_type is not None:
            discriminator = self.d_dict[env_id]

        # clear buf dict about to be filled
        self.buf_dict[env_id].clear()

        # temporary storage: one entry per rollout step
        obs = []
        acs = []
        next_obs = []
        lprobs = []
        dones = []
        gt_rewards = []
        irl_rewards = []
        pref_rewards = []
        masks = []
        values = []
        # running episode returns used only for logging
        ep_rew_gt = 0
        ep_rew_irl = 0

        with torch.no_grad():
            # note: despite its name, i_env indexes the rollout step, not an environment
            for i_env in range(self.opt.n_steps):
                ob_t = torch.from_numpy(ob).type(torch.get_default_dtype())
                ac_t = ppo.get_action(ob_t)
                lpr = ppo.get_lprobs(ob_t, ac_t)
                ac = ac_t.detach().numpy()
                # debugging aid: non-finite actions are printed but not filtered out
                action_finite = np.isfinite(ac).all()
                if not action_finite:
                    print(ac)

                # Perform the action (no rescaling is done here)
                clipped_actions = ac
                # Clip the actions to avoid out of bound error (Box spaces only)
                if isinstance(env.action_space, gym.spaces.Box):
                    clipped_actions = np.clip(ac, env.action_space.low, env.action_space.high)

                next_ob, gt_reward, done, _ = env.step(clipped_actions)
                next_ob_t = torch.from_numpy(next_ob).type(torch.get_default_dtype())
                # hacky for shape fitting purposes
                # (per-step mean over the parallel envs, accumulated for logging)
                ep_rew_gt += np.mean(gt_reward)
                if self.opt.discriminator_type is not None:
                    # NOTE(review): gt_reward is stored without the trailing axis here but
                    # with [:, np.newaxis] in the else-branch, so buffer entries differ in
                    # shape between the two modes -- confirm this is intended.
                    gt_rewards.append(gt_reward)
                    irl_reward = discriminator.get_reward(ob_t, ac_t, next_ob_t).detach().numpy()
                    # irl_rewards.append(irl_reward[:, np.newaxis])
                    irl_rewards.append(irl_reward)
                    ep_rew_irl += irl_reward
                else:
                    gt_rewards.append(gt_reward[:, np.newaxis])
                    # placeholder zeros so the buffer layout is the same in both modes
                    irl_reward = np.zeros([self.opt.n_envs, 1])
                    irl_rewards.append(irl_reward)

                # preference rewards are stored in the buffer but not used for GAE below
                if self.opt.pretrained_pref_reward is not None:
                    pref_reward = self.pref_reward_model.get_reward(ob_t, ac_t).detach().numpy()
                else:
                    pref_reward = np.zeros([self.opt.n_envs, 1])

                # check if episode has ended
                # NOTE(review): the counters are reset when *any* parallel env finishes, so
                # the logged value mixes partial episodes of the other envs.
                if np.any(done):
                    self.summary_writer.add_scalar('Reward/Ep_rewards_gt', ep_rew_gt, total_numsteps + i_env)
                    self.summary_writer.add_scalar('Reward/Ep_rewards_irl', np.mean(ep_rew_irl), total_numsteps + i_env)
                    ep_rew_gt = 0
                    ep_rew_irl = 0

                v_pred = ppo.get_value(ob_t).detach().numpy()

                pref_rewards.append(pref_reward)
                values.append(v_pred)
                # mask is 0 where the env finished on this step (cuts GAE bootstrapping)
                masks.append((1 - done)[:, np.newaxis])
                obs.append(ob)
                acs.append(ac)
                lprobs.append(lpr.detach().numpy())
                next_obs.append(next_ob)
                dones.append(done)

                ob = next_ob
                # print("One iteration takes: ", time.time()-ts)

            # get last value (bootstrap value of the state after the final step);
            # requires opt.n_steps >= 1 since next_ob_t is only set inside the loop
            next_value = ppo.get_value(next_ob_t).detach().numpy()

        # TODO: see if there is a potential issue with gae estimation based on a limited nr_steps
        if self.opt.discriminator_type is not None:
            adv, v_tgt = compute_gae(next_value, irl_rewards, masks, values, lam=self.opt.gae_lambda)
        else:
            adv, v_tgt = compute_gae(next_value, gt_rewards, masks, values, lam=self.opt.gae_lambda)

        # flatten data: concatenating the per-step arrays along axis 0 gives time-major rows
        # (step 0 for all envs, then step 1, ...), assuming each per-step array has n_envs rows
        obs = np.concatenate(obs, 0)
        acs = np.concatenate(acs, 0)
        next_obs = np.concatenate(next_obs, 0)
        gt_rewards = np.concatenate(gt_rewards, 0)

        irl_rewards = np.concatenate(irl_rewards, 0)
        pref_rewards = np.concatenate(pref_rewards, 0)
        dones = np.concatenate(dones, 0)
        values = np.concatenate(values, 0)
        lprobs = np.concatenate(lprobs, 0)
        adv = np.reshape(adv, (-1, 1))
        v_tgt = np.reshape(v_tgt, (-1, 1))

        # print(obs.shape, acs.shape, next_obs.shape, gt_rewards.shape, irl_rewards.shape,
        # pref_rewards.shape, dones.shape, values.shape, v_tgt.shape, adv.shape, lprobs.shape)

        # add transition to dict (one row per step/env pair)
        for i in range(self.opt.n_steps * self.opt.n_envs):
            # ('state', 'action', 'next_state', 'gt_reward', 'done', 'irl_reward',
            # 'pref_reward', 'v_pred', 'v_tgt', 'adv', 'logp_a'))
            self.buf_dict[env_id].push(obs[i], acs[i], next_obs[i], gt_rewards[i], dones[i],
                                       irl_rewards[i], pref_rewards[i],
                                       values[i], v_tgt[i], adv[i], lprobs[i])

        # batch-mean rewards are only logged when total_numsteps is a multiple of logging_interval
        if total_numsteps % self.logging_interval == 0:
            self.summary_writer.add_scalar('Reward/GT_rewards', np.mean(gt_rewards), total_numsteps)
            self.summary_writer.add_scalar('Reward/IRL rewards', np.mean(irl_rewards), total_numsteps)

    def update_policy(self, opt, env_id, whosturn, total_numsteps, update_iteration):
        """
        :param opt: cmd line options dict
        :param env_id: environment id
        :param whosturn: for IPO / IRMG formulation, the updates take turn according to it
        :param total_numsteps:
        :param update_iteration:
        :return:
        """
        ppo = self.ppo_dict[env_id]
        buf = self.buf_dict[env_id]
        ppo.assign_old_pi()

        for ep in range(opt.n_epochs):
            # invariant policy optimization
            if opt.ipo:
                mu_dict = {}
                mu_old_dict = {}
                value_dict = {}
                # sample from buffer who's turn it is
                transitions = self.buf_dict[whosturn].sample(opt.batch_size)
                batch = Transition(*zip(*transitions))

                state_batch = torch.from_numpy(np.stack(batch.state)).type(torch.get_default_dtype())
                action_batch = torch.from_numpy(np.stack(batch.action))
                if opt.discriminator_type is not None:
                    returns_batch = torch.from_numpy(np.stack(batch.irl_reward)).type(torch.get_default_dtype())
                else:
                    returns_batch = torch.from_numpy(np.stack(batch.gt_reward)).type(torch.get_default_dtype())

                advs_batch = torch.from_numpy(np.stack(batch.adv)).type(torch.get_default_dtype())
                # normalize advantages
                advs_batch = (advs_batch - advs_batch.mean()) / (advs_batch.std() + 1e-8)

                # forward all models
                for k in self.env_dict:
                    h = self.ppo_dict[k].ac.base(state_batch)
                    h_v = self.ppo_dict[k].ac.base_v(state_batch)
                    mu = self.ppo_dict[k].ac.actor(h)
                    h_ = self.ppo_dict[k].old_ac.base(state_batch)
                    mu_old = self.ppo_dict[k].old_ac.actor(h_)
                    v = self.ppo_dict[k].ac.critic(h_v)

                    mu_dict[k] = mu
                    mu_old_dict[k] = mu_old
                    value_dict[k] = v

                # calculate ensemble
                mu_avg = torch.mean(torch.stack((list(mu_dict.values()))))
                mu_old_avg = torch.mean(torch.stack((list(mu_old_dict.values()))))

                value_avg = torch.mean(torch.stack((list(value_dict.values()))))
                if isinstance(self.env_dict[k].action_space, Discrete):
                    pi = Categorical(mu_avg)
                    pi_old = Categorical(mu_old_avg)
                else:
                    pi = Normal(mu_avg, np.exp(ppo.ac.log_std))
                    pi_old = Normal(mu_old_avg, np.exp(ppo.old_ac.log_std))

                ens_ppo_loss, actor_loss, critic_loss, ent = ppo.calc_loss(pi,
                                                                           pi.log_prob(action_batch),
                                                                           pi_old.log_prob(action_batch),
                                                                           value_avg, returns_batch, advs_batch)

                ppo.actor_loss = actor_loss
                ppo.critic_loss = critic_loss
                ppo.ent = ent
                # ensemble loss
                # only update the discriminator who's turn it is (best response dynamics)
                self.ppo_dict[whosturn].step(ens_ppo_loss)
            else:
                # update actor and critic networks
                buf.shuffle()
                for batch in buf.iter(opt.batch_size):

                    # transitions = buf.sample(opt.batch_size)
                    batch = Transition(*zip(*batch))

                    state_batch = torch.from_numpy(np.stack(batch.state)).type(torch.get_default_dtype())
                    action_batch = torch.from_numpy(np.stack(batch.action))
                    if opt.discriminator_type is not None:
                        returns_batch = torch.from_numpy(np.stack(batch.irl_reward)).type(torch.get_default_dtype())
                    else:
                        returns_batch = torch.from_numpy(np.stack(batch.gt_reward)).type(torch.get_default_dtype())
                    advs_batch = torch.from_numpy(np.stack(batch.adv)).type(torch.get_default_dtype())
                    advs_batch = (advs_batch - advs_batch.mean()) / (advs_batch.std() + 1e-8)
                    # Normalizing the rewards:
                    # returns_batch = (returns_batch - returns_batch.mean()) / (returns_batch.std() + 1e-5)
                    # if update_iteration > 10 and update_iteration % opt.logging_interval == 0:
                    #     self.summary_writer.add_scalar('PPO_adv_norms/advs_norm',
                    #                                advs_batch.norm(2).item(), idx)

                    ppo.update(state_batch, action_batch, returns_batch, advs_batch)

            if update_iteration > 10 and update_iteration % opt.logging_interval == 0:
                self.summary_writer.add_scalar('PPO/actor_loss_' + env_id,
                                               np.mean(ppo.actor_loss.detach().numpy()), total_numsteps)
                self.summary_writer.add_scalar('PPO/critic_loss_' + env_id,
                                               np.mean(ppo.critic_loss.detach().numpy()), total_numsteps)
                self.summary_writer.add_scalar('PPO/entropy_' + env_id, np.mean(ppo.ent.detach().numpy()),
                                               total_numsteps)

                # plot parameter norms
                # for i, p in enumerate(list(ppo.ac.get_actor_parameters())):
                # self.summary_writer.add_scalar('PPO_param_norms/actor_grad_norm_' + env_id + str(i),
                # p.data.norm(2).item(),total_numsteps)
                # for i, p in enumerate(list(ppo.ac.get_critic_parameters())):
                # self.summary_writer.add_scalar('PPO_param_norms/critic_grad_norm_' + env_id + str(i),
                # p.data.norm(2).item(), total_numsteps)

                # for i, p in enumerate(list(filter(lambda p: p.grad is not None, ppo.ac.get_actor_parameters()))):
                #     self.summary_writer.add_scalar('PPO_grad_norms/actor_grad_norm_' + env_id + str(i),
                #                                    p.grad.data.norm(2).item(), idx)
                #     self.summary_writer.add_histogram("gradients/actor",
                #                                       torch.cat([p.data.view(-1) for p in ppo.ac.actor.parameters()]),
                #                                       global_step=idx)
                # for i, p in enumerate(list(filter(lambda p: p.grad is not None, ppo.ac.critic.parameters()))):
                #     self.summary_writer.add_scalar('PPO_grad_norms/critic_grad_norm_' + env_id + str(i),
                #                                    p.grad.data.norm(2).item(), idx)
                #     self.summary_writer.add_histogram("gradients/critic",
                #              torch.cat([p.data.view(-1) for p in ppo.ac.critic.parameters()]), global_step=idx)

    def update_discriminator(self, opt, policy_buffer_gen, ppo_env_id, disc_env_id, update_iteration, whosturn):
        """Run ``opt.n_irl_epochs`` updates of the discriminator used on ``disc_env_id``.

        Each epoch draws one policy batch from ``policy_buffer_gen`` and contrasts it with
        the expert demonstrations in ``self.expert_demos``:

        * ``gail`` / ``airl``: losses are computed against every env's demos and averaged.
          With ``opt.irmg`` (IRM Games) the discriminator outputs are averaged instead,
          and only ``self.d_dict[whosturn]`` is stepped on that ensemble loss.
        * ``airl_invrat``: two distinct, randomly chosen envs are contrasted in one loss.

        :param opt: command line options
        :param policy_buffer_gen: iterator yielding policy transition batches (objects with
            ``observations``/``actions`` when ``opt.use_sb_ppo``, otherwise presumably
            sequences of ``Transition`` tuples)
        :param ppo_env_id: environment id the ppo policy interacts with
        :param disc_env_id: environment id the discriminator is used on
        :param update_iteration: update counter, used as the TensorBoard x-axis
        :param whosturn: for IRMG formulation, the env whose discriminator is stepped
        :return: None
        """
        ppo = self.ppo_dict[ppo_env_id]
        discriminator = self.d_dict[disc_env_id]

        # per-env results, overwritten every epoch
        bce_losses = {}
        policy_estimates = {}
        expert_estimates = {}
        grad_pens = {}
        d_out_dict = {}

        for _irl_epoch in range(opt.n_irl_epochs):
            transitions = next(policy_buffer_gen)
            if opt.use_sb_ppo:
                policy_state_batch = transitions.observations
                policy_action_batch = transitions.actions
            else:
                student_batch = Transition(*zip(*transitions))
                policy_state_batch = np.array(list(student_batch.state))
                policy_action_batch = np.array(list(student_batch.action))

            if opt.discriminator_type == 'gail':
                # GAIL Discriminator Update
                for k in self.env_dict:
                    gail_update_dict = prepare_update_gail(self.env_dict[k], opt, self.expert_demos[k],
                                                           policy_state_batch, policy_action_batch)
                    if opt.irmg:
                        d_out_dict[k] = discriminator.forward(gail_update_dict['all_obs'],
                                                              gail_update_dict['all_acs'])

                    _, bce_loss, grad_pen = discriminator.compute_loss(gail_update_dict)
                    bce_losses[k] = bce_loss
                    grad_pens[k] = grad_pen

                bce_loss_all = torch.stack(list(bce_losses.values())).mean()
                grad_pen_all = torch.stack(list(grad_pens.values())).mean()
                loss = bce_loss_all + opt.irm_coeff * grad_pen_all
                # rescale by 1/irm_coeff for large coefficients (presumably to bound the loss scale)
                if opt.irm_coeff > 1.0:
                    loss /= opt.irm_coeff

                if opt.irmg:
                    # only the discriminator whose turn it is gets stepped (best response dynamics)
                    ens_bce_loss = self._irmg_ensemble_loss(d_out_dict)
                    self.summary_writer.add_scalar('IRL/IRMG GAIL AVG BCE Loss', ens_bce_loss, update_iteration)
                    self.d_dict[whosturn].update(ens_bce_loss)
                else:
                    discriminator.update(loss)

                self.summary_writer.add_scalar('IRL/GAIL BCE Loss', bce_loss_all, update_iteration)
                if opt.irm_coeff > 0:
                    self.summary_writer.add_scalar('IRL/GAIL IRM Loss', grad_pen_all, update_iteration)

            if opt.discriminator_type == 'airl':
                for k in self.env_dict:
                    airl_update_dict = prepare_update_airl(self.env_dict[k], opt,
                                                           self.expert_demos[k],
                                                           policy_state_batch,
                                                           policy_action_batch,
                                                           ppo)

                    # for IRM Games formulation, keep the raw outputs for the ensemble average
                    if opt.irmg:
                        _, _, _, d_out = discriminator.forward(airl_update_dict['all_obs'],
                                                               airl_update_dict['all_obs_next'],
                                                               airl_update_dict['all_acs'],
                                                               airl_update_dict['all_lprobs'])
                        d_out_dict[k] = d_out

                    airl_output_dict = discriminator.compute_loss(airl_update_dict)
                    bce_losses[k] = airl_output_dict['d_loss']
                    policy_estimates[k] = airl_output_dict['policy_estimate']
                    expert_estimates[k] = airl_output_dict['expert_estimate']
                    grad_pens[k] = airl_output_dict['grad_penalty']

                bce_loss_all = torch.stack(list(bce_losses.values())).mean()
                grad_pen_all = torch.stack(list(grad_pens.values())).mean()
                loss = bce_loss_all + opt.irm_coeff * grad_pen_all
                # TODO: is this necessary?
                if opt.irm_coeff > 1.0:
                    loss /= opt.irm_coeff

                if opt.irmg:
                    # only the discriminator whose turn it is gets stepped (best response dynamics)
                    self.d_dict[whosturn].update(self._irmg_ensemble_loss(d_out_dict))
                else:
                    discriminator.update(loss)

                self.summary_writer.add_scalar('IRL/AIRL_policy_estimate_' + disc_env_id,
                                               policy_estimates[disc_env_id].mean(),
                                               update_iteration)
                self.summary_writer.add_scalar('IRL/AIRL_expert_estimate_' + disc_env_id,
                                               expert_estimates[disc_env_id].mean(),
                                               update_iteration)
                self.summary_writer.add_scalar('IRL/AIRL_bceloss_' + disc_env_id,
                                               bce_losses[disc_env_id].mean(), update_iteration)
                if opt.irm_coeff > 0:
                    self.summary_writer.add_scalar('IRL/AIRL_irmloss_' + disc_env_id,
                                                   grad_pens[disc_env_id].mean(), update_iteration)

            if opt.discriminator_type == 'airl_invrat':
                # get two distinct, randomly chosen contrasting environments
                envs = random.sample(list(self.env_dict), 2)
                airl_update_dict = prepare_update_airl(self.env_dict[envs[0]], opt, self.expert_demos[envs[0]],
                                                       policy_state_batch, policy_action_batch, ppo)

                airl_update_dict_e = prepare_update_airl(self.env_dict[envs[1]], opt, self.expert_demos[envs[1]],
                                                         policy_state_batch, policy_action_batch, ppo)

                airli_output_dict = discriminator.compute_loss(airl_update_dict, airl_update_dict_e)
                discriminator.update(airli_output_dict['d_loss'])

                self.summary_writer.add_scalar('IRL/AIRL_inv_loss_' + disc_env_id, airli_output_dict['inv_loss'],
                                               update_iteration)
                self.summary_writer.add_scalar('IRL/AIRL_aware_loss_' + disc_env_id, airli_output_dict['aw_loss'],
                                               update_iteration)
                self.summary_writer.add_scalar('IRL/AIRL_diff_loss_' + disc_env_id, airli_output_dict['diff_loss'],
                                               update_iteration)
                # TODO(review): log the IRM penalty term for airl_invrat when irm_coeff > 0
                # (its output key is not visible here)

    @staticmethod
    def _irmg_ensemble_loss(d_out_dict):
        """BCE-with-logits loss of the discriminators' output averaged over envs (IRM Games).

        Every value of ``d_out_dict`` holds logits for the same stacked batch. The first
        half of the batch (named expert_out, following the chunk order) is labelled 1
        and the second half 0.
        NOTE(review): confirm prepare_update_* put expert samples first.
        """
        avg_dout = torch.stack(list(d_out_dict.values())).mean(dim=0)
        expert_out, policy_out = torch.chunk(avg_dout, chunks=2, dim=0)
        labels = torch.cat([torch.ones(policy_out.size()),
                            torch.zeros(expert_out.size())])
        return F.binary_cross_entropy_with_logits(avg_dout, labels)

    def eval_on_test_env(self, tenv, ppo, n_evals=10):
        """Evaluate ``ppo`` on a test environment and decide whether to early-stop.

        :param tenv: environment to evaluate on; falls back to ``self.testing_env`` when
            None. (Previously this argument was ignored and ``self.testing_env`` was
            always used.)
        :param ppo: policy passed through to ``test_env``
        :param n_evals: number of ``test_env`` calls averaged into the reported reward
        :return: tuple ``(mean test reward, True if it exceeds opt.threshold_reward)``
        """
        env = self.testing_env if tenv is None else tenv
        test_reward = np.mean([test_env(env, ppo) for _ in range(n_evals)])

        print("Testing env reward:", test_reward, self.opt.threshold_reward)
        early_stop = bool(test_reward > self.opt.threshold_reward)

        return test_reward, early_stop

    def train(self):
        """Run the alternating policy / discriminator training loop.

        Every outer iteration ("update"):

        1. collects ``opt.n_steps`` rollout steps in each environment (only the
           first one when ``opt.single_policy`` is set). For the custom (non-SB)
           PPO only, it evaluates on the testing env each time the step counter
           crosses a multiple of 10000.
        2. updates each environment's policy and keeps the sampled transitions
           (``get`` on SB3 rollout buffers, ``iter`` on ``RolloutBuffer``).
        3. updates the discriminator(s) with those transitions when
           ``opt.discriminator_type`` is set. Otherwise it saves expert
           trajectories on the last update or when early stopping.
        4. every 10 updates (and when early stopping) checkpoints the models and
           ``args.json`` into ``self.output_dir``.

        Training runs for ``opt.n_timesteps // opt.n_steps`` updates. It stops
        after the current update once the test reward exceeds a non-zero
        ``opt.threshold_reward``.
        """
        opt = self.opt
        env_ids = list(self.env_dict.keys())
        n_updates = opt.n_timesteps // opt.n_steps
        eval_interval = 10000  # counted steps between evaluations on the testing env

        total_numsteps = 0

        for i_update in tqdm(range(n_updates), 'performed ppo updates / episode count'):
            # calculate whose turn it is to update (IRM Games formulation)
            whosturn = env_ids[i_update % len(env_ids)]
            is_last_update = i_update == n_updates - 1
            stop_requested = False

            ## Gather experience for each env k with its respective model and
            ## store it in the buffer associated with that env
            for k in self.env_dict:
                if not opt.use_sb_ppo:
                    self.gather_experience(self.env_dict[k].reset(), k, total_numsteps)
                else:
                    # SB3 path: collect into the model's own rollout buffer and reuse it
                    ppo = self.ppo_dict[k]
                    if ppo._last_obs is None:
                        ppo._last_obs = self.env_dict[k].reset()
                    ppo.collect_rollouts(self.env_dict[k],
                                         self.callback_dict[k],
                                         ppo.rollout_buffer,
                                         n_rollout_steps=opt.n_steps)
                    self.buf_dict[k] = ppo.rollout_buffer

                # NOTE: counts n_steps per env per update (not multiplied by n_envs)
                total_numsteps += opt.n_steps
                # Evaluate whenever the counter crosses a multiple of eval_interval.
                # A plain `% eval_interval == 0` test never fires if n_steps does
                # not divide eval_interval (e.g. n_steps=2048).
                crossed_eval_boundary = (total_numsteps // eval_interval
                                         > (total_numsteps - opt.n_steps) // eval_interval)
                if crossed_eval_boundary and not opt.use_sb_ppo:
                    test_reward, early_stop = self.eval_on_test_env(self.testing_env, self.ppo_dict[k])
                    self.summary_writer.add_scalar('Test Env Reward', test_reward, total_numsteps)
                    if opt.threshold_reward != 0 and early_stop:
                        print("early stopping")
                        # finish this update (policy/disc step, expert save,
                        # checkpoint), then leave the outer loop below
                        stop_requested = True
                        break
                # only the first env is used when training a single policy
                if opt.single_policy:
                    break

            ## policy updates
            transitions_dict = {}
            for k in self.env_dict:
                if opt.use_sb_ppo:
                    self.ppo_dict[k].train()
                    # different buffer spec in SB
                    transitions = self.buf_dict[k].get(opt.batch_size)
                    logger.dump(total_numsteps)
                else:
                    transitions = self.buf_dict[k].iter(opt.batch_size)
                    self.update_policy(opt, k, whosturn, total_numsteps, i_update)
                transitions_dict[k] = transitions
                if opt.single_policy:
                    break

            ## discriminator updates (whosturn only relevant for irmg)
            for k in self.env_dict:
                if opt.single_policy:
                    # Only the first env's policy produced transitions above.
                    # NOTE(review): rebinding k means that with single_policy and
                    # not single_disc every iteration targets the first env's
                    # discriminator; confirm this is intended.
                    k = env_ids[0]

                if opt.discriminator_type is not None:
                    # with single_disc, every policy trains the first env's discriminator
                    disc_env_id = env_ids[0] if opt.single_disc else k
                    self.update_discriminator(opt, transitions_dict[k],
                                              ppo_env_id=k,
                                              disc_env_id=disc_env_id,
                                              update_iteration=total_numsteps,
                                              whosturn=whosturn)
                elif is_last_update or stop_requested:
                    # pure-PPO run: dump expert demos once training is finished
                    save_expert_traj(opt, self.testing_env, self.ppo_dict[k],
                                     opt.env_kwargs,
                                     extra_reward_threshold=opt.threshold_reward,
                                     nr_trajectories=opt.num_expert_traj, stable_baselines_model=False)

                if (i_update % 10 == 0 and i_update > 1) or stop_requested:
                    self._save_checkpoint(opt, k)

                if opt.single_policy and opt.single_disc:
                    break

            if stop_requested:
                break

    def _save_checkpoint(self, opt, k):
        """Save the policy (custom PPO only) and discriminator for env ``k``, plus the run args."""
        print(">>> Saving models.." + str(k))

        if not opt.use_sb_ppo:
            torch.save(self.ppo_dict[k].ac.state_dict(),
                       os.path.join(self.output_dir, 'policy' + format_name_string(str(k))))

        # Discriminators are only updated when a discriminator type is set, so
        # there is nothing meaningful to save (and possibly no entry in d_dict)
        # otherwise.
        if opt.discriminator_type is not None:
            if opt.single_disc:
                sd_id = list(self.env_dict.keys())[0]
                torch.save(self.d_dict[sd_id].state_dict(), os.path.join(self.output_dir, 'single_disc'))
            else:
                torch.save(self.d_dict[k].state_dict(),
                           os.path.join(self.output_dir, 'disc' + format_name_string(str(k))))

        with open(os.path.join(self.output_dir, 'args.json'), 'w') as fp:
            json.dump(vars(opt), fp)


if __name__ == '__main__':
    # NOTE: this stays at module level on purpose. Trainer.__init__ reads the
    # module-global ``args`` (args.resume), so ``args`` must remain a global.
    p = argparse.ArgumentParser()
    p.add_argument('-c', '--my-config', required=False, help='config file path')
    p.add_argument('--output_dir', default='exp_output', required=False, help='output dir path')
    p.add_argument('--demo_dir', default='demos', required=False, help='demo dir path')
    p.add_argument('--logging_interval', default=10, type=int, required=False, help='tensorboard logging interval')
    p.add_argument('--sb_config', default='config/config_sb3.yaml', required=False, help='sb config file path')
    p.add_argument('--exp_id', default='', required=False, help='unique experiment identifier')
    p.add_argument('--env_name', default='Pendulum-v0', required=False, help='environment name')
    p.add_argument('--env_kwargs', nargs='+', required=False, help='environment variation (e.g. gravity)')
    p.add_argument('--env_spec_test', required=False, help='environment variation for test (e.g. gravity)')
    p.add_argument('--n_envs', default=16, type=int, required=False, help='number of parallel envs')
    p.add_argument('--use_subprocess', default=False, required=False, action='store_true',
                   help='use subproc for parallel envs')
    p.add_argument('--seed', default=42, type=int, required=False, help='seed')
    p.add_argument('--use_seed_ranking', default=True, required=False, action='store_false',
                   help='seed ranking for vecenv')
    p.add_argument('--ac_base_type', default='flat', required=False, help='actor critic base network type')
    p.add_argument('--ac_layer_dims', default=[32, 32], nargs='+', required=False,
                   help='actor critic MLP layer dimensions')
    p.add_argument('--d_layer_dims', default=[128, 128], nargs='+', required=False,
                   help='discriminator MLP layer dimensions')
    p.add_argument('--lr', default=3e-4, type=float, required=False, help='learning rate')
    p.add_argument('--gae_lambda', default=0.95, type=float, required=False, help='GAE lambda')
    p.add_argument('--gamma', default=0.99, type=float, required=False, help='discount factor GAE / AIRL')
    p.add_argument('--ppo_eps', default=0.2, type=float, required=False, help='ppo update epsilon')
    p.add_argument('--ppo_vcoeff', default=0.5, type=float, required=False, help='ppo value loss coeff')
    p.add_argument('--ppo_entcoeff', default=0.0, type=float, required=False, help='ppo entropy loss coeff')
    p.add_argument('--ppo_vfclip', default=0.5, type=float, required=False, help='ppo value function clipping')
    p.add_argument('--max_grad_norm', default=0.5, type=float, required=False, help='max grad norm')
    p.add_argument('--buf_warmup', default=10, type=int, required=False, help='number of stored transitions in buffer')
    p.add_argument('--n_steps', default=16, type=int, required=False, help='number of policy steps')
    p.add_argument('--n_epochs', default=4, type=int, required=False, help='number of ppo update epochs')
    p.add_argument('--n_irl_epochs', default=4, type=int, required=False, help='number of IRL update epochs')
    p.add_argument('--num_expert_traj', default=10, nargs='+', type=int, required=False, help='number of expert trajectories used for training')
    p.add_argument('--pref_epochs', default=5, type=int, required=False, help='number of preference training epochs')
    p.add_argument('--nr_pairs', default=5000, type=int, required=False, help='number of pairs used for preference training')
    p.add_argument('--batch_size', default=16, type=int, required=False, help='batch size')
    p.add_argument('--threshold_reward', default=0, type=int, required=False,
                   help='reward threshold for expert generation')
    p.add_argument('--discriminator_type', required=False,
                   help='discriminator type (GAIL/AIRL/Null(ppo))')
    p.add_argument('--use_actions', default=True, required=False, action='store_false',
                   help='use actions in discriminator')
    p.add_argument('--extra_demos', default=False, required=False, action='store_true',
                   help='use extra demos')
    p.add_argument('--minigrid_wrapper', default='flat', required=False, help='use img wrapper')
    p.add_argument('--irm_coeff', default=0, type=float, required=False, help='use IRM penalty on discriminator')
    p.add_argument('--lip_coeff', default=0, type=float, required=False, help='use gradient penalty on discriminator')
    p.add_argument('--ipo', default=False, required=False, action='store_true',
                   help='use IRM Games formulation for policy opt (IPO)')
    p.add_argument('--irmg', default=False, required=False, action='store_true',
                   help='use IRM Games formulation')
    p.add_argument('--invrat', default=False, required=False, action='store_true',
                   help='use Invariant Rationalization formulation')
    p.add_argument('--n_timesteps', default=1000000, type=int, required=False,
                   help='total step budget (number of updates = n_timesteps // n_steps)')
    p.add_argument('--train_using_sb', default=False, action='store_true',
                   required=False, help='use stable baselines for baseline policy training')
    # --use_sb_ppo / --single_disc default to True, so a store_true flag alone
    # could never disable them; the store_false companions below can.
    p.add_argument('--use_sb_ppo', default=True, action='store_true',
                   required=False, help='use stable baselines for policy during IRL training (default)')
    p.add_argument('--no_sb_ppo', dest='use_sb_ppo', action='store_false',
                   required=False, help='use the custom PPO implementation during IRL training')
    p.add_argument('--single_policy', default=False, action='store_true',
                   required=False, help='use single policy to interact with different environment discriminators')
    p.add_argument('--single_disc', default=True, action='store_true',
                   required=False, help='use single discriminator to interact with different environment policies (default)')
    p.add_argument('--multi_disc', dest='single_disc', action='store_false',
                   required=False, help='use one discriminator per environment')
    p.add_argument('--train_pref', default=False, required=False, action='store_true',
                   help='train preference reward')
    p.add_argument('--load_sb_model', default=False, required=False, action='store_true',
                   help='load SB model')
    p.add_argument('--resume', required=False, help='resume from checkpoint (dirpath required)')
    p.add_argument('--pretrained_pref_reward', required=False, help='use pretrained preference reward')
    p.add_argument('--pretrained_irl_reward', required=False, help='use pretrained irl reward')
    p.add_argument('-v', help='verbose', action='store_true')

    args = p.parse_args()

    # Load the checkpointed arguments *before* seeding, so a resumed run uses
    # the seed it was started with rather than the CLI default.
    if args.resume is not None:
        with open(os.path.join(args.resume, 'args.json')) as fp:
            opt_ckpt = json.load(fp)
        opt_ckpt['resume'] = args.resume
        args = Namespace(**opt_ckpt)

    # normalise types (older checkpoints may hold strings for these)
    args.seed = int(args.seed)
    args.logging_interval = int(args.logging_interval)
    if isinstance(args.ac_layer_dims[0], str):
        args.ac_layer_dims = [int(d) for d in args.ac_layer_dims]
    if isinstance(args.d_layer_dims[0], str):
        args.d_layer_dims = [int(d) for d in args.d_layer_dims]

    if torch.cuda.is_available():
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        torch.set_default_tensor_type(torch.FloatTensor)
    # set all the seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    print("-"*100)
    print("   Experiment parameters: ")
    print("-"*100)
    pprint.pprint(vars(args), indent=5)
    print("-"*100)

    # Convert the env_kwargs string arguments to a list of kwarg dicts.
    # On resume they are already dicts, so the mode comparisons below are False.
    if args.env_kwargs is not None:
        mode = args.env_kwargs[0]
        if mode == 'parse_dir':
            tag = args.env_name.split("-")[0].lower()
            d = os.path.join('pybullet-gym/pybulletgym/envs/assets/mjcf', tag)
            # sorted() makes the env order (and hence "the first env") deterministic
            files = sorted(os.listdir(d))
            # additional filter tag for specific interventions (mass, leglength)
            if len(args.env_kwargs) > 1:
                filt = args.env_kwargs[1]
                files = [f for f in files if filt in f]

            env_kwargs = []
            for f in files:
                print(">>> Parsing directory for env files: ", f)
                env_kwargs.append({'xml_file': os.path.join(tag, f)})
            args.env_kwargs = env_kwargs
            print("-" * 100)

        elif mode == 'gen_interv':
            if len(args.env_kwargs) < 5:
                p.error("--env_kwargs gen_interv expects: gen_interv <kwarg> <min> <max> <number>")
            interv_key = args.env_kwargs[1]
            range_min = float(args.env_kwargs[2])
            range_max = float(args.env_kwargs[3])
            number = int(args.env_kwargs[4])
            args.env_kwargs = [{interv_key: np.random.uniform(range_min, range_max)}
                               for _ in range(number)]
            print(">> Generated interventions")
            print(args.env_kwargs)
            print("-" * 100)

        else:
            # explicit JSON strings from the CLI; checkpoints already store dicts
            args.env_kwargs = [env_kwarg if args.resume else json.loads(env_kwarg)
                               for env_kwarg in args.env_kwargs]

    if args.env_spec_test is not None and not args.resume:
        args.env_spec_test = json.loads(args.env_spec_test)

    pref_ckpt_dir = './sb_models/ckpt_' + args.env_name

    if args.train_using_sb:
        train_using_sb(args,
                       pref_reward_model=args.pretrained_pref_reward,
                       irl_reward_model=args.pretrained_irl_reward)
    elif args.train_pref:
        # makedirs also creates 'demos/' if it does not exist yet
        os.makedirs('demos/preference_learning', exist_ok=True)
        if not os.path.exists(pref_ckpt_dir):
            train_using_sb(args, save_checkpoints_for_pl=pref_ckpt_dir)
        for spec in args.env_kwargs:
            if len(get_env_demo_files('demos/preference_learning', args.env_name, spec)) == 0:
                save_ranked_expert_demos(args, pref_ckpt_dir, spec)

        pref_rew = PreferenceReward(args)
        pref_rew.train(args)
    else:
        # yaml.load() without a Loader is a TypeError on PyYAML >= 6;
        # the file is also closed now.
        with open(args.sb_config) as sb_yml:
            sb_args = yaml.safe_load(sb_yml)[args.env_name]
        for key in ['n_epochs', 'n_steps', 'gae_lambda', 'batch_size', 'gamma', 'n_envs']:
            setattr(args, key, sb_args[key])

        trainer = Trainer(args)
        trainer.train()
