import numpy as np
import tqdm

from tianshou.data import AsyncCollector, Batch, Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import BasePolicy

if __name__ == '__main__':
    from env import MyTestEnv
else:  # pytest
    from test.base.env import MyTestEnv


class MyPolicy(BasePolicy):
    """Deterministic stub policy used by the collector tests.

    It emits action ``1`` for every observation and, when ``need_state`` is
    set, threads a fake hidden state through successive ``forward`` calls.
    """

    def __init__(self, dict_state=False, need_state=True):
        """
        :param bool dict_state: if the observation of the environment is a dict
            (the batch size is then read from ``batch.obs['index']``)
        :param bool need_state: if the policy needs the hidden state (for RNN)
        """
        super().__init__()
        self.dict_state = dict_state
        self.need_state = need_state

    def forward(self, batch, state=None):
        """Return an all-ones action per observation plus the (fake) hidden state.

        :param batch: input batch; ``len(batch.obs)`` (or
            ``len(batch.obs['index'])`` when ``dict_state`` is set) gives the
            number of actions produced.
        :param state: previous hidden state, ``None`` on the first call.
        :return: ``Batch(act=..., state=...)``. With ``need_state`` set, a
            missing state is initialised to zeros of shape ``(len(batch.obs), 2)``
            and an existing one is incremented by 1 via ``+=``. Without
            ``need_state`` the given ``state`` is passed through unchanged.
        """
        if self.need_state:
            if state is None:
                state = np.zeros((len(batch.obs), 2))
            else:
                # ``+=`` updates the caller's array object rather than making a copy.
                state += 1
        if self.dict_state:
            return Batch(act=np.ones(len(batch.obs['index'])), state=state)
        return Batch(act=np.ones(len(batch.obs)), state=state)

    def learn(self, batch=None, **kwargs):
        """No-op training step.

        Accepts (and ignores) the ``(batch, **kwargs)`` arguments of tianshou's
        ``BasePolicy.learn`` so the stub does not raise ``TypeError`` if it is
        ever reached through ``BasePolicy.update``. Both are optional, so the
        previous zero-argument call ``learn()`` still works.

        :return: an empty dict of training statistics.
        """
        return {}


def test_collector_nstep():
    """Collect fixed step budgets and check that each call meets its budget.

    Nine ``MyTestEnv`` instances (sizes 2..10) are collected from. Each
    iteration resets the collector and requests ``i`` steps per environment,
    i.e. ``i * num`` steps in total. This is a multiple of the env count, so
    the collector must report at least that many steps.
    """
    policy = MyPolicy()
    env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)]
    dum = DummyVectorEnv(env_fns)
    num = len(env_fns)
    c3 = Collector(policy, dum, VectorReplayBuffer(total_size=40000, buffer_num=num))
    for i in tqdm.trange(1, 400, desc="test step collector n_step"):
        c3.reset()
        n_step = i * num
        result = c3.collect(n_step=n_step)
        # Check against the full requested budget; the old ``>= i`` bound was
        # looser than the request by a factor of ``num``.
        assert result['n/st'] >= n_step


def test_collector_nepisode():
    """Request 1..399 whole episodes and check the collector's bookkeeping.

    The episode count reported must match the request exactly. The reported
    step count must equal the number of stored transitions. This relies on
    ``reset()`` leaving the buffer empty before each collection.
    """
    policy = MyPolicy()
    env_fns = [lambda size=s: MyTestEnv(size=size) for s in np.arange(2, 11)]
    venv = DummyVectorEnv(env_fns)
    n_envs = len(env_fns)
    buffer = VectorReplayBuffer(total_size=40000, buffer_num=n_envs)
    collector = Collector(policy, venv, buffer)
    for n_episode in tqdm.trange(1, 400, desc="test step collector n_episode"):
        collector.reset()
        stats = collector.collect(n_episode=n_episode)
        assert stats['n/ep'] == n_episode
        assert stats['n/st'] == len(collector.buffer)


def test_asynccollector():
    """Exercise ``AsyncCollector`` and verify the per-env sub-buffer contents.

    Every env has a distinct episode length (``env_lens``). The length of a
    finished episode therefore identifies which env, and which sub-buffer
    ``c1.buffer.buffers[i]``, produced it. ``bufsize`` (300) is divisible by
    every env length, so wrapped-around data stays aligned to whole episodes.
    """
    env_lens = [2, 3, 4, 5]
    num_envs = len(env_lens)
    env_fns = [
        lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens
    ]

    venv = SubprocVectorEnv(env_fns, wait_num=num_envs - 1)
    try:
        policy = MyPolicy()
        bufsize = 300
        c1 = AsyncCollector(
            policy,
            venv,
            VectorReplayBuffer(total_size=bufsize * num_envs, buffer_num=num_envs),
        )
        # Next write position inside each sub-buffer, tracked modulo bufsize.
        ptr = [0] * num_envs
        for n_episode in tqdm.trange(1, 100, desc="test async n_episode"):
            result = c1.collect(n_episode=n_episode)
            assert result["n/ep"] >= n_episode
            # check buffer data, obs and obs_next, env_id
            # Number of finished episodes per env, looked up by episode length.
            counts = np.bincount(result["lens"], minlength=max(env_lens) + 1)[env_lens]
            for i, (env_len, count) in enumerate(zip(env_lens, counts)):
                total = env_len * count
                indices = np.arange(ptr[i], ptr[i] + total) % bufsize
                ptr[i] = (ptr[i] + total) % bufsize
                seq = np.arange(env_len)
                buf = c1.buffer.buffers[i]
                assert np.all(buf.info.env_id[indices] == i)
                assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
                assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
        # test async n_step, for now the buffer should be full of data
        for n_step in tqdm.trange(1, 150, desc="test async n_step"):
            result = c1.collect(n_step=n_step)
            assert result["n/st"] >= n_step
            for i, env_len in enumerate(env_lens):
                seq = np.arange(env_len)
                buf = c1.buffer.buffers[i]
                assert np.all(buf.info.env_id == i)
                assert np.all(buf.obs.reshape(-1, env_len) == seq)
                assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1)
    finally:
        # SubprocVectorEnv owns worker processes; close it even when an
        # assertion fails so the workers are not leaked.
        venv.close()


if __name__ == '__main__':
    # Script mode: run every test in this module in definition order, without pytest.
    for run in (test_collector_nstep, test_collector_nepisode, test_asynccollector):
        run()
