# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A wrapper for dm_control environments which applies color distractions."""
import os
import gym
from PIL import Image
import collections
from dm_control.rl import control
import numpy as np

# Number of background videos used by each difficulty preset in `bg_load`.
# `None` means no limit: every available video is used.
DIFFICULTY_NUM_VIDEOS = dict(easy=4, medium=8, hard=None)


def get_background_kwargs(domain_name,
                          num_videos,
                          dynamic,
                          dataset_path,
                          dataset_videos=None,
                          shuffle=False,
                          video_alpha=1.0):
    """Assembles `DistractingBackgroundEnv` constructor kwargs for a domain.

    Args:
      domain_name: Name of a supported dm_control domain.
      num_videos: Number of background videos to use (None for all of them).
      dynamic: Whether the background advances through frames on every step.
      dataset_path: Directory holding one sub-directory per video.
      dataset_videos: Video names or a 'train'/'val' placeholder; passed
        through unchanged.
      shuffle: If True, request a shuffle buffer of 100 frames.
      video_alpha: Opacity of the video frames over the original sky.

    Returns:
      A dict of keyword arguments. The ground-plane alpha is 0.0 for
      reacher, 1.0 for walker/cheetah/hopper and 0.3 for every other domain.

    Raises:
      AssertionError: If `domain_name` is not supported (when assertions are
        enabled).
    """
    supported = ('reacher', 'cartpole', 'finger', 'cheetah', 'ball_in_cup',
                 'walker', 'humanoid', 'hopper')
    assert domain_name in supported

    # Per-domain ground-plane opacity; domains not listed here fall back to 0.3.
    ground_alpha_overrides = {
        'reacher': 0.0,
        'walker': 1.0,
        'cheetah': 1.0,
        'hopper': 1.0,
    }
    buffer_size = 100 if shuffle else None

    return {
        'num_videos': num_videos,
        'video_alpha': video_alpha,
        'ground_plane_alpha': ground_alpha_overrides.get(domain_name, 0.3),
        'dynamic': dynamic,
        'dataset_path': dataset_path,
        'dataset_videos': dataset_videos,
        'shuffle_buffer_size': buffer_size,
    }


from dm_control.mujoco.wrapper import mjbindings

# Video sub-directory names substituted for the 'train' / 'training'
# placeholder of `DistractingBackgroundEnv(dataset_videos=...)`
# (presumably the DAVIS 2017 training split — confirm against the dataset).
DAVIS17_TRAINING_VIDEOS = [
    'bear', 'bmx-bumps', 'boat', 'boxing-fisheye', 'breakdance-flare', 'bus',
    'car-turn', 'cat-girl', 'classic-car', 'color-run', 'crossing',
    'dance-jump', 'dancing', 'disc-jockey', 'dog-agility', 'dog-gooses',
    'dogs-scale', 'drift-turn', 'drone', 'elephant', 'flamingo', 'hike',
    'hockey', 'horsejump-low', 'kid-football', 'kite-walk', 'koala',
    'lady-running', 'lindy-hop', 'longboard', 'lucia', 'mallard-fly',
    'mallard-water', 'miami-surf', 'motocross-bumps', 'motorbike',
    'night-race', 'paragliding', 'planes-water', 'rallye', 'rhino',
    'rollerblade', 'schoolgirls', 'scooter-board', 'scooter-gray', 'sheep',
    'skate-park', 'snowboard', 'soccerball', 'stroller', 'stunt', 'surf',
    'swing', 'tennis', 'tractor-sand', 'train', 'tuk-tuk', 'upside-down',
    'varanus-cage', 'walking'
]
# Video sub-directory names substituted for the 'val' / 'validation'
# placeholder (presumably the DAVIS 2017 validation split — confirm).
DAVIS17_VALIDATION_VIDEOS = [
    'bike-packing', 'blackswan', 'bmx-trees', 'breakdance', 'camel',
    'car-roundabout', 'car-shadow', 'cows', 'dance-twirl', 'dog', 'dogs-jump',
    'drift-chicane', 'drift-straight', 'goat', 'gold-fish', 'horsejump-high',
    'india', 'judo', 'kite-surf', 'lab-coat', 'libby', 'loading',
    'mbike-trick', 'motocross-jump', 'paragliding-launch', 'parkour', 'pigs',
    'scooter-black', 'shooting', 'soapbox'
]
# Index of the MuJoCo texture that the wrapper overwrites with background
# frames. Assumes the skybox is texture 0 in every supported domain — verify
# when adding domains.
SKY_TEXTURE_INDEX = 0
# Background state: `size` is the number of values in the sky texture,
# `address` its offset into `model.tex_rgb`, and `textures` the list of
# prepared (flattened) frames that can be copied there.
Texture = collections.namedtuple('Texture', ('size', 'address', 'textures'))


