# Original implementation: https://github.com/abaisero/asym-rlpo
#
####
#
# Extended to informed POMDPs by anonymous authors (2025)
#
####

import torch
import torch.nn as nn

from asym_rlpo.data import Episode

from .base import A2C_ABC

# Informed asymmetric A2C
class InformedAsymA2C(A2C_ABC):
    """A2C variant whose critic sees history features and episode information.

    The value estimate comes from ``vhi_model``, which is fed the history
    features concatenated with the output of ``information_model``.
    """

    # Names of the sub-models grouped by role. The agent list contains no
    # ``information_model``; only the critic list does.
    # NOTE(review): 'latent_model' is listed for the critic but is not used by
    # ``compute_v_values`` below -- confirm whether it is still required.
    model_keys = {
        'agent': [
            'action_model',
            'observation_model',
            'interaction_model',
            'history_model',
            'policy_model',
        ],
        'critic': [
            'latent_model',
            'action_model',
            'observation_model',
            'interaction_model',
            'history_model',
            'information_model',
            'vhi_model',
        ],
    }

    def compute_v_values(
        self, models: nn.ModuleDict, episode: Episode
    ) -> torch.Tensor:
        """Compute critic values from history and information features.

        Args:
            models: Module container exposing a ``critic`` entry with the
                ``interaction_model``, ``history_model``,
                ``information_model`` and ``vhi_model`` sub-modules.
            episode: Episode providing ``actions``, ``observations`` and
                ``information``.

        Returns:
            The output of ``vhi_model`` with its last dimension squeezed.
        """
        critic = models.critic

        # History features are built from the critic's own interaction and
        # history models, using the episode's actions and observations.
        history_encoding = self.compute_history_features(
            critic.interaction_model,
            critic.history_model,
            episode.actions,
            episode.observations,
        )
        information_encoding = critic.information_model(episode.information)

        # Both encodings are joined along the last axis; this assumes they
        # agree on every other dimension -- TODO confirm against the models.
        critic_inputs = torch.cat(
            [history_encoding, information_encoding], dim=-1
        )
        return critic.vhi_model(critic_inputs).squeeze(-1)
