import sys
import time

import gymnasium as gym
import numpy as np
import pytest
from gymnasium.spaces.discrete import Discrete

from tianshou.data import Batch
from tianshou.env import (
    ContinuousToDiscrete,
    DummyVectorEnv,
    MultiDiscreteToDiscrete,
    RayVectorEnv,
    ShmemVectorEnv,
    SubprocVectorEnv,
    VectorEnvNormObs,
)
from tianshou.env.gym_wrappers import TruncatedAsTerminated
from tianshou.utils import RunningMeanStd

if __name__ == "__main__":
    from env import MyTestEnv, NXEnv
else:  # pytest
    from test.base.env import MyTestEnv, NXEnv

try:
    import envpool
except ImportError:
    envpool = None


def has_ray():
    """Report whether the optional ``ray`` backend is importable.

    Used to decide whether ``RayVectorEnv`` joins the set of vector-env
    classes under test.
    """
    try:
        import ray  # noqa: F401
    except ImportError:
        return False
    return True


def recurse_comp(a, b):
    """Recursively check whether ``a`` and ``b`` hold the same data.

    Numeric ndarrays are compared with ``np.allclose``. Object ndarrays,
    lists and tuples are compared element-wise, and their lengths must match.
    Dicts are compared key by key, and their key sets must match. Any other
    leaf value is compared with ``np.allclose`` when that is possible and
    with ``==`` otherwise.

    Returns ``True`` if the structures match and ``False`` otherwise. This
    includes the case where the comparison itself raises (e.g. mismatched
    shapes or incompatible types).
    """
    try:
        if isinstance(a, np.ndarray) and a.dtype != object:
            return bool(np.allclose(a, b))
        if isinstance(a, (np.ndarray, list, tuple)):
            # zip() alone would silently ignore trailing extra elements
            if len(a) != len(b):
                return False
            return all(recurse_comp(m, n) for m, n in zip(a, b))
        if isinstance(a, dict):
            # extra keys on either side mean the structures differ
            if a.keys() != b.keys():
                return False
            return all(recurse_comp(a[k], b[k]) for k in a)
        # leaf value (python/numpy scalar, None, string, ...)
        try:
            return bool(np.allclose(a, b))
        except TypeError:
            return bool(a == b)
    except Exception:
        # deliberate best-effort: any failure to compare means "not equal"
        return False


