import openai
import time
import os
import pickle
import json
from collections import deque
from tools.utils import generate_scene_graph_json, get_path_and_adjacency, \
                        prompt_make, check_path_and_adjacency, generate_scene_graph_total_json, \
                        parse_region_file, check_collab_path_efficient_sim_graph, \
                        find_target_position, nearest_navpoint_to_object_vec, cluster_center_node
import networkx as nx
import math
import random

import re
def clean_task_instruction(instr: str) -> str:
    """
    Normalise a task instruction so that it carries no region/object indices.

    The clean-up runs in a fixed order:
      1. drop every underscore-number group ('Laundry_Room_6' -> 'Laundry_Room'),
      2. drop a number that trails one or more space-separated words
         ('Bedroom 10' -> 'Bedroom'),
      3. turn the remaining underscores into spaces,
      4. collapse runs of two or more whitespace characters into one space,
    and the result is returned with leading/trailing whitespace removed.
    """
    # Regex substitutions applied in sequence; order matters because step 2
    # looks at the text produced by step 1.
    numeric_cleanups = (
        (r'(_\d+)', ''),
        (r'\b([A-Za-z]+(?: [A-Za-z]+)*)\s+\d+\b', r'\1'),
    )
    cleaned = instr
    for pattern, replacement in numeric_cleanups:
        cleaned = re.sub(pattern, replacement, cleaned)

    spaced = cleaned.replace('_', ' ')
    squeezed = re.sub(r'\s{2,}', ' ', spaced)
    return squeezed.strip()

def convert(name):
    """
    Rewrite the first single-quoted argument of an action string.

    When that argument ends in '_<digits>', every underscore before the final
    numeric suffix is turned into a space (e.g. "Move_to('dining_table_3')"
    becomes "Move_to('dining table_3')"). Strings without a quoted argument,
    or whose argument has no such suffix, are returned unchanged.
    """
    quoted = re.search(r"'(.*?)'", name)
    if quoted is None:
        return name

    obj = quoted.group(1)
    suffixed = re.match(r'(.+?)(_)(\d+)$', obj)
    if suffixed is None:
        return name

    stem, _, number = suffixed.groups()
    spaced_stem = stem.replace('_', ' ')
    # str.replace substitutes every occurrence of the quoted text in `name`.
    return name.replace(obj, f"{spaced_stem}_{number}")

