import functools
import os
# Default the MUJOCO_GL environment variable to 'egl' unless the caller has
# already set it. NOTE(review): presumably this has to run before the
# MuJoCo-backed imports below are executed -- confirm before moving it.
os.environ.setdefault('MUJOCO_GL', 'egl')
from collections import OrderedDict
import re

from metaworld.envs.mujoco.env_dict import ALL_V2_ENVIRONMENTS

import elements
import embodied
import numpy as np
import cv2

import functools

import elements
import embodied
import gymnasium as gym
import numpy as np


class FromGymMetaworld(embodied.Env):
  """Expose a Gymnasium-style environment through the `embodied.Env` API.

  Nested dict observation/action spaces are flattened into '/'-joined keys.
  Every observation returned by `step` additionally carries `reward`,
  `is_first`, `is_last`, `is_terminal` and `success`.

  NOTE: `success` is emitted by `step` but is not declared in `obs_space`;
  the `Metaworld` wrapper in this file renames it to `log/success`.
  """

  def __init__(self, env, obs_key='image', act_key='action', **kwargs):
    """Wrap `env`.

    Args:
      env: Either an environment id string (built with `gym.make`) or an
        already constructed environment instance.
      obs_key: Key under which a non-dict observation is stored.
      act_key: Key under which a non-dict action is looked up.
      **kwargs: Forwarded to `gym.make`; only allowed when `env` is a string.
    """
    if isinstance(env, str):
      self._env = gym.make(env, **kwargs)
    else:
      assert not kwargs, kwargs
      self._env = env
    # Dict-like spaces expose a `spaces` attribute.
    self._obs_dict = hasattr(self._env.observation_space, 'spaces')
    self._act_dict = hasattr(self._env.action_space, 'spaces')
    self._obs_key = obs_key
    self._act_key = act_key
    # Start in the "done" state so the very first `step` call resets.
    self._done = True
    self._info = None

  @property
  def env(self):
    """The wrapped environment."""
    return self._env

  @property
  def info(self):
    """Info dict from the most recent reset or step (None before either)."""
    return self._info

  @functools.cached_property
  def obs_space(self):
    """Flattened observation spaces plus the reward and episode flags."""
    if self._obs_dict:
      spaces = self._flatten(self._env.observation_space.spaces)
    else:
      spaces = {self._obs_key: self._env.observation_space}
    spaces = {k: self._convert(v) for k, v in spaces.items()}
    return {
        **spaces,
        'reward': elements.Space(np.float32),
        'is_first': elements.Space(bool),
        'is_last': elements.Space(bool),
        'is_terminal': elements.Space(bool),
    }

  @functools.cached_property
  def act_space(self):
    """Flattened action spaces plus a boolean `reset` action."""
    if self._act_dict:
      spaces = self._flatten(self._env.action_space.spaces)
    else:
      spaces = {self._act_key: self._env.action_space}
    spaces = {k: self._convert(v) for k, v in spaces.items()}
    spaces['reset'] = elements.Space(bool)
    return spaces

  def step(self, action):
    """Advance the environment by one step, or reset it.

    A reset happens when `action['reset']` is truthy or when the previous
    step ended the episode; the returned observation then has zero reward
    and `is_first=True`.

    Args:
      action: Dict containing `reset` and the (flattened) action entries.

    Returns:
      Flat observation dict, see `_obs`.
    """
    if action['reset'] or self._done:
      self._done = False
      # Keep the reset info so the `info` property is not stale afterwards.
      obs, self._info = self._env.reset()
      return self._obs(obs, 0.0, is_first=True)
    if self._act_dict:
      action = self._unflatten(action)
    else:
      action = action[self._act_key]
    obs, reward, terminated, truncated, self._info = self._env.step(action)
    success = self._info.get('success', False)
    self._done = terminated or truncated
    return self._obs(
        obs, reward,
        is_last=bool(self._done),
        is_terminal=bool(terminated),
        success=success)

  def _obs(
      self, obs, reward, is_first=False, is_last=False, is_terminal=False,
      success=False):
    """Build the flat observation dict from a raw environment observation."""
    if not self._obs_dict:
      obs = {self._obs_key: obs}
    obs = self._flatten(obs)
    obs = {k: np.asarray(v) for k, v in obs.items()}
    obs.update(
        reward=np.float32(reward),
        is_first=is_first,
        is_last=is_last,
        is_terminal=is_terminal,
        success=success)
    return obs

  def render(self):
    """Return the frame rendered by the wrapped environment.

    The wrapped `render` is called without arguments, matching how
    `Metaworld.render` in this file calls it.
    """
    image = self._env.render()
    assert image is not None
    return image

  def close(self):
    """Close the wrapped environment, ignoring errors raised while closing."""
    try:
      self._env.close()
    except Exception:
      # Deliberately best-effort: a failing close must not mask the caller's
      # own shutdown path.
      pass

  def _flatten(self, nest, prefix=None):
    """Flatten nested dicts / `gym.spaces.Dict` into '/'-joined keys."""
    result = {}
    for key, value in nest.items():
      key = prefix + '/' + key if prefix else key
      if isinstance(value, gym.spaces.Dict):
        value = value.spaces
      if isinstance(value, dict):
        result.update(self._flatten(value, key))
      else:
        result[key] = value
    return result

  def _unflatten(self, flat):
    """Inverse of `_flatten`: rebuild nested dicts from '/'-joined keys."""
    result = {}
    for key, value in flat.items():
      parts = key.split('/')
      node = result
      for part in parts[:-1]:
        if part not in node:
          node[part] = {}
        node = node[part]
      node[parts[-1]] = value
    return result

  def _convert(self, space):
    """Convert a gym space into an `elements.Space`.

    Spaces exposing an `n` attribute become int32 scalars bounded by 0 and
    `n`; everything else keeps its dtype, shape, low and high.
    """
    if hasattr(space, 'n'):
      return elements.Space(np.int32, (), 0, space.n)
    return elements.Space(space.dtype, space.shape, space.low, space.high)


