import gymnasium as gym
from vizdoom import gymnasium_wrapper
import miniworld
from DuRND.Wrappers import *
import levdoom


class RewardForwardFilter:
    """
    Keep an exponentially discounted running sum of rewards.

    Each call to ``update`` folds the new rewards into the accumulator as
    ``acc = acc * gamma + rews``. The first call simply stores ``rews``.
    NOTE(review): presumably used to track returns for reward normalisation
    (as in RND) — confirm with the caller.
    """

    def __init__(self, gamma):
        # discount factor applied to the accumulator before adding new rewards
        self.gamma = gamma
        # running discounted sum; None until the first update
        self.rewems = None

    def update(self, rews):
        """
        Fold ``rews`` into the running discounted sum.
        :param rews: the rewards of the current step (scalar or array-like)
        :return: the updated running sum
        """
        previous = self.rewems
        self.rewems = rews if previous is None else previous * self.gamma + rews
        return self.rewems


def atari_games_env_maker(env_id, seed=1, render=False, reward_scale=1):
    """
    Build a single preprocessed Atari environment.
    :param env_id: the gymnasium id of the Atari game
    :param seed: seed for the action and observation spaces
    :param render: open a human render window when True
    :param reward_scale: factor every reward is multiplied by (no wrapper when 1)
    :return: the wrapped environment
    """
    make_kwargs = {"render_mode": "human"} if render else {}
    env = gym.make(env_id, **make_kwargs)

    assert isinstance(env.action_space, gym.spaces.Discrete), "only discrete action space is supported for Atari Games!"

    for space in (env.action_space, env.observation_space):
        space.seed(seed)

    # Atari preprocessing with gymnasium's defaults and no additional frame
    # skipping at this level, followed by a stack of the last 4 frames.
    # NOTE(review): frame_skip=1 presumably relies on the base env's own
    # frameskip — confirm against the env ids used.
    env = gym.wrappers.FrameStack(gym.wrappers.AtariPreprocessing(env, frame_skip=1), 4)

    # Per-game extra wrappers registered in DuRND.Wrappers.env_wrapper_list.
    extra_wrappers = env_wrapper_list[env_id] if env_id in env_wrapper_list else ()
    for spec in extra_wrappers:
        env = spec["wrapper"](env, **spec["kwargs"])

    if reward_scale != 1:
        env = gym.wrappers.TransformReward(env, lambda r: r * reward_scale)

    # Outermost wrapper so the recorded episodic return reflects the scaled reward.
    return gym.wrappers.RecordEpisodeStatistics(env)


def sync_vector_atari_envs_maker(env_id, num_envs, seed=1, reward_scale=1):
    """
    Make the synchronized vectorized Atari environments.
    :param env_id: the name of the environment
    :param num_envs: the number of environments
    :param seed: the base random seed; sub-environment ``i`` is seeded with
        ``seed + i`` so the copies do not share one RNG stream (``None`` is
        passed through unchanged)
    :param reward_scale: reward scaling factor forwarded to each sub-environment
    :return: the vectorized environments
    """
    # ``i`` is bound as a default argument so every thunk keeps its own index
    # (lambdas created in a loop are late-binding in Python).
    env_fns = [
        lambda i=i: atari_games_env_maker(
            env_id,
            None if seed is None else seed + i,
            render=False,
            reward_scale=reward_scale,
        )
        for i in range(num_envs)
    ]
    return gym.vector.SyncVectorEnv(env_fns)


def vizdoom_env_maker(env_id, seed=1, render=False, reward_scale=1):
    """
    Build a single ViZDoom environment with 84x84 grayscale screen observations.
    :param env_id: the gymnasium id of the ViZDoom scenario
    :param seed: seed for the action and observation spaces
    :param render: open a human render window when True
    :param reward_scale: factor every reward is multiplied by (no wrapper when 1)
    :return: the wrapped environment
    """
    make_kwargs = {"render_mode": "human"} if render else {}
    env = gym.make(env_id, **make_kwargs)

    for space in (env.action_space, env.observation_space):
        space.seed(seed)

    # Observation pipeline: screen wrapper -> resize to 84x84 -> grayscale ->
    # channel wrapper. NOTE(review): the semantics of the two project wrappers
    # (screen extraction / adding a channel axis) are presumed from their
    # names — confirm in DuRND.Wrappers.
    env = VizDoomScreenObsWrapper(env)
    env = gym.wrappers.GrayScaleObservation(gym.wrappers.ResizeObservation(env, (84, 84)))
    env = VizDoomAddChanelWrapper(env)

    # Per-scenario extra wrappers registered in DuRND.Wrappers.env_wrapper_list.
    extra_wrappers = env_wrapper_list[env_id] if env_id in env_wrapper_list else ()
    for spec in extra_wrappers:
        env = spec["wrapper"](env, **spec["kwargs"])

    if reward_scale != 1:
        env = gym.wrappers.TransformReward(env, lambda r: r * reward_scale)

    # Outermost wrapper so the recorded episodic return reflects the scaled reward.
    return gym.wrappers.RecordEpisodeStatistics(env)


