from EscapeEnv.common.base_agent import BaseAgent
from EscapeEnv.rpn.estimator import RPNEstimator
from EscapeEnv.common.torch_layers import QNetwork
from EscapeEnv.common.buffers import BaseBuffer
from EscapeEnv.rpn.network import PriorEnsembleQNetwork

class RPN(BaseAgent):
    """Agent whose Q-function is an ensemble network with prior functions.

    The class name and the ``PriorEnsembleQNetwork`` default suggest this is
    the Randomized Prior Networks approach. NOTE(review): confirm against
    ``EscapeEnv.rpn``.

    Generic agent plumbing is inherited from ``BaseAgent``. This subclass only
    overrides ``_build_network``, ``_build_estimator`` and ``_build_buffer``.
    Presumably ``BaseAgent`` invokes them during construction; confirm there.
    """

    def __init__(self,
                 env,
                 estimator_class=RPNEstimator,
                 buffer_class=BaseBuffer,
                 network_class=PriorEnsembleQNetwork,
                 **kwargs):
        """Delegate to ``BaseAgent`` with RPN-specific default component classes.

        Args:
            env: Environment, passed through unchanged to ``BaseAgent``.
            estimator_class: Estimator type (default ``RPNEstimator``).
            buffer_class: Replay-buffer type (default ``BaseBuffer``).
            network_class: Q-network type (default ``PriorEnsembleQNetwork``).
            **kwargs: Remaining ``BaseAgent`` options, forwarded untouched.
                The builders below read ``self.estimator_kwargs`` and require
                its ``'num_ensembles'`` and ``'prior_scale'`` entries.
                Presumably ``BaseAgent`` populates it from these kwargs; verify.
        """
        # The three component classes go to BaseAgent positionally, in this order.
        component_classes = (estimator_class, buffer_class, network_class)
        super().__init__(env, *component_classes, **kwargs)

    def _build_network(self):
        """Create ``self.network`` by instantiating ``self.network_class``.

        Ensemble size and prior scale are looked up in
        ``self.estimator_kwargs``. A missing key raises ``KeyError`` before
        any other attribute is read.
        """
        cfg = self.estimator_kwargs
        ensemble_size, scale = cfg['num_ensembles'], cfg['prior_scale']
        self.network = self.network_class(
            state_dim=self.state_dim,
            num_actions=self.num_actions,
            hidden_size=self.net_arch,
            num_ensembles=ensemble_size,
            prior_scale=scale,
            activator=self.activation_fn,
        )

    def _build_estimator(self):
        """Create ``self.q_estimator`` wrapping ``self.network``.

        Reads ``self.network``, so ``_build_network`` must already have run.
        """
        self.q_estimator = self.estimator_class(
            self.network,
            self.batch_size,
            self.learning_rate,
            self.loops_per_train,
            optimizer_kwargs=self.optimizer_kwargs,
            estimator_kwargs=self.estimator_kwargs,
            device=self.device,
        )

    def _build_buffer(self):
        """Create ``self.buffer`` sized by ``self.buffer_size``, sampling ``self.batch_size``."""
        self.buffer = self.buffer_class(
            size=self.buffer_size,
            batch_size=self.batch_size,
        )
        
        
if __name__ == '__main__':
    # Intentionally a no-op: this module only defines the RPN agent class.
    ...