class Metaworld(embodied.Env):
  """Meta-World task exposed through the `embodied.Env` interface.

  The underlying environment is wrapped in `FromGymMetaworld`, then in the
  `ActionRepeat` and `TimeLimit` wrappers. Each observation gets a rendered
  image and a `log/success` entry.
  """

  def __init__(
      self, task, repeat=1, size=(64, 64), proprio=True, image=True,
      camera=-1, seed=0, length=1000):
    """Build the environment.

    Args:
      task: Task name; the class registered under `<task>-goal-observable`
        by `create_observable_goal_envs` is instantiated. A non-string value
        is treated as an already constructed environment and wrapped as-is
        (`seed` and `camera` are then not applied to it).
      repeat: Passed to `embodied.wrappers.ActionRepeat`.
      size: Size of the rendered image; the declared image space is
        `size + (3,)`.
      proprio: If False, the `observation` entry is removed.
      image: If True the image is stored under `image`, else `log/image`.
      camera: Camera id handed to the environment; -1 selects camera 1.
      seed: Seed handed to the environment constructor.
      length: Passed to `embodied.wrappers.TimeLimit`.
    """
    if isinstance(task, str):
      if camera == -1:
        camera = 1
      envs = create_observable_goal_envs()
      env = envs[task + '-goal-observable'](
          seed=seed, render_mode='rgb_array', camera_id=camera)
    else:
      # Previously `env` was left unbound here and construction crashed.
      env = task

    self._task = task
    self._mwenv = env
    self._env = FromGymMetaworld(self._mwenv, obs_key='observation')
    self._env = embodied.wrappers.ActionRepeat(self._env, repeat)
    self._env = embodied.wrappers.TimeLimit(self._env, length)
    # Normalize to a tuple so that `size + (3,)` also works for list input.
    self._size = tuple(size)
    self._proprio = proprio
    self._image = image
    self._camera = camera

  @functools.cached_property
  def obs_space(self):
    """Wrapped observation spaces plus the image and `log/success`."""
    spaces = self._env.obs_space.copy()
    if not self._proprio:
      del spaces['observation']
    key = 'image' if self._image else 'log/image'
    spaces[key] = elements.Space(np.uint8, self._size + (3,))
    spaces['log/success'] = elements.Space(np.float32)
    return spaces

  @functools.cached_property
  def act_space(self):
    """Action spaces of the wrapped environment."""
    return self._env.act_space

  def step(self, action):
    """Step the wrapped environment and attach image and success entries.

    Args:
      action: Action dict; all non-discrete entries must be finite.

    Returns:
      Observation dict with the rendered image under `image` or `log/image`
      and the success signal under `log/success`.
    """
    for key, space in self.act_space.items():
      if not space.discrete:
        assert np.isfinite(action[key]).all(), (key, action[key])
    obs = self._env.step(action)
    if not self._proprio:
      del obs['observation']
    key = 'image' if self._image else 'log/image'
    obs[key] = self.render()
    # Cast so the value matches the float32 `log/success` space; the raw
    # value is a bool on reset and whatever the env's info dict holds later.
    obs['log/success'] = np.float32(obs.pop('success'))
    return obs

  def render(self, mode='rgb_array'):
    """Render the environment, resized to `size` and rotated by 180 degrees.

    Args:
      mode: Unused; kept for interface compatibility.
    """
    rgb = self._mwenv.render()
    # The image space is declared as size + (3,), i.e. (rows, cols, 3), while
    # cv2.resize expects dsize as (cols, rows).
    height, width = self._size
    rgb = cv2.resize(rgb, (width, height))
    # The rotation is specific to certain camera ids.
    rgb = cv2.rotate(rgb, cv2.ROTATE_180)
    return rgb

  def close(self):
    """Close the wrapped environment stack."""
    self._env.close()


