from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import tensorflow as tf


@dataclass
class TaskConfig:
    """
    Configuration passed to a Task at construction time.
    """

    # Step count that BaseNavTask.compute_step compares `num_steps` against
    # to raise its "time_limit" flag. Defaults to 100.
    # NOTE(review): the Optional annotation suggests None is meant to disable
    # the limit -- confirm that every consumer of this field handles None.
    time_limit: Optional[int] = 100


class Task:
    """
    Base class for tasks.

    Holds the task configuration, the goal of the current episode and a
    per-episode step counter. Subclasses are expected to override
    `sample_goal`, `compute_step` and `data_format`, which raise
    NotImplementedError here.
    """

    def __init__(self, config: TaskConfig):
        self.config = config
        self.goal = None
        # Initialised here as well as in `reset` so that `update` does not
        # fail with AttributeError when it is called before the first reset.
        self.num_steps = 0

    def sample_start_configuration(self):
        """
        Optionally sample a start configuration for the task.

        Return
        None
            The base implementation does not sample anything.
        """
        return None

    def sample_goal(self, observation: dict):
        """
        Sample a goal configuration based on the current observation

        Parameters
        observation : dict
            Current observation from the robot

        Return
        goal : dict
            Goal configuration for the task
        """
        raise NotImplementedError

    def compute_step(
        self,
        observation: dict,
        action: np.ndarray,
        next_observation: dict,
        goal: dict,
        num_steps: int,
    ) -> Tuple[float, bool, bool, dict]:
        """
        Compute information (reward, termination, etc.) from a step with the task.

        Parameters
        observation : dict
            Current observation from the robot
        action : np.ndarray
            Action taken by the robot
        next_observation : dict
            Next observation from the robot
        goal : dict
            Goal configuration for the task
        num_steps : int
            Number of steps taken in the current episode

        Return
        reward : float
            Reward for the current transition
        truncated : bool
            Whether the episode is truncated
        terminated : bool
            Whether the episode is terminated
        info : dict
        """
        raise NotImplementedError

    def data_format(self, obs_format: dict):
        """
        Get the data format for the goal and info.
        """
        raise NotImplementedError

    def reset(self, observation: dict):
        """
        Reset the task to a new episode.

        Samples a new goal from `observation` and sets the step counter
        back to zero.
        """
        self.goal = self.sample_goal(observation)
        self.num_steps = 0

    def update(
        self, observation: dict, action: np.ndarray, next_observation: dict
    ) -> Tuple[float, bool, bool, dict]:
        """
        Update the task based on the current state of the robot.

        Forwards the transition, the current goal and the step counter to
        `compute_step`, then increments the counter.

        Return
        reward, truncated, terminated, info
            The values returned by `compute_step`, unchanged.
        """
        # `compute_step` sees the counter *before* this transition is counted
        # (0 on the first call after `reset`).
        # NOTE(review): with a limit of N this means the limit flag is raised
        # on step N + 1 -- confirm this is the intended convention.
        reward, truncated, terminated, info = self.compute_step(
            observation, action, next_observation, self.goal, self.num_steps
        )
        self.num_steps += 1

        return reward, truncated, terminated, info

    def get_goal(self):
        """
        Get the current goal for the task.
        """
        return self.goal


class BaseNavTask(Task):
    """
    Goal-free navigation task.

    Gives a constant reward of 0.0, reports termination from the "crash"
    entry of the next observation and truncation from the configured
    time limit.
    """

    def __init__(self, config: TaskConfig):
        super().__init__(config)

    def sample_goal(self, observation: dict):
        """
        Return an empty goal; this task does not use one.
        """
        return {}

    def compute_step(
        self,
        observation: dict,
        action: np.ndarray,
        next_observation: dict,
        goal: dict,
        num_steps: int,
    ) -> Tuple[float, bool, bool, dict]:
        """
        Compute reward and episode-end flags for one transition.

        Parameters
        observation : dict
            Current observation from the robot (unused)
        action : np.ndarray
            Action taken by the robot (unused)
        next_observation : dict
            Next observation from the robot; must contain the key "crash"
        goal : dict
            Goal configuration for the task (unused)
        num_steps : int
            Step count compared against `config.time_limit`

        Return
        reward : float
            Always 0.0
        truncated : bool
            True when a time limit is configured and `num_steps` has reached it
        terminated : bool
            Value of `next_observation["crash"]`
        info : dict
            The two flags under the keys "crash" and "time_limit"
        """
        crashed = next_observation["crash"]

        # `time_limit` is declared Optional[int]; None means "no limit".
        # Comparing an int against None would raise TypeError, so guard first.
        limit = self.config.time_limit
        time_limit = limit is not None and num_steps >= limit

        return (
            0.0,
            time_limit,
            crashed,
            {
                "crash": crashed,
                "time_limit": time_limit,
            },
        )

    def data_format(self, obs_format: dict):
        """
        Get the data format for the goal and info.

        Return
        goal_format : dict
            Empty, matching the empty goal from `sample_goal`
        info_format : dict
            Scalar boolean TensorSpecs for "crash" and "time_limit"
        """
        return {}, {
            "crash": tf.TensorSpec((), tf.bool, name="crash"),
            "time_limit": tf.TensorSpec((), tf.bool, name="time_limit"),
        }
