import os
import copy
import json
import yaml
import open3d as o3d
from tqdm import tqdm
from openai import OpenAI
from hovsg.graph.graph import Graph
from omegaconf import DictConfig

class Nav3DSG:
    def __init__(self, api_key, params):
        """Set up the LLM client, load and label the scene graph, and back it up.

        Args:
            api_key: Key handed to the ``OpenAI`` client constructor.
            params: Configuration object; passed whole to ``Graph`` and read
                for ``params.main.graph_path``.
        """
        self.client = OpenAI(api_key=api_key)

        graph = Graph(params)
        graph.load_graph(params.main.graph_path)
        self.hovsg = graph

        # One default room type per line; surrounding whitespace is dropped.
        with open("REVERIE_Navigation/mp3d_room_types.txt", "r") as types_file:
            default_types = [line.strip() for line in types_file]

        # Label the rooms of the loaded graph.
        graph.generate_room_names(
            generate_method="view_embedding",
            default_room_types=default_types,
        )

        # Snapshot taken after naming, so reset() can restore the unpruned graph.
        self.hovsg_backup = copy.deepcopy(graph)
        # Targets resolved so far, filled in by navigate()/test().
        self.assigned_id = {}
    
    def generate_navigation_steps(self, instruction):
        """Ask the LLM to break an instruction into ordered querying steps.

        The requested line format is
        ``Step X, Target Type, Target Name, Target Description, Search Area``,
        i.e. comma-separated throughout, because ``execute_step`` splits each
        line on ", ". The few-shot examples number steps consecutively from 0
        and refer to earlier results as ``Target_<step number>``, which is how
        ``navigate`` keys its assignments.

        Args:
            instruction: Natural-language navigation instruction.

        Returns:
            The chat completion message object; its ``content`` attribute holds
            the generated step lines.
        """
        prompt = f"""
        You are given a scene graph representing an entire house, including floors, rooms, and objects. Each element in the house is connected by relations, such as spatial adjacency or containment.

        I will provide a language instruction, and you need to break it down into structured querying steps. Each step should be formatted as follows:

        Step X, Target Type, Target Name, Target Description, Search Area

        Guidelines:
        1. Focus only on navigation and spatial querying steps—do not include actions related to manipulating objects.
        2. Identify floors, rooms, and objects as needed to reach the final target.
        3. If a target is described relative to another feature (e.g., "next to the bathroom"), encode this relationship explicitly in the description.
        4. When a target from a previous step appears in a later step's description, assign it an ID to maintain clarity and avoid ambiguity.
        5. The target description should provide enough information to uniquely identify the target.
        6. Make the steps as many as needed to reach the final target.

        Examples:
        Example 1:
        Instruction: "Go to the dining room on level 2 that has a corded black telephone on the wall at the start of the room and tell me the time on the overhead clock."
        Output:
        Step 0, Floor, Level 2, None, all floors
        Step 1, Object, corded black telephone, on the wall at the start of the room, inside Target_0
        Step 2, Room, dining room, has Target_1, inside Target_0
        Step 3, Object, overhead clock, None, inside Target_2

        Example 2:
        Instruction: "Go to the hallway next to the bathroom on Level 4 and look at the flower picture."
        Output:
        Step 0, Floor, Level 4, None, all floors  
        Step 1, Room, hallway, next to bathroom, inside Target_0
        Step 2, Object, flower picture, None, inside Target_1

        Example 3:
        Instruction: "Open the cabinet by the toilet in the bathroom with a floral ring design in the sink."
        Output:
        Step 0, Object, sink, with a floral ring design, all floors  
        Step 1, Room, bathroom, has Target_0, all rooms
        Step 2, Object, cabinet, by the toilet, inside Target_1

        Now process the following instruction: "{instruction}"
        Output:
        """

        completion = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": "You are an expert in parsing natural language instructions into structured navigation steps."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.0,
        )

        return completion.choices[0].message
    
    def prune_graph(self, search_area):
        """
        Prune the graph based on the search area
        Only prune rooms and floors, objects should remain the same
        """
        # print("******************** Pruning Graph ********************")
        # generat the description of the whole scene graph for LLM
        floors = self.hovsg.floors
        description = f"The scene graph contains {len(floors)} floors: "
        for floor in floors:
            description += f"{floor.name}, "
            rooms = floor.rooms
            description += f"which contains {len(rooms)} rooms: "
            for room in rooms:
                description += f"(Id: {room.room_id}, Coordinate: {room.pcd.get_center()}); "
        
        # generate the pruned graph based on the search area
        for k, v in self.assigned_id.items():
            if k in search_area:
                search_area = search_area.replace(k, v)
        # print(f"Search Area: {search_area}")
        
        system_prompt = f"""
        Based on the description of the scene graph, prune the graph to only include '{search_area}'.
        Note that if a floor is pruned, all rooms in that floor should be pruned as well.
        And if a room is in the search area, its containing floor should also be included.
        Only return the id of floors and rooms that are part of the search area in the following format:
        Floors: id1, id2, ...
        Rooms: id1, id2, ...
        """
        completion = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": description}
            ],
            temperature=0.0,
        )
        response = completion.choices[0].message.content
        
        rooms = response.split("\n")[1].split("Rooms: ")[1].split(", ")    
        rooms = [x.strip().replace("Room ", "") for x in rooms]
        basic_floors = [room.floor_id for room in self.hovsg.rooms if room.room_id in rooms]

        try:
            floors = response.split("\n")[0].split("Floors: ")[1].split(", ")
            floors = basic_floors + [x.strip().split('_')[1] for x in floors]
            floors = list(set(floors))
        except:
            floors = list(set(basic_floors))

        # avoid prunning floors and rooms in assigned_id
        for k, v in self.assigned_id.items():
            if "Floor" in v:
                floor_id = v.split(' ')[1]
                floors.append(floor_id)
            elif "Room" in v:
                room_id = v.split(' ')[1]
                rooms.append(room_id)
                room = [room for room in self.hovsg.rooms if room.room_id == room_id][0]
                floors.append(room.floor_id)
            elif "Object" in v:
                obj_id = v.split(' ')[1]
                obj = [obj for obj in self.hovsg.objects if obj.object_id == obj_id][0]
                rooms.append(obj.room_id)
                room = [room for room in self.hovsg.rooms if room.room_id == obj.room_id][0]
                floors.append(room.floor_id)
            else:
                raise ValueError(f"Invalid target type: {v}")
        
        floors = list(set(floors))
        rooms = list(set(rooms))

        self.hovsg.prune_floors(floors)
        self.hovsg.prune_rooms(rooms)

        # print(f"Pruned Floors: {floors}")
        # print(f"Pruned Rooms: {rooms}")
        # print([floor.floor_id for floor in self.hovsg.floors])
        # print([room.room_id for room in self.hovsg.rooms])

    def analyse_description(self, target_name, target_desc):
        """
        Ask the LLM how a target description should be queried.

        The reply is expected in the form "Query Type, Relation Type, Query Text",
        as laid out in the prompt below; it is returned unparsed.

        Args:
            target_name: Name of the target being located.
            target_desc: Free-text description qualifying the target.

        Returns:
            The raw text content of the model's reply.
        """
        user_prompt = f"""
        I have a scene graph which contains floors, rooms, and objects. Each element is connected by relations, such as spatial adjacency or containment.
        For a given target, analyze its description as follows:
        1. If the description mentions other entities, use the following structure as the Query Text to represen their relationships: <subject; relation; object; relation; ...>. Note that the first subject must be the target name.
        2. Generate the Relation Type, which is the entitiy types in the Query Text, which can be floor, room, or object.
        3. If the descrption doesn't mention other entities, return a sentence to describe the target.
        
        Return the result in the format: Query Type, Relation Type, Query Text

        Examples:
        Example 1: Target Name: "bathroom", Target Description: "with flower wallpaper"
        Output: scene graph, room-object, <bathroom; has; flower wallpaper>

        Example 2: Target Name: "kitchen", Target Description: "between floor1 and floor2"
        Output: scene graph, room-floor-floor, <kitchen; between; floor1; and; floor2>

        Example 3: Target Name: "sink", Target Description: "with a floral ring design"
        Output: CLIP, None, "sink with a floral ring design"

        Now process the following target: Target Name: "{target_name}", Target Description: "{target_desc}"
        Output:
        """
        chat_messages = [
            {"role": "system", "content": "You are an expert in analysing target descriptions."},
            {"role": "user", "content": user_prompt},
        ]
        reply = self.client.chat.completions.create(
            model="gpt-4o",
            messages=chat_messages,
            temperature=0.0,
        )
        return reply.choices[0].message.content

    def query_room_sg(self, target_name, relation_type, query_text, target_desc):
        """Pick the rooms that satisfy a relational description, using the LLM.

        Candidate rooms for the target are listed together with the related
        entities (already-assigned targets or freshly queried candidates), and
        the LLM is asked which candidate rooms satisfy ``target_desc``.

        Args:
            target_name: Name of the room being located; must be the first
                entity of ``query_text``.
            relation_type: Hyphen-separated entity types, e.g. "room-object".
            query_text: Relation chain of the form "<subject; relation; object; ...>".
            target_desc: Original free-text description of the target.

        Returns:
            ``(room_ids, scores)``; both lists are empty when the LLM answers
            "None". Scores are those returned by the initial candidate query.

        Raises:
            ValueError: On a malformed query link, an unsupported entity type,
                an assigned target missing from the graph, or a reply that
                cannot be mapped back to the candidate rooms.
        """
        def room_line(room):
            return f"Room {room.room_id}, Type: {room.name}, Coordinate: {room.pcd.get_center()})\n"

        def object_line(obj):
            return (
                f"Object {obj.object_id}, belongs to Room {obj.room_id}, "
                f"Type: {obj.name}, Coordinate: {obj.pcd.get_center()})\n"
            )

        def lookup(items, attr, wanted, label):
            for item in items:
                if getattr(item, attr) == wanted:
                    return item
            raise ValueError(f"{label} {wanted!r} is not present in the current graph")

        query_link = query_text.replace("<", "").replace(">", "").split("; ")
        # Entities sit at the even positions; relations in between are skipped.
        etts = [ett.strip() for ett in query_link[::2]]
        if etts[0] != target_name:
            raise ValueError(
                f"Target name should be the subject of the query link: {query_text!r} vs {target_name!r}"
            )
        ett_types = [ett_type.strip() for ett_type in relation_type.split('-')]

        cand1_ids, scores = self.hovsg.query_room(etts[0], query_method='Nav3DSG')
        cand_scores = {cand: score for cand, score in zip(cand1_ids, scores)}
        cand1_rooms = [room for room in self.hovsg.rooms if room.room_id in cand1_ids]
        prompt = (
            f"\nI want to locate the '{etts[0]}' that has the following relations '{target_desc}' in a house.\n"
            f"There are {len(cand1_rooms)} candidates for the target '{etts[0]}':\n"
        )
        for room in cand1_rooms:
            prompt += room_line(room)

        for ett2, ett_type in zip(etts[1:], ett_types[1:]):
            if ett2 in self.assigned_id:
                assigned = self.assigned_id[ett2]
                prompt += f"The '{ett2}' has been identified as {assigned}.\n"
                if 'Object' in assigned:
                    obj = lookup(self.hovsg.objects, "object_id", assigned.split(' ')[1], "Object")
                    prompt += object_line(obj)
                elif 'Room' in assigned:
                    room = lookup(self.hovsg.rooms, "room_id", assigned.split(' ')[1], "Room")
                    prompt += room_line(room)
                else:
                    raise ValueError(f"Invalid target type: {assigned}")
            elif ett_type == "object":
                cand2_ids, _ = self.hovsg.query_object_all(ett2)
                cand2_objs = [obj for obj in self.hovsg.objects if obj.object_id in cand2_ids]
                prompt += f"There are {len(cand2_objs)} '{ett2}' in the house.\n"
                for obj in cand2_objs:
                    prompt += object_line(obj)
            elif ett_type == "room":
                cand2_ids, _ = self.hovsg.query_room(ett2, query_method='Nav3DSG')
                cand2_rooms = [room for room in self.hovsg.rooms if room.room_id in cand2_ids]
                prompt += f"There are {len(cand2_rooms)} '{ett2}' in the house.\n"
                for room in cand2_rooms:
                    prompt += room_line(room)
            else:
                raise ValueError(f"Invalid relation type: {relation_type}")

        prompt += f"Based on the description, return all the rooms that satisfy the specified relation '{target_desc}'."
        prompt += f"\nOnly return the id of the {etts[0]} in the format 'Rooms: id1, id2, ...'.\n"
        prompt += "If no rooms satisfy the relation, return 'None'."

        completion = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": "You are an expert in finding the target room based on the relation."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.0,
        )
        response = completion.choices[0].message.content

        if "None" in response:
            return [], []

        # Markdown emphasis (e.g. "**Rooms: 1, 2**") would otherwise stick to the ids.
        cleaned = response.replace("*", "")
        _, marker, tail = cleaned.partition("Rooms:")
        if not marker:
            raise ValueError(f"Unexpected room query response: {response!r}")
        # Only the line carrying the id list matters; the candidates are listed
        # to the LLM as "Room <id>", so drop that prefix if it is echoed back.
        id_line = tail.strip().split("\n")[0]
        room_ids = [part.strip().replace("Room ", "") for part in id_line.split(",")]
        room_ids = [room_id for room_id in room_ids if room_id]

        unknown = [room_id for room_id in room_ids if room_id not in cand_scores]
        if unknown:
            raise ValueError(
                f"Room ids {unknown} are not among the candidates {list(cand_scores)}; response: {response!r}"
            )
        scores = [cand_scores[room_id] for room_id in room_ids]
        return room_ids, scores

    def query_obj_sg(self, target_name, relation_type, query_text, target_desc):
        """Pick the objects that satisfy a relational description, using the LLM.

        Candidate objects for the target are listed together with the related
        entities (already-assigned targets or freshly queried candidates), and
        the LLM is asked which candidates satisfy ``target_desc``.

        Args:
            target_name: Name of the object being located; must be the first
                entity of ``query_text``.
            relation_type: Hyphen-separated entity types, e.g. "object-room".
            query_text: Relation chain of the form "<subject; relation; object; ...>".
            target_desc: Original free-text description of the target.

        Returns:
            ``(object_ids, scores)``; both lists are empty when the LLM answers
            "None". Scores are those returned by the initial candidate query.

        Raises:
            ValueError: On a malformed query link, an unsupported entity type,
                an assigned target missing from the graph, or a reply that
                cannot be mapped back to the candidate objects.
        """
        def room_line(room):
            return f"Room {room.room_id}, Type: {room.name}, Coordinate: {room.pcd.get_center()})\n"

        def object_line(obj):
            return (
                f"Object {obj.object_id}, belongs to Room {obj.room_id}, "
                f"Type: {obj.name}, Coordinate: {obj.pcd.get_center()})\n"
            )

        def lookup(items, attr, wanted, label):
            for item in items:
                if getattr(item, attr) == wanted:
                    return item
            raise ValueError(f"{label} {wanted!r} is not present in the current graph")

        query_link = query_text.replace("<", "").replace(">", "").split("; ")
        # Entities sit at the even positions; relations in between are skipped.
        etts = [ett.strip() for ett in query_link[::2]]
        if etts[0] != target_name:
            raise ValueError(
                f"Target name should be the subject of the query link: {query_text!r} vs {target_name!r}"
            )
        ett_types = [ett_type.strip() for ett_type in relation_type.split('-')]

        cand1_ids, scores = self.hovsg.query_object_all(etts[0])
        cand_scores = {cand: score for cand, score in zip(cand1_ids, scores)}
        cand1_objs = [obj for obj in self.hovsg.objects if obj.object_id in cand1_ids]
        prompt = (
            f"\nI want to locate the '{etts[0]}' that has the following relation '{target_desc}' in a house.\n"
            f"There are {len(cand1_objs)} candidates for the target '{etts[0]}'. \n"
        )
        for obj in cand1_objs:
            prompt += object_line(obj)

        for ett2, ett_type in zip(etts[1:], ett_types[1:]):
            if ett2 in self.assigned_id:
                assigned = self.assigned_id[ett2]
                prompt += f"The '{ett2}' has been identified as {assigned}.\n"
                if 'Object' in assigned:
                    obj = lookup(self.hovsg.objects, "object_id", assigned.split(' ')[1], "Object")
                    prompt += object_line(obj)
                elif 'Room' in assigned:
                    room = lookup(self.hovsg.rooms, "room_id", assigned.split(' ')[1], "Room")
                    prompt += room_line(room)
                else:
                    raise ValueError(f"Invalid target type: {assigned}")
            elif ett_type == "object":
                cand2_ids, _ = self.hovsg.query_object_all(ett2)
                cand2_objs = [obj for obj in self.hovsg.objects if obj.object_id in cand2_ids]
                prompt += f"There are {len(cand2_objs)} '{ett2}' in this house.\n"
                for obj in cand2_objs:
                    prompt += object_line(obj)
            elif ett_type == "room":
                cand2_ids, _ = self.hovsg.query_room(ett2, query_method='Nav3DSG')
                cand2_rooms = [room for room in self.hovsg.rooms if room.room_id in cand2_ids]
                prompt += f"There are {len(cand2_rooms)} '{ett2}' in this house.\n"
                for room in cand2_rooms:
                    prompt += room_line(room)
            else:
                raise ValueError(f"Invalid relation type: {relation_type}")

        prompt += f"Based on the description, return all the '{etts[0]}' that satisfy the specified relation '{target_desc}'."
        prompt += "\nOnly return the id of the objects in the format 'Objects: id1, id2, ...'."
        prompt += "\nIf no objects satisfy the relation, return 'None'."

        completion = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": "You are an expert in finding the target object based on the relation"},
                {"role": "user", "content": prompt}
            ],
            temperature=0.0,
        )
        response = completion.choices[0].message.content
        if "None" in response:
            return [], []

        # Markdown emphasis (e.g. "**Objects: 1, 2**") would otherwise stick to the ids.
        cleaned = response.replace("*", "")
        _, marker, tail = cleaned.partition("Objects:")
        if not marker:
            raise ValueError(f"Unexpected object query response: {response!r}")
        # Only the line carrying the id list matters; the candidates are listed
        # to the LLM as "Object <id>", so drop that prefix if it is echoed back.
        id_line = tail.strip().split("\n")[0]
        object_ids = [part.strip().replace("Object ", "") for part in id_line.split(",")]
        object_ids = [obj_id for obj_id in object_ids if obj_id]

        unknown = [obj_id for obj_id in object_ids if obj_id not in cand_scores]
        if unknown:
            raise ValueError(
                f"Object ids {unknown} are not among the candidates {list(cand_scores)}; response: {response!r}"
            )
        scores = [cand_scores[obj_id] for obj_id in object_ids]

        return object_ids, scores

    def locate_floor(self, target_name, target_desc):
        """Resolve a floor by name.

        Args:
            target_name: Floor name passed to ``self.hovsg.query_floor``.
            target_desc: Must contain "None"; floor descriptions are not supported.

        Returns:
            ``([floor_id], [1.0])`` - a single candidate with a fixed score.

        Raises:
            ValueError: If ``target_desc`` carries an actual description.
        """
        # Explicit raise instead of assert, so the check survives `python -O`.
        if "None" not in target_desc:
            raise ValueError(f"Floor description should be None, got: {target_desc!r}")
        floor_id = self.hovsg.query_floor(target_name)
        return [floor_id], [1.0]

    def locate_room(self, target_name, target_desc):
        """Find candidate rooms for a target name and optional description.

        Without a description (``target_desc`` contains "None") the name is
        queried directly. Otherwise ``analyse_description`` decides between a
        relational scene-graph query and a plain "CLIP" text query.

        Args:
            target_name: Room name to look for.
            target_desc: Free-text qualifier, or "None".

        Returns:
            ``(room_ids, scores)`` as produced by the selected query.

        Raises:
            ValueError: If the analysis is not of the form
                "Query Type, Relation Type, Query Text" or names an
                unsupported query/relation type.
        """
        if "None" in target_desc:
            room_ids, scores = self.hovsg.query_room(target_name, query_method='Nav3DSG')
            return room_ids, scores

        analysis = self.analyse_description(target_name, target_desc)
        # maxsplit=2: the query text itself may contain ", " and must stay in one piece.
        parts = [part.strip() for part in analysis.split(', ', 2)]
        if len(parts) != 3:
            raise ValueError(f"Unexpected description analysis: {analysis!r}")
        query_type, relation_type, query_text = parts

        if "scene graph" in query_type:
            room_ids, scores = self.query_room_sg(target_name, relation_type, query_text, target_desc)
        else:
            if "CLIP" not in query_type:
                raise ValueError(f"Invalid query type: {query_type!r}")
            if "None" not in relation_type:
                raise ValueError(f"Invalid relation type: {relation_type!r}")
            room_ids, scores = self.hovsg.query_room(query_text, query_method='Nav3DSG')

        return room_ids, scores

    def locate_object(self, target_name, target_desc):
        """Find candidate objects for a target name and optional description.

        Without a description (``target_desc`` contains "None") the name is
        queried directly. Otherwise ``analyse_description`` decides between a
        relational scene-graph query and a plain "CLIP" text query.

        Args:
            target_name: Object name to look for.
            target_desc: Free-text qualifier, or "None".

        Returns:
            ``(obj_ids, scores)`` as produced by the selected query.

        Raises:
            ValueError: If the analysis is not of the form
                "Query Type, Relation Type, Query Text" or names an
                unsupported query/relation type.
        """
        if "None" in target_desc:
            obj_ids, scores = self.hovsg.query_object_all(target_name)
            return obj_ids, scores

        analysis = self.analyse_description(target_name, target_desc)
        # maxsplit=2: the query text itself may contain ", " and must stay in one piece.
        parts = [part.strip() for part in analysis.split(', ', 2)]
        if len(parts) != 3:
            raise ValueError(f"Unexpected description analysis: {analysis!r}")
        query_type, relation_type, query_text = parts

        if "scene graph" in query_type:
            obj_ids, scores = self.query_obj_sg(target_name, relation_type, query_text, target_desc)
        else:
            if "CLIP" not in query_type:
                raise ValueError(f"Invalid query type: {query_type!r}")
            if "None" not in relation_type:
                raise ValueError(f"Invalid relation type: {relation_type!r}")
            obj_ids, scores = self.hovsg.query_object_all(query_text)

        return obj_ids, scores

    def select_target(self, candidates):
        """Choose one target among several candidates: the first one listed."""
        first_candidate = candidates[0]
        return first_candidate

    def execute_step(self, step):
        """Run one navigation step: prune the graph, then locate the target.

        Args:
            step: One step line, "Step N, Type, Name, Description, Search Area".
                Extra ", " separators are treated as part of the description,
                and the "Step N: Type, ..." variant is accepted as well.

        Returns:
            ``(candidates, scores, target_type)`` where ``candidates`` is
            always a list.

        Raises:
            ValueError: If the step has too few fields or an unknown target type.
        """
        fields = step.split(', ')

        # Accept "Step N: Type" by splitting the first field at the colon.
        label, colon, first = fields[0].partition(':')
        if colon:
            head = [label.strip()]
            if first.strip():
                head.append(first.strip())
            fields = head + fields[1:]

        if len(fields) < 5:
            raise ValueError(f"Invalid step format: {step!r}")

        target_type, target_name = fields[1], fields[2]
        # Everything between the name and the trailing search area is description.
        target_desc = ', '.join(fields[3:-1])
        search_area = fields[-1]

        # Prune the graph based on the search area.
        self.prune_graph(search_area)

        # Locate the target.
        if target_type == "Floor":
            candidates, scores = self.locate_floor(target_name, target_desc)
        elif target_type == "Room":
            candidates, scores = self.locate_room(target_name, target_desc)
        elif target_type == "Object":
            candidates, scores = self.locate_object(target_name, target_desc)
        else:
            raise ValueError(f"Invalid target type: {target_type}")

        if not isinstance(candidates, list):
            candidates = [candidates]

        return candidates, scores, target_type

    def navigate(self, instruction):
        """Resolve an instruction to a target by searching over step candidates.

        The instruction is decomposed into steps; every candidate of every step
        is explored depth-first, and the path with the highest summed score wins.
        The graph is restored via ``reset()`` afterwards, even if a step raises.

        Args:
            instruction: Natural-language navigation instruction.

        Returns:
            ``(final_path, target)`` where ``final_path`` is a list of
            ``(target_type, candidate_id)`` pairs and ``target`` is the id of
            the last one; ``([], None)`` when no steps were produced or no
            complete path exists.
        """
        self.assigned_id = {}
        raw_steps = self.generate_navigation_steps(instruction).content or ""
        # Strip before filtering, so indented "Step" lines are not discarded.
        steps = [line.strip() for line in raw_steps.split("\n")]
        steps = [line for line in steps if line.startswith("Step")]
        if not steps:
            return [], None

        all_paths = []

        def dfs(idx, path, total_score):
            if idx == len(steps):
                all_paths.append((path, total_score))
                return

            candidates, scores, target_type = self.execute_step(steps[idx])

            # Snapshot of the graph as pruned for this step; restored after
            # each candidate so deeper pruning does not leak between branches.
            hovsg_backup = copy.deepcopy(self.hovsg)
            key = 'Target_' + str(idx)

            for candidate, score in zip(candidates, scores):
                self.assigned_id[key] = target_type + ' ' + str(candidate)
                dfs(idx + 1, path + [(target_type, candidate)], total_score + score)

                self.hovsg = copy.deepcopy(hovsg_backup)

            # Drop this level's assignment on backtrack; otherwise the stale
            # entry is treated as a resolved target when this step is re-run
            # under the parent's next candidate.
            self.assigned_id.pop(key, None)

        try:
            dfs(0, [], 0)
        finally:
            self.reset()

        if not all_paths:
            return [], None
        final_path = max(all_paths, key=lambda x: x[1])[0]
        target = final_path[-1][1]

        return final_path, target

    def test(self):
        self.assigned_id = {}
        steps = ['Step 0, Floor, Level 1, None, all floors  ', 'Step 1, Room, kitchen, None, inside Target_0  ', 'Step 2, Room, bathroom, with flower wallpaper, next to Target_1  ', 'Step 3, Object, desk, None, inside Target_2']
        steps = [x.strip() for x in steps]
        for idx, step in enumerate(steps):
            print('='*50)
            print(step)
            step_num, target_type, target_name, target_desc, search_area = step.split(', ')

            # prune the graph based on the search area
            self.prune_graph(search_area)

            # locate the target
            if target_type == "Floor":
                candidates = self.locate_floor(target_name, target_desc)
            elif target_type == "Room":
                candidates = self.locate_room(target_name, target_desc)
            elif target_type == "Object":
                candidates = self.locate_object(target_name, target_desc)
            else:
                raise ValueError(f"Invalid target type: {target_type}")

            if type(candidates) is not list:
                candidates = [candidates]
            
            print(f"Candidates: {candidates}")
            # specify the target
            if len(candidates) == 1:
                target = candidates[0]
            else:
                print(f"Multiple candidates found for {target_name} in {search_area}")
                target = self.select_target(candidates)
            
            print(f"Target: {target}")
            self.assigned_id['Target_'+str(idx)] = target_type + ' ' + str(target)

        return target

    def reset(self):
        """Restore ``self.hovsg`` in place from a deep copy of the backup's attributes."""
        live_state = self.hovsg.__dict__
        backup_state = self.hovsg_backup.__dict__
        # Empty the current attributes first, then refill from the backup.
        live_state.clear()
        live_state.update(copy.deepcopy(backup_state))

