import openai
import random
import time
import os
import pickle
import json
from collections import deque
from tools.utils import generate_scene_graph_json, check_path_and_adjacency, prompt_make, check_two_path_and_adjacency, parse_region_file
import copy


def _chat_with_region_tool(messages, tools, model):
    """Send ``messages`` with the region-check function enabled; return the first choice's message."""
    response = openai.chat.completions.create(
        model=model,
        messages=messages,
        functions=tools,
        function_call="auto",
        timeout=30,
    )
    return response.choices[0].message


def _dispatch_region_tool(func_name, args, G, G_regionid):
    """Execute the function the model asked for; unknown names yield an error dict."""
    if func_name != "check_two_path_and_adjacency":
        return {"error": "Unknown function"}
    return check_two_path_and_adjacency(
        args["start_region"],
        args["target_region"],
        args["end_region"],
        G,
        G_regionid,
    )


def _region_check_feedback(result, args):
    """Build the user-turn hint telling the model how to adjust its region triple.

    ``result`` is the dict produced by ``_dispatch_region_tool``. Its
    ``s2t_valid`` / ``t2e_valid`` / ``valid`` keys are read. ``args`` holds the
    model's proposed ``start_region`` / ``target_region`` / ``end_region``.
    Returns ``None`` when no hint applies: the call itself errored (the error is
    already in the function result), or both legs pass but ``valid`` is falsy.
    """
    if "error" in result:
        return None

    start, target, end = args["start_region"], args["target_region"], args["end_region"]
    s2t_ok = result.get("s2t_valid")
    t2e_ok = result.get("t2e_valid")

    if s2t_ok and not t2e_ok:
        print("✅ start → target is ok, target → end is not ok，resample the end。")
        return (
            f"Great, we confirmed that this combination is partially valid:\n"
            f"- robot_1 start region = {start}\n"
            f"- target object region = {target}\n"
            f"You MUST keep the start and target object region fixed.\n"
            f"DO NOT change them.\n"
            f"Please resample a new delivery (end) region that meets all constraints."
        )
    if t2e_ok and not s2t_ok:
        print("✅ target → end is ok, start → target is not ok，resample the start。")
        return (
            f"Great, we confirmed that this combination is partially valid:\n"
            f"- target object region = {target}\n"
            f"- delivery (end) region = {end}\n"
            f"You MUST keep the object target and end region fixed.\n"
            f"DO NOT change them.\n"
            f"Please resample a new start region that meets all constraints."
        )
    if not s2t_ok and not t2e_ok:
        print("❌ both path is not ok，retry it。")
        return (
            f"The current combination is invalid.\n"
            f"Both paths (start → target, target → end) violate constraints.\n"
            f"Please resample a completely new combination of:\n"
            f"- start region\n"
            f"- target object region\n"
            f"- end region\n"
        )
    if result.get("valid"):
        print("✅ both path is ok")
        return (
            f"The following region combination has been validated and meets all constraints:\n\n"
            f"- robot_1 start region: {start}\n"
            f"- target object region: {target}\n"
            f"- delivery (end) region: {end}\n\n"
            f"You MUST use **exactly** this region combination in your final output.\n"
            f"DO NOT re-sample regions.\n"
            f"DO NOT invent new regions.\n"
        )
    return None


