from collections import defaultdict

import torch
from tqdm import tqdm

from data import Episode, EpisodeDataset
from env_loop import make_env_loop
from utils import coroutine


@coroutine
def make_collector(env_fn, model, dataset: EpisodeDataset, epsilon: float = 0.0, save_on_disk: bool = True, verbose: bool = True):
    """Generator that rolls `model` out in an environment and stores the episodes in `dataset`.

    Protocol (driven with ``.send``):
      * Each value sent in is ``num_steps`` for the next collection round.
      * The generator then repeatedly advances the env loop, counting steps as
        the sum of the ``mask`` returned by the env loop, until that count
        reaches ``num_steps``. If ``num_steps`` is None, exactly one
        ``env_loop.send(None)`` call is made per round.
      * At the end of a round it yields ``to_log``: a list with one dict per
        episode finished during the round, holding the dataset episode id
        (key ``'<dataset.name>/episode_id'``) and ``Episode.compute_metrics()``.

    NOTE(review): the very first ``yield`` produces None; the ``@coroutine``
    decorator presumably primes the generator past it -- confirm in ``utils``.

    Args:
        env_fn: Zero-argument callable returning the environment.
        model: Passed through to ``make_env_loop``.
        dataset: Destination for collected episodes (``add_episode`` is called
            after every env-loop step, so partial episodes are stored too).
        epsilon: Passed through to ``make_env_loop``.
        save_on_disk: Forwarded to ``dataset.add_episode``.
        verbose: When False, the tqdm progress bar is disabled.

    Yields:
        list[dict]: Log entries for episodes completed in the round.
    """
    # In-progress episodes and their dataset ids live across collection rounds:
    # a round may stop mid-episode and the next round keeps extending it.
    episodes, episode_ids = None, defaultdict(lambda: None)

    env = env_fn()
    env_loop = make_env_loop(env, model, epsilon)
    n, to_log, pbar = (None,) * 3

    def reset():
        """Start a new round: zero the step counter, fresh log list, fresh progress bar."""
        nonlocal pbar, n, to_log
        n = 0
        # A new list is bound each round so the list already yielded to the
        # caller is never mutated afterwards.
        to_log = []
        pbar = tqdm(total=num_steps, desc=f'Collect {dataset.name}', disable=not verbose)

    num_steps = yield
    reset()

    while True:

        with torch.no_grad():
            # NOTE(review): 1 presumably requests a single step and None a run
            # to episode end -- confirm against make_env_loop.
            all_obs, act, rew, end, trunc, *_, mask = env_loop.send(1 if num_steps is not None else None)

        # Reduce the mask once and keep the counter a plain Python number, so
        # the stop check below does not depend on tensor truthiness.
        num_valid = mask.sum().item()
        n += num_valid
        pbar.update(num_valid)

        new_episodes = [Episode(*o_a_r_d_t) for o_a_r_d_t in zip(all_obs, act, rew, end, trunc)]
        episodes = new_episodes if episodes is None else [old.merge(new) for old, new in zip(episodes, new_episodes)]

        # Passing back the id returned by the previous call lets the dataset
        # associate the grown episode with the same entry (None on first add).
        for i, episode in enumerate(episodes):
            episode_ids[i] = dataset.add_episode(episode, episode_id=episode_ids[i], save_on_disk=save_on_disk)

        if env.all_done:
            to_log.extend([{f'{dataset.name}/episode_id': episode_ids[i], **e.compute_metrics()} for i, e in enumerate(episodes)])
            # Forget the finished episodes so the next step starts new ones.
            episodes, episode_ids = None, defaultdict(lambda: None)

        if num_steps is None or n >= num_steps:
            pbar.close()
            num_steps = yield to_log
            reset()
