# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import deque
from typing import Any, NamedTuple

import dm_env
import numpy as np
from dm_env import StepType, specs
from gym import ObservationWrapper
from gym.wrappers import FrameStack


class VstackWrapper(ObservationWrapper):
    """Merge the two leading axes of a stacked observation into one.

    The wrapped env's observation space must be a ``Box`` with at least two
    dimensions ``(d0, d1, ...)``. This wrapper advertises the shape
    ``(d0 * d1, ...)`` and converts every observation with ``np.vstack``.
    Presumably ``d0`` is the frame-stack axis and ``d1`` the channel axis
    (channels-first pixels) -- confirm against the wrapped env.
    """

    def __init__(self, env):
        from gym import spaces

        super().__init__(env)
        stacked_space = self.observation_space
        stacked_shape = stacked_space.shape
        self.observation_space = spaces.Box(
            # Per-element bounds are collapsed to a single scalar pair.
            low=stacked_space.low.min(),
            high=stacked_space.high.max(),
            shape=(stacked_shape[0] * stacked_shape[1], *stacked_shape[2:]),
            # Carry the source dtype through; without it Box falls back to its
            # default dtype and no longer matches the arrays that
            # ``observation`` returns.
            dtype=stacked_space.dtype,
        )

    def observation(self, lazy_frames):
        """Return the frames concatenated along their first axis."""
        return np.vstack(lazy_frames)


