from textwrap import dedent
from rich import print as rprint

# Natural-language description of the Snake task. It is printed when this
# module is run as a script. Adjacent string literals are concatenated, so
# every fragment except the last must end with a space.
env_description = (
    "The Snake environment from the Gym-Snake repository provides a grid-based game where the agent controls a snake "
    "that moves around the screen to consume randomly placed food items. Each time the snake eats food, it grows in "
    "length, increasing the complexity of navigation. The agent receives visual observations representing the current "
    "game grid, including the snake's position and the location of the food. The action space is discrete, allowing the "
    "agent to choose directional movements (up, down, left, right). The objective is to maximize the length of the snake "
    "while avoiding collisions with the walls or the snake’s own body. The reward structure is sparse: a positive "
    "reward is given when the snake consumes food, and a negative reward is given when a collision occurs (which "
    "ends the episode)."
)
# Name of the evaluation metric associated with this environment.
eval_criteria = "mean_ep_snake_length_increase"

# Arguments of the reward function, expressed as (name, annotation) pairs and
# expanded into dicts with "name" and "type_annotation" keys. The annotations
# are plain strings, not live type objects.
reward_func_args = [
    {"name": arg_name, "type_annotation": annotation}
    for arg_name, annotation in (
        ("game_grid", "np.ndarray"),
        ("prev_game_grid", "np.ndarray"),
        ("action", "int"),
        ("food_eaten", "bool"),
        ("snake_death", "bool"),
        ("snake_steps", "int"),
    )
]

# Return annotation (as a string) of the non-baseline reward function.
reward_func_return_type = "float"
# Args/Example section shared verbatim by both reward-function prompts below.
# Keeping one copy prevents the two prompts from drifting apart. The string
# ends with a blank line so that each prompt's "Returns" section follows it
# after one empty line. dedent strips the common 4-space indent and turns
# whitespace-only lines into empty lines.
_reward_func_args_and_example = dedent("""\
    Args:
        - `game_grid` (2D numpy array): An array representation of the current game grid, where `0` indicates an empty cell, `1` indicates a food item, `2` indicates a snake body, and `3` indicates a snake head.
           The grid follows typical numpy array indexing, i.e. [0,0] is located at the upper left most pixel, [0, 1] is the pixel to the right of [0,0], [1, 0] is the pixel below [0,0].
        - `prev_game_grid` (2D numpy array): An array representation of the game grid on the previous step. prev_game_grid has the same shape as `game_grid`.
        - `action` (int): The action taken by the snake that led to the current game grid. The action space is discrete, with the following possible values: 0-Move up, 1-Move right, 2-Move down, 3-Move left.
        - `food_eaten` (bool): Whether the snake has eaten food in the current step.
        - `snake_death` (bool): Whether the snake has died in the current step.
        - `snake_steps` (int): The number of steps the snake has taken since the start of the episode.

    Example:
        Consider a 5x5 game grid where the snake has moved only once since the start of the episode, the current arguments are:
        - game_grid:
            [
                [0, 0, 0, 0, 0],
                [0, 2, 0, 0, 0],
                [0, 2, 3, 0, 0],
                [0, 0, 0, 0, 0],
                [0, 0, 0, 1, 0]
            ]
        - prev_game_grid:
            [
                [0, 2, 0, 0, 0],
                [0, 2, 0, 0, 0],
                [0, 3, 0, 0, 0],
                [0, 0, 0, 0, 0],
                [0, 0, 0, 1, 0]
            ]
        - action: 1
        - food_eaten: False
        - snake_death: False
        - snake_steps: 1

        This example shows a snake of length 3, whose body is currently at position (1, 1), (2, 1), and the head is at position (2, 2). The food is located at position (4, 3).
        The snake took action 1 (move right) from the previous step where the snake was at position (0, 1), (1, 1), and the head was at position (2, 1).
        Since the snake has not eaten food in the current step, `food_eaten` is False.
        Since the snake has not died in the current step, `snake_death` is False.
        Since the snake has moved only once since the start of the episode, `snake_steps` is 1.

""")

# Prompt describing a reward function that returns a single float.
reward_func_definition = _reward_func_args_and_example + dedent("""\
    Returns:
        You need to return the reward signal for the current step.
""")



# Baseline variant: the reward function also returns its individual components.
reward_func_return_type_baseline = "Tuple[float, Dict[str, float]]"
# The prompt text keeps a literal backslash before "[str" (unchanged from the
# original). It presumably stops rich's console markup from reading
# "[str, float]" as a style tag if the prompt is printed with `rprint`.
# It is spelled `\\[` because `\[` is an invalid Python escape sequence
# (a SyntaxWarning on Python 3.12+).
reward_func_definition_baseline = _reward_func_args_and_example + dedent("""\
    Returns (Tuple[float, Dict\\[str, float]]):
        1. return the reward signal for the current step.
        2. return a dictionary of each individual reward component for the current step.
""")



if __name__ == "__main__":
    # Running the module as a script previews the environment description
    # through rich's `print` (imported here as `rprint`).
    rprint(env_description)