import torch
from erl_lib.agent.module.critic.critic import EnsembleCriticNetwork


class TwoHeadsCritic(EnsembleCriticNetwork):
    """Ensemble critic whose network emits two Q-value heads.

    The parent ``EnsembleCriticNetwork`` is built with ``dim_output=2``.
    Index 0 along axis 2 of the hidden-layer output is treated as the
    "greedy" Q head and index 1 as the "variational" Q head. ``forward``
    returns only one of them, chosen by the ``latent_state`` flag
    (falsy -> greedy, truthy -> variational).
    """

    def __init__(self, *args, **kwargs):
        # dim_output is pinned to 2 (one unit per head); a caller that also
        # passes dim_output in kwargs gets a duplicate-keyword TypeError.
        super().__init__(*args, dim_output=2, **kwargs)
        # Head selector read by ``forward``; starts on the greedy head.
        self.latent_state = 0

    def base_forward(self, state_action):
        """Run the shared hidden layers and split the output into both heads.

        Args:
            state_action: input passed unchanged to ``self.hidden_layers``.

        Returns:
            A ``(greedy_q, variational_q)`` tuple: the slices at index 0 and 1
            of the leading axis after swapping axes 0 and 2 of the output.
        """
        raw = self.hidden_layers(state_action)
        # Move the head axis (axis 2) to the front, then split along it.
        # Unpacking into exactly two names raises if that axis is not size 2.
        # NOTE(review): if ``raw`` is (ensemble, batch, 2), each head comes out
        # as (batch, ensemble) -- ensemble and batch axes are swapped as well.
        # Confirm downstream consumers expect that layout.
        head_major = raw.transpose(0, 2)
        greedy, variational = head_major.unbind(0)
        return greedy, variational

    def forward(self, state_action):
        """Return the Q head selected by ``self.latent_state``.

        Both heads are always computed. The greedy head is returned when
        ``latent_state`` is falsy, and the variational head otherwise.
        """
        greedy, variational = self.base_forward(state_action)
        return variational if self.latent_state else greedy

    def set_latent_state(self, latent_state):
        """Select which head ``forward`` returns (truthy -> variational)."""
        self.latent_state = latent_state
