import threading
import functools
import embodied
import gym
import numpy as np
import os

from . import from_gymnasium


# Short task name -> registered MyoSuite gym id. Every hand task comes in two
# variants: 'myo-<task>' maps to the 'Fixed' env id and 'myo-<task>-hard' to
# the 'Random' one, e.g. 'myo-key-turn-hard' -> 'myoHandKeyTurnRandom-v0'.
MYOSUITE_TASKS = {
  f'myo-{short}{suffix}': f'myoHand{camel}{variant}-v0'
  for short, camel in (
    ('reach', 'Reach'),
    ('pose', 'Pose'),
    ('obj-hold', 'ObjHold'),
    ('key-turn', 'KeyTurn'),
    ('pen-twirl', 'PenTwirl'),
  )
  for suffix, variant in (('', 'Fixed'), ('-hard', 'Random'))
}

class MyoSuitePixel(embodied.Env):
  """MyoSuite hand task exposed as an ``embodied.Env`` with optional pixels.

  The MyoSuite gym environment is built with ``myosuite.utils.gym.make`` and
  wrapped in ``from_gymnasium.FromGymnasium(..., obs_key='state')``. Each
  non-reset ``step`` repeats the action up to ``repeat`` times, summing the
  reward and stopping early at the end of an episode. Observations gain a
  ``'log_success'`` flag (whether the wrapped env's ``info['solved']`` was
  truthy on any repeated step) and, when rendering is enabled, an offscreen
  ``'image'`` from the configured camera.
  """

  def __init__(
      self, task, repeat=1, size=(64, 64), render=True,
      camera_id='hand_side_inter'):
    """Build the wrapped MyoSuite environment.

    Args:
      task: Task name without the ``'myo-'`` prefix. Underscores are treated
        as hyphens, so ``'key_turn_hard'`` selects ``'myo-key-turn-hard'``.
      repeat: Number of underlying env steps per non-reset ``step`` call;
        must be at least 1.
      size: Two-element size unpacked into ``render_offscreen`` and used as
        the leading dimensions of the ``'image'`` space. Any sequence is
        accepted; it is stored as a tuple.
      render: Whether to add the ``'image'`` observation. Forced off for
        tasks whose name starts with ``'leg'``.
      camera_id: Camera passed to ``render_offscreen``.

    Raises:
      ValueError: If ``repeat < 1`` or the task is not in ``MYOSUITE_TASKS``.
    """
    if repeat < 1:
      # With zero iterations, ``step`` would never assign ``obs``.
      raise ValueError(f'repeat must be at least 1, got {repeat}')
    task = task.replace('_', '-')
    if task.startswith('leg'):
      render = False
    task = 'myo-' + task
    if task not in MYOSUITE_TASKS:
      raise ValueError(f'Unknown task: {task}')
    # MUJOCO_GL is read when MuJoCo's rendering modules are first imported,
    # so it has to be set before importing myosuite (which presumably pulls
    # in mujoco). Assigning it after that import cannot change the backend.
    os.environ['MUJOCO_GL'] = 'egl'
    from myosuite.utils import gym as myogym
    env = myogym.make(MYOSUITE_TASKS[task])
    self._env = from_gymnasium.FromGymnasium(env, obs_key='state')
    self._camera_id = camera_id
    # Stored as a tuple so ``self._size + (3,)`` works for list inputs too.
    self._size = tuple(size)
    self._repeat = repeat
    self._render = render
    # Re-armed on every reset; not read anywhere in this class.
    self._once = True

  def step(self, action):
    """Advance the environment and return the observation dict.

    On ``action['reset']``, the wrapped env is stepped once with the reset
    action and ``'log_success'`` is False. Otherwise the action is applied up
    to ``self._repeat`` times, stopping early when ``is_last`` or
    ``is_terminal`` is set. In that case ``'reward'`` becomes the summed
    reward, and ``'log_success'`` records whether ``info['solved']`` was
    truthy on any of those steps. When rendering is enabled, an ``'image'``
    entry is added in both cases.
    """
    if action['reset']:
      obs = self._env.step(action)
      obs['log_success'] = False
      self._once = True
    else:
      reward, success = 0.0, False
      for _ in range(self._repeat):
        obs = self._env.step(action)
        # Coerce so the value matches the bool ``log_success`` space even
        # when 'solved' is missing or is a NumPy scalar.
        success = success or bool(self._env.info.get('solved', False))
        reward += obs['reward']
        if obs['is_last'] or obs['is_terminal']:
          break
      obs['reward'] = reward
      obs['log_success'] = success
    if self._render:
      # Reaches through the FromGymnasium wrapper to the raw MyoSuite env.
      obs['image'] = self._env._env.sim.renderer.render_offscreen(
          *self._size, camera_id=self._camera_id).copy()
    return obs

  @functools.cached_property
  def obs_space(self):
    """Wrapped env's observation space plus 'image' (if rendering) and
    'log_success'."""
    spaces = self._env.obs_space.copy()
    if self._render:
      spaces['image'] = embodied.Space(np.uint8, self._size + (3,))
    spaces['log_success'] = embodied.Space(bool)
    return spaces

  @functools.cached_property
  def act_space(self):
    """Action space of the wrapped env, unchanged."""
    return self._env.act_space

  def render(self, *args, **kwargs):
    """Delegate to the wrapped env's ``render``; arguments are ignored."""
    return self._env.render()

  def close(self):
    """Close the wrapped environment so its simulator resources are freed."""
    self._env.close()