def call_llm_single_agent_with_reasoning(prompt_path, scene_graph_json, G, G_regionid,
                                         *, model="gpt-4o-mini", max_retries=3,
                                         max_tool_calls=5):
    """Ask the LLM for a single-robot task and let it validate regions via function calling.

    The prompt is assembled from ``rule.txt``, ``example.txt``, one randomly
    chosen user profile from ``rolo.json``, and ``system.txt`` (through
    ``prompt_make``). All of these are read from ``prompt_path``, which is
    string-concatenated with the file names, so it must end in a separator.

    While the model keeps requesting ``check_two_path_and_adjacency``, the
    function is executed against ``G`` / ``G_regionid``. Its result, plus a
    natural-language hint, is fed back to the model.

    Args:
        prompt_path: Prompt directory prefix (file names are appended to it).
        scene_graph_json: JSON-serializable scene graph embedded in the prompt.
        G: Graph forwarded to ``check_two_path_and_adjacency``.
        G_regionid: Region-index data forwarded to ``check_two_path_and_adjacency``.
        model: Chat model name.
        max_retries: Number of full attempts before giving up.
        max_tool_calls: Function calls allowed within one attempt.

    Returns:
        ``(answer, role_idx)``. ``answer`` is the model's final reply parsed as
        JSON (```json fences tolerated). ``role_idx`` indexes the chosen user
        profile. Returns ``(None, None)`` when the tool-call budget is exceeded
        or every attempt fails.
    """
    tools = [
        {
            "name": "check_two_path_and_adjacency",
            "description": "Check if two regions are connected and adjacent, and return path info",
            "parameters": {
                "type": "object",
                "properties": {
                    "start_region": {
                        "type": "string",
                        "description": "e.g., Bedroom_1"
                    },
                    "target_region": {
                        "type": "string",
                        "description": "e.g., Bathroom_5"
                    },
                    "end_region": {
                        "type": "string",
                        "description": "e.g., Kichen_0"
                    },
                },
                "required": ["start_region", "target_region", "end_region"]
            }
        }
    ]

    # API_KEY only exists when this file runs as a script. Fall back to the
    # module-level openai key so the function also works after an import.
    api_key = globals().get("API_KEY") or openai.api_key
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    openai.default_headers = headers

    with open(prompt_path + "rule.txt", "r", encoding="utf-8") as f:
        prompt_rule = f.read()
    with open(prompt_path + "example.txt", "r", encoding="utf-8") as f:
        prompt_example = f.read()
    with open(prompt_path + "rolo.json", "r", encoding="utf-8") as f:
        prompts_role = json.load(f)

    random_role_idx = random.randrange(len(prompts_role))

    prompt_input = "Please observe the above rules strictly. Think step by step.\nINPUT:\n```\nscene graph: " \
        + json.dumps(scene_graph_json) \
        + "\nuser profile: " + json.dumps(prompts_role[random_role_idx]) + "```\nOUTPUT:"
    prompt = prompt_rule + "\n" + prompt_example + "\n" + prompt_input
    prompt_system, prompt = prompt_make(prompt_path + "system.txt", prompt)

    base_messages = [
        {"role": "system", "content": prompt_system},
        {"role": "user", "content": prompt},
    ]

    for attempt in range(max_retries):
        # Each attempt starts from the pristine prompt, so hints from a failed
        # attempt do not leak into the next one.
        messages = list(base_messages)
        try:
            msg = _chat_with_region_tool(messages, tools, model)
            tool_calls = 0
            while msg.function_call:
                print("LLM call locatin function for path checking...")
                print(msg)
                tool_calls += 1
                if tool_calls > max_tool_calls:
                    print("⚠️ Max retries, return None.")
                    return None, None

                func_name = msg.function_call.name
                raw_args = msg.function_call.arguments
                args = json.loads(raw_args)
                result = _dispatch_region_tool(func_name, args, G, G_regionid)

                # Function-calling protocol: the assistant's function_call is
                # followed by the function result, then any extra user guidance.
                messages.append({
                    "role": "assistant",
                    "content": msg.content,
                    "function_call": {"name": func_name, "arguments": raw_args},
                })
                messages.append({
                    "role": "function",
                    "name": func_name,
                    "content": json.dumps(result),
                })
                feedback = _region_check_feedback(result, args)
                if feedback is not None:
                    messages.append({"role": "user", "content": feedback})

                print("Update Path...")
                msg = _chat_with_region_tool(messages, tools, model)

            return extract_json_from_message(msg.content), random_role_idx

        except Exception as e:  # best-effort: network, malformed JSON, bad tool args
            print(f"[Retry {attempt + 1}/{max_retries}] Failed: {e!r}")
            time.sleep(1)

    return None, None


# ---------------------------------------------------------------------------
# Utilities & v2 interface (no function-calling loop)
# ---------------------------------------------------------------------------

def extract_json_from_message(message_content: str):
    """
    Parse an LLM reply as JSON, tolerating Markdown code fences.

    When the stripped reply begins with a fence, every line whose stripped text
    starts with three backticks is dropped before parsing; otherwise the
    stripped reply is parsed directly. Raises json.JSONDecodeError on bad JSON.
    """
    text = message_content.strip()
    if not text.startswith("```"):
        return json.loads(text)
    body_lines = [
        line for line in text.splitlines()
        if not line.strip().startswith("```")
    ]
    return json.loads("\n".join(body_lines))

