from typing import Dict

import flax.linen as nn
import gym
import numpy as np
import imageio
import os
# Output directory for GIFs written by `evaluate`, relative to the current working directory.
video_dir = './videos/'


def evaluate(agent: nn.Module, env: gym.Env,
             num_episodes: int, step=0, run_name="default", save_video=False,
             camera_name='corner') -> Dict[str, float]:
    """Roll out ``agent`` deterministically and return mean episode statistics.

    Runs ``num_episodes`` episodes, choosing each action with
    ``agent.sample_actions(observation, temperature=0.0)``. After every
    episode, the ``'return'`` and ``'length'`` entries of the final step's
    ``info['episode']`` are collected (when present) and averaged.

    Args:
        agent: Policy exposing ``sample_actions(observation, temperature=...)``.
        env: Environment using the ``reset() -> obs`` /
            ``step(action) -> (obs, reward, done, info)`` API.
        num_episodes: Number of episodes to run.
        step: Training step; only used in the saved video's filename.
        run_name: Run identifier; only used in the saved video's filename.
        save_video: If true, render every step of the first episode with
            ``env.render(camera_name=camera_name)`` and save the frames as a
            GIF in ``video_dir`` (created if it does not exist).
        camera_name: Camera passed to ``env.render`` when recording.

    Returns:
        Dict mapping ``'return'`` and ``'length'`` to their mean over the
        episodes that reported them. A key with no reported values maps to
        NaN (``np.mean`` of an empty list, which also emits a RuntimeWarning).
    """
    stats = {'return': [], 'length': []}
    images = []
    for i in range(num_episodes):
        # Only the first episode is recorded, to bound memory and file size.
        record = save_video and i == 0
        observation, done = env.reset(), False

        while not done:
            action = agent.sample_actions(observation, temperature=0.0)
            observation, _, done, info = env.step(action)
            if record:
                images.append(env.render(camera_name=camera_name))

        # `info` is always bound here because `done` starts False, so the
        # loop body runs at least once. The episode summary may be absent
        # (e.g. no episode-statistics wrapper); missing keys are skipped.
        episode_info = info.get('episode', {})
        for k, values in stats.items():
            if k in episode_info:
                values.append(episode_info[k])

        if record:
            returns = episode_info.get('return', 0)
            # The output directory may not exist yet, e.g. on a fresh checkout.
            os.makedirs(video_dir, exist_ok=True)
            imageio.mimsave(os.path.join(video_dir, f"{run_name}_{returns}_{step}.gif"), images)

    return {k: np.mean(v) for k, v in stats.items()}