def main(target_scan="8WUmhLawc2A", resume_from="4224_495_2"):
    """Run the agent over one scan's instructions and checkpoint the results.

    Instructions are read from ``REVERIE_Navigation/<scan>_enc.json``; results
    are written to ``REVERIE_Navigation/query_results_Nav3DSG_<scan>.json``
    after every processed instruction, so an interrupted run can be resumed.

    Args:
        target_scan: Scan id used to build the input and output file names.
        resume_from: Instruction id ("<item id>_<index>") at which processing
            starts; everything before it is skipped. Pass ``None`` to start
            from the first instruction.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    param_path = "config/Nav3DSG.yaml"
    with open(param_path, "r") as f:
        params = DictConfig(yaml.safe_load(f))

    agent = Nav3DSG(api_key=api_key, params=params)

    with open(f"REVERIE_Navigation/{target_scan}_enc.json", "r") as f:
        nav_data = json.load(f)

    results_path = f"REVERIE_Navigation/query_results_Nav3DSG_{target_scan}.json"
    if os.path.exists(results_path):
        with open(results_path, "r") as f:
            results = json.load(f)
    else:
        results = {}

    started = resume_from is None
    for item in tqdm(nav_data):
        item_id = item['id']
        for idx, inst in enumerate(item['instructions']):
            inst_id = f"{item_id}_{idx}"
            if inst_id == resume_from:
                started = True

            # Skip everything before the resume point and anything already stored.
            if not started or inst_id in results:
                print(f"Skipping {inst_id}")
                continue

            _path, target = agent.navigate(inst)
            if target is None:
                continue

            # The last step may resolve to a floor or room, or to an id that is
            # not in the graph; skip those instead of crashing the whole run.
            target_obj = next(
                (obj for obj in agent.hovsg.objects if obj.object_id == target),
                None,
            )
            if target_obj is None:
                print(f"Skipping {inst_id}: final target {target!r} is not an object in the graph")
                continue

            results[inst_id] = {
                "query": inst,
                "target_name": target_obj.name,
                "target_id": target,
                "target_position": target_obj.pcd.get_center().tolist()
            }

            # Checkpoint after every instruction.
            with open(results_path, "w") as f:
                json.dump(results, f, indent=4)

# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()