import copy
import os
from collections import deque
import time
from datetime import datetime

import gym
from gym.wrappers import TimeLimit
from mpi4py import MPI
from scipy.special import softmax
from sklearn.metrics import mean_squared_error
from stable_baselines.common import fmt_row, dataset, tf_util
import numpy as np
from stable_baselines import logger
import tensorflow as tf
from stable_baselines.common.misc_util import flatten_lists, set_global_seeds
from stable_baselines.common.mpi_adam import MpiAdam

from causal_irl.algorithms.common.base_class import SetVerbosity, BaseRLModel, ActorCriticRLModel
from causal_irl.algorithms.common.policies import ActorCriticPolicy
from causal_irl.algorithms.common.runners import traj_segment_generator
from causal_irl.algorithms.common.utils import name_to_models, extended_calibration_curve, plot_calibration
from causal_irl.algorithms.gail import generate_expert_traj
from causal_irl.algorithms.fusion.my_adversary import TransitionClassifier
from causal_irl.algorithms.gail.model import GAIL
from causal_irl.algorithms.trpo_mpi import TRPO
from causal_irl.envs.noisy_action_wrapper import NoisyActionWrapper
from causal_irl.envs.my_observation_wrapper import MyObservationWrapper
from causal_irl.envs.noisy_observation_wrapper import NoisyObservationWrapper