def test_async_env(size=10000, num=8, sleep=0.1):
    """Async stepping (``wait_num < env_num``) must beat synchronous timing.

    Every env sleeps a random amount bounded by ``sleep`` per step. With
    ``wait_num=num // 2`` the vector env hands back results as soon as part
    of the requested envs are ready. Consuming the whole action list should
    therefore be measurably faster than the synchronous estimate below. The
    timing assertion is only enforced on Linux and not for ``RayVectorEnv``.
    """
    # simplify the test case, just keep stepping
    env_fns = [
        lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True)
        for i in range(size, size + num)
    ]
    test_cls = [SubprocVectorEnv, ShmemVectorEnv]
    if has_ray():
        test_cls += [RayVectorEnv]
    for cls in test_cls:
        v = cls(env_fns, wait_num=num // 2, timeout=1e-3)
        try:
            v.seed(None)
            v.reset()
            # for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un}
            # P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1}
            # expectation of v is n / (n + 1)
            # for a synchronous environment, the following actions should take
            # about 7 * sleep * num / (num + 1) seconds
            # for async simulation, the analysis is complicated, but the time
            # cost should be smaller
            action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
            current_idx_start = 0
            act = action_list[:num]
            env_ids = list(range(num))
            o = []
            # perf_counter is monotonic, so wall-clock adjustments cannot
            # skew the measured duration
            start = time.perf_counter()
            while current_idx_start < len(action_list):
                A, B, C, D, E = v.step(action=act, id=env_ids)
                b = Batch({
                    "obs": A, "rew": B, "terminate": C, "truncated": D, "info": E
                })
                env_ids = b.info.env_id
                o.append(b)
                current_idx_start += len(act)
                # len of action may be smaller than len(A) in the end
                act = action_list[current_idx_start:current_idx_start + len(A)]
                # truncate env_ids with the first terms
                # typically len(env_ids) == len(A) == len(action), except for
                # the last batch when actions are not enough
                env_ids = env_ids[:len(act)]
            spent_time = time.perf_counter() - start
            # the collected batches must be concatenable
            Batch.cat(o)
        finally:
            # always release the worker processes, even if stepping failed
            v.close()
        # assure 1/7 improvement
        if sys.platform == "linux" and cls != RayVectorEnv:
            # macOS/Windows cannot pass this check
            assert spent_time < 6.0 * sleep * num / (num + 1)


def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
    """Async stepping must return the env ids that finished in time.

    The four envs sleep 2, 3, 5 and 7 times ``sleep`` per step, and the
    vector env is built with ``wait_num=num - 1`` and ``timeout``. After each
    ``step`` call, the returned env ids are compared with the corresponding
    row of ``expect_result``. The check also verifies that the call finished
    before ``timeout`` exactly when ``num - 1`` envs came back. Timing is
    flaky, so only two backends need to pass, and only on Linux.
    """
    env_fns = [
        lambda: MyTestEnv(size=size, sleep=sleep * 2),
        lambda: MyTestEnv(size=size, sleep=sleep * 3),
        lambda: MyTestEnv(size=size, sleep=sleep * 5),
        lambda: MyTestEnv(size=size, sleep=sleep * 7)
    ]
    test_cls = [SubprocVectorEnv, ShmemVectorEnv]
    if has_ray():
        test_cls += [RayVectorEnv]
    total_pass = 0
    for cls in test_cls:
        pass_check = 1
        v = cls(env_fns, wait_num=num - 1, timeout=timeout)
        try:
            t = time.perf_counter()
            v.reset()
            t = time.perf_counter() - t
            print(f"{cls} reset {t}")
            if t > sleep * 9:  # longer than the maximum sleep time (7 sleep)
                pass_check = 0
            # sorted env ids expected back from each successive step call
            expect_result = [
                [0, 1],
                [0, 1, 2],
                [0, 1, 3],
                [0, 1, 2],
                [0, 1],
                [0, 2, 3],
                [0, 1],
            ]
            ids = np.arange(num)
            for res in expect_result:
                t = time.perf_counter()
                _, _, _, _, info = v.step([1] * len(ids), ids)
                t = time.perf_counter() - t
                ids = Batch(info).env_id
                print(ids, t)
                # the call must return before the timeout exactly when
                # wait_num (= num - 1) envs came back
                if not (
                    len(ids) == len(res) and np.allclose(sorted(ids), res) and
                    (t < timeout) == (len(res) == num - 1)
                ):
                    pass_check = 0
                    break
        finally:
            # always release the worker processes
            v.close()
        total_pass += pass_check
    if sys.platform == "linux":  # Windows/macOS may not pass this check
        assert total_pass >= 2


def test_vecenv(size=10, num=8, sleep=0.001):
    """All vector-env backends must agree on rollouts and env attributes.

    The same ``MyTestEnv`` factories are wrapped by each backend: dummy,
    subprocess and shared memory, plus ray on Linux when it is installed.
    Every backend is driven with the same action sequence. The obs, reward,
    terminated and truncated outputs are compared across backends with
    ``recurse_comp``. Afterwards ``get_env_attr`` and ``set_env_attr`` are
    checked on every backend. When the file is run as a script, a small
    timing comparison of the backends is printed as well.
    """
    env_fns = [
        lambda i=i: MyTestEnv(size=i, sleep=sleep, recurse_state=True)
        for i in range(size, size + num)
    ]
    venv = [
        DummyVectorEnv(env_fns),
        SubprocVectorEnv(env_fns),
        ShmemVectorEnv(env_fns),
    ]
    if has_ray() and sys.platform == "linux":
        venv += [RayVectorEnv(env_fns)]
    try:
        for v in venv:
            v.seed(0)
        action_list = [1] * 5 + [0] * 10 + [1] * 20
        for v in venv:
            v.reset()
        for a in action_list:
            o = []
            for v in venv:
                A, B, C, D, E = v.step([a] * num)
                done = np.logical_or(C, D)
                if np.any(done):
                    A, _ = v.reset(np.where(done)[0])
                o.append([A, B, C, D, E])
            # o[k] is one backend's (obs, rew, terminated, truncated, info);
            # compare every field against the first backend's value
            for index, field_values in enumerate(zip(*o)):
                if index == 4:  # do not check info here
                    continue
                for value in field_values:
                    assert recurse_comp(field_values[0], value)

        if __name__ == "__main__":
            elapsed = []
            for e in venv:
                start = time.perf_counter()
                e.reset()
                for a in action_list:
                    _, _, terminated, truncated, _ = e.step([a] * num)
                    # reset on truncation too, like the comparison loop above
                    done = np.logical_or(terminated, truncated)
                    if np.any(done):
                        e.reset(np.where(done)[0])
                elapsed.append(time.perf_counter() - start)
            for v, t in zip(venv, elapsed):
                print(f"{type(v)}: {t:.6f}s")

        def assert_get(v, expected):
            assert v.get_env_attr("size") == expected
            assert v.get_env_attr("size", id=0) == [expected[0]]
            assert v.get_env_attr("size", id=[0, 1, 2]) == expected[:3]

        for v in venv:
            assert_get(v, list(range(size, size + num)))
            assert v.env_num == num
            assert v.action_space == [Discrete(2)] * num

            v.set_env_attr("size", 0)
            assert_get(v, [0] * num)

            v.set_env_attr("size", 1, 0)
            assert_get(v, [1] + [0] * (num - 1))

            v.set_env_attr("size", 2, [1, 2, 3])
            assert_get(v, [1] + [2] * 3 + [0] * (num - 4))
    finally:
        # release worker processes even when an assertion fails
        for v in venv:
            v.close()


def test_attr_unwrapped():
    """An attribute set through the vector env must reach the unwrapped env."""
    venv = DummyVectorEnv([lambda: gym.make("CartPole-v1")])
    venv.set_env_attr("test_attribute", 1337)
    assert venv.get_env_attr("test_attribute") == [1337]
    worker_env = venv.workers[0].env
    # visible both on the (possibly wrapped) env and on its innermost env
    for target in (worker_env, worker_env.unwrapped):
        assert hasattr(target, "test_attribute")


def test_env_obs_dtype():
    """Observations of ``NXEnv`` sub-envs must be gathered into object arrays.

    The sub-envs are built with different sizes (5, 10, 15, 20). Presumably
    their observations cannot be stacked into a regular array, so both
    ``reset`` and ``step`` are expected to return ``dtype=object``.
    """
    for obs_type in ["array", "object"]:
        envs = SubprocVectorEnv(
            [lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]]
        )
        try:
            obs, _ = envs.reset()
            assert obs.dtype == object
            obs = envs.step([1, 1, 1, 1])[0]
            assert obs.dtype == object
        finally:
            # shut down the subprocess workers for this obs_type
            envs.close()


