# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


class ACAgent(object):
    """Facade pairing a learning algorithm with its experience storage.

    Forwards model-related calls (mode switching, acting, value estimation)
    to ``algo.actor_critic`` and storage-related calls to ``storage``.

    Args:
        algo: object exposing an ``actor_critic`` attribute and an
            ``update(storage, discard_grad=..., kl_dict=...)`` method.
        storage: object exposing ``after_update()``, ``to(device)`` and
            ``insert(...)``.
    """

    def __init__(self, algo, storage):
        self.algo = algo
        self.storage = storage

    @property
    def _actor_critic(self):
        # Looked up through ``self.algo`` on every access so that a
        # replaced ``algo.actor_critic`` is always honoured.
        return self.algo.actor_critic

    def update(self, discard_grad=False, kl_dict=None):
        """Run one algorithm update over the storage.

        After the update, ``storage.after_update()`` is called.

        Args:
            discard_grad: forwarded unchanged to ``algo.update``.
            kl_dict: forwarded unchanged to ``algo.update``.

        Returns:
            Whatever ``algo.update`` returns.
        """
        result = self.algo.update(
            self.storage,
            discard_grad=discard_grad,
            kl_dict=kl_dict,
        )
        self.storage.after_update()
        return result

    def to(self, device):
        """Move the actor-critic, then the storage, to ``device``.

        Returns ``self`` so the call can be chained.
        """
        for component in (self._actor_critic, self.storage):
            component.to(device)
        return self

    def train(self):
        """Put the actor-critic into training mode."""
        self._actor_critic.train()

    def eval(self):
        """Put the actor-critic into evaluation mode."""
        self._actor_critic.eval()

    def random(self):
        """Set the actor-critic's ``random`` flag to ``True``."""
        self._actor_critic.random = True

    def process_action(self, action):
        """Transform ``action`` via the actor-critic's hook, if it has one.

        If the actor-critic defines no ``process_action``, ``action`` is
        returned unchanged.
        """
        try:
            hook = self._actor_critic.process_action
        except AttributeError:
            return action
        return hook(action)

    def act(self, *args, **kwargs):
        """Delegate to ``actor_critic.act`` with all arguments untouched."""
        return self._actor_critic.act(*args, **kwargs)

    def get_value(self, *args, **kwargs):
        """Delegate to ``actor_critic.get_value`` with all arguments untouched."""
        return self._actor_critic.get_value(*args, **kwargs)

    def insert(self, *args, **kwargs):
        """Delegate to ``storage.insert`` with all arguments untouched."""
        return self.storage.insert(*args, **kwargs)

    @property
    def is_recurrent(self):
        """Mirror of ``actor_critic.is_recurrent``."""
        return self._actor_critic.is_recurrent