def call_llm_collab_agent_with_reasoning(prompt_path, scene_graph_json, single_graph_json, target_region_key, phase_0_list, phase_1_list):
    """
    Ask the LLM for a two-robot collaboration plan, validating its proposals
    through the ``check_collab_path_efficient_sim_graph`` function call.

    The prompt is assembled from ``rule.txt``, ``example.txt`` and
    ``system.txt`` under ``prompt_path`` plus the JSON-serialised scene graph
    and single-agent task. Whenever the model issues a function call, the
    proposed combination is checked and the outcome is fed back to the model
    until it produces a final (non function-call) message.

    Parameters
    ----------
    prompt_path : str
        Directory prefix (string-concatenated with the file names) holding
        the prompt files.
    scene_graph_json, single_graph_json
        JSON-serialisable inputs embedded in the user prompt.
    target_region_key : str
        '_<id>' key of the target object's region. A proposed transfer region
        with the same id suffix is rejected without running the path check.
    phase_0_list : sequence
        ``[target_object_region, end_asset_region]`` labels echoed back to
        the model once a combination is confirmed efficient.
    phase_1_list : sequence
        Positional arguments forwarded to the path checker: indices 0-4 and
        the last element (the caller in this file passes solo cost, robot_1
        start node, target node, end node, parsed region objects, graph).

    Returns
    -------
    list | None
        ``[parsed_answer, check_result]`` where ``check_result`` is the dict
        returned by the most recent path check of the successful attempt, or
        ``None`` when the tool-call budget is exceeded or every attempt fails.

    Notes
    -----
    ``API_KEY`` is read as a module-level global; in this file it is assigned
    by the ``__main__`` entry point.
    ``messages`` is shared across retry attempts, so feedback appended during
    a failed attempt is still visible to the following attempts.
    """
    tools = [
        {
            "name": "check_collab_path_efficient_sim_graph",
            "description": "Check spatial relationship between robot_2's candidate start region (region_a) with robot_1's start region, target object region and end region.",
            "parameters": {
                "type": "object",
                "properties": {
                    "robot_2_start_region": {
                        "type": "string",
                        "description": "e.g., Bedroom_1"
                    },
                    "transfer_region": {
                        "type": "string",
                        "description": "e.g., Bathroom_5"
                    },
                    "transfer_asset":{
                        "type": "string",
                        "description": "e.g. table"
                    },
                    "collab_type": {
                        "type": "string",
                        "description": "e.g., Type-A1"
                    }
                },
                "required": ["robot_2_start_region", "transfer_region", "transfer_asset", "collab_type"]
            }
        }
    ]

    openai.default_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}"
    }

    with open(prompt_path + "rule.txt", "r") as f:
        prompt_rule = f.read()

    with open(prompt_path + "example.txt", "r") as f:
        prompt_example = f.read()

    prompt_input = "Please observe the above rules strictly. Think step by step.\nINPUT:\n```\nscene graph: " + json.dumps(scene_graph_json) + "\nsingle agent task: " + json.dumps(single_graph_json) + "\n```\nOUTPUT:"
    prompt = prompt_rule + "\n" + prompt_example + "\n" + prompt_input

    prompt_system, prompt = prompt_make(prompt_path + "system.txt", prompt)

    messages = [
        {"role": "system", "content": prompt_system},
        {"role": "user", "content": prompt}
    ]

    max_retries = 5
    MAX_TOOL_CALLS = 5

    gpt_inference_model = "gpt-4o-mini" # "cc-3-5-haiku-20241022"

    for attempt in range(max_retries):
        # Reset per attempt: a result left over from an earlier, failed
        # attempt must never be returned alongside a new answer.
        check_result = None
        try:
            response = openai.chat.completions.create(
                model=gpt_inference_model, # gpt-4o-mini
                messages=messages,
                functions=tools,
                function_call='auto',
                timeout=30)
            msg = response.choices[0].message
            print(msg)
            tool_calls = 0
            while msg.function_call:
                print("LLM call locatin function for path checking...")
                tool_calls += 1
                if tool_calls > MAX_TOOL_CALLS:
                    print("⚠️  Max retries, return None.")
                    return None

                func_call = msg.function_call
                func_name = func_call.name
                args = json.loads(func_call.arguments)
                # User-role feedback accompanying the function result; stays
                # None when the requested function is not recognised.
                feedback = None

                if func_name == "check_collab_path_efficient_sim_graph":
                    print("Funcation Calling...")

                    if '_'+args["transfer_region"].split('_')[-1] == target_region_key:
                        # Handing over inside the target object's own region
                        # is rejected outright, without a path check.
                        print("Same transfer region and target object region")
                        result = {
                            "efficient": False,
                        }

                    else:
                        check_result = check_collab_path_efficient_sim_graph(
                            args["robot_2_start_region"],
                            args["transfer_region"],
                            args["transfer_asset"],
                            args['collab_type'],
                            phase_1_list[0],  # solo_ids
                            phase_1_list[1],  # start_1 nodes
                            phase_1_list[2],  # target node
                            phase_1_list[3],  # end node
                            phase_1_list[4],  # parse_file 
                            phase_1_list[-1]  # G
                        )

                        result = {
                            "efficient": check_result['efficient'],
                        }

                    if result["efficient"]:
                        print("Sample success...")
                        feedback = (
                            f"Great, we confirmed that this combination is efficient:\n"
                            f"robot_2_start_region = {args['robot_2_start_region']}\n"
                            f"transfer_region = {args['transfer_region']}\n"
                            f"transfer_asset = {args['transfer_asset']}\n"
                            f"target_object_region = {phase_0_list[0]}\n"
                            f"end_asset_region = {phase_0_list[1]}\n"
                            # Previously this line echoed the transfer asset
                            # instead of the collaboration type.
                            f"collab_type = {args['collab_type']}\n"
                            f"You MUST use this exact combination in your final output. Do NOT invent new ones."
                        )
                    else:
                        print("Low efficient sample，user feedback...")
                        feedback = (
                            "We confirmed that the current collaboration setup is INEFFICIENT.\n"
                            "- Re-sample a new robot_2_start_region.\n"
                            "- Re-sample a new transfer_region.\n"
                            "DO NOT reuse any region (start or transfer) that was chosen in the previous attempt."
                        )

                else:
                    # No 'efficient' key here, so no efficiency feedback is
                    # sent; only the error is reported back to the model.
                    result = {"error": "Unknown function"}

                if feedback is not None:
                    messages.append({"role": "user", "content": feedback})

                messages.append({
                    "role": "function",
                    "name": func_name,
                    "content": json.dumps(result)
                })
                print("Update path...")
                response = openai.chat.completions.create(
                    model=gpt_inference_model, # gpt-4o-mini
                    messages=messages,
                    functions=tools,
                    function_call="auto",
                    timeout=30,
                )
                msg = response.choices[0].message

            if check_result is None:
                # The model answered without any path check having run in
                # this attempt; callers index into check_result, so retry.
                raise RuntimeError("final answer produced without a path check")

            content = msg.content
            return [json.loads(content), check_result]

        except Exception as e:
            # Best-effort retry on any failure (API error, malformed JSON,
            # missing arguments, ...), but report it instead of hiding it.
            print(f"[WARN] LLM attempt {attempt + 1}/{max_retries} failed: {e!r}")
            time.sleep(1)

    return None

