from agent.history_ensemble import HistoryEnsembleMixin
from agent.aps import APSAgent
import utils
import torch

class APSHistoryEnsembleAgent(HistoryEnsembleMixin, APSAgent):
    """APS agent whose actor update is routed through ``add_ensemble_loss``.

    Mixes ``HistoryEnsembleMixin`` in ahead of ``APSAgent`` in the MRO.
    Only ``update_actor`` is overridden here. The actor, critic, noise
    schedule and ``add_ensemble_loss`` all come from the base classes
    (presumably the mixin supplies ``add_ensemble_loss``; verify against
    ``agent.history_ensemble``).
    """

    def update_actor(self, obs, task, step):
        """Build the critic-driven actor loss and delegate to the ensemble hook.

        Args:
            obs: Observation batch given to both the actor and the critic.
            task: Conditioning input passed to the critic (presumably the APS
                skill/task vector; confirm against ``APSAgent``).
            step: Training step used to evaluate ``self.stddev_schedule``.

        Returns:
            The return value of ``self.add_ensemble_loss`` unchanged.
        """
        noise_std = utils.schedule(self.stddev_schedule, step)
        policy = self.actor(obs, noise_std)

        # Draw an action (clipped via self.stddev_clip) and sum its per-dimension
        # log-probabilities over the last axis, keeping that axis in the shape.
        sampled = policy.sample(clip=self.stddev_clip)
        sampled_log_prob = policy.log_prob(sampled).sum(-1, keepdim=True)

        # Pessimistic value estimate: elementwise minimum of the two critic heads.
        q_first, q_second = self.critic(obs, sampled, task)
        q_pessimistic = torch.min(q_first, q_second)

        # The actor maximises Q, so its loss is the negated mean estimate.
        return self.add_ensemble_loss(
            obs=obs,
            step=step,
            dist=policy,
            log_prob=sampled_log_prob,
            action=sampled,
            stddev=noise_std,
            actor_loss=-q_pessimistic.mean(),
            metrics={},
        )