import gym

from core.config import BaseConfig
from collections import defaultdict

# Task ids this config accepts (handed to BaseConfig as ``valid_envs``).
# ``DmControlConfig.new_game`` prefixes the selected id with ``'dm2gym:'``
# before calling ``gym.make``. Grouped below by the id's name prefix.
ENVS = [
    # Acrobot
    "AcrobotSwingup-v0",
    "AcrobotSwingup_sparse-v0",
    # Ball_in_cup
    "Ball_in_cupCatch-v0",
    # Cartpole
    "CartpoleBalance-v0",
    "CartpoleBalance_sparse-v0",
    "CartpoleSwingup-v0",
    "CartpoleSwingup_sparse-v0",
    "CartpoleTwo_poles-v0",
    "CartpoleThree_poles-v0",
    # Cheetah
    "CheetahRun-v0",
    # Dog
    "DogStand-v0",
    "DogWalk-v0",
    "DogTrot-v0",
    "DogRun-v0",
    "DogFetch-v0",
    # Finger
    "FingerSpin-v0",
    "FingerTurn_easy-v0",
    "FingerTurn_hard-v0",
    # Fish
    "FishUpright-v0",
    "FishSwim-v0",
    # Hopper
    "HopperStand-v0",
    "HopperHop-v0",
    # Humanoid / Humanoid_cmu
    "HumanoidStand-v0",
    "HumanoidWalk-v0",
    "HumanoidRun-v0",
    "HumanoidRun_pure_state-v0",
    "Humanoid_cmuStand-v0",
    "Humanoid_cmuRun-v0",
    # Manipulator
    "ManipulatorBring_ball-v0",
    "ManipulatorBring_peg-v0",
    "ManipulatorInsert_ball-v0",
    "ManipulatorInsert_peg-v0",
    # Pendulum
    "PendulumSwingup-v0",
    # Point_mass
    "Point_massEasy-v0",
    "Point_massHard-v0",
    # Quadruped
    "QuadrupedWalk-v0",
    "QuadrupedRun-v0",
    "QuadrupedEscape-v0",
    "QuadrupedFetch-v0",
    # Reacher
    "ReacherEasy-v0",
    "ReacherHard-v0",
    # Stacker
    "StackerStack_2-v0",
    "StackerStack_4-v0",
    # Swimmer
    "SwimmerSwimmer6-v0",
    "SwimmerSwimmer15-v0",
    # Walker
    "WalkerStand-v0",
    "WalkerWalk-v0",
    "WalkerRun-v0",
]

# Per-task action repeat; a task without an explicit entry resolves to 2.
ACTION_REPEATS = defaultdict(lambda: 2)

# Per-task camera index handed to DMControlWrapper; unlisted tasks resolve to 0.
CAMERA_ID = defaultdict(lambda: 0)
# NOTE(review): every explicit entry below equals the default (0), so this
# only pre-populates the keys and does not change which camera is used.
# If a different camera was intended for these humanoid/manipulator tasks,
# the value here is the place to change it — TODO confirm intent.
CAMERA_ID.update(dict.fromkeys((
    "ManipulatorBring_ball-v0",
    "ManipulatorBring_peg-v0",
    "ManipulatorInsert_ball-v0",
    "ManipulatorInsert_peg-v0",
    "HumanoidStand-v0",
    "HumanoidWalk-v0",
    "HumanoidRun-v0",
    "HumanoidRun_pure_state-v0",
    "Humanoid_cmuStand-v0",
    "Humanoid_cmuRun-v0",
), 0))


class DMControlWrapper(gym.Wrapper):
    """Expose only the ``'observations'`` entry of a dict-observation env.

    ``new_game`` creates the underlying dm2gym env with
    ``flat_observation=True``. This wrapper reduces the observation space and
    every observation returned by ``step``/``reset`` to that env's
    ``'observations'`` entry. It also renders from a configurable camera.

    Args:
        env: wrapped env whose ``observation_space`` and observations are
            indexable by ``'observations'``.
        camera_id: camera index passed to ``env.render`` unless the caller
            supplies its own ``camera_id``.
    """

    def __init__(self, env, camera_id=0):
        super().__init__(env)
        # Advertise only the flat observation sub-space to downstream code.
        self.observation_space = env.observation_space['observations']
        self.camera_id = camera_id

    def step(self, action):
        """Step the wrapped env.

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` is the wrapped env's
            ``'observations'`` entry.
        """
        obs, reward, done, info = self.env.step(action)
        return obs['observations'], reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env and return its ``'observations'`` entry.

        Keyword arguments are forwarded to ``env.reset`` (as
        ``gym.Wrapper.reset`` does), so outer wrappers that pass reset options
        keep working. Calling with no arguments behaves exactly as before.
        """
        obs = self.env.reset(**kwargs)
        return obs['observations']

    def render(self, mode='human', *args, **kwargs):
        """Render the wrapped env.

        By default this renders with ``use_opencv_renderer=True`` and
        ``camera_id=self.camera_id``. A caller may pass either keyword to
        override it.
        """
        # setdefault rather than explicit keywords: passing e.g. camera_id
        # explicitly *and* via **kwargs raises "got multiple values for
        # keyword argument"; here a caller-supplied value simply wins.
        kwargs.setdefault('use_opencv_renderer', True)
        kwargs.setdefault('camera_id', self.camera_id)
        return self.env.render(mode, *args, **kwargs)


class DmControlConfig(BaseConfig):
    """Run configuration for DeepMind Control tasks created through dm2gym."""

    def __init__(self):
        # Fixed settings handed to BaseConfig; their semantics live there.
        super().__init__(
            max_env_steps=int(4e6),
            start_step=int(5e3),
            replay_memory_capacity=int(1e6),
            valid_envs=ENVS,
            action_repeats=ACTION_REPEATS,
            action_repeat_set=[1, 2, 4, 8, 16],
            env_itr_steps=1000,
            test_interval_steps=25000,
        )

    def new_game(self, seed=None):
        """Create a wrapped environment for ``self.env_name``.

        The env is built with ``gym.make('dm2gym:' + self.env_name, ...)``,
        requesting flat observations and ``visualize_reward=True``. It is then
        wrapped in ``DMControlWrapper`` using the task's ``CAMERA_ID`` entry.
        ``self.env_name`` is presumably set by ``BaseConfig`` (not visible
        here).

        Args:
            seed: if not None, passed to ``env.seed`` before returning.

        Returns:
            The ``DMControlWrapper`` instance.
        """
        game_id = 'dm2gym:' + self.env_name
        base_env = gym.make(game_id,
                            environment_kwargs={'flat_observation': True},
                            visualize_reward=True)
        wrapped_env = DMControlWrapper(base_env, CAMERA_ID[self.env_name])

        if seed is None:
            return wrapped_env
        wrapped_env.seed(seed)
        return wrapped_env


# Module-level DmControlConfig instance, constructed when this module is imported.
run_config = DmControlConfig()
