"""Definition of the ground truth macro-states. Sand is `0` and water is `1`;
any terrain that is not water (e.g. rock) is reported as sand (`0`).
"""
from swmpo.transition import Transition
import torch
from terrain_mass.environment import EnvironmentInstance
from terrain_mass.environment import Terrain
import enum


class GroundTruthState(enum.Enum):
    """Ground-truth macro-state labels, identified by their integer value."""

    # Label used for every terrain that is not water.
    SAND = 0
    # Label used for water terrain.
    WATER = 1


def get_ground_truth_state(
        environment_instance: EnvironmentInstance,
        state: torch.Tensor,
        ) -> int:
    """Return the integer ground-truth macro-state label for ``state``.

    The terrain at ``state`` is obtained from
    ``environment_instance.get_terrain``. Water terrain yields
    ``GroundTruthState.WATER.value``; any other terrain yields
    ``GroundTruthState.SAND.value``.
    """
    is_water = environment_instance.get_terrain(state) is Terrain.WATER
    label = GroundTruthState.WATER if is_water else GroundTruthState.SAND
    return label.value


def get_ground_truth_states(
        environment_instance: EnvironmentInstance,
        episode: list[Transition],
        ) -> list[int]:
    """Label every state visited in ``episode`` with its ground-truth state.

    The visited states are each transition's ``source_state`` in order,
    followed by the ``next_state`` of the final transition. A non-empty
    episode of ``k`` transitions therefore yields ``k + 1`` labels, and an
    empty episode yields an empty list.
    """
    if not episode:
        return []
    visited = [transition.source_state for transition in episode]
    # Also label the state reached by the last transition.
    visited.append(episode[-1].next_state)
    return [
        get_ground_truth_state(environment_instance, visited_state)
        for visited_state in visited
    ]