# Month-day stamp (e.g. "8-5") taken at import time. Read the clock once so the
# month and day always come from the same instant.
date = '{0.month}-{0.day}'.format(datetime.now())
class RL_GAN(TRPO):
    """Discriminator-guided planner built on top of a pretrained generator policy.

    A pretrained generator (loaded in ``setup_model`` from ``generator_path``)
    proposes actions and a discriminator (``self.reward_giver``) scores them.
    ``setup_model`` installs ``self.step``. On plan steps it syncs
    ``n_rollout_envs`` copies of the environment to the live env's state and
    rolls each copy out for up to ``look_ahead`` steps. It then takes the first
    action of the rollout with the highest discriminator return. The
    previously chosen rollout is kept in ``self.saved_rollout``, re-simulated
    and allowed to compete again. ``learn`` trains only the discriminator,
    using fixed expert and policy datasets.
    """

    def __init__(self, policy, env, discriminator=None, generator_name=None, generator_path=None,
                 _init_setup_model=True, expert_dataset=None, policy_dataset=None,
                 timesteps_per_batch=128, d_stepsize=3e-4, val_interval=None, verbose=0, seed=None,
                 save_path=None, calibration=False, rejection_sampling=True, mpc=False,
                 using_airl=False, using_random=False, code='001', plan_interval=1000,
                 look_ahead=1000, n_rollout_envs=10):
        """
        :param policy: forwarded to the TRPO base class; not used by this class's own logic.
        :param env: the live environment. Its ``unwrapped.spec.id`` is used to build the
            rollout copies.
        :param discriminator: optional existing model providing ``graph``, ``sess`` and
            ``reward_giver``. If None, a new ``TransitionClassifier`` is built.
        :param generator_name: key into ``name_to_models`` selecting the generator class.
        :param generator_path: path used to locate the saved generator (see ``setup_model``).
        :param _init_setup_model: call ``setup_model`` at the end of construction.
        :param expert_dataset: expert data used by ``learn``.
        :param policy_dataset: generator/policy data used by ``learn``.
        :param timesteps_per_batch: batch size passed to the datasets' ``init_dataloader``.
        :param d_stepsize: MpiAdam step size for the discriminator.
        :param val_interval: validate every ``val_interval`` epochs. If None, ``learn`` picks it.
        :param verbose: verbosity level.
        :param seed: seed for python/numpy/tf and the env (None leaves seeds untouched).
        :param save_path: directory where ``learn`` saves the best discriminator.
        :param calibration: compute and plot a calibration curve at the best epoch.
        :param rejection_sampling: stored only; not read anywhere in this file.
        :param mpc: stored only; not read anywhere in this file.
        :param using_airl: score transitions with the AIRL-style
            ``get_reward(ob, ac, logp, next_ob)`` call.
        :param using_random: replace generator actions with uniformly random ones.
        :param code: prefix for saved file names.
        :param plan_interval: plan every ``plan_interval`` env steps (wrapped envs only).
        :param look_ahead: rollout horizon used when planning.
        :param n_rollout_envs: number of fresh rollouts per plan step.
        """
        # Policy here is useless.
        # NOTE(review): verbose/seed are passed positionally. If TRPO's signature is
        # (policy, env, gamma, timesteps_per_batch, ...) as in stable-baselines, they
        # land in the wrong slots. Both attributes are reassigned below, but confirm.
        super(RL_GAN, self).__init__(policy, env, verbose, seed, _init_setup_model=False)
        self.env = env
        self.generator_name = generator_name
        self.generator_path = generator_path
        self.discriminator = discriminator
        self.expert_dataset = expert_dataset
        # Stored so that ``learn`` / ``_initialize_dataloader`` can find it. Previously
        # this argument was silently dropped.
        self.policy_dataset = policy_dataset
        self.val_interval = val_interval
        self.verbose = verbose
        self.seed = seed
        self.timesteps_per_batch = timesteps_per_batch
        self.d_stepsize = d_stepsize
        self.save_path = save_path
        self.rejection_sampling = rejection_sampling
        self.calibration = calibration
        self.code = code
        self.using_airl = using_airl
        self.using_random = using_random
        # Planning settings.
        self.mpc = mpc
        self.saved_rollout = {}  # rollout chosen at the last plan step (see setup_model's step)
        self.plan_interval = plan_interval
        self.look_ahead = look_ahead
        self.n_rollout_envs = n_rollout_envs

        self.params = None
        self.trained_epochs = 0
        if _init_setup_model:
            self.setup_model()

    def set_random_seed(self, seed) -> None:
        """
        Seed python, numpy and tf, plus ``self.env`` and its action space.

        :param seed: (Optional[int]) Seed for the pseudo-random generators. If None,
            do not change the seeds.
        """
        if seed is None:
            return
        set_global_seeds(seed)
        if self.env is not None:
            self.env.seed(seed)
            # Seeding the action space makes action_space.sample() reproducible.
            self.env.action_space.seed(seed)

    def setup_model(self):
        """Load the generator, build or reuse the discriminator, and install ``self.step``.

        Side effects: sets ``self.generator``, ``self.graph``, ``self.sess``,
        ``self.reward_giver``, ``self.d_adam``, ``self.allmean``, ``self.envs``
        (``n_rollout_envs + 1`` env copies, the last one reserved for re-simulating
        the saved rollout), ``self.step`` and ``self.params``.
        """
        with SetVerbosity(self.verbose):
            self.nworkers = MPI.COMM_WORLD.Get_size()
            self.rank = MPI.COMM_WORLD.Get_rank()
            np.set_printoptions(precision=3)

            # The generator is loaded from
            # <cwd>/results/<grand-parent dir name of generator_path>/<generator_name>/<file name>.
            run_dir_name = os.path.basename(os.path.dirname(os.path.dirname(self.generator_path)))
            generator_file = os.path.join(os.getcwd(), 'results', run_dir_name,
                                          self.generator_name, os.path.basename(self.generator_path))
            self.generator = name_to_models[self.generator_name].load(generator_file)

            if self.discriminator is None:
                self.graph = tf.Graph()
            else:
                self.graph = self.discriminator.graph
                self.reward_giver = self.discriminator.reward_giver

            with self.graph.as_default():
                self.set_random_seed(self.seed)
                if self.discriminator is None:
                    self.sess = tf_util.make_session(graph=self.graph, make_default=True)
                    assert not self.using_airl, "Can't create AIRL discriminator yet"
                    self.reward_giver = TransitionClassifier(self.observation_space, self.action_space,
                                                             self.hidden_size_adversary,
                                                             entcoeff=self.adversary_entcoeff, normalize=False)
                else:
                    self.sess = self.discriminator.sess
                    # Make the discriminator's session the default one (it is never exited here).
                    self.sess.__enter__()

                with tf.variable_scope("Adam_mpi", reuse=False):
                    self.d_adam = MpiAdam(self.reward_giver.get_trainable_variables(), sess=self.sess)

                def allmean(arr):
                    """Average a numpy array element-wise across all MPI workers."""
                    assert isinstance(arr, np.ndarray)
                    out = np.empty_like(arr)
                    MPI.COMM_WORLD.Allreduce(arr, out, op=MPI.SUM)
                    out /= self.nworkers
                    return out

                self.allmean = allmean

                self.actions_scores = []
                self.rewards = []

                # Simulation copies of the live env: n_rollout_envs for fresh rollouts,
                # plus one (the last) for re-simulating the saved rollout.
                env_id = self.env.unwrapped.spec.id
                self.envs = [gym.make(env_id) for _ in range(self.n_rollout_envs + 1)]
                for i in range(len(self.envs)):
                    if env_id == 'Pendulum-v0':
                        self.envs[i]._max_episode_steps = 100
                    if isinstance(self.env, MyObservationWrapper):
                        # The copies must use the same ("confounded") observation wrapper.
                        self.envs[i] = MyObservationWrapper(self.envs[i], "confounded")
                    self.envs[i].reset()

                def step(obs, state=None, mask=None, deterministic=False):
                    """Choose an action for ``obs``, planning with the discriminator when due.

                    Off-plan steps return ``self.generator.step(obs, state, mask,
                    deterministic)`` directly, with a uniformly random action when
                    ``using_random`` is set. On plan steps the rollout copies are synced
                    to ``self.env``'s simulator state. Each of ``n_rollout_envs`` copies
                    is rolled out from a stochastic (deterministic=False) generator
                    action, and the previously saved rollout is re-simulated and
                    competes too. The first step of the rollout with the highest
                    discriminator return is popped and returned. The rest is kept in
                    ``self.saved_rollout``.

                    :return: the generator's 4-tuple (action, value, state, nlogp)
                    """
                    print("rejection_sampling MPC !!!!!")
                    # NOTE(review): the attribute paths below assume a fixed wrapper layout
                    # (observation/noise wrapper -> TimeLimit -> env); confirm against
                    # how the live env is built.
                    is_wrapped = isinstance(self.env.env, (MyObservationWrapper, NoisyActionWrapper,
                                                           NoisyObservationWrapper))
                    if is_wrapped:
                        plan_now = self.env.env.env._elapsed_steps % self.plan_interval == 0
                    else:
                        # Unwrapped envs are only planned at the very first step.
                        plan_now = self.env.env._elapsed_steps == 0

                    if not plan_now:
                        action, value, state, logp = self.generator.step(obs, state, mask, deterministic)
                        if self.using_random:
                            action = self.get_random_action()
                        return action, value, state, logp

                    # Sync every simulation copy to the live env. This is only needed
                    # when planning, because each plan step re-syncs before use.
                    if self.env.unwrapped.spec.id == 'Pendulum-v0':
                        saved_state = copy.deepcopy(self.env.unwrapped.state)
                    else:
                        saved_state = copy.deepcopy(self.env.unwrapped.sim.get_state())
                    for sim_env in self.envs:
                        sim_env.reset()
                        if isinstance(self.env.env, MyObservationWrapper):
                            sim_env.env._elapsed_steps = self.env.env.env._elapsed_steps
                            sim_env.prev_action = self.env.env.prev_action
                        elif isinstance(self.env.env, (NoisyActionWrapper, NoisyObservationWrapper)):
                            sim_env._elapsed_steps = self.env.env.env._elapsed_steps

                        if self.env.unwrapped.spec.id == 'Pendulum-v0':
                            sim_env.unwrapped.state = saved_state
                        else:
                            sim_env.unwrapped.sim.set_state(saved_state)

                    # One stochastic first action per fresh rollout.
                    gen_results = [self.generator.step(obs, state, mask, False) for _ in range(self.n_rollout_envs)]
                    gen_actions = [result[0] for result in gen_results]
                    gen_value = [result[1] for result in gen_results]
                    gen_snew = [result[2] for result in gen_results]
                    gen_nlogp = [result[3] for result in gen_results]
                    if self.using_random:
                        gen_actions = [self.get_random_action() for _ in range(len(gen_results))]
                    rollout_results = [self.get_rollout_return(self.envs[i],
                                                               obs,
                                                               gen_actions[i],
                                                               gen_value[i],
                                                               gen_snew[i],
                                                               gen_nlogp[i],
                                                               max_t=self.look_ahead)
                                       for i in range(len(gen_results))]

                    # Re-score the remainder of the previously chosen rollout from the
                    # current state and let it compete. It is appended last, i.e. at
                    # index n_rollout_envs == len(self.envs) - 1.
                    if self.saved_rollout and self.saved_rollout["ep_len"] > 0:
                        self.update_buffer(self.envs[-1], obs, self.saved_rollout, self.look_ahead)
                        print("Updated Buffer: {}".format((self.saved_rollout["ep_ret"], self.saved_rollout["true_ep_ret"], self.saved_rollout["ep_len"])))
                        rollout_results.append(self.saved_rollout)

                    print([(result["ep_ret"], result["true_ep_ret"], result["ep_len"]) for result in rollout_results])
                    rollout_returns = [result["ep_ret"] for result in rollout_results]
                    idx_return = np.argmax(rollout_returns)
                    print("Picked Trajectory: {}".format(idx_return))
                    if idx_return == len(self.envs) - 1:
                        print("Picked saved rollout")

                    # Execute the first step of the winner; keep the rest for later.
                    self.saved_rollout = rollout_results[idx_return]
                    pred_action = self.saved_rollout["acs_list"].pop(0)
                    pred_value = self.saved_rollout["val_list"].pop(0)
                    pred_snew = self.saved_rollout["state_list"].pop(0)
                    pred_nlogp = self.saved_rollout["nlogp_list"].pop(0)
                    self.saved_rollout["ep_len"] -= 1
                    self.saved_rollout["ep_ret"] -= self.saved_rollout["rew_list"].pop(0)
                    self.saved_rollout["true_ep_ret"] -= self.saved_rollout["true_rew_list"].pop(0)
                    return pred_action, pred_value, pred_snew, pred_nlogp

                tf_util.initialize(sess=self.sess)
                self.step = step
                self.params = self.reward_giver.get_trainable_variables()

    def _initialize_dataloader(self):
        """(Re)initialize the expert and policy dataloaders with ``timesteps_per_batch``."""
        self.expert_dataset.init_dataloader(self.timesteps_per_batch)
        self.policy_dataset.init_dataloader(self.timesteps_per_batch)

    def _get_pretrain_placeholders(self):
        """
        Return the placeholders needed for the pretraining, delegated to the generator:
        - obs_ph: observation placeholder
        - actions_ph will be populated with an action from the environment
            (from the expert dataset)
        - deterministic_actions_ph: e.g., in the case of a Gaussian policy,
            the mean.

        :return: ((tf.placeholder)) (obs_ph, actions_ph, deterministic_actions_ph)
        """
        return self.generator._get_pretrain_placeholders()

    def get_random_action(self):
        """Sample an action uniformly from ``[low, high)`` of the action space.

        :return: (np.ndarray) float32 array with a leading batch axis of size 1
            (shape ``(1, action_space.shape[0])`` for a 1-D Box, presumably)
        """
        action = np.atleast_1d((self.action_space.high - self.action_space.low) * np.random.random_sample(
            self.action_space.shape[0]) + self.action_space.low)
        return action.reshape(-1, *action.shape).astype('float32')

    def update_buffer(self, env, ob, rollout_buffer, max_t):
        """Re-simulate a saved rollout in ``env`` from ``ob`` and refresh its scores in place.

        Stored actions are replayed first. If the episode has not ended after them,
        the rollout is extended with deterministic generator actions (or random ones
        when ``using_random``) until ``max_t`` steps have been taken. After ``max_t``
        steps without termination, the generator's value estimate of the final
        observation is added as a bootstrap, as in ``get_rollout_return``.

        Overwrites ``ep_len``, ``ep_ret``, ``true_ep_ret``, ``rew_list`` and
        ``true_rew_list`` of ``rollout_buffer``, and may append to its
        ``acs_list``/``val_list``/``state_list``/``nlogp_list``. Those lists are not
        truncated, so they can be longer than ``ep_len`` if the episode ends early.

        :param env: simulation env, already synced to the state that produced ``ob``
        :param ob: current observation
        :param rollout_buffer: (dict) rollout in the format of ``get_rollout_return``
        :param max_t: (int) maximum number of steps to simulate
        """
        actions = rollout_buffer["acs_list"]
        episode_return = 0.0
        true_return = 0.0
        rew_list = []
        true_rew_list = []

        i = 0
        while i < max_t:
            old_ob = ob
            if i < len(actions):
                ac = actions[i]
            else:
                ac, val, state, nlogp = self.generator.step(ob.reshape(-1, *ob.shape), deterministic=True)
                if self.using_random:
                    ac = self.get_random_action()
                rollout_buffer["acs_list"].append(copy.deepcopy(ac))
                rollout_buffer["val_list"].append(copy.deepcopy(val))
                rollout_buffer["state_list"].append(copy.deepcopy(state))
                rollout_buffer["nlogp_list"].append(copy.deepcopy(nlogp))

            ob, true_reward, done, _info = env.step(ac.squeeze(axis=0))
            # Score the transition exactly as get_rollout_return does. The four action
            # lists are appended and popped in lockstep, so nlogp_list[i] belongs to ac.
            if self.using_airl:
                reward = self.reward_giver.get_reward(old_ob, ac, -rollout_buffer["nlogp_list"][i], ob).item()
            else:
                reward = self.reward_giver.get_reward(old_ob, ac).squeeze()
            if i + 1 == max_t and not done:
                reward += self.generator.step(ob.reshape(-1, *ob.shape), deterministic=True)[1].item()

            episode_return += reward
            true_return += true_reward
            rew_list.append(reward)
            true_rew_list.append(true_reward)
            i += 1
            if done:
                break

        rollout_buffer["ep_len"] = i
        rollout_buffer["ep_ret"] = episode_return
        rollout_buffer["true_ep_ret"] = true_return
        rollout_buffer["rew_list"] = rew_list
        rollout_buffer["true_rew_list"] = true_rew_list

    def get_rollout_return(self, env, ob, ac, val, state, nlogp, max_t=1000):
        """Roll the generator out in ``env``, starting with action ``ac``, and score it.

        ``env`` is expected to already be in the state that produced ``ob``. The
        first step uses the supplied (ac, val, state, nlogp). Later steps use
        deterministic generator actions, or uniformly random ones when
        ``using_random``. Each transition is scored by the discriminator. After
        ``max_t`` steps without termination, the generator's value estimate of the
        final observation is added as a bootstrap. With ``max_t == 0`` a single step
        is taken and its score is replaced by that value estimate.

        :return: (dict) ``acs_list``/``val_list``/``state_list``/``nlogp_list`` (the
            generator outputs behind each step taken), ``rew_list``/``true_rew_list``
            (per-step discriminator / environment rewards), ``ep_ret``,
            ``true_ep_ret`` and ``ep_len``
        """
        # TODO: consider reducing max_t over time.
        episode_return = 0.0
        true_return = 0.0
        acs_list = []
        val_list = []
        state_list = []
        nlogp_list = []
        rew_list = []
        true_rew_list = []

        t = 0
        done = False
        while not done:
            # Record the generator outputs behind the action about to be executed.
            acs_list.append(copy.deepcopy(ac))
            val_list.append(copy.deepcopy(val))
            state_list.append(copy.deepcopy(state))
            nlogp_list.append(copy.deepcopy(nlogp))

            old_ob = ob
            ob, true_reward, done, _info = env.step(ac.squeeze(axis=0))
            true_return += true_reward
            if self.using_airl:
                reward = self.reward_giver.get_reward(old_ob, ac, -nlogp, ob).item()
            else:
                reward = self.reward_giver.get_reward(old_ob, ac).squeeze()

            # Next action, plus the value estimate of the new observation (used for bootstrapping).
            ac, val, state, nlogp = self.generator.step(ob.reshape(-1, *ob.shape), deterministic=True)
            if self.using_random:
                ac = self.get_random_action()
            if max_t == 0:
                reward = val.item()
            elif t + 1 == max_t and not done:
                reward += val.item()
            episode_return += reward

            rew_list.append(reward)
            true_rew_list.append(true_reward)

            t += 1
            done = done or t >= max_t

        return {"acs_list": acs_list,
                "val_list": val_list,
                "state_list": state_list,
                "nlogp_list": nlogp_list,
                "ep_len": t,
                "true_ep_ret": true_return,
                "ep_ret": episode_return,
                "rew_list": rew_list,
                "true_rew_list": true_rew_list}

    def _align_batches(self, ob_batch, ac_batch, ob_expert, ac_expert):
        """Prepare a (policy, expert) minibatch pair for the discriminator.

        When the expert batch is smaller, it is repeated by the integer factor
        ``len(policy) // len(expert)``, so the sizes can still differ. For
        ``Discrete`` action spaces, only the first column of 2-D action arrays is kept.

        :return: (ob_batch, ac_batch, ob_expert, ac_expert)
        """
        if ob_expert.shape[0] < ob_batch.shape[0]:
            n_repeat = ob_batch.shape[0] // ob_expert.shape[0]
            ob_expert = np.repeat(ob_expert, n_repeat, axis=0)
            ac_expert = np.repeat(ac_expert, n_repeat, axis=0)
        if isinstance(self.env.action_space, gym.spaces.Discrete):
            if len(ac_batch.shape) == 2:
                ac_batch = ac_batch[:, 0]
            if len(ac_expert.shape) == 2:
                ac_expert = ac_expert[:, 0]
        return ob_batch, ac_batch, ob_expert, ac_expert

    def _run_calibration(self, epoch):
        """Measure and plot the discriminator's calibration on the validation split.

        Re-initializes both dataloaders, labels policy samples 0 and expert samples
        1, and prints the bin-weighted MSE between the observed positive fraction
        and the mean predicted probability. Saves a calibration plot named after
        ``epoch``.
        """
        self._initialize_dataloader()
        pred_probs = []
        targets = []
        for _ in range(len(self.policy_dataset.val_loader)):
            batch_obs, batch_actions = self.policy_dataset.get_next_batch('val')
            expert_obs, expert_actions = self.expert_dataset.get_next_batch('val')
            batch_obs, batch_actions, expert_obs, expert_actions = self._align_batches(
                batch_obs, batch_actions, expert_obs, expert_actions)

            policy_logits = self.reward_giver.get_discrim_logit(batch_obs, batch_actions)
            pred_probs.append(np.ravel(1 / (1 + np.exp(-policy_logits))))
            targets.append(np.ravel(np.zeros_like(policy_logits)))

            expert_logits = self.reward_giver.get_discrim_logit(expert_obs, expert_actions)
            pred_probs.append(np.ravel(1 / (1 + np.exp(-expert_logits))))
            targets.append(np.ravel(np.ones_like(expert_logits)))

        # Flatten each batch before joining so ragged or size-1 batches are handled.
        targets = np.concatenate(targets)
        pred_probs = np.concatenate(pred_probs)
        fraction_of_positives, mean_predicted_value, bin_weights = extended_calibration_curve(targets,
                                                                                              pred_probs,
                                                                                              n_bins=10)
        calibration_error = mean_squared_error(fraction_of_positives, mean_predicted_value,
                                               sample_weight=bin_weights)
        print('Calibration error', calibration_error)

        # NOTE(review): the '8-5' output directory is hard-coded; confirm it is intended.
        cal_savefile = os.path.join(os.getcwd(), 'results', '8-5', 'calib', "calibration_in_{}".format(epoch))
        plot_calibration(fraction_of_positives, mean_predicted_value, cal_savefile)

    def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="run",
              reset_num_timesteps=True):
        """Train the discriminator on the fixed expert and policy datasets.

        Whenever the validation loss improves, the model is saved under
        ``save_path``. With ``calibration`` set, a calibration curve is produced at
        the best epoch.

        :param total_timesteps: (int) interpreted as the number of epochs
        :param callback, log_interval, tb_log_name, reset_num_timesteps: accepted for
            interface compatibility; unused
        :return: (range, list, list) epoch indices, per-epoch mean training loss, and
            the validation loss of each validated epoch
        """
        n_epochs = total_timesteps
        assert self.expert_dataset is not None, "You must pass an expert dataset for training"
        assert self.policy_dataset is not None, "You must pass an policy dataset for training"
        if self.val_interval is None:
            # Validate every 10% of the epochs, and at least every epoch
            # (keeps the modulo below from dividing by zero).
            self.val_interval = 1 if n_epochs < 100 else int(n_epochs / 10)

        self._initialize_dataloader()
        best_val_loss = np.inf
        best_epoch = 0

        train_losses = []
        validation_losses = []

        for i in range(n_epochs):
            assert not self.using_airl, "Can't train AIRL yet"
            logger.log("********** Epoch %i ************" % i)

            # ------------------ Update D ------------------
            logger.log("Optimizing Discriminator...")
            logger.log(fmt_row(13, self.reward_giver.loss_name))

            d_losses = []  # one tuple of losses per minibatch
            for batch_idx in range(len(self.policy_dataset.train_loader)):
                ob_batch, ac_batch = self.policy_dataset.get_next_batch(split='train')
                ob_expert, ac_expert = self.expert_dataset.get_next_batch(split='train')
                ob_batch, ac_batch, ob_expert, ac_expert = self._align_batches(
                    ob_batch, ac_batch, ob_expert, ac_expert)
                # Update running mean/std for reward_giver.
                if self.reward_giver.normalize:
                    self.reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))

                *newlosses, grad = self.reward_giver.lossandgrad(ob_batch, ac_batch, ob_expert, ac_expert)
                self.d_adam.update(self.allmean(grad), self.d_stepsize)
                d_losses.append(newlosses)

                if batch_idx % 10 == 0:
                    logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

            mean_train_loss = np.mean(d_losses, axis=0)
            # Sum of loss entries 0, 1 and 3 of reward_giver.loss_name.
            train_losses.append(mean_train_loss[0] + mean_train_loss[1] + mean_train_loss[3])

            if self.val_interval is not None and i % self.val_interval == 0:
                logger.log("Validating Discriminator...")
                val_losses = []
                for _ in range(len(self.policy_dataset.val_loader)):
                    batch_obs, batch_actions = self.policy_dataset.get_next_batch('val')
                    expert_obs, expert_actions = self.expert_dataset.get_next_batch('val')
                    batch_obs, batch_actions, expert_obs, expert_actions = self._align_batches(
                        batch_obs, batch_actions, expert_obs, expert_actions)
                    # Gradients are discarded: no update is made on validation data.
                    *newlosses, _grad = self.reward_giver.lossandgrad(batch_obs, batch_actions,
                                                                      expert_obs, expert_actions)
                    val_losses.append(newlosses)
                mean_val_losses = np.mean(val_losses, axis=0)
                logger.log(fmt_row(13, mean_val_losses))
                val_loss = mean_val_losses[0] + mean_val_losses[1] + mean_val_losses[3]
                validation_losses.append(val_loss)
                if val_loss < best_val_loss:
                    logger.log("Found a better Discriminator from validation, saving at epoch{}".format(i))
                    best_val_loss = val_loss
                    self.save(os.path.join(self.save_path, self.code+"_best_rl_gan_{}_traj{}_{}_{}".format(self.env.unwrapped.spec.id,
                                                                                                self.expert_dataset.num_traj,
                                                                                                self.env.mode if isinstance(self.env.env, MyObservationWrapper) else 'testing',
                                                                                                self.seed)))
                    best_epoch = i

            if self.calibration and i == best_epoch:
                self._run_calibration(best_epoch)
            self.trained_epochs += 1
        return range(n_epochs), train_losses, validation_losses

    def save(self, save_path, cloudpickle=False):
        """Save the settings and discriminator parameters to ``save_path``.

        :param save_path: (str) target file path
        :param cloudpickle: (bool) forwarded to ``_save_to_file``
        """
        data = {
            "policy": self.policy,
            "n_envs": self.n_envs,
            "generator_name": self.generator_name,
            "generator_path": self.generator_path,
            "verbose": self.verbose,
            "seed": self.seed,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "trained_epochs": self.trained_epochs,
            "val_interval": self.val_interval,
            "save_path": self.save_path,
            "lr": self.d_stepsize,
            "num_epochs": self.trained_epochs,
            "code": self.code,
            "using_airl": self.using_airl,
            "using_random": self.using_random,
            "plan_interval": self.plan_interval,
            "look_ahead": self.look_ahead,
            "n_rollout_envs": self.n_rollout_envs,
            # The remaining constructor settings, stored under their attribute names.
            # Presumably the base class's load restores attributes from these keys,
            # as stable-baselines does (confirm). "lr" is kept for older readers.
            "d_stepsize": self.d_stepsize,
            "timesteps_per_batch": self.timesteps_per_batch,
            "calibration": self.calibration,
            "rejection_sampling": self.rejection_sampling,
            "mpc": self.mpc,
        }

        params_to_save = self.get_parameters()

        self._save_to_file(save_path, data=data, params=params_to_save, cloudpickle=cloudpickle)

    def get_parameter_list(self):
        """Return the discriminator's trainable variables (set in ``setup_model``)."""
        return self.params
