from typing import Protocol, NamedTuple, Tuple
import jax.numpy as jnp


class EvalOutput(NamedTuple):
    """Per-solution evaluation results for a batch of K candidate solutions.

    Bundles the fitness values and behavior descriptors together with their
    derivatives taken with respect to the solution parameters.

    Notation used in the shapes below:
        K: number of solutions in the batch.
        m: number of behavior-descriptor dimensions.
        D: number of parameters in one solution.

    Attributes:
        fitnesses (jnp.ndarray): One fitness value per solution, shape (K,).
        descriptors (jnp.ndarray): Behavior descriptor of every solution,
            shape (K, m).
        fitness_grads (jnp.ndarray): Fitness gradient of every solution,
            shape (K, D); row `k` holds `∇_θ f(θ_k)`.
        descriptor_grads (jnp.ndarray): Descriptor Jacobian of every
            solution, shape (K, m, D); slice `k` holds `J_b(θ_k)` with
            entries `J[i, j] = ∂b_i / ∂θ_j`.
    """

    # Declaration order defines the tuple layout (NamedTuple supports
    # positional indexing/unpacking), so these lines must not be reordered.
    fitnesses: jnp.ndarray
    descriptors: jnp.ndarray
    fitness_grads: jnp.ndarray
    descriptor_grads: jnp.ndarray


class Task(Protocol):
    """Structural interface that an optimization task is expected to satisfy.

    Because this is a `typing.Protocol`, any object exposing the attributes
    and methods below matches it for static type checking; explicit
    subclassing is not required.

    Attributes:
        solution_size: Shape of a single solution, declared as a 1-tuple of
            int (presumably `(D,)`, the per-solution parameter count —
            confirm against implementations).
        descriptor_dim: Number of behavior-descriptor dimensions (presumably
            the `m` in `EvalOutput` shapes — confirm).
    """

    solution_size: Tuple[int]
    descriptor_dim: int

    def evaluate(
        self,
        solutions: jnp.ndarray,
        key: jnp.ndarray,
        return_grad: bool = True,
    ) -> EvalOutput:
        """Score a batch of K solutions, including gradients.

        Args:
            solutions: The K solutions to evaluate (presumably shaped
                (K, D) to match the gradient shapes in `EvalOutput`).
            key: Randomness source for the evaluation (presumably a JAX
                PRNG key — confirm).
            return_grad: Whether gradients are requested; defaults to True.
                What implementations place in the gradient fields when this
                is False is not specified here — confirm per task.

        Returns:
            An `EvalOutput` holding fitnesses, descriptors, and their
            gradients with respect to the solution parameters.
        """
        ...

    def simple_evaluate(
        self,
        solution: jnp.ndarray,
        key: jnp.ndarray,
    ) -> Tuple[float, jnp.ndarray]:
        """Score one solution without any gradient information.

        Intended for evaluation and analysis code paths that do not need to
        pay for gradient computation.

        Args:
            solution: A single solution.
            key: Randomness source for the evaluation (presumably a JAX
                PRNG key — confirm).

        Returns:
            A `(fitness, descriptor)` pair for the given solution.
        """
        ...

    def get_random_solution(
        self,
        n: int,
        key: jnp.ndarray,
    ) -> jnp.ndarray:
        """Produce `n` freshly initialized random solutions.

        Args:
            n: How many solutions to generate.
            key: Randomness source for initialization (presumably a JAX
                PRNG key — confirm).

        Returns:
            The generated solutions (presumably stacked along a leading
            axis of length `n` — confirm against implementations).
        """
        ...
