from copy import deepcopy

from . import (
    simple_replay_pool,
    extra_policy_info_replay_pool,
    union_pool,
    trajectory_replay_pool,
    continuous_goal_pool)


# Registry from the replay pool ``type`` name (as written in configs) to the
# class implementing that pool. Insertion order matches the declaration below.
POOL_CLASSES = dict(
    SimpleReplayPool=simple_replay_pool.SimpleReplayPool,
    TrajectoryReplayPool=trajectory_replay_pool.TrajectoryReplayPool,
    ExtraPolicyInfoReplayPool=(
        extra_policy_info_replay_pool.ExtraPolicyInfoReplayPool),
    UnionPool=union_pool.UnionPool,
    ContinuousGoalPool=continuous_goal_pool.ContinuousGoalPool,
)

# Replay pool type name (a key of ``POOL_CLASSES``) designated as the default.
DEFAULT_REPLAY_POOL = 'SimpleReplayPool'


def _get_pool_class(pool_type):
    """Return the class registered in ``POOL_CLASSES`` under ``pool_type``.

    Raises:
        KeyError: if ``pool_type`` is not registered. The message lists the
            registered names so that config typos are easy to spot.
    """
    try:
        return POOL_CLASSES[pool_type]
    except KeyError:
        raise KeyError(
            "Unknown replay pool type {!r}; expected one of {}.".format(
                pool_type, sorted(POOL_CLASSES))) from None


def get_replay_pool_from_variant(variant, env, *args, **kwargs):
    """Build a replay pool for ``env`` as described by ``variant``.

    ``variant['replay_pool_params']`` is read as a dict with keys:

    - ``'type'``: name of a class in ``POOL_CLASSES``. Falls back to
      ``DEFAULT_REPLAY_POOL`` when omitted.
    - ``'kwargs'``: keyword arguments for the pool constructor. Treated as
      an empty dict when omitted. It is deep-copied first, so the ``pop``
      calls below never mutate the caller's ``variant``.

    Special cases:

    - ``'UnionPool'``: ``kwargs`` must contain ``num_sub_pools``,
      ``sub_pool_type`` and ``adaptive_sampling``. ``num_sub_pools``
      instances of ``sub_pool_type`` are built from the remaining kwargs
      (plus the env spaces). They are handed to the union pool as ``pools``
      together with ``adaptive_sampling``. The union pool itself does not
      receive the env spaces.
    - ``'ContinuousGoalPool'``: ``kwargs`` must contain ``sub_pool_type``,
      which is forwarded to the pool constructor by keyword.

    Args:
        variant: experiment configuration holding ``'replay_pool_params'``.
        env: object whose ``observation_space`` and ``action_space`` are
            passed to the constructed pool(s).
        *args, **kwargs: extra arguments forwarded to every pool constructor
            called here (including each UnionPool sub-pool).

    Returns:
        The constructed replay pool instance.

    Raises:
        KeyError: if a required config key is missing or a pool type is not
            registered in ``POOL_CLASSES``.
    """
    replay_pool_params = variant['replay_pool_params']
    replay_pool_type = replay_pool_params.get('type', DEFAULT_REPLAY_POOL)
    # Deep copy: the branches below pop keys out of this dict.
    replay_pool_kwargs = deepcopy(replay_pool_params.get('kwargs', {}))
    pool_class = _get_pool_class(replay_pool_type)

    if replay_pool_type == 'UnionPool':
        num_sub_pools = replay_pool_kwargs.pop('num_sub_pools')
        sub_pool_class = _get_pool_class(
            replay_pool_kwargs.pop('sub_pool_type'))
        adaptive_sampling = replay_pool_kwargs.pop('adaptive_sampling')
        # Each sub-pool gets its own instance built from the same kwargs;
        # ``**`` unpacking passes a fresh mapping on every call.
        pools = [
            sub_pool_class(
                *args,
                observation_space=env.observation_space,
                action_space=env.action_space,
                **replay_pool_kwargs,
                **kwargs)
            for _ in range(num_sub_pools)
        ]
        return pool_class(
            *args,
            pools=pools,
            adaptive_sampling=adaptive_sampling,
            **kwargs)

    if replay_pool_type == 'ContinuousGoalPool':
        # ``sub_pool_type`` is mandatory for this pool; popping it makes a
        # missing entry fail here with a KeyError.
        sub_pool_type = replay_pool_kwargs.pop('sub_pool_type')
        return pool_class(
            *args,
            observation_space=env.observation_space,
            action_space=env.action_space,
            sub_pool_type=sub_pool_type,
            **replay_pool_kwargs,
            **kwargs)

    return pool_class(
        *args,
        observation_space=env.observation_space,
        action_space=env.action_space,
        **replay_pool_kwargs,
        **kwargs)
