from typing import Dict

import flax.linen as nn
import gym
import numpy as np


def evaluate(
    agent: nn.Module,
    env: gym.Env,
    num_episodes: int,
    num_actions_to_sample: int,
    fixed_action_noise: float,
    optimism_parameter: float,
) -> Dict[str, float]:
    """Roll out ``agent`` in ``env`` and average the per-episode statistics.

    Each episode runs until ``env.step`` reports ``done``. Actions come from
    ``agent.sample_actions`` with ``temperature=0.0``. When an episode ends,
    the ``"return"`` and ``"length"`` entries of ``info["episode"]`` from its
    final step are recorded.

    NOTE(review): this relies on ``env`` being wrapped so that the terminal
    step's ``info`` contains an ``"episode"`` mapping with ``"return"`` and
    ``"length"`` keys; otherwise a ``KeyError`` is raised. Confirm against
    the environment construction code.

    Args:
        agent: Object exposing ``sample_actions``. That method must return a
            pair whose first element is the action; the second is discarded.
        env: Environment whose ``reset()`` returns only the observation and
            whose ``step()`` returns a 4-tuple ``(obs, reward, done, info)``.
        num_episodes: Number of episodes to roll out.
        num_actions_to_sample: Forwarded to ``agent.sample_actions``.
        fixed_action_noise: Forwarded to ``agent.sample_actions``.
        optimism_parameter: Forwarded to ``agent.sample_actions``.

    Returns:
        Dict with keys ``"return"`` and ``"length"``, each holding the
        ``np.mean`` of that statistic over all episodes.
    """
    episode_returns = []
    episode_lengths = []

    for _ in range(num_episodes):
        obs = env.reset()
        finished = False

        # The loop body always runs at least once, so `step_info` is bound
        # by the time the episode statistics are read below.
        while not finished:
            act, _ = agent.sample_actions(
                obs,
                temperature=0.0,
                num_actions_to_sample=num_actions_to_sample,
                fixed_action_noise=fixed_action_noise,
                optimism_parameter=optimism_parameter,
            )
            obs, _, finished, step_info = env.step(act)

        episode_summary = step_info["episode"]
        episode_returns.append(episode_summary["return"])
        episode_lengths.append(episode_summary["length"])

    return {
        "return": np.mean(episode_returns),
        "length": np.mean(episode_lengths),
    }