def imread(filename):
    """Reads an image file into a numpy array.

    Args:
      filename: Path of the image file to read.

    Returns:
      The decoded pixels as a numpy array that owns its data (for an RGB
      image this is an (H, W, 3) uint8 array).
    """
    # `Image.open` is lazy and keeps the file handle open until the image is
    # closed or garbage collected; the context manager closes it as soon as
    # the pixels have been copied out.
    with Image.open(filename) as img:
        return np.array(img)


def size_and_flatten(image, ref_height, ref_width):
    """Resizes `image` to (ref_height, ref_width) if needed and flattens it.

    Args:
      image: Image array whose first two dimensions are (height, width).
      ref_height: Target height in pixels.
      ref_width: Target width in pixels.

    Returns:
      A 1-D array holding the (possibly resized) image in row-major order.
      Resized images are returned as uint8; images that already match the
      target size are flattened unchanged.
    """
    image_height, image_width = image.shape[:2]

    if image_height != ref_height or image_width != ref_width:
        from PIL import Image  # pylint: disable=g-import-not-at-top
        # PIL sizes are (width, height), the reverse of numpy's (rows, cols);
        # passing (height, width) would transpose non-square targets.
        resized = Image.fromarray(image).resize(
            (int(ref_width), int(ref_height)), Image.BICUBIC)
        image = np.asarray(resized).astype(np.uint8)
    return image.flatten()


def blend_to_background(alpha, image, background):
    """Alpha-composites `image` over `background`.

    Args:
      alpha: Weight of `image` in [0, 1]; `1 - alpha` goes to `background`.
      image: Foreground array.
      background: Background array, broadcast-compatible with `image`.

    Returns:
      `image` itself when alpha is exactly 1.0, `background` itself when alpha
      is exactly 0.0, otherwise the weighted sum computed in float32 and cast
      (truncated) to uint8.
    """
    if alpha == 0.0:
        return background
    if alpha == 1.0:
        return image

    foreground = image.astype(np.float32)
    backdrop = background.astype(np.float32)
    mixed = alpha * foreground + (1. - alpha) * backdrop
    return mixed.astype(np.uint8)