def sync_vector_vizdoom_envs_maker(env_id, num_envs, seed=1, reward_scale=1):
    """
    Make the synchronized vectorized ViZDoom environments.
    :param env_id: the name of the environment
    :param num_envs: the number of environments
    :param seed: the base random seed; sub-environment ``i`` is seeded with
        ``seed + i`` so the copies do not share one RNG stream (``None`` is
        passed through unchanged)
    :param reward_scale: reward scaling factor forwarded to each sub-environment
    :return: the vectorized environments
    """
    # ``i`` is bound as a default argument so every thunk keeps its own index
    # (lambdas created in a loop are late-binding in Python).
    env_fns = [
        lambda i=i: vizdoom_env_maker(
            env_id,
            None if seed is None else seed + i,
            render=False,
            reward_scale=reward_scale,
        )
        for i in range(num_envs)
    ]
    return gym.vector.SyncVectorEnv(env_fns)


def miniworld_env_maker(env_id, seed=1, render=False, reward_scale=1):
    """
    Make the MiniWorld environment with 84x84 grayscale observations.
    :param env_id: the name of the environment
    :param seed: seed for the action and observation spaces
    :param render: whether to open a human render window
    :param reward_scale: factor every reward is multiplied by (no wrapper when 1)
    :return: the environment
    """
    make_kwargs = {"render_mode": "human"} if render else {}
    env = gym.make(env_id, **make_kwargs)

    for space in (env.action_space, env.observation_space):
        space.seed(seed)

    # Resize to 84x84, convert to grayscale, then reuse the ViZDoom channel
    # wrapper. NOTE(review): presumably it adds back a channel axis — confirm
    # in DuRND.Wrappers.
    env = gym.wrappers.GrayScaleObservation(gym.wrappers.ResizeObservation(env, (84, 84)))
    env = VizDoomAddChanelWrapper(env)

    # Per-environment extra wrappers registered in DuRND.Wrappers.env_wrapper_list.
    extra_wrappers = env_wrapper_list[env_id] if env_id in env_wrapper_list else ()
    for spec in extra_wrappers:
        env = spec["wrapper"](env, **spec["kwargs"])

    if reward_scale != 1:
        env = gym.wrappers.TransformReward(env, lambda r: r * reward_scale)

    # Outermost wrapper so the recorded episodic return reflects the scaled reward.
    env = gym.wrappers.RecordEpisodeStatistics(env)

    # Optional, currently disabled: remove the pickup action from the action space.
    # env = RemovePickUpActionWrapper(env)

    return env


def sync_vector_miniworld_envs_maker(env_id, num_envs, seed=1, reward_scale=1):
    """
    Make the synchronized vectorized MiniWorld environments.
    :param env_id: the name of the environment
    :param num_envs: the number of environments
    :param seed: the base random seed; sub-environment ``i`` is seeded with
        ``seed + i`` so the copies do not share one RNG stream (``None`` is
        passed through unchanged)
    :param reward_scale: reward scaling factor forwarded to each sub-environment
    :return: the vectorized environments
    """
    # ``i`` is bound as a default argument so every thunk keeps its own index
    # (lambdas created in a loop are late-binding in Python).
    env_fns = [
        lambda i=i: miniworld_env_maker(
            env_id,
            None if seed is None else seed + i,
            render=False,
            reward_scale=reward_scale,
        )
        for i in range(num_envs)
    ]
    return gym.vector.SyncVectorEnv(env_fns)


def levdoom_env_maker(env_id, seed=1, render=False, reward_scale=1):
    """
    Build a single LevDoom environment with grayscale observations.

    Observations are not resized here (the resize step is intentionally
    commented out).
    NOTE(review): ``render`` is accepted for signature parity with the other
    makers, but it is not forwarded to ``levdoom.make`` and currently has no
    effect.
    :param env_id: the LevDoom environment id
    :param seed: seed for the action and observation spaces
    :param render: unused (see note above)
    :param reward_scale: factor every reward is multiplied by (no wrapper when 1)
    :return: the wrapped environment
    """
    env = levdoom.make(env_id)

    for space in (env.action_space, env.observation_space):
        space.seed(seed)

    # env = gym.wrappers.ResizeObservation(env, (84, 84))
    # Grayscale, then the channel wrapper. NOTE(review): presumably it adds
    # back a channel axis — confirm in DuRND.Wrappers.
    env = VizDoomAddChanelWrapper(gym.wrappers.GrayScaleObservation(env))

    # Per-level extra wrappers registered in DuRND.Wrappers.env_wrapper_list.
    extra_wrappers = env_wrapper_list[env_id] if env_id in env_wrapper_list else ()
    for spec in extra_wrappers:
        env = spec["wrapper"](env, **spec["kwargs"])

    if reward_scale != 1:
        env = gym.wrappers.TransformReward(env, lambda r: r * reward_scale)

    # Outermost wrapper so the recorded episodic return reflects the scaled reward.
    return gym.wrappers.RecordEpisodeStatistics(env)


def sync_vector_levdoom_envs_maker(env_id, num_envs, seed=1, reward_scale=1):
    """
    Make the synchronized vectorized LevDoom environments.
    :param env_id: the name of the environment
    :param num_envs: the number of environments
    :param seed: the base random seed; sub-environment ``i`` is seeded with
        ``seed + i`` so the copies do not share one RNG stream (``None`` is
        passed through unchanged)
    :param reward_scale: reward scaling factor forwarded to each sub-environment
    :return: the vectorized environments
    """
    # ``i`` is bound as a default argument so every thunk keeps its own index
    # (lambdas created in a loop are late-binding in Python).
    env_fns = [
        lambda i=i: levdoom_env_maker(
            env_id,
            None if seed is None else seed + i,
            render=False,
            reward_scale=reward_scale,
        )
        for i in range(num_envs)
    ]
    return gym.vector.SyncVectorEnv(env_fns)
