import time
import numpy as np
import collections
import matplotlib.pyplot as plt
import dm_env
import os
from libero.libero import benchmark
from libero.libero.envs import OffScreenRenderEnv
import cv2
import sys

# Make modules in the parent directory of this file importable.
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))


class RealEnv:
    """dm_env-style wrapper around a single LIBERO simulation task.

    Builds an ``OffScreenRenderEnv`` for one task of a LIBERO benchmark
    suite. It restores one of the suite's fixed initial states on every
    ``reset`` and wraps observations in ``dm_env.TimeStep`` objects. Each
    ``observation`` is an ``OrderedDict`` with ``qpos``/``qvel``/``effort``/
    ``images`` keys.

    NOTE(review): the class name and observation layout presumably mirror a
    real-robot interface used elsewhere in the project — confirm before
    renaming keys.
    """

    # Number of zero-action steps applied after restoring the initial state
    # in ``reset`` (presumably to let the scene settle).
    _SETTLE_STEPS = 5

    def __init__(self, task_suite_name="libero_object", task_id=3,
                 init_state_id=0, camera_heights=128, camera_widths=128,
                 seed=1000):
        """Build the LIBERO environment for one task and pick its init state.

        Args:
            task_suite_name: Key into ``benchmark.get_benchmark_dict()``,
                e.g. ``"libero_object"`` or ``"libero_spatial"``.
            task_id: Index of the task within the suite.
            init_state_id: Index into the suite's fixed initial states for
                this task. The selected state is restored on every ``reset``.
            camera_heights: Rendered image height passed to
                ``OffScreenRenderEnv``.
            camera_widths: Rendered image width passed to
                ``OffScreenRenderEnv``.
            seed: Value passed to ``env.seed``.

        The defaults select libero_object task 3, initial state 0, 128x128
        camera images and seed 1000. The task's language description is
        printed to stdout.
        """
        from libero.libero import get_libero_path

        benchmark_dict = benchmark.get_benchmark_dict()
        task_suite = benchmark_dict[task_suite_name]()

        task = task_suite.get_task(task_id)
        self.task_name = task.name
        self.task_description = task.language
        print(self.task_description)
        task_bddl_file = os.path.join(
            get_libero_path("bddl_files"), task.problem_folder, task.bddl_file
        )

        env = OffScreenRenderEnv(
            bddl_file_name=task_bddl_file,
            camera_heights=camera_heights,
            camera_widths=camera_widths,
        )
        env.seed(seed)

        # The suite provides a fixed set of initial states per task so that
        # benchmark runs are reproducible; pick one of them.
        init_states = task_suite.get_task_init_states(task_id)
        self.init_state = init_states[init_state_id]

        self.env = env

    def get_observation(self, ts):
        """Convert a raw LIBERO observation dict into this wrapper's layout.

        Args:
            ts: Observation dict returned by the wrapped env's ``step``.
                Despite the name, this is not a ``TimeStep``. It must contain
                ``robot0_gripper_qpos``, ``robot0_joint_pos``,
                ``agentview_image`` and ``robot0_eye_in_hand_image``.

        Returns:
            ``OrderedDict`` with:
            - ``qpos``: gripper positions followed by joint positions.
            - ``qvel`` and ``effort``: placeholders equal to ``qpos``.
            - ``images``: dict with ``agentview_rgb`` and ``eye_in_hand_rgb``.
        """
        obs = collections.OrderedDict()

        obs['qpos'] = np.concatenate([ts["robot0_gripper_qpos"], ts["robot0_joint_pos"]])
        # Placeholders: no velocity/effort is read from the env here. Copies
        # (not the same array object) so an in-place edit of one key cannot
        # silently change the others.
        obs['qvel'] = obs['qpos'].copy()
        obs['effort'] = obs['qpos'].copy()
        obs['images'] = {
            # [::-1] reverses the first axis (a vertical flip if images are
            # (H, W, C) — TODO confirm). NOTE(review): only the agent-view
            # image is flipped; confirm the wrist image orientation matches
            # what downstream consumers expect.
            "agentview_rgb": ts["agentview_image"][::-1],
            "eye_in_hand_rgb": ts["robot0_eye_in_hand_image"]
        }

        return obs

    def get_reward(self):
        """Return the reward reported on reset (always 0)."""
        return 0

    def reset(self, fake=False):
        """Reset to the stored initial state and return a ``FIRST`` TimeStep.

        After ``env.reset()`` the chosen initial state is restored. A few
        zero actions are then applied, and the observation from the last of
        them is returned.

        Args:
            fake: Unused.
        """
        self.env.reset()
        self.env.set_init_state(self.init_state)

        # NOTE(review): assumes a 7-D action (presumably 6-DoF arm + gripper)
        # — confirm against the controller used by the task.
        dummy_action = np.zeros(7)
        for _ in range(self._SETTLE_STEPS):
            raw_obs, _, _, _ = self.env.step(dummy_action)

        return dm_env.TimeStep(
            step_type=dm_env.StepType.FIRST,
            reward=self.get_reward(),
            discount=None,
            observation=self.get_observation(raw_obs))

    def step(self, action):
        """Apply ``action`` to the env and return the resulting TimeStep.

        The step type is ``LAST`` when the wrapped env reports ``done`` and
        ``MID`` otherwise. The env's reward is passed through unchanged.
        """
        raw_obs, reward, done, _info = self.env.step(action)

        # Surface termination so callers that check ``timestep.last()`` can
        # stop the episode.
        step_type = dm_env.StepType.LAST if done else dm_env.StepType.MID

        return dm_env.TimeStep(
            step_type=step_type,
            reward=reward,
            discount=None,
            observation=self.get_observation(raw_obs))


def make_real_env():
    """Construct and return a ``RealEnv`` with its default configuration."""
    return RealEnv()
