import cv2
import matplotlib.pyplot as plt
from PIL import Image
import matplotlib
import os
import random
import numpy as np
import omnigibson as og
from omnigibson.macros import gm
from omnigibson.objects import DatasetObject
from omnigibson.maps import segmentation_map
from omnigibson_modified import TraversableMap
from utils import (
    get_trav_map,
    convert
)
from omnigibson.utils.asset_utils import get_og_scene_path
from omnigibson.robots import Fetch
from omnigibson.scenes.interactive_traversable_scene import InteractiveTraversableScene
from omnigibson.sensors import VisionSensor
from constants import *
from video_recorder import video_recorder
from action import *
from sub_task import cook, pour, wash, clean, pick_and_place

class Task:
    """A sequential household task assembled in an OmniGibson scene.

    Constructing a ``Task`` immediately calls :meth:`reset`, which builds the
    traversability map and imports the scene, the Fetch robot and the sampled
    sub-task objects into ``og.sim``.

    Args:
        task_flag: Either ``"Rearrangement"`` or ``"SequencePlanning"``.
        task_name: Used as the file name of the figure written by
            :meth:`save_figure`.
        scene_model: OmniGibson scene model name (e.g. ``"Beechwood_0_int"``).
        sub_task_nums: Maps a sub-task name (``"Cook"``, ``"Wash"``,
            ``"Clean"``, ``"Pour"``, ``"PickAndPlace"``) to the number of
            objects to sample for it. ``None`` (the default) means no
            sub-tasks.
        **kwargs: Accepted for call compatibility; currently ignored.
    """

    def __init__(self,
                 task_flag: str = "SequencePlanning",
                 task_name: str = "",
                 scene_model: str = "Beechwood_0_int",
                 sub_task_nums: "dict | None" = None,
                 **kwargs):
        self.task_flag = task_flag
        assert self.task_flag in ("Rearrangement", "SequencePlanning")
        self.task_name = task_name
        self.scene_model = scene_model
        # Fresh dict per instance: a shared `{}` default would be aliased
        # across every Task created without an explicit argument.
        self.sub_task_nums = {} if sub_task_nums is None else sub_task_nums
        self.check = True
        self.reset()

    def import_scene(self):
        """Load the interactive scene (without ceilings) and its extra objects.

        Every entry of ``SceneObjects[self.scene_model]`` is imported with its
        configured ``"model"`` and placed at its configured ``"pos"``.
        """
        self.scene = InteractiveTraversableScene(self.scene_model, not_load_object_categories=["ceilings"])
        og.sim.import_scene(self.scene)

        for obj_name, obj_cfg in SceneObjects[self.scene_model].items():
            obj = import_A_of_B(obj_name, obj_cfg["model"])
            set_A_at_P(obj, obj_cfg["pos"])

    def import_objects(self):
        """Sample, import and place the objects for every requested sub-task.

        Fills ``self.sub_task_list`` as ``{sub_task: {object_name: system}}``.
        ``system`` is the particle-system name put on/in the object:
        ``"dust"``/``"stain"`` for Wash and Clean, one of the bottle's
        ``"contain"`` entries for Pour, and ``None`` for Cook and PickAndPlace.
        An object name is never sampled for more than one sub-task.

        Raises:
            ValueError: if a sub-task requests more objects than remain
                available for it.
        """
        self.sub_task_list = {}
        loaded_obj = set()
        for sub_task, num_objs in self.sub_task_nums.items():
            self.sub_task_list[sub_task] = {}
            available_objs = set()
            for obj_type in Subtasks[sub_task]:
                available_objs.update(ObjectInfo[obj_type].keys())
            available_objs -= loaded_obj
            if num_objs > len(available_objs):
                raise ValueError(
                    f"Sub-task {sub_task!r} requests {num_objs} objects but only "
                    f"{len(available_objs)} are available"
                )
            # random.sample() rejects sets on Python >= 3.11, and set order of
            # strings varies with hash randomisation; sampling from a sorted
            # list keeps the draw reproducible under a fixed random seed.
            obj_names = random.sample(sorted(available_objs), num_objs)

            for obj_name in obj_names:
                loaded_obj.add(obj_name)
                print(sub_task, obj_name)
                # NOTE(review): element [1] is assumed to be the sampled position;
                # confirm against TraversableMap.get_random_point.
                obj_pos = self.trav_map.get_random_point(floor=0)[1]
                if sub_task == 'Cook':
                    food_info = ObjectInfo["food"][obj_name]
                    if "model" in food_info:
                        obj = import_A_of_B(obj_name, food_info["model"])
                    else:
                        obj = import_A_of_B(obj_name)
                    set_A_at_P(obj, obj_pos)
                    self.sub_task_list[sub_task][obj_name] = None
                elif sub_task == 'Wash':
                    obj = import_A_of_B(obj_name)
                    set_A_at_P(obj, obj_pos)
                    system = random.choice(["dust", "stain"])
                    import_A_on_B(system, obj)
                    self.sub_task_list[sub_task][obj_name] = system
                elif sub_task == 'Clean':
                    # For Clean, obj_name is a category: reuse an object of that
                    # category already in the scene and key the entry by its name.
                    obj = random.choice(list(og.sim.scene.object_registry("category", obj_name)))
                    set_A_at_P(obj, obj_pos)
                    system = random.choice(["dust", "stain"])
                    import_A_on_B(system, obj)
                    self.sub_task_list[sub_task][obj.name] = system
                elif sub_task == 'Pour':
                    obj = import_A_of_B(obj_name)
                    set_A_at_P(obj, obj_pos)
                    system = random.choice(ObjectInfo["bottles"][obj_name]["contain"])
                    import_A_in_B(system, obj)
                    self.sub_task_list[sub_task][obj_name] = system
                elif sub_task == 'PickAndPlace':
                    obj = import_A_of_B(obj_name)
                    set_A_at_P(obj, obj_pos)
                    self.sub_task_list[sub_task][obj_name] = None

    def import_robot(self):
        """Import a fixed-base Fetch robot at a random traversable point.

        The robot uses a position-mode arm joint controller, a binary
        multi-finger gripper, a joint-controlled camera, sticky grasping and
        RGB observations. The simulator is played and stepped once so the robot
        is fully initialised. The horizontal aperture of each vision sensor is
        then set to 50.
        """
        self.robot = Fetch(
            prim_path="/World/robot",
            name="robot",
            fixed_base=True,
            controller_config={
                "arm_0": {
                    "name": "JointController",
                    "motor_type": "position"
                },
                "gripper_0": {
                    "name": "MultiFingerGripperController",
                    "mode": "binary"
                },
                "camera": {
                    "name": "JointController"
                }
            },
            grasping_mode="sticky",
            obs_modalities=["rgb"]
        )
        og.sim.import_object(self.robot)

        # Random traversable position on floor 0.
        pos = self.trav_map.get_random_point(floor=0)[1]
        self.robot.set_position(pos)

        # At least one simulation step while the simulator is playing must occur for the robot (or in general, any object)
        # to be fully initialized after it is imported into the simulator
        og.sim.play()
        og.sim.step()
        # Make sure none of the joints are moving
        self.robot.keep_still()
        # Expand the field of view
        for sensor in self.robot.sensors.values():
            if isinstance(sensor, VisionSensor):
                # sensor.focal_length = 4.0
                sensor.horizontal_aperture = 50

    def reset(self):
        """Build the traversability map, then import scene, robot and objects.

        NOTE(review): nothing here clears ``og.sim`` first, so calling reset()
        a second time imports another scene/robot on top of the existing ones.
        Confirm whether that is intended.
        """
        self.trav_map_img, self.trav_map_size = get_trav_map(self.scene_model)
        layout_dir = os.path.join(get_og_scene_path(self.scene_model), "layout")
        self.trav_map = TraversableMap()
        self.trav_map.load_map(layout_dir)
        # A copy is passed so build_trav_graph cannot modify self.trav_map_img.
        self.trav_map.build_trav_graph(self.trav_map_size, layout_dir, 1, self.trav_map_img.copy())

        self.import_scene()
        self.import_robot()
        self.import_objects()

    def step(self):
        """Execute every sampled sub-task in turn with the robot.

        For each entry of ``self.sub_task_list`` this calls the matching
        primitive: ``cook`` (random cooker from ``SceneInfo`` plus the
        object named "bowl"), ``wash`` (random sink from ``SceneInfo``),
        ``clean``, ``pour`` (into the object named "mug") or ``pick_and_place``
        (into a random fridge, at the fridge's position).
        """
        for sub_task, objects in self.sub_task_list.items():
            for obj_name, system_name in objects.items():
                obj = og.sim.scene.object_registry("name", obj_name)
                if sub_task == 'Cook':
                    cooker = og.sim.scene.object_registry("name", random.choice(SceneInfo[self.scene_model]["cooker"]["name"]))
                    bowl = og.sim.scene.object_registry("name", "bowl")
                    cook(obj, None, True, False, cooker, bowl, self.robot, self.scene_model,
                         self.trav_map, self.trav_map_img, self.trav_map_size)
                elif sub_task == 'Wash':
                    system = og.sim.scene.system_registry("name", system_name)
                    sink = og.sim.scene.object_registry("name", random.choice(SceneInfo[self.scene_model]["sink"]["name"]))
                    wash(system, obj, None, True, False, sink, self.robot, self.scene_model,
                         self.trav_map, self.trav_map_img, self.trav_map_size)
                elif sub_task == 'Clean':
                    system = og.sim.scene.system_registry("name", system_name)
                    clean(system, obj, self.robot, self.scene_model,
                          self.trav_map, self.trav_map_img, self.trav_map_size)
                elif sub_task == 'Pour':
                    system = og.sim.scene.system_registry("name", system_name)
                    mug = og.sim.scene.object_registry("name", "mug")
                    pour(system, obj, None, True, False, mug, self.robot, self.scene_model,
                         self.trav_map, self.trav_map_img, self.trav_map_size)
                elif sub_task == 'PickAndPlace':
                    fridge = random.choice(list(og.sim.scene.object_registry("category", "fridge")))
                    pick_and_place(obj, None, True, False, fridge, fridge.get_position(), False, self.robot, self.scene_model,
                                   self.trav_map, self.trav_map_img, self.trav_map_size)

    def init_figure(self,
                    camera_pos=None,
                    camera_ori=None,
                    save_path=".",
                    save_name="task_sample_desk_test"):
        """Point the viewer camera and configure the global video recorder.

        Args:
            camera_pos: Viewer camera position; defaults to
                ``[2.32248, -8.74338, 9.85436]``.
            camera_ori: 4-component viewer camera orientation (a unit
                quaternion; component order follows OmniGibson's convention).
                Defaults to ``[0.39592, 0.13485, 0.29286, 0.85982]``.
            save_path: Sub-directory under ``<og.root_path>/../../dataset/``
                handed to the recorder.
            save_name: Recording name handed to the recorder.
        """
        # Build defaults per call instead of sharing mutable np.array defaults.
        if camera_pos is None:
            camera_pos = np.array([2.32248, -8.74338, 9.85436])
        if camera_ori is None:
            camera_ori = np.array([0.39592, 0.13485, 0.29286, 0.85982])
        # plt.figure(figsize=(12, 12))
        # plt.imshow(self.trav_map_img)
        # plt.title(f"Traversable area of {self.scene_model} scene")

        # Update the viewer camera's pose so that it points towards the robot
        og.sim.viewer_camera.set_position_orientation(position=camera_pos, orientation=camera_ori)
        video_recorder.set(camera=og.sim.viewer_camera, robot=self.robot,
            save_path=os.path.join(f"{og.root_path}/../../dataset/", save_path), name=save_name, trav_map_img=self.trav_map_img)
        og.log.info(og.root_path)

    def save_figure(self):
        """Save the current matplotlib figure as ``<task_name>.png``.

        The file goes under ``<og.root_path>/../../images/sequence/``. That
        directory is created if it does not exist.
        """
        out_path = f"{og.root_path}/../../images/sequence/{self.task_name}.png"
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        plt.savefig(out_path)

    def close(self):
        """Release the video recorder and stop the simulator."""
        video_recorder.release()
        og.sim.stop()