class DistractingBackgroundEnv(control.Environment):
    """Environment wrapper for background visual distraction.

    **NOTE**: This wrapper should be applied BEFORE the pixel wrapper to make
    sure the background image changes are applied before rendering occurs.

    On every reset the sky texture (`SKY_TEXTURE_INDEX`) of the wrapped
    environment's physics is replaced by frames loaded from video
    directories, optionally blended with the original sky. With
    `dynamic=True` the background moves through the loaded frames on every
    step. Attribute lookups that fail on the wrapper are forwarded to the
    wrapped environment.
    """

    def __init__(self,
                 env,
                 dataset_path=None,
                 dataset_videos=None,
                 video_alpha=1.0,
                 ground_plane_alpha=1.0,
                 num_videos=None,
                 dynamic=False,
                 seed=None,
                 shuffle_buffer_size=None):
        """Initializes the wrapper.

        Args:
          env: The dm_control environment to wrap.
          dataset_path: Directory with one sub-directory of frames per video.
            If falsy, no videos are used and only the ground-plane alpha is
            applied.
          dataset_videos: Names of the video sub-directories to use, or one of
            the placeholders 'train'/'training'/'val'/'validation'. If falsy,
            all sub-directories of `dataset_path` are used (sorted by name).
          video_alpha: Opacity of the video frames over the sky, in [0, 1].
          ground_plane_alpha: Alpha written to the 'grid' material on reset,
            or None to leave it unchanged.
          num_videos: If not None, only the first `num_videos` videos are used.
          dynamic: If True, the background steps through the loaded frames on
            every environment step.
          seed: Seed for the random state that picks videos and frames.
          shuffle_buffer_size: If set, frames of all videos are shuffled
            together and at most this many are loaded per reset.

        Raises:
          ValueError: If `video_alpha` is outside [0, 1], or `num_videos` is
            negative or larger than the number of available videos.
        """
        if not 0 <= video_alpha <= 1:
            raise ValueError('`video_alpha` must be in the range [0, 1]')

        self._env = env
        self._video_alpha = video_alpha
        self._ground_plane_alpha = ground_plane_alpha
        self._random_state = np.random.RandomState(seed=seed)
        self._dynamic = dynamic
        self._shuffle_buffer_size = shuffle_buffer_size
        self._background = None
        self._current_img_index = 0
        # Direction (+1 / -1) in which dynamic backgrounds advance; re-drawn
        # on every reset that loads videos.
        self._step_direction = 1
        # Pristine copy of the sky texture, captured on the first reset.
        # `_apply` overwrites the texture memory in the model, so re-reading
        # it on later resets would blend against the previous background.
        self._original_sky_texture = None

        if not dataset_path or num_videos == 0:
            # Allow running the wrapper without backgrounds to still set the
            # ground plane alpha value.
            self._video_paths = []
            return

        if not dataset_videos:
            # Use all videos if no specific ones were passed.
            dataset_videos = sorted(os.listdir(dataset_path))
        elif dataset_videos in ['train', 'training']:
            dataset_videos = DAVIS17_TRAINING_VIDEOS
        elif dataset_videos in ['val', 'validation']:
            dataset_videos = DAVIS17_VALIDATION_VIDEOS
        video_paths = [
            os.path.join(dataset_path, subdir) for subdir in dataset_videos
        ]

        # Optionally keep only the first `num_videos` paths.
        if num_videos is not None:
            if num_videos > len(video_paths) or num_videos < 0:
                raise ValueError(
                    f'`num_videos` is {num_videos} but '
                    'should not be larger than the number of available '
                    f'background paths ({len(video_paths)}) and at '
                    'least 0.')
            video_paths = video_paths[:num_videos]

        self._video_paths = video_paths

    def reset(self):
        """Resets the wrapped environment and draws a new background."""
        time_step = self._env.reset()
        self._reset_background()
        return time_step

    def _load_images(self):
        """Loads the raw frames to use for a new episode.

        File names are sorted so that dynamic backgrounds play in frame order
        and seeded runs do not depend on the file system's listing order.
        """
        if self._shuffle_buffer_size:
            # Shuffle frames from all videos together.
            file_names = [
                os.path.join(path, fn) for path in self._video_paths
                for fn in sorted(os.listdir(path))
            ]
            self._random_state.shuffle(file_names)
            # Load only the first n images for performance reasons.
            file_names = file_names[:self._shuffle_buffer_size]
        else:
            # Randomly pick a video and use all of its frames.
            video_path = self._random_state.choice(self._video_paths)
            file_names = [
                os.path.join(video_path, fn)
                for fn in sorted(os.listdir(video_path))
            ]
            if not self._dynamic:
                # Randomly pick a single static frame.
                file_names = [self._random_state.choice(file_names)]
        return [imread(fn) for fn in file_names]

    def _reset_background(self):
        """Applies the ground-plane alpha and installs a fresh background."""
        physics = self._env.physics
        model = physics.model

        # Make grid semi-transparent.
        if self._ground_plane_alpha is not None:
            physics.named.model.mat_rgba['grid', 'a'] = self._ground_plane_alpha

        # For some reason the height of the skybox is set to 4800 by default,
        # which does not work with new textures.
        model.tex_height[SKY_TEXTURE_INDEX] = 800

        # Locate the sky texture inside the model's texture memory.
        sky_height = model.tex_height[SKY_TEXTURE_INDEX]
        sky_width = model.tex_width[SKY_TEXTURE_INDEX]
        sky_size = sky_height * sky_width * 3
        sky_address = model.tex_adr[SKY_TEXTURE_INDEX]

        if (self._original_sky_texture is None or
                self._original_sky_texture.size != sky_size):
            # `astype` copies, so later writes to `tex_rgb` leave this intact.
            self._original_sky_texture = model.tex_rgb[
                sky_address:sky_address + sky_size].astype(np.float32)
        sky_texture = self._original_sky_texture

        if self._video_paths:
            images = self._load_images()

            # Pick a random starting point and stepping direction.
            self._current_img_index = self._random_state.choice(len(images))
            self._step_direction = self._random_state.choice([-1, 1])

            # Resize, flatten and blend every frame into the texture format.
            texturized_images = [
                blend_to_background(
                    self._video_alpha,
                    size_and_flatten(image, sky_height, sky_width),
                    sky_texture)
                for image in images
            ]
        else:
            self._current_img_index = 0
            texturized_images = [sky_texture]

        self._background = Texture(sky_size, sky_address, texturized_images)
        self._apply()

    def step(self, action):
        """Steps the wrapped environment and advances dynamic backgrounds."""
        time_step = self._env.step(action)

        if time_step.first():
            # A FIRST step means the wrapped env began a new episode.
            self._reset_background()
            return time_step

        if self._dynamic and self._video_paths:
            # Move forward / backward in the image sequence, bouncing off the
            # ends of the sequence.
            self._current_img_index += self._step_direction
            num_frames = len(self._background.textures)

            if self._current_img_index <= 0:
                self._current_img_index = 0
                self._step_direction = abs(self._step_direction)
            if self._current_img_index >= num_frames:
                self._current_img_index = num_frames - 1
                self._step_direction = -abs(self._step_direction)

            self._apply()
        return time_step

    def _apply(self):
        """Copies the current frame into the sky texture and re-uploads it."""
        if self._background:
            start = self._background.address
            end = start + self._background.size
            texture = self._background.textures[self._current_img_index]

            physics = self._env.physics
            physics.model.tex_rgb[start:end] = texture
            # Upload the new texture to the GPU. Note: we need to make sure
            # that the OpenGL context belonging to this Physics instance is
            # the current one.
            with physics.contexts.gl.make_current() as ctx:
                ctx.call(
                    mjbindings.mjlib.mjr_uploadTexture,
                    physics.model.ptr,
                    physics.contexts.mujoco.ptr,
                    SKY_TEXTURE_INDEX,
                )

    def __getattr__(self, attr):
        """Forwards failed attribute lookups to the wrapped environment."""
        # Check `vars` directly: reading `self._env` before it is assigned
        # would re-enter `__getattr__` and recurse forever.
        if '_env' in vars(self):
            try:
                return getattr(self._env, attr)
            except AttributeError:
                pass
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))