def create_observable_goal_envs():
  """Create a goal-observable subclass for every entry of ALL_V2_ENVIRONMENTS.

  Each generated class replaces `__init__` with an initializer that accepts
  `seed`, `render_mode` and `camera_id`, marks the goal as observable,
  resets the environment once and then freezes its random vector.

  Returns:
    OrderedDict mapping `<env_name>-goal-observable` to the generated class,
    whose name is the CamelCase env name followed by `GoalObservable`.
  """

  def make_initializer(base_cls):
    """Return an `__init__` bound to `base_cls`.

    Binding the base class explicitly (rather than `super(type(env), env)`)
    keeps the initializer from recursing when a generated class is itself
    subclassed.
    """

    def initialize(env, seed=None, render_mode=None, camera_id=None):
      saved_state = None
      if seed is not None:
        # Seed the global NumPy RNG only for the duration of construction.
        saved_state = np.random.get_state()
        np.random.seed(seed)
      try:
        base_cls.__init__(env, render_mode=render_mode)
        if camera_id is not None:
          env.camera_id = camera_id
        env._partially_observable = False
        env._freeze_rand_vec = False
        env._set_task_called = True
        env.reset()
        env._freeze_rand_vec = True
        if seed is not None:
          env.seed(seed)
      finally:
        # Restore the caller's RNG state even if construction failed.
        if seed is not None:
          np.random.set_state(saved_state)

    return initialize

  observable_goal_envs = {}
  for env_name, env_cls in ALL_V2_ENVIRONMENTS.items():
    # Upper-case the first letter of every '-'-separated part, then drop the
    # dashes to obtain a CamelCase class name.
    camel_name = re.sub(
        r"(^|[-])\s*([a-zA-Z])", lambda p: p.group(0).upper(), env_name
    ).replace("-", "")
    observable_goal_envs[f"{env_name}-goal-observable"] = type(
        f"{camel_name}GoalObservable",
        (env_cls,),
        {"__init__": make_initializer(env_cls)},
    )

  return OrderedDict(observable_goal_envs)