def is_duplicate(habitat_json, save_dir):
    """
    Tell whether an equal JSON document is already stored in ``save_dir``.

    Every entry of ``save_dir`` whose name ends in '.json' is loaded and
    compared with ``habitat_json``; the first equal document yields True.
    Files that fail to load are reported with a warning and skipped.
    """
    json_names = (name for name in os.listdir(save_dir) if name.endswith(".json"))
    for name in json_names:
        file_path = os.path.join(save_dir, name)
        with open(file_path, "r") as handle:
            try:
                same = json.load(handle) == habitat_json
            except Exception as e:
                print(f"[WARN] Failed to parse {file_path}: {e}")
                continue
        if same:
            return True
    return False

def insert_temp_point(tmp_id, G: nx.Graph, point_xyz, anchor_node_id, 
                      node_prefix="tmp_obj") -> str:
    """
    Attach a temporary node to the graph next to an existing anchor node.

    The node ``tmp_id`` is added with ``position=point_xyz`` and linked to
    ``anchor_node_id`` in both directions, each edge weighted by the planar
    distance computed from the first two coordinates of the two positions.
    ``node_prefix`` is accepted but not used. Returns ``tmp_id``.
    """
    G.add_node(tmp_id, position=point_xyz)

    # Only the first two coordinates take part in the distance.
    px, py, *_ = point_xyz
    ax, ay, *_ = G.nodes[anchor_node_id]['position']
    planar_gap = math.hypot(ax - px, ay - py)

    # Add both orientations explicitly, as the original did.
    for src, dst in ((tmp_id, anchor_node_id), (anchor_node_id, tmp_id)):
        G.add_edge(src, dst, weight=planar_gap)

    return tmp_id

def extract_json_from_message(message_content: str):
    """
    Extract and parse the JSON payload from an LLM ChatCompletionMessage
    content string which may be wrapped in ```json ... ``` fences.

    Parsing is attempted in two stages:
      1. the whole content, after dropping fence lines when the content
         starts with a fence;
      2. if that fails, each fenced block found anywhere in the content
         (covers answers that put prose around the fenced JSON, and fences
         written on a single line).

    Returns
    -------
    object
        The parsed JSON value (a dict for a JSON object).

    Raises
    ------
    ValueError
        If the content is not a string or valid JSON cannot be located.
    """
    if not isinstance(message_content, str):
        # e.g. ``None`` content on a message that only carries a function call
        raise ValueError(
            f"Failed to parse JSON from LLM content: expected str, got {type(message_content).__name__}"
        )

    # Remove triple-backtick fence lines if the content starts with a fence
    txt = message_content.strip()
    if txt.startswith("```"):
        txt = "\n".join(
            ln for ln in txt.splitlines()
            if not ln.strip().startswith("```")
        )

    try:
        return json.loads(txt)
    except json.JSONDecodeError as exc:
        # Fallback: look for a parsable fenced block anywhere in the content.
        fenced_blocks = re.finditer(
            r"```(?:json)?(.*?)```", message_content, re.DOTALL | re.IGNORECASE
        )
        for block in fenced_blocks:
            try:
                return json.loads(block.group(1))
            except json.JSONDecodeError:
                continue
        raise ValueError(f"Failed to parse JSON from LLM content: {exc}") from exc