from copy import deepcopy
def build_subtask_list(task_json: dict, region_assets_dir):
    """Attach a Move_to/Grab/Move_to/Release subtask list to ``task_json["phase_1"]``.

    The assets of the delivery (end) region are looked up in the region file
    at ``region_assets_dir``. The first asset whose name (case-insensitively)
    appears in the task instruction is chosen, after the target object's own
    name has been removed from that text.

    ``task_json`` is modified in place.

    Returns:
        ``task_json`` with ``["phase_1"]["Subtask list"]`` set, or ``None`` if
        the instruction mentions none of the end region's assets.

    Raises:
        ValueError: if the end region has no assets at all.
        KeyError: if ``task_json`` lacks the expected phase_1 fields.
    """
    region_assets = parse_region_file(region_assets_dir)
    phase = task_json["phase_1"]

    # Target object, e.g. "appliance". Underscores become spaces so the name
    # matches the wording used in the instruction.
    target_object = phase["Target object"].replace('_', ' ')
    obj_region = phase["Single robot target object region"]["robot_1"]
    obj_regid = obj_region.split("_")[-1]
    obj_identifier = f"{target_object}_{obj_regid}"

    end_region = phase["Single robot end region"]["robot_1"]   # e.g. Bathroom_4
    end_regid = end_region.split("_")[-1]                      # e.g. 4
    # NOTE(review): assumes parse_region_file keys regions as "_<id>" and each
    # asset is a dict with a "name" entry — confirm against tools.utils.
    end_assets = region_assets.get(f"_{end_regid}", [])
    end_asset_names = [a["name"] for a in end_assets]

    # Lower-case BEFORE removing the target object. This way "Towel" and
    # "towel" are both stripped, and the target cannot be picked as the asset.
    search_content = phase["Task instruction"].lower().replace(target_object.lower(), "")

    chosen_asset = next(
        (name for name in end_asset_names if name.lower() in search_content),
        None,
    )

    if chosen_asset is None:
        if not end_asset_names:
            raise ValueError(f"Region {end_region} no asset is available.")
        return None

    asset_identifier = f"{chosen_asset}_{end_regid}"           # e.g. sink_4
    phase["Subtask list"] = [
        f"Move_to('{obj_identifier}')",
        f"Grab('{target_object}')",
        f"Move_to('{asset_identifier}')",
        f"Release('{target_object}')"
    ]
    return task_json

import re
def clean_task_instruction(instr: str) -> str:
    """
    Normalise a generated task instruction for use as a human-readable label.

    Steps, in order:
      * drop underscore-number suffixes ("Laundry_Room_6" -> "Laundry_Room");
      * drop whitespace-separated numbers that follow a word ("Kitchen 12" -> "Kitchen");
      * turn underscores into spaces;
      * collapse runs of whitespace to one space and trim the ends.
    """
    text = re.sub(r'(_\d+)', '', instr)
    text = re.sub(r'\b([A-Za-z]+(?: [A-Za-z]+)*)\s+\d+\b', r'\1', text)
    text = re.sub(r'\s{2,}', ' ', text.replace('_', ' '))
    return text.strip()

from pathlib import Path
def collect_ids(dir_path: str, scene_idx: str, ext: str) -> set[int]:
    """Return the numeric suffixes of files in ``dir_path`` named ``<scene_idx>_..._<n><ext>``.

    Only file names starting with ``scene_idx + "_"`` and ending with ``ext`` are
    considered. The last ``_``-separated token of the stem is kept when it is
    all digits.
    """
    prefix = scene_idx + "_"
    candidates = (
        Path(name).stem.split("_")[-1]
        for name in os.listdir(dir_path)
        if name.startswith(prefix) and name.endswith(ext)
    )
    return {int(tail) for tail in candidates if tail.isdigit()}