class FrameStackWithState(FrameStack):
    """FrameStack that also keeps a rolling window of simulator states.

    Alongside the stacked observations, the last ``num_stack`` values of
    ``info[state_key]`` are retained, and ``info[state_key]`` returned from
    ``step`` is replaced by that window (oldest first) as a list.
    """

    def __init__(self, env, num_stack, lz4_compress=False, state_key='sim_state'):
        super().__init__(env, num_stack, lz4_compress)
        self._state_key = state_key
        # Same length as the frame buffer so states and frames stay aligned.
        self.state_frames = deque(maxlen=num_stack)

    def step(self, action):
        """Step the wrapped env and push the new frame and state."""
        obs, reward, done, info = self.env.step(action)
        self.frames.append(obs)
        self.state_frames.append(info[self._state_key])
        # Swap the single state for the whole window of recent states.
        info[self._state_key] = list(self.state_frames)
        return self._get_observation(), reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env and fill both buffers with the initial values.

        The initial state is read straight from the physics object of the
        unwrapped env, since ``reset`` does not return an ``info`` dict here.
        """
        first_obs = self.env.reset(**kwargs)
        initial_state = self.env.unwrapped.env.physics.get_state()
        for _ in range(self.num_stack):
            self.frames.append(first_obs)
            self.state_frames.append(initial_state)
        return self._get_observation()


def get_env(name, frame_stack, action_repeat, seed,
            distraction_config=('background', 'camera', 'color'),
            rails_wrapper=False, save_info_wrapper=False, intensity=0, size=84):
    """Build a pixel-based, frame-stacked gym environment.

    Args:
        name: ``'<module>:<env_name>'``. Expected formats (from the original
            notes):
            ``f'distracting_control:{Domain}-{task_name}-{difficulty}-v1'`` or
            ``f'dmc:{Domain}-{task_name}-v1'``.
        frame_stack: number of frames to stack.
        action_repeat: forwarded to ``gym.make`` as ``frame_skip``.
        seed: used for ``env.seed`` and for the distraction / color seeds.
        distraction_config: names of the distractions to enable. The default
            is ``('background', 'camera', 'color')``. Two extra values are
            handled here: ``'video-background'`` (rewritten to
            ``'background'`` with ``dynamic=True``) and
            ``'dmcgen-color-hard'`` (must be the only entry).
        rails_wrapper: wrap the env in ``RailsWrapper`` (non-dmcgen path only).
        save_info_wrapper: wrap the env in ``SaveInfoWrapper`` (non-dmcgen
            path only).
        intensity: distraction intensity. When truthy, ``env_name`` must end
            with ``'intensity-v1'``; with a background distraction it must lie
            in ``[0, 1]``.
        size: render height and width in pixels.

    Returns:
        The wrapped env, with a ``physics`` attribute attached.
    """
    import os

    # Not referenced below; presumably imported so that its environments get
    # registered with gym -- confirm before removing.
    import distracting_control
    import gym
    from ml_logger import logger

    module, env_name = name.split(':', 1)
    domain_name = env_name.split('-')[0].lower()

    # Arguments shared by every gym.make call below.
    make_kwargs = dict(
        from_pixels=True,
        frame_skip=action_repeat,
        channels_first=True,
        height=size,
        width=size,
        # Quadruped uses the zoomed-in camera.
        camera_id=2 if env_name.startswith('Quadruped') else 0,
    )

    # 'video-background' means a 'background' distraction with dynamic=True.
    dynamic = 'video-background' in distraction_config
    if dynamic:
        entries = list(distraction_config)
        position = entries.index('video-background')
        distraction_config = entries[:position] + entries[position + 1:] + ['background']

    dmcgen_color_hard = 'dmcgen-color-hard' in distraction_config
    if dmcgen_color_hard:
        assert len(distraction_config) == 1, 'dmcgen-color-hard cannot be used with other distractions'

    if intensity:
        assert env_name.endswith('intensity-v1')

    # Original note: the default gdc background path is
    # $HOME/datasets/DAVIS/JPEGImages/480p/
    print('making environment:', module, name, distraction_config)

    if dmcgen_color_hard:
        from .dmc_gen_env.wrappers import ColorWrapper, FrameStack

        # This env is not part of gym-dmc or distracting_control, so it is
        # built by hand. It is not wrapped in FrameStackWithState, so not all
        # states are available on this path.
        env = gym.make('dmc:' + env_name, **make_kwargs)
        env = FrameStack(env, frame_stack)
        # Only the 'color_hard' mode is supported for now.
        env = ColorWrapper(env, 'color_hard', domain_name, seed, fix_color=True)
        env = gym.wrappers.RescaleAction(env, -1.0, 1.0)
        env.seed(seed)
    else:
        extra_kwargs = {}
        if module == 'distracting_control':
            extra_kwargs = {
                'background_data_path': os.environ.get("DC_BG_PATH"),
                'distraction_seed': seed,
                'distraction_types': distraction_config,
                'dynamic': dynamic,
                'fix_distraction': True,
                'intensity': intensity,
                'disable_zoom': True,
                'sample_from_edge': bool(intensity),
            }

            # Project-specific setting: intensity drives both the background
            # video alpha and the ground-plane alpha.
            if intensity and "background" in distraction_config:
                alpha_cap_by_domain = {'reacher': 0.0, 'walker': 1.0, 'cheetah': 1.0}
                alpha_cap = alpha_cap_by_domain.get(domain_name, 0.3)

                assert 0 <= intensity <= 1
                extra_kwargs['background_kwargs'] = {
                    'video_alpha': intensity,
                    # Linear blend from 1.0 (intensity 0) to the domain cap
                    # (intensity 1).
                    'ground_plane_alpha': 1.0 * (1.0 - intensity) + alpha_cap * intensity,
                }
            logger.print('name', name)
            logger.print('extra_kwargs', extra_kwargs)

        env = gym.make(name, **make_kwargs, **extra_kwargs)
        env.seed(seed)

        # Optional data-collection wrappers sit directly on the raw env.
        if save_info_wrapper:
            from .collect_offline_data import SaveInfoWrapper
            env = SaveInfoWrapper(env)
        if rails_wrapper:
            from .collect_offline_data import RailsWrapper
            env = RailsWrapper(env)

        # Standard wrappers: action rescaling, frame stacking, axis merge.
        env = gym.wrappers.RescaleAction(env, -1.0, 1.0)
        env = FrameStackWithState(env, frame_stack)
        env = VstackWrapper(env)

    # HACK: video_recorder requires access to env.physics.render
    env.physics = env.unwrapped.env.physics
    return env
