"""Evaluation for DQN models."""
import random
from typing import Any, Callable, Optional, Union

import gymnasium as gym
import jax
import numpy as np
import tqdm


def evaluate_policy(
    policy: Callable[[jax.typing.ArrayLike, Any], tuple[jax.typing.ArrayLike, Any]],
    env: Union[gym.Env, gym.vector.VectorEnv],
    num_episodes: int = 10,
    epsilon: float = 0.05,
    show_progress: bool = False,
    *,
    rng: Optional[random.Random] = None,
) -> tuple[float, float]:
    """Evaluate a policy on a given environment with epsilon-greedy exploration.

    Each episode starts from ``env.reset()`` and runs until the environment
    reports ``terminated`` or ``truncated``. At every step, with probability
    ``epsilon`` a random action is drawn from ``env.action_space``;
    otherwise the policy is queried for the action.

    Args:
        policy: Callable invoked as ``policy(obs, last_activations)`` that
            returns an ``(actions, activations)`` pair. ``last_activations``
            is the ``activations`` value the policy returned on the previous
            step, or ``None`` at the start of an episode and after any step
            on which a random action was taken.
        env: The environment to evaluate the policy on.
            NOTE(review): episode termination is detected by truth-testing
            ``terminated or truncated``, so a ``VectorEnv`` presumably only
            works when it wraps a single sub-environment (with more, the
            batched flags have an ambiguous truth value) — confirm before
            relying on vectorized evaluation.
        num_episodes: The number of episodes to evaluate the policy on.
            Default: 10. With 0 episodes both statistics are NaN.
        epsilon: The probability of taking a random action. Default: 0.05.
        show_progress: Whether to show a ``tqdm`` progress bar over the
            episodes. Default: ``False``.
        rng: Random number generator used for the epsilon coin flip. If
            ``None`` (default), the global :mod:`random` module is used.
            The random actions themselves come from
            ``env.action_space.sample()``, which draws from the action
            space's own generator (seed it via ``env.action_space.seed``).

    Returns:
        The mean and (population, ``ddof=0``) standard deviation of the
        undiscounted episodic returns, as Python floats.
    """
    # Hoisted once: the uniform [0, 1) sampler used for the epsilon test.
    coin_flip = random.random if rng is None else rng.random

    episodic_returns = []
    for _ in tqdm.trange(num_episodes, disable=not show_progress):
        obs, _ = env.reset()
        last_activations = None
        total_reward = 0.0
        done = False
        while not done:
            if coin_flip() < epsilon:
                actions = env.action_space.sample()
                # A random action has no associated activations, so the next
                # policy call receives None, as at the start of an episode.
                # NOTE(review): if activations carry recurrent state, this
                # discards it after every random action — confirm intended.
                activations = None
            else:
                actions, activations = policy(obs, last_activations)

            obs, reward, terminated, truncated, _ = env.step(actions)
            total_reward += reward
            last_activations = activations
            done = terminated or truncated

        episodic_returns.append(total_reward)

    # Convert numpy scalars to plain floats to honour the return annotation.
    return float(np.mean(episodic_returns)), float(np.std(episodic_returns))