from dm_control import suite  # pylint: disable=g-import-not-at-top


def bg_load(
    domain_name,
    task_name,
    difficulty="easy",
    dynamic=False,
    background_dataset_path=None,
    background_dataset_videos="train",
    background_kwargs=None,
    task_kwargs=None,
    environment_kwargs=None,
    visualize_reward=False,
):
    """Loads a suite task wrapped in `DistractingBackgroundEnv`.

    ```python
    env = bg_load('cartpole', 'balance',
                  background_dataset_path='/path/to/davis')
    ```

    A difficulty of 'easy', 'medium' or 'hard' limits the number of videos
    to `DIFFICULTY_NUM_VIDEOS[difficulty]` and sets a domain-specific
    ground-plane alpha. With `difficulty=None`, no preset is applied: the
    selected videos are used with the wrapper's default settings.

    Users can also toggle dynamic properties for distractions.

    Args:
      domain_name: A string containing the name of a domain.
      task_name: A string containing the name of a task.
      difficulty: Difficulty for the suite. One of None, 'easy', 'medium',
        'hard'.
      dynamic: Boolean controlling whether distractions are dynamic or static.
      background_dataset_path: Path of the DAVIS directory that contains the
        video directories. Required.
      background_dataset_videos: String ('train'/'val') or list of strings of
        the DAVIS videos to be used for backgrounds.
      background_kwargs: Dict, overwrites settings for background
        distractions (e.g. `seed`, `video_alpha`, `shuffle_buffer_size`).
      task_kwargs: Dict, dm control task kwargs.
      environment_kwargs: Optional `dict` specifying keyword arguments for the
        environment.
      visualize_reward: Optional `bool`. If `True`, object colours in rendered
        frames are set to indicate the reward at each step. Default `False`.

    Returns:
      The requested environment wrapped in `DistractingBackgroundEnv`.

    Raises:
      ValueError: If `difficulty` is invalid or `background_dataset_path` is
        None.
    """
    if difficulty not in [None, "easy", "medium", "hard"]:
        raise ValueError(
            "Difficulty should be one of: 'easy', 'medium', 'hard'.")
    # Validate before building the (comparatively expensive) suite env.
    if background_dataset_path is None:
        raise ValueError("`background_dataset_path` must be provided.")

    env = suite.load(domain_name,
                     task_name,
                     task_kwargs=task_kwargs,
                     environment_kwargs=environment_kwargs,
                     visualize_reward=visualize_reward)

    if difficulty is None:
        # No preset: only choose the dataset, videos and dynamics.
        final_background_kwargs = dict(
            dataset_path=background_dataset_path,
            dataset_videos=background_dataset_videos,
            dynamic=dynamic,
        )
    else:
        final_background_kwargs = get_background_kwargs(
            domain_name, DIFFICULTY_NUM_VIDEOS[difficulty], dynamic,
            background_dataset_path, background_dataset_videos)
    if background_kwargs:
        # Explicit kwargs take precedence over the preset.
        final_background_kwargs.update(background_kwargs)
    return DistractingBackgroundEnv(env, **final_background_kwargs)