def test_env_reset_optional_kwargs(size=10000, num=8):
    """``reset(seed=...)`` must return one info dict per sub-env on every backend."""
    env_fns = [lambda i=i: MyTestEnv(size=i) for i in range(size, size + num)]
    test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv]
    if has_ray():
        test_cls += [RayVectorEnv]
    for cls in test_cls:
        v = cls(env_fns, wait_num=num // 2, timeout=1e-3)
        try:
            _, info = v.reset(seed=1)
            assert len(info) == len(env_fns)
            assert isinstance(info[0], dict)
        finally:
            # release the workers created for this backend
            v.close()


def test_venv_wrapper_gym(num_envs: int = 4):
    """``VectorEnvNormObs.reset`` must yield a batched ``(obs, info)`` pair.

    Regression test for issue 697.
    """
    base_envs = DummyVectorEnv(
        [lambda: gym.make("CartPole-v1") for _ in range(num_envs)]
    )
    norm_envs = VectorEnvNormObs(base_envs)
    try:
        obs, info = norm_envs.reset()
    except ValueError:
        # fallback for a reset signature that needs return_info=True
        # (presumably an older gym API -- confirm)
        obs, info = norm_envs.reset(return_info=True)
    assert isinstance(obs, np.ndarray)
    assert isinstance(info, list)
    assert isinstance(info[0], dict)
    assert obs.shape[0] == len(info) == num_envs


def run_align_norm_obs(raw_env, train_env, test_env, action_list):
    """Check observation normalisation against a hand-rolled reference.

    ``raw_env`` returns unnormalised observations. ``train_env`` and
    ``test_env`` are expected to normalise observations and expose
    ``get_obs_rms()`` (the callers pass ``VectorEnvNormObs`` instances).

    1. ``raw_env`` and ``train_env`` are stepped in lock-step with
       ``action_list``. A reference ``RunningMeanStd`` fed with the raw
       observations must reproduce every normalised ``train_env``
       observation.
    2. The final reference statistics must match those reported by both
       ``train_env`` and ``test_env``.
    3. ``test_env`` is rolled out with the same actions. Its observations
       must equal the raw observations normalised with the final statistics.

    Both the 5-tuple (``terminated``/``truncated``) and the 4-tuple
    ``step`` results are accepted. ``reset`` results may be a bare
    observation or an ``(obs, info)`` tuple.
    """

    def reset_result_to_obs(reset_result):
        """Extract the observation from a ``reset`` result.

        The result is possibly an ``(obs, info)`` tuple.
        """
        if isinstance(reset_result, tuple) and len(reset_result) == 2:
            obs, _ = reset_result
        else:
            obs = reset_result
        return obs

    def step_and_record(env, action, obs_list):
        """Step ``env`` once and append its observation to ``obs_list``.

        If any sub-env finished, those sub-envs are reset and the reset
        observations are appended as well.
        """
        step_result = env.step(action)
        if len(step_result) == 5:
            obs, _, terminated, truncated, _ = step_result
            done = np.logical_or(terminated, truncated)
        else:
            obs, _, done, _ = step_result
        obs_list.append(obs)
        if np.any(done):
            obs_list.append(reset_result_to_obs(env.reset(np.where(done)[0])))

    # presumably the same epsilon the wrapper adds to the variance -- confirm
    eps = np.finfo(np.float32).eps.item()
    raw_obs = [reset_result_to_obs(raw_env.reset())]
    train_obs = [reset_result_to_obs(train_env.reset())]
    for action in action_list:
        # raw env first, then train env, for every action
        step_and_record(raw_env, action, raw_obs)
        step_and_record(train_env, action, train_obs)
    # zip() below would silently drop trailing items on a length mismatch
    assert len(raw_obs) == len(train_obs)
    ref_rms = RunningMeanStd()
    for ro, to in zip(raw_obs, train_obs):
        ref_rms.update(ro)
        no = (ro - ref_rms.mean) / np.sqrt(ref_rms.var + eps)
        assert np.allclose(no, to)
    assert np.allclose(ref_rms.mean, train_env.get_obs_rms().mean)
    assert np.allclose(ref_rms.var, train_env.get_obs_rms().var)
    assert np.allclose(ref_rms.mean, test_env.get_obs_rms().mean)
    assert np.allclose(ref_rms.var, test_env.get_obs_rms().var)
    test_obs = [reset_result_to_obs(test_env.reset())]
    for action in action_list:
        step_and_record(test_env, action, test_obs)
    assert len(raw_obs) == len(test_obs)
    for ro, to in zip(raw_obs, test_obs):
        # test_env uses fixed statistics: normalise with the final reference
        no = (ro - ref_rms.mean) / np.sqrt(ref_rms.var + eps)
        assert np.allclose(no, to)


def test_venv_norm_obs():
    """Run the normalisation alignment check on ``MyTestEnv`` vector envs."""
    sizes = np.array([5, 10, 15, 20])
    total_step = 30
    action_list = [np.array([1, 1, 1, 1])] * total_step
    env_fns = [lambda i=x: MyTestEnv(size=i, array_state=True) for x in sizes]
    raw_env = DummyVectorEnv(env_fns)
    train_env = VectorEnvNormObs(DummyVectorEnv(env_fns))
    print(train_env.observation_space)
    # evaluation env: shares the training statistics but does not update them
    test_env = VectorEnvNormObs(DummyVectorEnv(env_fns), update_obs_rms=False)
    test_env.set_obs_rms(train_env.get_obs_rms())
    run_align_norm_obs(raw_env, train_env, test_env, action_list)


def test_gym_wrappers():
    """Exercise the action-space wrappers and ``TruncatedAsTerminated``."""

    class DummyEnv(gym.Env):
        """Env whose ``step`` always reports terminated=False, truncated=True."""

        def __init__(self):
            self.action_space = gym.spaces.Box(
                low=-1.0, high=2.0, shape=(4, ), dtype=np.float32
            )
            self.observation_space = gym.spaces.Discrete(2)

        def step(self, act):
            # (obs, reward, terminated, truncated, info)
            return self.observation_space.sample(), -1, False, True, {}

    bsz = 10
    action_per_branch = [4, 6, 10, 7]
    env = DummyEnv()
    original_act = env.action_space.high
    # convert continuous to multidiscrete action space
    # with different action number per dimension
    env_m = ContinuousToDiscrete(env, action_per_branch)
    # check conversion is working properly for one action
    np.testing.assert_allclose(env_m.action(env_m.action_space.nvec - 1), original_act)
    # check conversion is working properly for a batch of actions
    np.testing.assert_allclose(
        env_m.action(np.array([env_m.action_space.nvec - 1] * bsz)),
        np.array([original_act] * bsz),
    )
    # convert multidiscrete with different action number per
    # dimension to discrete action space
    env_d = MultiDiscreteToDiscrete(env_m)
    # check conversion is working properly for one action
    np.testing.assert_allclose(
        env_d.action(env_d.action_space.n - 1), env_m.action_space.nvec - 1
    )
    # check conversion is working properly for a batch of actions
    np.testing.assert_allclose(
        env_d.action(np.array([env_d.action_space.n - 1] * bsz)),
        np.array([env_m.action_space.nvec - 1] * bsz),
    )
    # check terminated is True when the wrapped env reports truncated=True
    # (DummyEnv always returns terminated=False, truncated=True)
    try:
        env_t = TruncatedAsTerminated(env)
    except EnvironmentError:
        # NOTE(review): presumably raised for unsupported gym versions -- confirm
        env_t = None
    if env_t is not None:
        _, _, terminated, _, _ = env_t.step(env_t.action_space.sample())
        assert terminated


@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_venv_wrapper_envpool():
    """Run the normalisation alignment check on EnvPool ``Ant-v3`` envs."""
    num_envs, total_step = 4, 30
    raw = envpool.make_gymnasium("Ant-v3", num_envs=num_envs)
    train = VectorEnvNormObs(envpool.make_gymnasium("Ant-v3", num_envs=num_envs))
    test = VectorEnvNormObs(
        envpool.make_gymnasium("Ant-v3", num_envs=num_envs), update_obs_rms=False
    )
    # the evaluation env reuses the training statistics
    test.set_obs_rms(train.get_obs_rms())
    # pre-sample one batched action per step so all envs see the same actions
    actions = [
        np.array([raw.action_space.sample() for _ in range(num_envs)])
        for _step in range(total_step)
    ]
    run_align_norm_obs(raw, train, test, actions)


@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_venv_wrapper_envpool_gym_reset_return_info():
    """``reset`` through ``VectorEnvNormObs`` must return batched obs and info."""
    num_envs = 4
    raw_env = envpool.make_gymnasium(
        "Ant-v3", num_envs=num_envs, gym_reset_return_info=True
    )
    env = VectorEnvNormObs(raw_env)
    obs, info = env.reset()
    assert obs.shape[0] == num_envs
    # every non-dict info entry must be batched along its first axis
    for value in info.values():
        if isinstance(value, dict):
            continue
        assert value.shape[0] == num_envs


if __name__ == "__main__":
    test_venv_norm_obs()
    test_venv_wrapper_gym()
    # pytest skips the EnvPool tests when envpool is unavailable; do the same
    # here instead of crashing on ``envpool.make_gymnasium``
    if envpool is not None:
        test_venv_wrapper_envpool()
        test_venv_wrapper_envpool_gym_reset_return_info()
    test_env_obs_dtype()
    test_vecenv()
    test_attr_unwrapped()
    test_async_env()
    test_async_check_id()
    test_env_reset_optional_kwargs()
    test_gym_wrappers()