if __name__ == '__main__':
    # Script entry point: sample scenes for one floor count, ask the LLM for
    # single-robot tasks, validate them against the scene graph, and write a
    # cleaned config plus the raw LLM answer for each accepted task.
    import argparse

    API_KEY = '...'
    openai.api_key = API_KEY  # Set the OpenAI API key

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, help="batch index")
    parser.add_argument('--total_num', type=int, help="number of task to generate")
    parser.add_argument('--max_iters', type=int, help="let LLM to try how many times")
    parser.add_argument('--record_idx', type=str, help="record timestamp")
    parser.add_argument('--floor_num', type=str, help="floor number")
    args = parser.parse_args()

    batch_idx = args.batch
    total_num = args.total_num
    max_iters = args.max_iters
    record_idx = args.record_idx
    floor_num = args.floor_num

    nav_root = '.../nav_graph_generation'
    record_dir = f'{nav_root}/database/floor_{floor_num}/record/record_{record_idx}'
    prompt_path = f'{nav_root}/navcraft_s_prompts/'
    save_dir = f'{record_dir}/batch_raw_{batch_idx}'      # raw LLM answers
    habitat_save_dir = f'{record_dir}/batch_{batch_idx}'  # cleaned task configs
    graph_dir = f'{nav_root}/G_T/G_{floor_num}'
    region_index_dir = f'{nav_root}/G_region_index_T/G_region_index_{floor_num}'

    for out_dir in (record_dir, save_dir, habitat_save_dir):
        os.makedirs(out_dir, exist_ok=True)

    # Keep scenes whose line in scene_floors.txt is "<scene>.txt <floor_num>".
    scene_idx_list = []
    with open("scene_floors.txt", "r") as f:
        for line in f:
            parts = line.split()
            if len(parts) == 2 and parts[1] == str(floor_num):
                scene_idx_list.append(parts[0].replace(".txt", ""))

    if total_num and not scene_idx_list:
        raise SystemExit(f"No scenes for floor {floor_num} found in scene_floors.txt")

    for _ in range(total_num):
        scene_idx = random.choice(scene_idx_list)

        if str(floor_num) == '1':
            graph_connect_dir = os.path.join(graph_dir, f"{scene_idx}.pkl")
            graph_region_index_dir = os.path.join(region_index_dir, f"{scene_idx}.json")
        else:
            # Multi-floor scenes store one graph per floor: "<scene>_<floor>.pkl/.json".
            common_ids = sorted(
                collect_ids(graph_dir, scene_idx, ".pkl")
                & collect_ids(region_index_dir, scene_idx, ".json")
            )
            if not common_ids:
                continue
            floor_idx = random.choice(common_ids)
            graph_connect_dir = os.path.join(graph_dir, f"{scene_idx}_{floor_idx}.pkl")
            graph_region_index_dir = os.path.join(region_index_dir, f"{scene_idx}_{floor_idx}.json")
        region_assets_dir = os.path.join('.../nav_gen/scene', f"{scene_idx}.txt")

        # NOTE: pickle can execute arbitrary code — only load trusted graph files.
        with open(graph_connect_dir, 'rb') as f:
            G = pickle.load(f)
        with open(graph_region_index_dir, 'r') as f:
            G_regionid = json.load(f)
        scene_graph_json = generate_scene_graph_json(G, G_regionid, region_assets_dir)

        # "Region <id>: <type words>" -> assets + objects of that region.
        # Depends only on the scene, so it is built once per scene.
        input_scene = {}
        for entry in scene_graph_json["floor_1"]["item"]:
            *type_parts, region_id = entry['region'].split('_')
            input_scene[f"Region {region_id}: {' '.join(type_parts)}"] = entry['asset'] + entry['object']

        for _ in range(max_iters):
            answer, role_idx = call_llm_single_agent_with_reasoning(prompt_path, scene_graph_json, G, G_regionid)
            if answer is None:
                print("LLM returned no usable answer.")
                continue

            print("Path is OK!")
            print(answer)

            # build_subtask_list mutates its argument; keep the raw answer intact.
            answer_copy = copy.deepcopy(answer)
            update_answer = build_subtask_list(answer, region_assets_dir)
            if update_answer is None:
                continue

            phase = update_answer["phase_1"]
            habitat_json = {
                "Task instruction": clean_task_instruction(phase['Task instruction']).replace('/', ''),
                "Subtask list": phase["Subtask list"],
                "Robot": random.choice(['spot', 'fetch']),
                "Scene": scene_idx,
            }

            # Every Move_to('<name>_<region id>') must name an item present in
            # exactly that region ("Region 4:" must not match "Region 14:").
            objs = []
            wrong = False
            for task in habitat_json["Subtask list"]:
                if "Move_to" not in task:
                    continue
                obj_id = task[len("Move_to('"):-len("')")]
                obj_name, _, obj_region_id = obj_id.rpartition("_")
                region_prefix = f"Region {obj_region_id}:"
                region_key = next((k for k in input_scene if k.startswith(region_prefix)), None)
                if region_key is None or obj_name.replace(' ', '_') not in input_scene[region_key]:
                    wrong = True
                    break
                objs.append([obj_name, region_key])

            instruction = habitat_json['Task instruction']
            # Instructions must not leak internal "region" wording.
            if wrong or "region" in instruction or "Region" in instruction:
                continue

            habitat_json['Object'] = objs
            habitat_json["Start"] = {"robot_1": phase["Single robot start region"]["robot_1"]}
            habitat_json["Role"] = role_idx
            if str(floor_num) != '1':
                habitat_json["Floor"] = floor_idx

            # The cleaned instruction doubles as the output directory name.
            habitat_save_final_dir = os.path.join(habitat_save_dir, instruction)
            save_raw_dir = os.path.join(save_dir, instruction)
            os.makedirs(habitat_save_final_dir, exist_ok=True)
            os.makedirs(save_raw_dir, exist_ok=True)

            with open(os.path.join(habitat_save_final_dir, 'single_config.json'), 'w') as f:
                json.dump(habitat_json, f, indent=2)

            answer_copy["Scene"] = scene_idx
            with open(os.path.join(save_raw_dir, 'single_config.json'), 'w') as f:
                json.dump(answer_copy, f, indent=2)
    