from collections import deque

from causal_irl.algorithms.trpo_mpi.trpo_mpi import TRPO
import tensorflow as tf
import gym

class AIRL(TRPO):
    """
    Adversarial Inverse Reinforcement Learning (AIRL) on top of TRPO.

    This class only configures TRPO and stores AIRL-specific settings
    (``using_airl``, ``g_step``, ``d_step``, ``d_stepsize``, ...). The training
    loop itself lives in the causal_irl ``TRPO`` implementation, which
    presumably reads these attributes. TODO confirm against ``TRPO.setup_model``
    and ``TRPO.learn``.

    :param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, ...)
    :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
    :param expert_dataset: (ExpertDataset) the dataset manager; must be set before calling ``learn``
    :param hidden_size_adversary: (int) the hidden dimension for the adversary (discriminator) MLP
    :param val_interval: stored as ``self.val_interval``; presumably a validation interval (TODO confirm)
    :param save_path: stored as ``self.save_path``; presumably where models are saved (TODO confirm)
    :param g_step: (int) number of steps to train policy in each epoch
    :param d_step: (int) number of steps to train discriminator in each epoch
    :param d_stepsize: (float) the reward giver (discriminator) stepsize
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
    :param fusion: (bool) when True, an empty ``self.traj_buffer`` list is created
    :param exp_code: (str) experiment identifier stored as ``self.exp_code``
    :param kwargs: extra keyword arguments forwarded to ``TRPO.__init__``, for example:
        gamma: (float) the discount value
        timesteps_per_batch: (int) the number of timesteps to run per batch (horizon); defaults to 1000 here
        max_kl: (float) the Kullback-Leibler loss threshold
        cg_iters: (int) the number of iterations for the conjugate gradient calculation
        lam: (float) GAE factor
        entcoeff: (float) the weight for the entropy loss
        cg_damping: (float) the compute gradient dampening factor
        vf_stepsize: (float) the value function stepsize; defaults to 0.1 here
        vf_iters: (int) the value function's number iterations for learning
        full_tensorboard_log: (bool) enable additional logging when using tensorboard
            WARNING: this logging can take a lot of space quickly
    """

    def __init__(self, policy, env, expert_dataset=None,
                 hidden_size_adversary=32,
                 val_interval=None, save_path=None,
                 g_step=1, d_step=50, d_stepsize=1e-3, verbose=0,
                 _init_setup_model=True, fusion=False, exp_code='000', **kwargs):
        # NOTE: the original author's notes used d_step=50 for Pendulum and
        # d_step=10 for Ant.
        #
        # timesteps_per_batch / vf_stepsize default to the values of the active
        # configuration (apparently tuned for Pendulum); the Ant notes used the
        # TRPO defaults instead. They used to be hard-coded next to **kwargs,
        # so passing either explicitly raised "got multiple values for keyword
        # argument". setdefault keeps the old defaults and allows overrides.
        kwargs.setdefault('timesteps_per_batch', 1000)
        kwargs.setdefault('vf_stepsize', 0.1)

        # The parent must not build the model yet: setup_model() is called
        # below, after the AIRL-specific attributes exist.
        super().__init__(policy, env, verbose=verbose,
                         _init_setup_model=False, **kwargs)

        self.using_airl = True
        self.exp_code = exp_code
        self.expert_dataset = expert_dataset
        self.g_step = g_step
        self.d_step = d_step
        self.d_stepsize = d_stepsize
        self.hidden_size_adversary = hidden_size_adversary
        self.val_interval = val_interval
        self.save_path = save_path
        self.fusion = fusion
        if self.fusion:
            # Only created in fusion mode; consumers are outside this class.
            self.traj_buffer = []

        if _init_setup_model:
            self.setup_model()

    def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="AIRL",
              reset_num_timesteps=True):
        """
        Train the model by delegating to ``TRPO.learn``.

        :param total_timesteps: (int) total number of samples to train on
        :param callback: forwarded unchanged to ``TRPO.learn``
        :param log_interval: (int) forwarded unchanged to ``TRPO.learn``
        :param tb_log_name: (str) tensorboard run name, "AIRL" by default
        :param reset_num_timesteps: (bool) forwarded unchanged to ``TRPO.learn``
        :return: whatever ``TRPO.learn`` returns
        :raises ValueError: if no expert dataset was provided
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.expert_dataset is None:
            raise ValueError("You must pass an expert dataset to AIRL for training")
        return super().learn(total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps)