class DeepMindControlBackground:
    """Gym-style interface to a dm_control task with distracting backgrounds.

    Observations are dicts holding the task's observation entries (scalar
    entries wrapped in one-element lists), a rendered "image", and the
    "is_first" / "is_terminal" flags.
    """
    metadata = {}

    def __init__(self,
                 name,
                 action_repeat=1,
                 size=(64, 64),
                 camera=None,
                 seed=0,
                 difficulty="easy",
                 dynamic=False,
                 background_dataset_path=None):
        """Creates the environment.

        Args:
          name: "<domain>_<task>", e.g. "cheetah_run"; domain "cup" is mapped
            to "ball_in_cup".
          action_repeat: Number of times `step` repeats each action (>= 1).
          size: (height, width) of rendered images.
          camera: Camera id used for rendering; defaults to 2 for quadruped
            and 0 otherwise.
          seed: Seed for the task's random state and for background selection.
          difficulty: Forwarded to `bg_load`.
          dynamic: Forwarded to `bg_load`.
          background_dataset_path: Forwarded to `bg_load`.

        Raises:
          ValueError: If `action_repeat` is smaller than 1.
        """
        if action_repeat < 1:
            # `step` needs at least one inner step to produce a time step.
            raise ValueError(
                f'`action_repeat` must be at least 1, got {action_repeat}.')
        domain, task = name.split("_", 1)
        if domain == "cup":  # Only domain with multiple words.
            domain = "ball_in_cup"
        self._env = bg_load(
            domain,
            task,
            difficulty=difficulty,
            dynamic=dynamic,
            task_kwargs={"random": seed},
            background_dataset_path=background_dataset_path,
            # Seed the background RNG as well so that video / frame choices
            # are reproducible for a given seed.
            background_kwargs={"seed": seed})
        self._action_repeat = action_repeat
        self._size = size
        if camera is None:
            camera = dict(quadruped=2).get(domain, 0)
        self._camera = camera
        self.reward_range = [-np.inf, np.inf]

    @property
    def observation_space(self):
        """Dict space: spec entries (scalars as shape (1,)) plus "image"."""
        spaces = {}
        for key, value in self._env.observation_spec().items():
            shape = (1, ) if len(value.shape) == 0 else value.shape
            spaces[key] = gym.spaces.Box(-np.inf,
                                         np.inf,
                                         shape,
                                         dtype=np.float32)
        spaces["image"] = gym.spaces.Box(0,
                                         255,
                                         self._size + (3, ),
                                         dtype=np.uint8)
        return gym.spaces.Dict(spaces)

    @property
    def action_space(self):
        """Box space bounded by the task's action spec."""
        spec = self._env.action_spec()
        return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)

    def _format_obs(self, time_step):
        """Builds the observation dict returned by `step` and `reset`."""
        obs = {
            key: [val] if len(val.shape) == 0 else val
            for key, val in dict(time_step.observation).items()
        }
        obs["image"] = self.render()
        # There is no terminal state in DMC; a zero discount is treated as
        # terminal, and the first step of an episode never is.
        obs["is_terminal"] = False if time_step.first(
        ) else time_step.discount == 0
        obs["is_first"] = time_step.first()
        return obs

    def step(self, action):
        """Repeats `action`, summing rewards; returns (obs, reward, done, info)."""
        assert np.isfinite(action).all(), action
        reward = 0
        for _ in range(self._action_repeat):
            time_step = self._env.step(action)
            reward += time_step.reward or 0
            if time_step.last():
                break
        obs = self._format_obs(time_step)
        done = time_step.last()
        info = {"discount": np.array(time_step.discount, np.float32)}
        return obs, reward, done, info

    def reset(self):
        """Resets the environment and returns the first observation."""
        return self._format_obs(self._env.reset())

    def render(self, *args, **kwargs):
        """Renders an RGB frame of size `self._size` from `self._camera`."""
        if kwargs.get("mode", "rgb_array") != "rgb_array":
            raise ValueError("Only render mode 'rgb_array' is supported.")
        return self._env.physics.render(*self._size, camera_id=self._camera)
