import cv2
import matplotlib.pyplot as plt
from PIL import Image
import matplotlib
import os
import random
import numpy as np
import omnigibson as og
from omnigibson_modified import TraversableMap
from utils import (
    get_trav_map,
    get_robot_position
)
from omnigibson.utils.asset_utils import get_og_scene_path
from omnigibson.robots import Fetch
from omnigibson.scenes.interactive_traversable_scene import InteractiveTraversableScene
from omnigibson.sensors import VisionSensor
from constants import *
from video_recorder import video_recorder
from action import *
from sub_task import cook, pour, wash, clean, pick_and_place

class Task:
    """A sequence of manipulation sub-tasks set up in an OmniGibson scene.

    Constructing a ``Task`` immediately calls :meth:`reset`. That builds the
    traversability map for ``scene_model`` and imports the scene, a Fetch
    robot and the sampled sub-task objects into the simulator.

    Args:
        task_flag: Either ``"Rearrangement"`` or ``"SequencePlanning"``.
        task_name: Name used as a directory in the dataset output path
            (see :meth:`init_figure`).
        task_id: Identifier used in the dataset output path and video name.
        scene_model: OmniGibson scene model name.
        sub_task_nums: Mapping from sub-task name (``"Cook"``, ``"Wash"``,
            ``"Clean"``, ``"Pour"``, ``"PickAndPlace"``) to the number of
            objects to sample for it. ``None`` (the default) means no
            sub-tasks.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 task_flag: str = "SequencePlanning",
                 task_name: str = "",
                 task_id: int = 0,
                 scene_model: str = "Beechwood_0_int",
                 sub_task_nums: "dict | None" = None,
                 **kwargs):
        self.task_flag = task_flag
        assert self.task_flag in ["Rearrangement", "SequencePlanning"], \
            f"unsupported task_flag: {self.task_flag!r}"
        self.task_name = task_name
        self.task_id = task_id
        self.scene_model = scene_model
        # Avoid a shared mutable default; None means "no sub-tasks".
        self.sub_task_nums = {} if sub_task_nums is None else sub_task_nums
        self.check = True
        self.reset()

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #
    def _is_reachable(self, obj):
        """Return whether the point from ``get_robot_position(obj, ...)`` is a floor-0 graph node."""
        robot_pos = get_robot_position(obj, self.trav_map, self.trav_map_size)
        return self.trav_map.has_node(0, np.array([robot_pos[0], robot_pos[1]]))

    def _place_at_reachable_point(self, obj):
        """Move ``obj`` to random floor-0 traversable points until :meth:`_is_reachable` holds."""
        while True:
            obj_pos = self.trav_map.get_random_point(floor=0)[1]
            set_A_at_P(obj, obj_pos)
            if self._is_reachable(obj):
                return

    @staticmethod
    def _category_instances(category):
        """Return the scene objects registered under ``category`` as a list.

        A missing/falsy registry entry is treated as "no instances" instead of
        crashing on ``list(None)``.
        """
        return list(og.sim.scene.object_registry("category", category) or ())

    @staticmethod
    def _sample_names(candidates, count):
        """Return ``count`` distinct names drawn uniformly from ``candidates``.

        ``random.sample`` no longer accepts sets (deprecated in Python 3.9,
        ``TypeError`` since 3.11), so the candidates are sorted first. Sorting
        also makes the draw reproducible under a fixed ``random.seed``.
        """
        return random.sample(sorted(candidates), count)

    # ------------------------------------------------------------------ #
    # Scene setup
    # ------------------------------------------------------------------ #
    def import_scene(self):
        """Import the interactive scene (without ceilings) and its extra objects.

        Every entry of ``SceneObjects[self.scene_model]`` is imported with its
        ``"model"`` and placed at its ``"pos"``.
        """
        self.scene = InteractiveTraversableScene(self.scene_model, not_load_object_categories=["ceilings"])
        og.sim.import_scene(self.scene)

        for obj_name, obj_info in SceneObjects[self.scene_model].items():
            obj = import_A_of_B(obj_name, obj_info["model"])
            set_A_at_P(obj, obj_info["pos"])

    def import_objects(self):
        """Sample, import and prepare the objects for every requested sub-task.

        Fills ``self.sub_task_list`` as ``{sub_task: {object_name: value}}``.
        ``value`` is the system name for Wash/Clean/Pour and ``None`` for
        Cook/PickAndPlace. An object name is used by at most one sub-task.
        Freshly imported objects are placed at a random point that passes
        :meth:`_is_reachable`. Clean instead picks an existing, reachable
        scene instance of the sampled category.

        Raises:
            ValueError: from ``random.sample`` if a sub-task requests more
                objects than there are eligible candidates.
        """
        self.sub_task_list = {}
        loaded_obj = set()
        for sub_task, num_objs in self.sub_task_nums.items():
            self.sub_task_list[sub_task] = {}
            available_objs = set()
            for obj_type in Subtasks[sub_task]:
                available_objs.update(ObjectInfo[obj_type].keys())
            available_objs -= loaded_obj
            if sub_task == 'Clean':
                # Clean reuses objects already in the scene, so only keep
                # categories with at least one reachable instance. Otherwise
                # the instance selection below could not succeed. Sampling from
                # the filtered set equals rejection-sampling all-valid subsets.
                available_objs = {
                    name for name in available_objs
                    if any(self._is_reachable(o) for o in self._category_instances(name))
                }
            obj_names = self._sample_names(available_objs, num_objs)

            for obj_name in obj_names:
                loaded_obj.add(obj_name)
                print(sub_task, obj_name)
                if sub_task == 'Cook':
                    food_info = ObjectInfo["food"][obj_name]
                    if "model" in food_info:
                        obj = import_A_of_B(obj_name, food_info["model"])
                    else:
                        obj = import_A_of_B(obj_name)
                    self._place_at_reachable_point(obj)
                    self.sub_task_list[sub_task][obj_name] = None
                elif sub_task == 'Wash':
                    obj = import_A_of_B(obj_name)
                    self._place_at_reachable_point(obj)
                    system = random.choice(["dust", "stain"])
                    import_A_on_B(system, obj)
                    self.sub_task_list[sub_task][obj_name] = system
                elif sub_task == 'Clean':
                    reachable = [o for o in self._category_instances(obj_name) if self._is_reachable(o)]
                    obj = random.choice(reachable)
                    system = random.choice(["dust", "stain"])
                    import_A_on_B(system, obj)
                    # Keyed by the concrete instance name, not the category.
                    self.sub_task_list[sub_task][obj.name] = system
                elif sub_task == 'Pour':
                    obj = import_A_of_B(obj_name)
                    self._place_at_reachable_point(obj)
                    system = random.choice(ObjectInfo["bottles"][obj_name]["contain"])
                    import_A_in_B(system, obj)
                    self.sub_task_list[sub_task][obj_name] = system
                elif sub_task == 'PickAndPlace':
                    obj = import_A_of_B(obj_name)
                    self._place_at_reachable_point(obj)
                    self.sub_task_list[sub_task][obj_name] = None

    def import_robot(self):
        """Import a fixed-base Fetch robot at a random floor-0 traversable point.

        Also starts the simulator, steps it once, and sets every vision
        sensor to a 50 horizontal aperture at 512x512.
        """
        self.robot = Fetch(
            prim_path="/World/robot",
            name="robot",
            fixed_base=True,
            controller_config={
                "arm_0": {
                    "name": "JointController",
                    "motor_type": "position"
                },
                "gripper_0": {
                    "name": "MultiFingerGripperController",
                    "mode": "binary"
                },
                "camera": {
                    "name": "JointController"
                }
            },
            grasping_mode="sticky",
            obs_modalities=["rgb"]
        )
        og.sim.import_object(self.robot)

        # Get a random position in the scene.
        pos = self.trav_map.get_random_point(floor=0)[1]
        self.robot.set_position(pos)

        # At least one simulation step while the simulator is playing must occur for the robot (or in general, any object)
        # to be fully initialized after it is imported into the simulator
        og.sim.play()
        og.sim.step()
        # Make sure none of the joints are moving
        self.robot.keep_still()
        # Expand the field of view
        for sensor in self.robot.sensors.values():
            if isinstance(sensor, VisionSensor):
                sensor.horizontal_aperture = 50
                sensor.image_width = 512
                sensor.image_height = 512

    def reset(self):
        """Rebuild the traversability map, then import the scene, robot and objects."""
        self.trav_map_img, self.trav_map_size = get_trav_map(self.scene_model)
        self.trav_map = TraversableMap()
        layout_dir = os.path.join(get_og_scene_path(self.scene_model), "layout")
        self.trav_map.load_map(layout_dir)
        self.trav_map.build_trav_graph(self.trav_map_size, layout_dir, 1, self.trav_map_img.copy())

        # Allow user to move camera more easily
        og.sim.enable_viewer_camera_teleoperation()

        self.import_scene()
        self.import_robot()
        self.import_objects()

    def step(self):
        """Run every sampled sub-task, in order, with the ``sub_task`` primitives."""
        # Trailing arguments shared by every sub-task primitive.
        common = (self.robot, self.scene_model, self.trav_map, self.trav_map_img, self.trav_map_size)
        for sub_task, objects in self.sub_task_list.items():
            for obj_name, system_name in objects.items():
                obj = og.sim.scene.object_registry("name", obj_name)
                if sub_task == 'Cook':
                    cooker = og.sim.scene.object_registry("name", random.choice(SceneInfo[self.scene_model]["cooker"]))
                    bowl = og.sim.scene.object_registry("name", "plate")
                    cook(obj, None, True, False, cooker, bowl, *common)
                elif sub_task == 'Wash':
                    system = og.sim.scene.system_registry("name", system_name)
                    sink = og.sim.scene.object_registry("name", random.choice(SceneInfo[self.scene_model]["sink"]))
                    wash(system, obj, None, True, False, sink, *common)
                elif sub_task == 'Clean':
                    system = og.sim.scene.system_registry("name", system_name)
                    clean(system, obj, *common)
                elif sub_task == 'Pour':
                    system = og.sim.scene.system_registry("name", system_name)
                    mug = og.sim.scene.object_registry("name", "mug")
                    pour(system, obj, None, True, False, mug, *common)
                elif sub_task == 'PickAndPlace':
                    container = og.sim.scene.object_registry("name", random.choice(SceneInfo[self.scene_model]["container"]))
                    pick_and_place(obj, None, True, False, container, container.get_position(), False, *common)

    def init_figure(self,
                    camera_pos = np.array([2.32248, -8.74338, 9.85436]),
                    camera_ori = np.array([0.39592, 0.13485, 0.29286, 0.85982]),
                    save_name="0"):
        """Pose the viewer camera and configure ``video_recorder`` for this task.

        Output goes to ``<og.root_path>/../../dataset/<task_name>/<scene_model>/<task_id>``.
        ``save_name`` is currently unused.
        """
        # Update the viewer camera's pose so that it points towards the robot
        og.sim.viewer_camera.set_position_orientation(position=camera_pos, orientation=camera_ori)
        video_recorder.set(camera=og.sim.viewer_camera, robot=self.robot,\
                           save_path=os.path.join(f"{og.root_path}/../../dataset/", f"{self.task_name}", f"{self.scene_model}", f"{self.task_id}"), name=f"{self.task_id}", trav_map_img=self.trav_map_img)
        og.log.info(og.root_path)

    def close(self):
        """Stop the simulator."""
        # NOTE(review): video_recorder.release() was previously commented out
        # here; confirm whether the recorder should be released on close.
        og.sim.stop()