def extract_transfer_asset(subtask_list):
    """
    Return the asset named by the first Grab('...') action in a subtask list.

    Each entry is stripped of surrounding whitespace and must begin with
    Grab('<name>') to count, e.g. ["Grab('rug')", ...] yields 'rug'.
    None is returned when no entry matches.
    """
    grab_pattern = re.compile(r"Grab\('([^']+)'\)")
    candidates = (grab_pattern.match(action.strip()) for action in subtask_list)
    return next((hit.group(1) for hit in candidates if hit), None)

if __name__ == '__main__':
    # Script entry point: for every task folder in a batch, load the
    # single-agent task and the scene navigation graph, ask the LLM for
    # two-robot collaborative variants, validate them against the scene and
    # save the accepted ones under <task>/collab_json/.

    # API_KEY must remain a module-level global:
    # call_llm_collab_agent_with_reasoning reads it when building headers.
    API_KEY = '...'
    openai.api_key = API_KEY # Set the OpenAI API key

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, help="batch index")
    parser.add_argument('--record_idx', type=str, help="record timestamp")
    parser.add_argument('--max_tries', type=int, help="let LLM to try how many times")
    parser.add_argument('--max_iters', type=int, help="stop once a task's collab_json directory holds this many entries")
    parser.add_argument('--floor_num', type=str, help="floor number")
    args = parser.parse_args()

    batch_idx = args.batch
    record_idx = args.record_idx
    max_tries = args.max_tries
    max_iters = args.max_iters
    floor_num = args.floor_num
    
    batch_path = f'.../nav_graph_generation/database/floor_{floor_num}/record/record_{record_idx}/batch_{batch_idx}' 
    batch_raw_path = f'.../nav_graph_generation/database/floor_{floor_num}/record/record_{record_idx}/batch_raw_{batch_idx}' 
    prompt_path = '.../nav_graph_generation/navcraft_c_prompts/' 

    for task in os.listdir(batch_path):
        # The raw config feeds the LLM prompt; the config under batch_path
        # supplies the robot, object and floor fields used below.
        single_agent_task = os.path.join(batch_raw_path, task, 'single_config.json')

        with open(single_agent_task, 'r') as f:
            single_agent_json = json.load(f)

        habitat_single_agent_task = os.path.join(batch_path, task, 'single_config.json')
        with open(habitat_single_agent_task, 'r') as f:
            habitat_agent_json = json.load(f)

        scene_idx = single_agent_json['Scene']
        print(scene_idx)
        if str(floor_num) != '1':
            floor_idx = habitat_agent_json['Floor']
            graph_connect_dir = os.path.join(f'.../nav_graph_generation/G_T/G_{floor_num}', f"{scene_idx}_{floor_idx}.pkl")
            graph_region_index_dir = os.path.join(f'.../nav_graph_generation/G_region_index_T/G_region_index_{floor_num}', f"{scene_idx}_{floor_idx}.json")
        else:
            graph_connect_dir = os.path.join('.../nav_graph_generation/G_T/G_1', f"{scene_idx}.pkl")
            graph_region_index_dir = os.path.join('.../nav_graph_generation/G_region_index_T/G_region_index_1', f"{scene_idx}.json")
        
        region_assets_dir = os.path.join('.../nav_gen/scene', f"{scene_idx}.txt")

        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only load graphs produced by this project.
        with open(graph_connect_dir, 'rb') as f:
            G_orin = pickle.load(f)

        # Work on a copy and weight every edge by the planar (x, y) distance
        # between its endpoints.
        G = G_orin.copy()
        for u, v in G.edges():
            x1, y1 = G.nodes[u]['position'][0], G.nodes[u]['position'][1]
            x2, y2 = G.nodes[v]['position'][0], G.nodes[v]['position'][1]
            euclid = math.hypot(x2 - x1, y2 - y1)   # sqrt(dx**2 + dy**2)
            G.edges[u, v]['weight'] = euclid        
        
        with open(graph_region_index_dir, 'r') as f:
            G_regionid = json.load(f)

        scene_graph_total_json = generate_scene_graph_total_json(G, G_regionid, region_assets_dir)

        scene_graph_json = generate_scene_graph_json(G, G_regionid, region_assets_dir)

        save_dir = os.path.join(batch_path, task, 'collab_json')
        os.makedirs(save_dir, exist_ok=True)

        start_pos = single_agent_json['phase_1']['Single robot start region']['robot_1']
        target_pos = single_agent_json['phase_1']['Single robot target object region']['robot_1']
        end_pos = single_agent_json['phase_1']['Single robot end region']['robot_1']

        region_objects = parse_region_file(os.path.join('.../nav_gen', 'scene', scene_idx+'.txt'))

        target_object = habitat_agent_json["Object"][0][0]
        end_asset = habitat_agent_json["Object"][1][0]
        # Region keys are the '_<id>' suffix of the region names.
        target_region_key = '_'+target_pos.split('_')[-1]
        end_asset_key = '_'+end_pos.split('_')[-1]
        target_object_pos = find_target_position(region_objects, target_region_key, target_object)
        end_asset_pos = find_target_position(region_objects, end_asset_key, end_asset)

        nearest_target_node, _, _ = nearest_navpoint_to_object_vec(
                    target_object_pos,
                    target_region_key,
                    nx.nodes(G),
                    nx.get_node_attributes(G, 'position'),  
                )
        # Insert the target object into the graph as a temporary node.
        insert_temp_point('targetobj_'+target_object, G, target_object_pos, nearest_target_node)
        
        nearest_end_asset_node, _, _ = nearest_navpoint_to_object_vec(
                    end_asset_pos,
                    end_asset_key,
                    nx.nodes(G),
                    nx.get_node_attributes(G, 'position'),  # node_view[0:3]
                )
        
        # Same for the end asset.
        insert_temp_point('endasset_'+end_asset, G, end_asset_pos, nearest_end_asset_node)

        start_pos_key = '_'+start_pos.split('_')[-1]
        start_node = cluster_center_node(start_pos_key, nx.nodes(G), nx.get_node_attributes(G, 'position'))
        # Solo baseline: start -> target navpoint -> end-asset navpoint.
        # Tasks whose legs are not connected in the graph are skipped.
        try:
            r1_to_target = nx.shortest_path_length(G,
                                    source=start_node,
                                    target=nearest_target_node,
                                    weight='weight')  
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            continue

        try:
            target_to_end = nx.shortest_path_length(G,
                                    source=nearest_target_node,
                                    target=nearest_end_asset_node,
                                    weight='weight')
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            continue
        
        solo_cost = r1_to_target + target_to_end
        print("solo distance: ", solo_cost)
        print("solo robot 1 from start to target distance: ", r1_to_target)
        print("solo robot 1 from target to end distance: ", target_to_end)
        user_feed_target = target_object + target_region_key
        user_feed_end = end_asset + end_asset_key
        
        forward_path = check_path_and_adjacency(start_pos, target_pos, G, G_regionid)['path']
        backward_path = check_path_and_adjacency(target_pos, end_pos, G, G_regionid)['path']

        single_agent_json['phase_1']['Single robot travel path']['robot_1'] = forward_path + backward_path

        for i in range(max_tries):
            print(i)
            if len(os.listdir(save_dir)) >= max_iters:
                # The directory only grows during this loop, so the remaining
                # tries could not save anything either.
                break
        
            llm_pkg = call_llm_collab_agent_with_reasoning(prompt_path, scene_graph_json, single_agent_json, target_region_key, [user_feed_target, user_feed_end], [solo_cost, start_node, 'targetobj_'+target_object, 'endasset_'+end_asset, region_objects, G])
    
            if llm_pkg is None:
                continue

            answer, check_result = llm_pkg[0], llm_pkg[1]

            if not check_result['efficient']:
                continue

            print("Path is ok!")
            print(answer)

            habitat_json = {}
            habitat_json['robot_1'] = {}
            habitat_json['robot_2'] = {}
            
            habitat_json['robot_1']["Task instruction"] = clean_task_instruction(answer["phase_2"]['Subtask instruction']['robot_1'])
            habitat_json['robot_2']["Task instruction"] = clean_task_instruction(answer["phase_2"]['Subtask instruction']['robot_2'])

            habitat_json['robot_1']['Subtask list'] = [convert(s) for s in answer["phase_2"]['Subtask list']['robot_1']]
            habitat_json['robot_2']['Subtask list'] = [convert(s) for s in answer["phase_2"]['Subtask list']['robot_2']]

            habitat_json['robot_1']["Robot"] = habitat_agent_json['Robot']
            habitat_json['robot_2']["Robot"] = random.choice(['spot', 'fetch'])

            habitat_json["Scene"] = scene_idx
            
            if str(floor_num) != '1':
                habitat_json['Floor'] = floor_idx

            habitat_json["Type"] = answer["phase_2"]["Collaborative type"]
            print(habitat_json["Type"])
            habitat_json["G_Rate"] = check_result['g_rate']
            habitat_json["R1_Rate"] = check_result['r1_rate']
                                    
            # Map "Region <id>: <type>" -> assets + objects of that region,
            # plus an id -> key index for exact region-id lookups.
            input_scene = {}
            region_key_by_id = {}
            for entry in scene_graph_total_json["floor_1"]["item"]:
                # Split region string by '_' and extract id
                parts = entry['region'].split('_')
                region_id = parts[-1]              # numeric id as string
                region_type = ' '.join(parts[:-1]) # join remaining parts for full type
                combined_items = entry['asset'] + entry['object']
                output_key = f"Region {region_id}: {region_type}"
                input_scene[output_key] = combined_items
                region_key_by_id.setdefault(region_id, output_key)
            
            valid = True

            for robot_idx in range(1, 3):
                robot_key = f'robot_{robot_idx}'
                wrong = False
                objs = []
                for subtask in habitat_json[robot_key]["Subtask list"]:
                    if "Move_to" not in subtask:
                        continue
                    # "Move_to('<name>_<region id>')" -> "<name>_<region id>"
                    obj_id = subtask[9:-2]
                    obj = obj_id.split("_")
                    if len(obj) < 2:
                        # No region suffix: reject the sample rather than
                        # crash the whole batch on obj[1].
                        wrong = True
                        break
                    # Exact id match; a substring test would let id "1"
                    # match "Region 10: ...".
                    ok = region_key_by_id.get(obj[1])
                    if ok is None:
                        wrong = True
                        break
                    objs.append([obj[0], ok])
                    if obj[0].replace(' ', '_') not in input_scene[ok]:
                        wrong = True
                        break
                instruction = habitat_json[robot_key]['Task instruction']
                if "region" in instruction or "Region" in instruction:
                    wrong = True
                if wrong:
                    valid = False
                    break
                habitat_json[robot_key]['Object'] = objs
                habitat_json[robot_key]["Start"] = answer["phase_2"]["Collaborative robot start region"][robot_key]
  
            if valid:
                if is_duplicate(habitat_json, save_dir):
                    print("[SKIP] Duplicate habitat json, not saving.")
                else:
                    exist = str(len(os.listdir(save_dir))).zfill(2)
                    habitat_save_path = os.path.join(save_dir, f'collab_config_{exist}.json')                    
                    with open(habitat_save_path, 'w') as f:
                        json.dump(habitat_json, f, indent=2)
                    print(f"[OK] Saved habitat json: {habitat_save_path}")
            else:
                print(f"[SKIP] Invalid habitat json, not saved.")