import numpy as np
from deep_sprl.teachers.util import Buffer
from deep_sprl.teachers.abstract_teacher import BaseWrapper

class SelfPacedWrapper(BaseWrapper):
    """Environment wrapper that logs (initial state, context, reward) triples.

    Each time ``done_callback`` fires, one entry is appended to
    ``self.context_buffer``. The reward stored is the undiscounted one when
    ``use_undiscounted_reward`` is truthy, otherwise the discounted one.
    ``get_context_buffer`` reads the buffer back as three numpy arrays.
    """

    def __init__(self, env, sp_teacher, discount_factor, context_visible, max_context_buffer_size=1000,
                 reset_contexts=True, reward_from_info=False, use_undiscounted_reward=False):
        # Assigned before the base initializer runs. This keeps the original ordering,
        # in case BaseWrapper.__init__ relies on it (NOTE(review): confirm).
        self.use_undiscounted_reward = use_undiscounted_reward
        BaseWrapper.__init__(self, env, sp_teacher, discount_factor, context_visible, reward_from_info=reward_from_info)

        # Presumably 3 = fields per entry, max_context_buffer_size = capacity,
        # reset_contexts = whether reads clear the buffer.
        # TODO confirm against deep_sprl.teachers.util.Buffer.
        self.context_buffer = Buffer(3, max_context_buffer_size, reset_contexts)

    def done_callback(self, step, cur_initial_state, cur_context, discounted_reward, undiscounted_reward):
        """Append one (initial state, context, reward) entry to the context buffer.

        ``step`` is accepted for interface compatibility but not used here.
        The reward recorded is ``undiscounted_reward`` if
        ``self.use_undiscounted_reward`` is truthy, else ``discounted_reward``.
        """
        episode_return = undiscounted_reward if self.use_undiscounted_reward else discounted_reward
        entry = (cur_initial_state, cur_context, episode_return)
        self.context_buffer.update_buffer(entry)

    def get_context_buffer(self):
        """Read the buffer and return its three fields as numpy arrays.

        Returns:
            Tuple ``(initial_states, contexts, rewards)``, each passed through
            ``np.array``.
        """
        # Explicit three-way unpack: fails loudly if the buffer ever yields a different field count.
        initial_states, contexts, returns = self.context_buffer.read_buffer()
        return tuple(np.array(column) for column in (initial_states, contexts, returns))
