import warnings

import numpy as np
import torch
from numba import njit

from data import Batch, VectorBuffer


class Collector:
    """Roll a policy out in a vectorised environment and record transitions.

    The collector keeps the most recent observation of every environment in
    ``self.data`` and, on each step, queries ``policy``, steps ``env`` and
    hands the resulting transition batch to ``buffer``.

    Args:
        policy: Callable taking the current data batch and returning an
            object with ``act`` and ``feature`` tensors; it must also provide
            ``map_action``.
        env: Vectorised environment supporting ``len()``, ``reset([ids])``,
            ``step(action, ids)`` and an ``action_space`` attribute.
        buffer: Transition store exposing ``reset()`` and ``add(data, ids)``.
            When omitted, ``VectorBuffer(env_num, env_num)`` is created.
    """

    def __init__(self, policy, env, buffer=None):
        super().__init__()

        self.env = env  # type: ignore
        self.env_num = len(self.env)
        if buffer is None:
            buffer = VectorBuffer(self.env_num, self.env_num)
        self.buffer = buffer
        self.policy = policy
        self._action_space = self.env.action_space
        # reset() is what creates self.data and the statistics counters, so it
        # is called here to make sure they exist once __init__ returns.  The
        # buffer is left untouched (it may have been supplied pre-filled).
        self.reset(False)

    @staticmethod
    def _new_data():
        """Return an empty transition batch with the standard keys."""
        return Batch(
            obs={}, act={}, rew={}, done={}, obs_next={}, info={}
        )

    def reset(self, reset_buffer=True):
        """Reset the data batch, the environments and the statistics.

        Args:
            reset_buffer: When true, the buffer is reset as well.
        """
        self.data = self._new_data()
        self.reset_env()
        if reset_buffer:
            self.reset_buffer()
        self.reset_stat()

    def reset_stat(self):
        """Zero the cumulative step/episode/time counters."""
        # NOTE(review): collect_time is initialised here but collect() in this
        # class never updates it.
        self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0

    def reset_buffer(self):
        """Reset the underlying buffer."""
        self.buffer.reset()

    def reset_env(self) -> None:
        """Reset every environment and store the returned observations."""
        self.data.obs = self.env.reset()

    def collect(self, n_step=None, n_episode=None):
        """Collect transitions until a step or episode budget is reached.

        Args:
            n_step: Minimum number of transitions to collect.  Every loop
                iteration steps all ready environments, so the actual count is
                rounded up to a multiple of the number of ready environments
                (a warning is issued if ``n_step`` is not a multiple of
                ``env_num``).
            n_episode: Number of finished episodes to collect.  Only the
                first ``min(env_num, n_episode)`` environments are used, and
                environments are retired as the remaining quota shrinks.
                After collection the data batch is recreated and all
                environments are reset.

            ``n_step`` takes precedence when choosing the initial set of
            environments.  If both are given, collection stops as soon as
            either budget is met.

        Returns:
            dict with keys ``"n/ep"`` (episodes finished), ``"n/st"`` (steps
            taken), ``"rews"``/``"lens"`` (per-environment episode reward and
            length as masked arrays, or empty arrays if no episode finished),
            and their ``"rew"``/``"rew_std"``/``"len"``/``"len_std"``
            mean and standard deviation (``0`` if no episode finished).

        Raises:
            ValueError: If the selected budget (``n_step`` or ``n_episode``)
                is not positive.
            TypeError: If neither ``n_step`` nor ``n_episode`` is given.
        """
        if n_step is not None:
            # A non-positive budget can never satisfy the truthiness-based
            # stop condition below and would loop forever, so reject it early.
            if n_step <= 0:
                raise ValueError(
                    f"n_step must be a positive number, got {n_step}."
                )
            if n_step % self.env_num != 0:
                warnings.warn(
                    f"n_step={n_step} is not a multiple of #env ({self.env_num}), "
                    "which may cause extra transitions collected into the buffer."
                )
            ready_env_ids = np.arange(self.env_num)
        elif n_episode is not None:
            if n_episode <= 0:
                raise ValueError(
                    f"n_episode must be a positive number, got {n_episode}."
                )
            active_env_num = min(self.env_num, n_episode)
            ready_env_ids = np.arange(active_env_num)
            self.data = self.data[:active_env_num]
        else:
            raise TypeError(
                "Please specify at least one (either n_step or n_episode) "
                "in Collector.collect()."
            )

        step_count = 0
        episode_count = 0
        # One slot per initially-ready environment; -inf marks "no episode
        # finished yet" and is masked out when statistics are computed.
        # NOTE: slots are indexed by environment id, so if an environment
        # finishes several episodes only its latest one is kept here, while
        # episode_count still counts all of them.
        episode_rews = np.full(ready_env_ids.shape, -np.inf, dtype=np.float64)
        episode_lens = np.full(ready_env_ids.shape, -np.inf, dtype=np.float64)

        while True:
            assert len(self.data) == len(ready_env_ids)

            with torch.no_grad():
                result = self.policy(self.data)

            act = result.act.detach().cpu().numpy()
            feature = result.feature.detach().cpu().numpy()
            self.data.update(act=act, feature=feature)

            action_remap = self.policy.map_action(self.data.act)
            obs_next, rew, done, info = self.env.step(
                action_remap, ready_env_ids
            )

            self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)

            # ep_rew / ep_len are indexed like the batch (local indices);
            # presumably they hold the totals of episodes that just ended.
            ep_rew, ep_len = self.buffer.add(self.data, ready_env_ids)

            step_count += len(ready_env_ids)

            self.data.obs = self.data.obs_next.copy()
            if np.any(done):
                # Local indices address rows of self.data; global indices
                # address environments.
                env_ind_local = np.where(done)[0]
                env_ind_global = ready_env_ids[env_ind_local]
                episode_count += len(env_ind_local)
                episode_rews[env_ind_global] = ep_rew[env_ind_local].copy()
                episode_lens[env_ind_global] = ep_len[env_ind_local].copy()
                obs_reset = self.env.reset(env_ind_global)
                self.data.obs[env_ind_local] = obs_reset

                if n_episode:
                    # Retire just-finished environments that are no longer
                    # needed to reach the episode quota, so that no more than
                    # the remaining number of episodes can be in flight.
                    surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
                    if surplus_env_num > 0:
                        mask = np.ones_like(ready_env_ids, dtype=bool)
                        mask[env_ind_local[:surplus_env_num]] = False
                        ready_env_ids = ready_env_ids[mask]
                        self.data = self.data[mask]

            if (n_step and step_count >= n_step) or \
                    (n_episode and episode_count >= n_episode):
                break

        self.collect_step += step_count
        self.collect_episode += episode_count

        if n_episode:
            # self.data may have been shrunk above; rebuild it for all
            # environments and start them afresh.
            self.data = self._new_data()
            self.reset_env()

        if episode_count > 0:
            rews = np.ma.masked_equal(episode_rews, -np.inf)
            rew_mean, rew_std = rews.mean(), rews.std()
            lens = np.ma.masked_equal(episode_lens, -np.inf)
            len_mean, len_std = lens.mean(), lens.std()
        else:
            rews = np.array([])
            rew_mean = rew_std = 0
            lens = np.array([])
            len_mean = len_std = 0

        return {
            "n/ep": episode_count,
            "n/st": step_count,
            "rews": rews,
            "rew": rew_mean,
            "rew_std": rew_std,
            "lens": lens,
            "len": len_mean,
            "len_std": len_std
        }
