from collections import defaultdict

import numpy as np
import torch


def results_to_dict(results):
    """Merge a sequence of per-episode metric dicts into one dict of lists.

    Every key that appears in any element of ``results`` maps to the list of
    values it took, in input order. A key missing from some elements simply
    gets fewer entries. The result is a ``defaultdict(list)``.
    """
    merged = defaultdict(list)
    # Flatten all (key, value) pairs lazily, preserving dict and list order.
    pairs = (pair for record in results for pair in record.items())
    for key, value in pairs:
        merged[key].append(value)
    return merged


@torch.no_grad()
def evaluate(
    model,
    env,
    num_eval_episodes: int = 1,
    device: str = "cpu",
):
    """Roll out ``num_eval_episodes`` episodes and aggregate their metrics.

    Args:
        model: Policy exposing ``eval()`` and ``get_action(state)``.
        env: Environment passed through to ``rollout_episode``.
        num_eval_episodes: Number of episodes to run.
        device: Device the observation tensors are moved to.

    Returns:
        A ``defaultdict(list)`` mapping each metric name returned by
        ``rollout_episode`` to its per-episode values. ``episode_return`` and
        ``episode_length`` have one entry per episode. Per-room keys only
        receive an entry for episodes in which that room was reported, so
        those lists may be shorter than ``num_eval_episodes`` and are not
        index-aligned with the episodes.
    """
    # The @torch.no_grad() decorator already disables autograd for the whole
    # call; a nested no_grad context per episode is unnecessary.
    results = [
        rollout_episode(model, env, device=device)
        for _ in range(num_eval_episodes)
    ]
    return results_to_dict(results)


@torch.no_grad()
def rollout_episode(
    model,
    env,
    device: str = "cpu",
):
    """Run one episode of ``env`` with ``model`` and return its metrics.

    The model is switched to eval mode for the rollout. If it was in training
    mode beforehand (``model.training`` truthy), training mode is restored
    afterwards, even if the rollout raises. Autograd is disabled for the
    whole call, so this is safe to call directly as well as via ``evaluate``.

    Args:
        model: Policy exposing ``eval()`` and ``get_action(state_tensor)``.
        env: Environment whose ``step`` returns the 6-tuple
            ``(state, reward, force_done, done, log_reward, info)``.
        device: Device the observation tensor is moved to.

    Returns:
        Dict with ``"episode_return"`` (sum of rewards) and
        ``"episode_length"`` (number of steps). For every room in the final
        step's ``info["rooms_survived"]`` it also contains ``time_spent/<room>``,
        ``survived/<room>``, ``visitation/<room>``, ``lives_lost/<room>`` and
        ``rooms_done/<room>`` (missing per-room values default to 0).
    """
    was_training = getattr(model, "training", None)
    model.eval()
    try:
        episode_return, episode_length, info = _play_episode(model, env, device)
    finally:
        # Restore the caller's mode so evaluating mid-training does not leave
        # mode-dependent layers (e.g. dropout) in inference behaviour.
        if was_training:
            model.train()

    return {
        "episode_return": episode_return,
        "episode_length": episode_length,
        **_room_metrics(info),
    }


def _play_episode(model, env, device):
    """Step ``env`` until ``done``; return (total reward, step count, final info)."""
    state = env.reset()
    done = False
    episode_return, episode_length = 0, 0
    info = {}
    while not done:
        # Scale observations by 1/255 (presumably uint8 pixels -> [0, 1];
        # TODO confirm) as a float32 tensor on the target device.
        obs = torch.from_numpy(np.asarray(state, dtype=np.float32) / 255.0).to(device)
        action = model.get_action(obs)
        # NOTE(review): force_done and log_reward are unused and only `done`
        # ends the loop — confirm force_done should not also terminate it.
        state, reward, _force_done, done, _log_reward, info = env.step(action)

        episode_return += reward
        episode_length += 1
    return episode_return, episode_length, info


def _room_metrics(info):
    """Flatten per-room statistics from a final ``info`` dict into metric keys."""
    rooms_time = info.get('episode', {}).get('visited_rooms_full', {})
    rooms_survived = info.get('rooms_survived', {})
    rooms_visitation = info.get('room_visitation', {})
    lives_lost = info.get('lives_lost', {})
    rooms_done = info.get('rooms_done', {})

    metrics = {}
    # Rooms are enumerated from `rooms_survived` only; rooms present solely in
    # the other dicts are not reported.
    for room, survived_time in rooms_survived.items():
        metrics[f"time_spent/{room}"] = rooms_time.get(room, 0)
        metrics[f"survived/{room}"] = survived_time
        metrics[f"visitation/{room}"] = rooms_visitation.get(room, 0)
        metrics[f"lives_lost/{room}"] = lives_lost.get(room, 0)
        metrics[f"rooms_done/{room}"] = rooms_done.get(room, 0)
    return metrics
