from typing import Dict, Optional

import torch

from offline_rl.rewards.evaluation.reward_collection import RewardCollection
from offline_rl.rewards.reward_model import RewardModel


class ModelCollection:
    """A collection of labeled reward models.

    The collection keeps a reference to the dictionary it is given (no copy is
    made), so later changes to that dictionary are visible through this object.

    Args:
        label_to_model: Dictionary mapping labels to corresponding models.
            Must contain at least one entry.

    Raises:
        AssertionError: If `label_to_model` is empty.
    """
    def __init__(self, label_to_model: Dict[str, RewardModel]):
        assert len(label_to_model) > 0, "ModelCollection requires at least one model"
        self.label_to_model = label_to_model

    def rewards(
            self,
            states: torch.Tensor,
            actions: torch.Tensor,
            next_states: Optional[torch.Tensor],
            terminals: Optional[torch.Tensor],
    ) -> RewardCollection:
        """Computes the rewards for each model and returns a reward collection.

        See `RewardModel.reward` for details on args.

        Returns:
            A `RewardCollection` holding, for every label in this collection,
            the flattened (1-D) output of that model's `reward` call.

        Raises:
            AssertionError: If the leading dimensions of the inputs disagree, or
                if the resulting collection reports itself as invalid.
        """
        # All provided inputs must describe the same number of transitions.
        num_transitions = len(states)
        assert num_transitions == len(actions), "states and actions differ in length"
        assert next_states is None or len(next_states) == num_transitions, (
            "next_states and states differ in length")
        assert terminals is None or len(terminals) == num_transitions, (
            "terminals and states differ in length")

        rewards = RewardCollection()
        for label, model in self.label_to_model.items():
            # This assumes that rewards should always be flat.
            rewards[label] = model.reward(states, actions, next_states, terminals).reshape(-1)
        assert rewards.is_valid()
        return rewards

    def get_model(self, label: str) -> RewardModel:
        """Gets the model associated with the label.

        Args:
            label: The label of the model to return. Must be present in the
                model collection.

        Returns:
            The model with the associated label.

        Raises:
            AssertionError: If no model is registered under `label`.
        """
        assert label in self.label_to_model, f"No model with label {label}"
        return self.label_to_model[label]

    def __repr__(self) -> str:
        # Report this class's own name (the original mistakenly printed
        # "RewardCollection"); type(self) keeps subclasses accurate as well.
        return f"{type(self).__name__}({self.label_to_model})"
