from simple_rl.mdp.MDPClass import MDP
from simple_rl.mdp.StateClass import State
import virtualhome
import random
import atexit
import time
import numpy as np
import itertools

from unity_simulator import comm_unity as comm_unity
from evolving_graph import utils
from .utils import *

class SimpleRLVirtualHomeEnv(MDP):
    def __init__(self,
                 max_episode_length=200,
                 observation_types=None,
                 use_editor=False,
                 base_port=8080,
                 port_id=0,
                 executable_args=None,
                 recording_options=None,
                 seed=123,
                 restriction_dict=None,
                 task_id=0,
                 handmade_reward_fn=None,
                 gamma=0.95):
        """
        Connect to a VirtualHome Unity simulator, reset it and initialise the MDP.

        :param max_episode_length: number of steps after which ``step`` marks the state terminal
        :param observation_types: per-agent list of observation types; defaults to 'full' for every agent
        :param use_editor: if True, connect with ``UnityCommunication()`` default arguments
            (``port_number`` is recorded as 8080); otherwise connect on ``base_port + port_id``
            and forward ``executable_args``
        :param base_port: base port used when ``use_editor`` is False
        :param port_id: offset added to ``base_port``
        :param executable_args: extra keyword arguments for ``UnityCommunication``; defaults to ``{}``
        :param recording_options: dict whose 'recording' and 'output_folder' keys are read by ``step``;
            defaults to recording disabled
        :param seed: stored on ``self.seed``; not used for seeding anything in this constructor
        :param restriction_dict: stored on ``self.restriction_dict``
        :param task_id: stored on ``self.task_id``
        :param handmade_reward_fn: callable ``(state, action, next_state) -> (reward, done)`` used by
            ``reward``; if None, ``reward`` yields ``0, False``
        :param gamma: discount factor, forwarded to the ``MDP`` base constructor
        """
        # Build the dict defaults per instance: a dict literal in the signature would be
        # one shared object across every environment created with the defaults.
        if executable_args is None:
            executable_args = {}
        if recording_options is None:
            recording_options = {'recording': False,
                                 'output_folder': None,
                                 'file_name_prefix': None,
                                 'cameras': 'PERSON_FROM_BACK',
                                 'modality': 'normal'}

        self.handmade_reward_fn = handmade_reward_fn

        self.seed = seed
        self.prev_reward = 0.
        self.task_id = task_id

        self.steps = 0
        self.env_id = 0
        self.max_ids = {}

        self.num_agents = 1
        self.max_episode_length = max_episode_length
        self.actions_available = [
            # 'turnleft',
            # 'walkforward',
            # 'turnright',
            'walk',
            # 'run',
            # 'walktowards',
            'open',
            'close',
            'put',
            'grab',
            # 'no_action'
        ]
        self.restriction_dict = restriction_dict

        self.recording_options = recording_options
        self.base_port = base_port
        self.port_id = port_id
        self.executable_args = executable_args

        # Observation parameters
        self.num_camera_per_agent = 6
        self.CAMERA_NUM = 1  # 0 TOP, 1 FRONT, 2 LEFT..
        self.default_image_width = 300
        self.default_image_height = 300

        if observation_types is not None:
            self.observation_types = observation_types
        else:
            self.observation_types = ['full' for _ in range(self.num_agents)]

        self.agent_info = {
            0: 'Chars/Female1',
            # 1: 'Chars/Male1'
        }

        self.changed_graph = True
        self.rooms = None
        self.id2node = None
        self.num_static_cameras = None

        if use_editor:
            # Use Unity Editor
            self.port_number = 8080
            self.comm = comm_unity.UnityCommunication()
        else:
            # Launch the executable
            self.port_number = self.base_port + port_id
            self.comm = comm_unity.UnityCommunication(port=str(self.port_number), **self.executable_args)
        # Make sure the simulator connection is closed when the interpreter exits.
        atexit.register(self.close)
        obs = self.reset()
        # Forward gamma explicitly: reward() reads self.gamma, and without this the
        # value requested by the caller was dropped in favour of the base-class default.
        super().__init__(actions=self.actions_available,
                         transition_func=self.step,
                         reward_func=self.reward,
                         init_state=obs,
                         gamma=gamma)


    def close(self):
        """Close the simulator communication object held in ``self.comm``."""
        simulator = self.comm
        simulator.close()

    def relaunch(self):
        """
        Close the current communication object and replace it with a new one.

        The replacement is created with the stored ``port_number`` (as a string)
        and the stored ``executable_args``.
        """
        self.comm.close()
        port = str(self.port_number)
        self.comm = comm_unity.UnityCommunication(port=port, **self.executable_args)

    def reward(self, state, action, next_state, get_terminal_condition=False):
        # This function is kind of unnecessary, since it just called self.handmade_reward_fn which is passed in via the constructor

        if self.handmade_reward_fn is not None:
            reward, done = self.handmade_reward_fn(state, action, next_state)
        else:
            reward, done = 0, False

        if get_terminal_condition:
            return reward * self.gamma ** self.steps, done
        else:
            return reward * self.gamma ** self.steps

    def step(self, state=None, action=None):
        """
        Execute ``action`` in the simulator and return the resulting state.

        The action is converted to a script with ``convert_action``; a non-empty script
        is rendered through ``self.comm.render_script`` (recorded when
        ``recording_options['recording']`` is truthy, otherwise with animation skipped).
        A failed render only prints the simulator message; the step still counts.

        :param state: state before the action; only passed through to ``reward``
        :param action: action for agent 0
        :return: ``State`` holding the new observations, terminal when the reward function
            reports done or the episode length limit has been reached
        """
        # TODO: potential try catch here
        script_list = convert_action({0: action})
        if len(script_list[0]) > 0:
            if self.recording_options['recording']:
                success, message = self.comm.render_script(
                    script_list,
                    recording=True,
                    output_folder=self.recording_options['output_folder'])
            else:
                success, message = self.comm.render_script(
                    script_list,
                    recording=False,
                    skip_animation=True)
            if success:
                # The scene changed, so the cached graph must be refetched.
                self.changed_graph = True
            else:
                print(message)

        self.steps += 1
        obs = self.get_observations()

        _, done = self.reward(state, action, State(data=obs), get_terminal_condition=True)
        # Use >= rather than ==: once the cap is reached every further step must stay
        # terminal, even if the caller keeps stepping past the limit.
        if self.steps >= self.max_episode_length:
            done = True
        return State(data=obs, is_terminal=done)

    def reset(self, environment_graph=None, environment_id=None, init_rooms=None, seed=None):
        """
        Reset the simulator scene, add the agents and return the initial state.

        :param environment_graph: the initial graph we should reset the environment with
        :param environment_id: which id to start
        :param init_rooms: where to initialize the agents; falls back to
            ``['kitchen', 'livingroom']`` when missing, empty, or when its first entry
            is not a known room name
        :param seed: accepted for interface compatibility; not used by this method
        :return: non-terminal ``State`` wrapping the initial observations, or ``None``
            if expanding the scene with ``environment_graph`` failed
        """
        self.env_id = environment_id

        if self.env_id is not None:
            self.comm.reset(self.env_id)
        else:
            self.comm.reset()

        # Record the largest node id of each freshly reset scene the first time it is seen.
        _, scene_graph = self.comm.environment_graph()
        if self.env_id not in self.max_ids:
            self.max_ids[self.env_id] = max(node['id'] for node in scene_graph['nodes'])

        if environment_graph is not None:
            # TODO: this should be modified to extend well
            # updated_graph = utils.separate_new_ids_graph(environment_graph, max_id)
            updated_graph = environment_graph
            success, _ = self.comm.expand_scene(updated_graph)
        else:
            success = True

        if not success:
            print("Error expanding scene")
            return None
        self.num_static_cameras = self.comm.camera_count()[1]

        # `not init_rooms` also covers an empty sequence, which previously raised
        # IndexError on init_rooms[0].
        if not init_rooms or init_rooms[0] not in ['kitchen', 'bedroom', 'livingroom', 'bathroom']:
            rooms = ['kitchen', 'livingroom']
        else:
            rooms = list(init_rooms)

        for i in range(self.num_agents):
            if i in self.agent_info:
                self.comm.add_character(self.agent_info[i], initial_room=rooms[i])
            else:
                self.comm.add_character()

        _, self.init_unity_graph = self.comm.environment_graph()

        # Force a refetch so the cached graph includes the characters just added.
        self.changed_graph = True
        graph = self.get_graph()
        self.rooms = [(node['class_name'], node['id']) for node in graph['nodes'] if node['category'] == 'Rooms']
        self.id2node = {node['id']: node for node in graph['nodes']}

        obs = self.get_observations()
        self.steps = 0
        self.prev_reward = 0.
        time.sleep(0.1)
        return State(data=obs, is_terminal=False)

    def get_graph(self):
        if self.changed_graph:
            s, graph = self.comm.environment_graph()
            # if not s:
            #     pdb.set_trace()
            self.graph = graph
            self.changed_graph = False
        return self.graph

    def get_observations(self):
        dict_observations = {}
        for agent_id in range(self.num_agents):
            obs_type = self.observation_types[agent_id]
            dict_observations[agent_id] = self.get_observation(agent_id, obs_type)
        return dict_observations

    def get_action_space(self):
        dict_action_space = {}
        for agent_id in range(self.num_agents):
            if self.observation_types[agent_id] not in ['partial', 'full', 'full_trimmed', 'full_trimmed_large']:
                raise NotImplementedError
            else:
                # Even if you can see all the graph, you can only interact with visible objects
                obs_type = self.observation_types[agent_id]

            visible_graph = self.get_observation(agent_id, obs_type)
            dict_action_space[agent_id] = [node['id'] for node in visible_graph['nodes']]
        return dict_action_space

    def get_observation(self, agent_id, obs_type, info={}):
        if obs_type == 'partial':
            # agent 0 has id (0 + 1)
            curr_graph = self.get_graph()
            return utils.get_visible_nodes(curr_graph, agent_id=(agent_id+1))

        elif obs_type == 'full':
            return self.get_graph()
        
        elif obs_type == "full_trimmed":
            untrimmed = self.get_graph()
            # CAN_OPEN: microwave, fridge -> 313, 305
            # SURFACES: bookshelf -> 249
            # GRABBABLE: salmon, pie -> 327, 319
            obj_ids = [249, 305, 313, 319, 327]
            # all rooms: bathroom, bedroom, kitchen, livingroom
            room_ids = [11, 73, 205, 335]
            # there should only be one character with id == 1
            char_ids = [n["id"] for n in untrimmed["nodes"] if n["category"] == "Characters"]
            assert char_ids[0] == 1 and len(char_ids) == 1
            
            # trimmed nodes are all objects plus kitchen (205) and character
            trimmed_nodes = [n for n in untrimmed["nodes"] if n['id'] in obj_ids + [205] + char_ids]
            # only include the edge if both objects are rooms or in the trimmed state graph
            # all objects are inside kitchen, and salmon is on microwave
            trimmed_edges = [e for e in untrimmed["edges"] if e["from_id"] in obj_ids + room_ids + char_ids and e["to_id"] in obj_ids + room_ids + char_ids]

            # NOTE: We may want to do some additional filtering on the nodes, as each node contains properties
            # that may break our abstractions. E.g. 'obj_transform' is probably not necessary
            property_blacklist = ["obj_transform", 'prefab_name', 'bounding_box']
            for n in trimmed_nodes:
                for p in property_blacklist:
                    if p in n:
                        del n[p]
            
            return {"nodes": trimmed_nodes, "edges": trimmed_edges}

        elif obs_type == "full_trimmed_large":
            untrimmed = self.get_graph()
            # CAN_OPEN: microwave, fridge, cabinet (has surface but not used) -> 313, 305, 415
            # SURFACES: bookshelf, sofa, kitchentable -> 249, 368, 231
            # GRABBABLE: salmon, pie, toothpaste, remotecontrol, cereal -> 327, 319, 62, 452, 334
            obj_ids = [249, 305, 313, 319, 327, 62, 452, 334, 368, 231, 415]
            # all rooms: bathroom, bedroom, kitchen, livingroom
            room_ids = [11, 73, 205, 335]
            # there should only be one character with id == 1
            char_ids = [n["id"] for n in untrimmed["nodes"] if n["category"] == "Characters"]
            assert char_ids[0] == 1 and len(char_ids) == 1
            
            # trimmed nodes are all objects plus all rooms and character
            trimmed_nodes = [n for n in untrimmed["nodes"] if n['id'] in obj_ids + room_ids + char_ids]
            # only include the edge if both objects are rooms or in the trimmed state graph
            # all objects are inside kitchen, and salmon is on microwave
            trimmed_edges = [e for e in untrimmed["edges"] if e["from_id"] in obj_ids + room_ids + char_ids and e["to_id"] in obj_ids + room_ids + char_ids]
            
            # NOTE: We may want to do some additional filtering on the nodes, as each node contains properties
            # that may break our abstractions. E.g. 'obj_transform' is probably not necessary
            property_blacklist = ["obj_transform", 'prefab_name', 'bounding_box']
            for n in trimmed_nodes:
                for p in property_blacklist:
                    if p in n:
                        del n[p]
            
            return {"nodes": trimmed_nodes, "edges": trimmed_edges}

        elif obs_type == 'visible':
            # Only objects in the field of view of the agent
            raise NotImplementedError

        elif obs_type == 'image':
            camera_ids = [self.num_static_cameras + agent_id * self.num_camera_per_agent + self.CAMERA_NUM]
            if 'image_width' in info:
                image_width = info['image_width']
                image_height = info['image_height']
            else:
                image_width, image_height = self.default_image_width, self.default_image_height
            if 'mode' in info:
                current_mode = info['mode']
            else:
                current_mode = 'normal'
            s, images = self.comm.camera_image(camera_ids, mode=current_mode, image_width=image_width, image_height=image_height)
            # if not s:
            #     pdb.set_trace()
            return images[0]
        else:
            raise NotImplementedError
        
    def get_parameters(self):
        return {}


        
