from typing import Dict

import flax.linen as nn
import gym
import numpy as np


def evaluate(
    agent: "nn.Module", env: "gym.Env", num_episodes: int, temperature: float = 0.00
) -> Dict[str, float]:
    """Roll out ``agent`` in ``env`` and return the mean episode return and length.

    The environment is driven with the 4-tuple Gym step API:
    ``reset()`` returns an observation, and ``step(action)`` returns
    ``(observation, reward, done, info)``.

    Args:
        agent: Object exposing ``sample_actions(observation, temperature=...)``.
        env: Environment to evaluate in.
        num_episodes: Number of complete episodes to run.
        temperature: Forwarded unchanged to ``agent.sample_actions``.

    Returns:
        ``{"return": mean_return, "length": mean_length}`` as floats. Both
        values are NaN when ``num_episodes`` is 0.
    """
    stats: Dict[str, list] = {"return": [], "length": []}

    for _ in range(num_episodes):
        observation, done = env.reset(), False
        episode_return, episode_length = 0.0, 0
        info = {}
        while not done:
            action = agent.sample_actions(observation, temperature=temperature)
            observation, reward, done, info = env.step(action)
            episode_return += reward
            episode_length += 1

        # If a monitoring wrapper reports episode statistics under
        # info["episode"], prefer those values. Such wrappers may see the
        # unscaled reward. Otherwise fall back to the tallies collected above.
        recorded = info.get("episode", {}) if isinstance(info, dict) else {}
        stats["return"].append(recorded.get("return", episode_return))
        stats["length"].append(recorded.get("length", episode_length))

    # Avoid np.mean([]), which warns; report NaN explicitly when no episodes ran.
    return {k: float(np.mean(v)) if v else float("nan") for k, v in stats.items()}
