import random

import torch
from torch.distributions.categorical import Categorical

from utils import coroutine


@coroutine
def make_env_loop(env, model, epsilon=0.0):
    """Generator-based rollout loop collecting batched trajectories from `env`.

    Protocol: every value sent into the generator is `num_steps`, the maximum
    number of environment steps of the next rollout (`None` means "step until
    `env.all_done`"). The generator answers with the tuple
    `(all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap, mask)`:

    * the first seven entries are the per-step values stacked along dim 1;
    * `val_bootstrap` is the model's value output for the observation that
      follows the last collected step, computed without gradients;
    * `mask` is `compute_mask_after_end_or_trunc(end, trunc)` AND-ed with the
      per-env aliveness recorded at the start of the rollout.

    The env is reset only once the first `num_steps` has been received, and
    again whenever `env.all_done` holds after a rollout has been handed out.
    Otherwise the two-element recurrent state is carried over to the next
    rollout, detached from the autograd graph.

    NOTE(review): `coroutine` (from utils) presumably primes the generator so
    that the first `send` lands on the initial bare `yield` — confirm.

    Args:
        env: batched environment exposing `reset()` (returns a 2-tuple whose
            first item is the observation), `step(act)` (returns a 5-tuple
            `next_obs, rew, end, trunc, info`), `is_alive`, `all_done` and
            `num_actions`.
        model: callable `model(obs, hx_cx) -> (logits_act, val, hx_cx)`.
        epsilon: probability, drawn once per step, of replacing the sampled
            actions of the whole batch with uniformly random actions.

    Raises:
        ValueError: if a rollout collects no step (e.g. `num_steps <= 0`).
    """
    obs = None
    hx_cx = None

    def reset():
        nonlocal obs, hx_cx
        obs, _ = env.reset()
        hx_cx = None  # new episodes start without a recurrent state

    # Wait for the first request before touching the environment.
    num_steps = yield

    reset()

    while True:

        steps = []
        n = 0
        # Aliveness at the *start* of the rollout. Cloned so the snapshot stays
        # valid even if the env updates its tensor in place during step().
        is_alive = env.is_alive.clone()

        while num_steps is None or n < num_steps:
            logits_act, val, hx_cx = model(obs, hx_cx)
            act = Categorical(logits=logits_act).sample()

            # Epsilon exploration: one draw per step, applied to the whole batch.
            if random.random() < epsilon:
                act = torch.randint(low=0, high=env.num_actions, size=(obs.size(0),), device=obs.device)

            next_obs, rew, end, trunc, _ = env.step(act)
            steps.append((obs, act, rew, end, trunc, logits_act, val))
            obs = next_obs
            n += 1
            if env.all_done:
                break

        if not steps:
            # Nothing to stack or bootstrap from; fail with an explicit message
            # instead of an obscure tuple-unpacking error.
            raise ValueError(f"make_env_loop: rollout collected no steps (num_steps={num_steps!r})")

        # Transpose the list of per-step tuples and stack each field over time.
        all_obs, act, rew, end, trunc, logits_act, val = (torch.stack(x, dim=1) for x in zip(*steps))

        with torch.no_grad():
            # `obs` is the observation following the last step.
            _, val_bootstrap, _ = model(obs, hx_cx)  # do not update hx/cx

        mask = compute_mask_after_end_or_trunc(end, trunc)
        mask = torch.logical_and(mask, is_alive.unsqueeze(1))  # mask envs that are already dead

        num_steps = yield all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap, mask

        if env.all_done:
            reset()
        else:
            # Keep the recurrent state values but cut the graph between rollouts.
            hx_cx = (hx_cx[0].detach(), hx_cx[1].detach())


def compute_mask_after_end_or_trunc(end: torch.ByteTensor, trunc: torch.ByteTensor) -> torch.BoolTensor:
    """Return a mask that is True up to and including each row's first end/trunc.

    Args:
        end: 2-D tensor of termination flags (nonzero means set); each row is
            scanned along dim 1.
        trunc: truncation flags, same size as `end`.

    Returns:
        Bool tensor of the same size as `end`. Within a row, every position up
        to and including the first one where `end` or `trunc` is set is True,
        and every later position is False. A row with no flag set is all True.
    """
    assert end.ndim == 2 and end.size() == trunc.size()
    # A step is "dead" when either flag is set; a truncation alone is enough
    # to invalidate the steps that follow it.
    dead = torch.logical_or(end, trunc).long()
    # Count of dead flags strictly before each position: a position is valid
    # iff nothing has ended or truncated the row before it.
    num_dead_before = dead.cumsum(dim=1) - dead
    return num_dead_before == 0
