import gym
import numpy as np

class DeepMindControl:
    """Gym-style wrapper around a dm_control suite task.

    ``reset``/``step`` return a dict holding every dm_control observation
    (0-d entries wrapped in a one-element list), plus:
      * ``image``: a rendered RGB frame when ``render_image`` is True,
        otherwise an all-zero uint8 placeholder of shape ``size + (3,)``;
      * ``state``: only when ``render_image`` is False, all dm_control
        observations flattened into one new float32 vector;
      * ``is_first`` / ``is_terminal`` flags;
      * ``token_embed``: an all-zero float32 placeholder vector.
    """

    metadata = {}

    # Length of the placeholder ``token_embed`` vector (space and observation).
    _TOKEN_EMBED_DIM = 384

    def __init__(self, name, action_repeat=1, size=(64, 64), camera=None, seed=0, render_image=False):
        """Load the dm_control task named ``name``.

        Args:
            name: ``"<domain>_<task>"``. The domain ``"cup"`` is mapped to
                ``"ball_in_cup"``.
            action_repeat: how many times ``step`` applies each action
                (must be >= 1).
            size: ``(height, width)`` of the ``image`` observation.
            camera: camera id to render from. Defaults to 2 for
                ``quadruped`` and 0 otherwise.
            seed: forwarded to the task as ``task_kwargs["random"]``.
            render_image: if True, configure EGL and render real frames.
                Otherwise emit zero images plus a flat ``state`` vector.

        Raises:
            ValueError: if ``action_repeat`` < 1. ``step`` would then have no
                time step to build an observation from.
        """
        if action_repeat < 1:
            raise ValueError(f"action_repeat must be >= 1, got {action_repeat!r}")
        domain, task = name.split("_", 1)
        if domain == "cup":  # Only domain with multiple words.
            domain = "ball_in_cup"
        if render_image:
            # Must happen before dm_control.suite is imported below.
            self._configure_egl_rendering()

        from dm_control import suite

        self._env = suite.load(domain, task, task_kwargs={"random": seed})
        self._action_repeat = action_repeat
        # Tuple so that ``self._size + (3,)`` also works when a list is passed.
        self._size = tuple(size)
        if camera is None:
            camera = dict(quadruped=2).get(domain, 0)
        self._camera = camera
        self.reward_range = [-np.inf, np.inf]
        self.render_image = render_image
        self.task = name
        # Camera reused across frames. None means ``render`` falls back to
        # ``physics.render``.
        self._renderer = (
            PersistentRenderer(
                self._env.physics,
                height=self._size[0],
                width=self._size[1],
                camera_id=self._camera,
            )
            if render_image
            else None
        )
        # Layout cache for ``_flatten_obs``, filled on the first observation.
        self._flatten_initialized = False
        self._flat_keys = None
        self._flat_slices = None
        self._flat_size = None

    @staticmethod
    def _configure_egl_rendering():
        """Point MuJoCo / PyOpenGL at the NVIDIA EGL backend for headless rendering."""
        import os

        os.environ["MUJOCO_GL"] = "egl"
        os.environ["PYOPENGL_PLATFORM"] = "egl"
        # NOTE(review): the dynamic loader reads LD_LIBRARY_PATH / LD_PRELOAD
        # at process start. Setting them here only affects child processes
        # spawned later, not libraries loaded into this process.
        os.environ["LD_LIBRARY_PATH"] = (
            "/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/nvidia:"
            + os.environ.get("LD_LIBRARY_PATH", "")
        )
        os.environ["LD_PRELOAD"] = "/usr/lib/x86_64-linux-gnu/libEGL_nvidia.so.0"
        # NOTE(review): dm_control presumably selects its backend from MUJOCO_GL
        # when ``_render`` is first imported. If dm_control was imported earlier,
        # the settings above may come too late -- verify.
        from dm_control import _render

        print(_render.BACKEND)

    @property
    def observation_space(self):
        """Gym ``Dict`` space describing the dicts returned by ``reset``/``step``.

        Each dm_control observation gets an unbounded float32 ``Box``; 0-d specs
        become shape ``(1,)``. ``state`` is present only when ``render_image``
        is False. Its length is the total element count across all specs.
        """
        spaces = {}
        state_size = 0
        for key, value in self._env.observation_spec().items():
            shape = value.shape if len(value.shape) else (1,)
            state_size += int(np.prod(shape))
            spaces[key] = gym.spaces.Box(-np.inf, np.inf, shape, dtype=np.float32)
        spaces["image"] = gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8)
        if not self.render_image:
            spaces["state"] = gym.spaces.Box(
                -np.inf, np.inf, (state_size,), dtype=np.float32
            )
        spaces["token_embed"] = gym.spaces.Box(
            -np.inf, np.inf, (self._TOKEN_EMBED_DIM,), dtype=np.float32
        )
        return gym.spaces.Dict(spaces)

    @property
    def action_space(self):
        """Float32 ``Box`` spanning the task's action spec bounds."""
        spec = self._env.action_spec()
        return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)

    def step(self, action):
        """Apply ``action`` up to ``action_repeat`` times, stopping at episode end.

        Returns:
            ``(obs, reward, done, info)``:
            - ``reward`` is the sum of sub-step rewards, with None counted as 0.
            - ``done`` is ``time_step.last()`` for the final sub-step.
            - ``info["discount"]`` is that sub-step's discount as a float32 array.
        """
        assert np.isfinite(action).all(), action
        reward = 0
        for _ in range(self._action_repeat):
            time_step = self._env.step(action)
            reward += time_step.reward or 0
            if time_step.last():
                break
        obs = self._build_obs(time_step)
        done = time_step.last()
        info = {"discount": np.array(time_step.discount, np.float32)}
        return obs, reward, done, info

    def reset(self):
        """Reset the task and return its first observation dict."""
        return self._build_obs(self._env.reset())

    def _build_obs(self, time_step):
        """Assemble the observation dict shared by ``reset`` and ``step``."""
        raw = time_step.observation
        # Wrap 0-d observations in a list so every entry has at least one dimension.
        obs = {key: [val] if len(val.shape) == 0 else val for key, val in raw.items()}
        if self.render_image:
            obs["image"] = self.render()
        else:
            # Placeholder frame whose dtype matches the uint8 ``image`` space.
            obs["image"] = np.zeros(self._size + (3,), dtype=np.uint8)
            obs["state"] = self._flatten_obs(raw)
        # DMC tasks normally have no terminal state; discount == 0 is treated as one.
        obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
        obs["is_first"] = time_step.first()
        obs["token_embed"] = np.zeros(self._TOKEN_EMBED_DIM, dtype=np.float32)
        return obs

    def render(self, *args, **kwargs):
        """Return an RGB frame from camera ``self._camera``.

        Only ``mode="rgb_array"`` (the default) is supported. When
        ``render_image`` was enabled, this uses the persistent renderer.
        Otherwise it falls back to ``physics.render`` instead of failing on a
        missing renderer.

        Raises:
            ValueError: for any other ``mode``.
        """
        if kwargs.get("mode", "rgb_array") != "rgb_array":
            raise ValueError("Only render mode 'rgb_array' is supported.")
        if self._renderer is None:
            return self._env.physics.render(*self._size, camera_id=self._camera)
        return self._renderer.render()

    def _init_flatten_obs(self, obs):
        """Record the key order and each key's slice of the flat ``state`` vector.

        Runs once, on the first observation (normally from ``reset``). Later
        observations are assumed to have the same keys and element counts.
        """
        self._flat_keys = list(obs.keys())
        self._flat_slices = []
        start = 0
        for key in self._flat_keys:
            end = start + np.asarray(obs[key]).size
            self._flat_slices.append((start, end))
            start = end
        self._flat_size = start
        self._flatten_initialized = True

    def _flatten_obs(self, obs):
        """Concatenate all observation values into one new float32 vector.

        Replaces ``_obs_to_array``. A fresh array is returned on every call.
        Reusing one shared buffer would let later steps silently overwrite
        ``state`` vectors that callers still hold (e.g. the reset observation).
        The dtype is float32 to match the ``state`` space, rather than the dtype
        of whichever key happens to come first.
        """
        if not self._flatten_initialized:
            self._init_flatten_obs(obs)
        flat = np.empty(self._flat_size, dtype=np.float32)
        for key, (start, end) in zip(self._flat_keys, self._flat_slices):
            flat[start:end] = np.asarray(obs[key]).ravel()
        return flat

    def _obs_to_array(self, obs):
        """Legacy flattener, superseded by ``_flatten_obs``.

        Uses NumPy's dtype promotion across all values.
        """
        return np.concatenate([v.flatten() for v in obs.values()])

#[todo] start

class PersistentRenderer:
    """Long-lived dm_control camera for fast, repeated frame rendering.

    Written against dm_control==1.0.9. The ``engine.Camera`` (and its
    rendering context) is built once in the constructor. Every ``render()``
    call then reuses it instead of creating a new one.
    """

    def __init__(self, physics, height, width, camera_id=0):
        self.physics = physics
        self.height = height
        self.width = width
        self.camera_id = camera_id

        # Build the camera / rendering context exactly once.
        from dm_control.mujoco import engine as mj_engine

        self._camera = mj_engine.Camera(
            physics,
            height=height,
            width=width,
            camera_id=camera_id,
        )

    def render(self):
        """Return whatever the cached camera's ``render()`` produces."""
        cam = self._camera
        return cam.render()
#[todo] end