import os
import random
from collections import defaultdict
import networkx as nx
import pickle
import json
import re

def prompt_make(prompt_path, ex_prompt):
    """Load a (system, user) prompt pair from a text template.

    Template layout, by 0-based line index:
      * line 1 is the system prompt (returned with its trailing newline);
      * lines 3..end are concatenated verbatim to form the user prompt;
      * lines 0 and 2 are ignored (presumably section headers — not checked).
    ``ex_prompt`` is appended to the end of the user prompt.

    Args:
        prompt_path: path of the UTF-8 template file.
        ex_prompt: extra text appended to the user prompt.

    Returns:
        ``(system_prompt, user_prompt)``.

    Raises:
        IndexError: if the template has fewer than four lines.
    """
    with open(prompt_path, "r", encoding='utf-8') as handle:
        lines = handle.readlines()

    system_prompt = lines[1]
    # lines[3] is indexed explicitly so a too-short template still raises.
    user_prompt = lines[3] + "".join(lines[4:]) + ex_prompt
    return system_prompt, user_prompt

def parse_region_file(file_path):
    """Parse a region/object listing into ``{region_id: [object, ...]}``.

    Recognized (stripped) lines:
      * ``Region id:_<n>, position:[...]`` starts a new region keyed ``"_<n>"``;
        a repeated header resets that region's list.
      * ``Id:<int>, name:<text>, position:[<x>, <y>, <z>]`` appends
        ``{'id': int, 'name': str, 'position': [float, float, float]}``
        to the most recent region.
    Blank lines, unrecognized lines and object lines seen before the first
    region header are skipped.
    """
    header_pattern = re.compile(r"Region id:(_\d+), position:\[.*\]")
    object_pattern = re.compile(
        r"Id:(\d+), name:(.*?), position:\[([-\d.eE+]+), ([\d.eE+-]+), ([-\d.eE+]+)\]"
    )

    regions = {}
    active_region = None

    with open(file_path, 'r') as f:
        for raw_line in f:
            text = raw_line.strip()
            if not text:
                continue

            header = header_pattern.match(text)
            if header:
                active_region = header.group(1)
                regions[active_region] = []
                continue

            if active_region is None:
                continue

            hit = object_pattern.match(text)
            if hit:
                obj_id, obj_name, x, y, z = hit.groups()
                regions[active_region].append({
                    'id': int(obj_id),
                    'name': obj_name,
                    'position': [float(x), float(y), float(z)],
                })

    return regions

def clean_room_name(name):
    """Turn a free-form room name into an underscore-separated identifier.

    ``:``, ``&``, ``/`` and spaces become ``_``; runs of ``_`` collapse to one,
    and leading/trailing ``_`` are removed.
    """
    underscored = re.sub(r'[:&/]', '_', name).replace(' ', '_')
    collapsed = re.sub(r'_+', '_', underscored)
    return collapsed.strip('_')


def is_asset(name: str) -> bool:
    """Heuristically classify *name* as a furniture/fixture-like asset.

    The name is lower-cased and spaces become underscores; it is an asset when
    any keyword below occurs anywhere inside it. This is a plain substring test,
    so "coffee_table" and "kitchen_cabinets" match, but so does e.g. "vegetable".
    """
    ASSET_KEYWORDS = {
    "bed", "sofa", "couch", "table", "desk", "shelf", "bookshelf", "cabinet", "drawer", "sink",
    "stove", "bathtub", "toilet", "fridge", "refrigerator", "tv", "television", "stand",
    "armchair", "chair", "bench", "counter", "fireplace", "closet", "wardrobe", "kitchen_extractor",
    "mirror", "dresser", "nightstand", "faucet", "washing_machine", "dryer", "oven",
    "coffee_machine", "dishwasher", "kitchen_island", "door", "archway", "vent"
    }

    normalized = name.lower().replace(' ', '_')
    return any(keyword in normalized for keyword in ASSET_KEYWORDS)

def sample_unique_random(name_list, k=5):
    """Return up to *k* distinct items chosen at random from *name_list*.

    Duplicates are removed in first-occurrence order (``dict.fromkeys``) rather
    than via ``set``. Set iteration order for strings depends on hash
    randomization, which would make the sample differ between processes even
    after ``random.seed``; this keeps seeded runs reproducible.

    Args:
        name_list: iterable of hashable items (may be empty or None).
        k: maximum number of items to return.

    Returns:
        A list of ``min(k, number_of_distinct_items)`` distinct items in random
        order, or ``[]`` for empty/None input.
    """
    if not name_list:
        return []

    unique_names = list(dict.fromkeys(name_list))
    sample_num = min(k, len(unique_names))
    return random.sample(unique_names, sample_num)

def get_unique_names(name_list):
    """Return the distinct items of *name_list* in first-occurrence order.

    Uses ``dict.fromkeys`` instead of ``set`` so the result order is stable
    across runs. With ``set``, string hash randomization would reorder the
    emitted JSON on every run.

    Returns ``[]`` for empty or None input.
    """
    if not name_list:
        return []
    return list(dict.fromkeys(name_list))

def generate_scene_graph_json(G, G_regionid, scene_item, k=3):
    """Build a compact single-floor scene-graph dict with a few sampled items per room.

    Args:
        G: navigation graph whose node ids are strings of the form
            ``"<region_id>_<suffix>"`` (e.g. ``"8_00013"``); only ``G.nodes()``
            and ``G.edges()`` are read.
        G_regionid: mapping from region id (anything ``int()`` accepts) to a
            region name; names are normalized with ``clean_room_name``.
        scene_item: path of the region/object listing read by ``parse_region_file``.
        k: maximum number of distinct asset names, and of distinct object names,
            sampled per room (default 3).

    Returns:
        ``{"floor_1": {"region": [...], "link": [...], "item": [...]}}`` where
        ``region`` is ``[{"id": "<name>_<rid>"}, ...]`` for every region id found
        in ``G`` (ascending), ``link`` is ``["<a> <-> <b>", ...]`` for each pair of
        distinct regions joined by at least one edge, and ``item`` lists
        ``{"region", "asset", "object"}`` dicts for rooms with at least one kept item.

    Raises:
        KeyError: if a region id present in ``G`` is missing from ``G_regionid``.
    """
    IGNORE_ITEMS = {"wood", "cleaner", "door knob", "mirror", "switch", "device"}
    # Item names are compared after lower-casing and replacing spaces with
    # underscores, so normalize the ignore list the same way (otherwise
    # "door knob" could never match "door_knob").
    ignore_names = {ignored.replace(' ', '_').lower() for ignored in IGNORE_ITEMS}

    region_assets = parse_region_file(scene_item)

    region_ids = sorted({int(node.split('_')[0]) for node in G.nodes()})
    region_name_map = {
        int(rid): clean_room_name(name)
        for rid, name in G_regionid.items()
    }

    def room_label(rid):
        # Room identifier used throughout the JSON, e.g. "living_room_3".
        return f"{region_name_map[rid]}_{rid}"

    room_list = [{'id': room_label(rid)} for rid in region_ids]

    # Unordered pairs of distinct regions joined by at least one graph edge.
    region_edges = set()
    for u, v in G.edges():
        rid_u = int(u.split('_')[0])
        rid_v = int(v.split('_')[0])
        if rid_u != rid_v:
            region_edges.add((min(rid_u, rid_v), max(rid_u, rid_v)))

    links = [f"{room_label(r1)} <-> {room_label(r2)}" for r1, r2 in sorted(region_edges)]

    # Split each known room's items into assets (per is_asset) and other objects.
    room_assets_map = defaultdict(list)
    room_objects_map = defaultdict(list)
    for region_key, asset_list in region_assets.items():
        rid = int(region_key.lstrip('_'))
        if rid not in region_name_map:
            continue
        room_name = room_label(rid)
        # Touch both maps so every known room gets an entry, even if empty.
        assets = room_assets_map[room_name]
        objects = room_objects_map[room_name]

        for item in asset_list:
            name = item['name'].replace(' ', '_').lower()
            if name in ignore_names:
                continue
            if is_asset(name):
                assets.append(name)
            else:
                objects.append(name)

    items = []
    for room_name, asset_names in room_assets_map.items():
        asset_samples = sample_unique_random(asset_names, k=k)
        object_samples = sample_unique_random(room_objects_map.get(room_name, []), k=k)
        if not asset_samples and not object_samples:
            continue
        items.append({
            'region': room_name,
            'asset': asset_samples,
            'object': object_samples
        })

    return {
        "floor_1": {
            "region": room_list,
            "link": links,
            'item': items
        }
    }

def generate_scene_graph_total_json(G, G_regionid, scene_item):
    """Build a single-floor scene-graph dict listing every distinct item per room.

    Same layout as ``generate_scene_graph_json``, except that each room's
    ``asset``/``object`` lists contain all distinct names (first-occurrence
    order via ``get_unique_names``) instead of a random sample.

    Args:
        G: navigation graph whose node ids are strings of the form
            ``"<region_id>_<suffix>"``; only ``G.nodes()`` and ``G.edges()`` are read.
        G_regionid: mapping from region id (anything ``int()`` accepts) to a
            region name; names are normalized with ``clean_room_name``.
        scene_item: path of the region/object listing read by ``parse_region_file``.

    Returns:
        ``{"floor_1": {"region": [...], "link": [...], "item": [...]}}``.

    Raises:
        KeyError: if a region id present in ``G`` is missing from ``G_regionid``.
    """
    IGNORE_ITEMS = {"wood", "cleaner", "door knob", "mirror", "switch", "device"}
    # Item names are compared after lower-casing and replacing spaces with
    # underscores, so normalize the ignore list the same way (otherwise
    # "door knob" could never match "door_knob").
    ignore_names = {ignored.replace(' ', '_').lower() for ignored in IGNORE_ITEMS}

    region_assets = parse_region_file(scene_item)

    region_ids = sorted({int(node.split('_')[0]) for node in G.nodes()})
    region_name_map = {
        int(rid): clean_room_name(name)
        for rid, name in G_regionid.items()
    }

    def room_label(rid):
        # Room identifier used throughout the JSON, e.g. "living_room_3".
        return f"{region_name_map[rid]}_{rid}"

    room_list = [{'id': room_label(rid)} for rid in region_ids]

    # Unordered pairs of distinct regions joined by at least one graph edge.
    region_edges = set()
    for u, v in G.edges():
        rid_u = int(u.split('_')[0])
        rid_v = int(v.split('_')[0])
        if rid_u != rid_v:
            region_edges.add((min(rid_u, rid_v), max(rid_u, rid_v)))

    links = [f"{room_label(r1)} <-> {room_label(r2)}" for r1, r2 in sorted(region_edges)]

    # Split each known room's items into assets (per is_asset) and other objects.
    room_assets_map = defaultdict(list)
    room_objects_map = defaultdict(list)
    for region_key, asset_list in region_assets.items():
        rid = int(region_key.lstrip('_'))
        if rid not in region_name_map:
            continue
        room_name = room_label(rid)
        # Touch both maps so every known room gets an entry, even if empty.
        assets = room_assets_map[room_name]
        objects = room_objects_map[room_name]

        for item in asset_list:
            name = item['name'].replace(' ', '_').lower()
            if name in ignore_names:
                continue
            if is_asset(name):
                assets.append(name)
            else:
                objects.append(name)

    items = []
    for room_name, asset_names in room_assets_map.items():
        asset_samples = get_unique_names(asset_names)
        object_samples = get_unique_names(room_objects_map.get(room_name, []))
        if not asset_samples and not object_samples:
            continue
        items.append({
            'region': room_name,
            'asset': asset_samples,
            'object': object_samples
        })

    return {
        "floor_1": {
            "region": room_list,
            "link": links,
            'item': items
        }
    }


def check_path_and_adjacency(region_a, region_b, G, G_regionid):
    """
    Determine whether two regions are connected and/or adjacent, and describe the shortest route.

    A region is selected by the trailing ``_<id>`` of ``region_a``/``region_b``;
    a graph node belongs to region ``<id>`` when the part of its name before
    the first ``_`` equals ``<id>``.

    Args:
        region_a, region_b: region labels ending in ``_<id>``.
        G: graph whose node names are ``"<id>_<suffix>"`` strings.
        G_regionid: mapping from the region-id string (node-name prefix) to the
            region name used in the path description.

    Returns:
        dict with
          * ``connect``: True if some node of region_a reaches some node of region_b;
          * ``adjacency``: True if an edge joins a node of region_a directly to
            a node of region_b;
          * ``path_length``: number of region-to-region transitions on the
            shortest (fewest nodes) path, 0 if not connected;
          * ``path``: those transitions as ``"<name>_<id> -> <name>_<id>"`` strings.
    """
    region1_num = region_a.split("_")[-1]
    region2_num = region_b.split("_")[-1]

    start_nodes = [node for node in G.nodes if node.split("_")[0] == region1_num]
    target_nodes = [node for node in G.nodes if node.split("_")[0] == region2_num]

    # Adjacency: at least one direct edge between the two regions.
    adjacency = any(G.has_edge(na, nb) for na in start_nodes for nb in target_nodes)

    # Unweighted shortest path over all start/target node pairs; the first
    # minimal path found is kept.
    shortest_path = None
    for s in start_nodes:
        for t in target_nodes:
            try:
                path = nx.shortest_path(G, source=s, target=t)
            except nx.NetworkXNoPath:
                continue
            if shortest_path is None or len(path) < len(shortest_path):
                shortest_path = path

    if shortest_path is None:
        return {
            "connect": False,
            "adjacency": adjacency,
            "path_length": 0,
            "path": []
        }

    # Collapse the node path into the sequence of regions it passes through.
    # Must walk the best path found, not `path`, which only holds the last pair tried.
    region_steps = []
    prev_region_id = None
    for node in shortest_path:
        region_id = node.split("_")[0]
        if region_id != prev_region_id:
            region_type = G_regionid[region_id].replace(" ", "_")
            region_steps.append(f"{region_type}_{region_id}")
            prev_region_id = region_id

    transitions = [f"{a} -> {b}" for a, b in zip(region_steps, region_steps[1:])]

    return {
        "connect": True,
        "adjacency": adjacency,
        "path_length": len(transitions),
        "path": transitions
    }


def check_collab_path_efficient(
    robot_2_region,
    transfer_nodes,
    collab_type,
    robot_1_region,
    target_region,
    end_region,
    G
    ):
    """Compare a two-robot plan with robot 1 doing the whole task alone, using region-level distances.

    Each region argument is a label whose trailing ``_<id>`` selects every graph
    node named ``"<id>_..."``. The distance between two regions is the smallest
    weighted (``weight`` edge attribute) shortest-path length between any pair of
    their nodes, or ``inf`` if none is reachable.

    Solo baseline: robot 1 goes start -> target -> end.

    ``collab_type``:
      * ``"Type-A1"``: robot 2 goes start -> target -> transfer while robot 1 goes
        start -> transfer; robot 1 then goes transfer -> end.
      * ``"Type-A2"``: robot 1 goes start -> target -> transfer while robot 2 goes
        start -> transfer; robot 2 then goes transfer -> end.
      * ``"Type-B1"``: robot 2 alone goes start -> target -> end.
      * ``"Type-B2"``: cost equals the solo baseline.
    For the A-types the two concurrent legs are combined with ``max``.

    Args:
        robot_2_region, robot_1_region: start regions of robot 2 / robot 1.
        transfer_nodes: transfer region label; required for the A-types.
        collab_type: one of the types above.
        target_region, end_region: target and end regions.
        G: networkx graph with string node names.

    Returns:
        ``{"efficient": False}`` for an unknown ``collab_type``, a missing transfer
        region on an A-type, or a zero solo cost (ratio undefined); otherwise
        ``{"efficient": parallel_cost < solo_cost, "rate": parallel_cost / solo_cost}``.
    """

    def get_region_nodes(region_id):
        """Return all nodes in a region by matching prefix like '8_'."""
        region_prefix = region_id.split("_")[-1]
        return [n for n in G.nodes if n.startswith(f"{region_prefix}_")]

    def shortest_geometric_path_length(nodes_a, nodes_b):
        """Smallest weighted shortest-path length between two node sets (inf if unreachable)."""
        min_dist = float('inf')
        for na in nodes_a:
            for nb in nodes_b:
                try:
                    length = nx.shortest_path_length(G, na, nb, weight='weight')
                except nx.NetworkXNoPath:
                    continue
                if length < min_dist:
                    min_dist = length
        return min_dist

    if collab_type not in ("Type-A1", "Type-A2", "Type-B1", "Type-B2"):
        return {"efficient": False}

    needs_transfer = collab_type in ("Type-A1", "Type-A2")
    if needs_transfer and not transfer_nodes:
        # The hand-over legs cannot be measured without a transfer region.
        return {"efficient": False}

    # Gather region nodes
    r2_nodes = get_region_nodes(robot_2_region)
    r1_nodes = get_region_nodes(robot_1_region)
    target_nodes = get_region_nodes(target_region)
    end_nodes = get_region_nodes(end_region)
    transfer_region_nodes = get_region_nodes(transfer_nodes) if needs_transfer else []

    r1_to_target = shortest_geometric_path_length(r1_nodes, target_nodes)
    target_to_end = shortest_geometric_path_length(target_nodes, end_nodes)
    solo_cost = r1_to_target + target_to_end

    if collab_type == "Type-A1":
        r2_to_target = shortest_geometric_path_length(r2_nodes, target_nodes)
        target_to_transfer = shortest_geometric_path_length(target_nodes, transfer_region_nodes)
        robot2_first_leg = r2_to_target + target_to_transfer
        r1_to_transfer = shortest_geometric_path_length(r1_nodes, transfer_region_nodes)
        parallel_leg = max(robot2_first_leg, r1_to_transfer)
        transfer_to_end = shortest_geometric_path_length(transfer_region_nodes, end_nodes)
        parallel_cost = parallel_leg + transfer_to_end

    elif collab_type == "Type-A2":
        target_to_transfer = shortest_geometric_path_length(target_nodes, transfer_region_nodes)
        robot1_first_leg = r1_to_target + target_to_transfer
        r2_to_transfer = shortest_geometric_path_length(r2_nodes, transfer_region_nodes)
        parallel_leg = max(robot1_first_leg, r2_to_transfer)
        transfer_to_end = shortest_geometric_path_length(transfer_region_nodes, end_nodes)
        parallel_cost = parallel_leg + transfer_to_end

    elif collab_type == "Type-B1":
        r2_to_target = shortest_geometric_path_length(r2_nodes, target_nodes)
        parallel_cost = r2_to_target + target_to_end

    else:  # "Type-B2": same as solo
        parallel_cost = solo_cost

    efficient = parallel_cost < solo_cost
    print("efficient: ", efficient)

    if solo_cost == 0:
        # The ratio is undefined, and a zero-cost solo plan cannot be beaten.
        return {"efficient": False}

    rate = parallel_cost / solo_cost
    print("rate: ", rate)

    return {
        "efficient": efficient,
        "rate": rate
    }

import numpy as np
def cluster_center_node(
    region_key: str,
    node_view,                # networkx NodeView
    nav_coords_dict: dict,    # {'8_00013': np.array([x, y, z]), ...}
    use_3d: bool = False
):
    """Return the node of a region that lies closest to the region's centroid.

    A node belongs to the region when the part of ``str(node)`` before the first
    ``_`` equals ``region_key`` with its leading underscores removed
    (``'_2'`` -> ``'2'``).

    Args:
        region_key: region key such as ``'_2'``.
        node_view: iterable of graph nodes.
        nav_coords_dict: node -> coordinate sequence; at least (x, y) is read.
        use_3d: if True, use all coordinate components; otherwise use (x, -y).

    Returns:
        The member node nearest (Euclidean) to the mean position of all members;
        the first such node on ties. ``None`` if the region has no nodes.
    """
    region_prefix = region_key.lstrip('_')
    members = [node for node in node_view
               if str(node).split('_', 1)[0] == region_prefix]
    if not members:
        return None

    xyz = np.array([nav_coords_dict[node] for node in members])

    if use_3d:
        points = xyz
    else:
        # Planar projection with the y axis flipped: (x, -y).
        points = np.column_stack((xyz[:, 0], -xyz[:, 1]))

    centroid = points.mean(axis=0)
    distances = np.linalg.norm(points - centroid, axis=1)
    return members[np.argmin(distances)]

def nearest_navpoint_to_object_vec(
    target_pos,
    region_key,
    node_view,
    nav_coords_dict,
    index=None
):
    """Find the navigation node of a region closest to an object, in the ground plane.

    Nav node positions are projected to ``(x, -y)`` and the object position to
    ``(x, z)``. This code implies that object z corresponds to nav -y; confirm
    against the data source.

    Args:
        target_pos: object position; components 0 and 2 are read.
        region_key: region key such as ``'_8'``; nodes whose name prefix
            (before the first ``_``) equals ``'8'`` are candidates.
        node_view: indexable/iterable collection of nodes.
        nav_coords_dict: node -> position; components 0 and 1 are read.
        index: optional selector applied as ``node_view[index]`` before filtering.

    Returns:
        ``(node, node_position, distance)`` for the nearest candidate (first on
        ties), or ``(None, None, inf)`` when the region has no candidates.
    """
    prefix = region_key.lstrip('_')
    pool = node_view if index is None else node_view[index]
    in_region = [node for node in pool if str(node).split('_', 1)[0] == prefix]
    if not in_region:
        return None, None, np.inf

    nav_plane = np.array([
        (nav_coords_dict[node][0], -nav_coords_dict[node][1])
        for node in in_region
    ])
    object_plane = np.array([target_pos[0], target_pos[2]])

    distances = np.linalg.norm(nav_plane - object_plane, axis=1)
    best = np.argmin(distances)

    chosen = in_region[best]
    return chosen, nav_coords_dict[chosen], float(distances[best])

def find_target_position(region_objects, region_key, target_object):
    """Return the position of the lowest-id object named *target_object* in a region.

    Args:
        region_objects: mapping region key -> list of ``{'id', 'name', 'position'}``
            dicts (the shape produced by ``parse_region_file``).
        region_key: region key such as ``'_3'``; a missing key means no objects.
        target_object: exact object name to match.

    Returns:
        The matching object's ``'position'``, or ``None`` if there is no match.
    """
    best = min(
        (obj for obj in region_objects.get(region_key, [])
         if obj['name'] == target_object),
        key=lambda obj: obj['id'],
        default=None,
    )
    if best is None:
        return None
    return best['position']

import math
def insert_temp_point(tmp_id, G: nx.Graph, point_xyz, anchor_node_id, 
                      node_prefix="tmp_obj") -> str:
    """Add node *tmp_id* at *point_xyz* to *G* and link it to *anchor_node_id*.

    The new node gets ``position=point_xyz``. It is connected to the anchor with
    an edge whose ``weight`` is the Euclidean distance between the first two
    components of the two positions. ``G`` is modified in place.

    NOTE(review): nearest_navpoint_to_object_vec compares object (x, z) with
    nav (x, -y). If *point_xyz* is an object-frame position, using components
    (0, 1) of both points here may mix frames. Confirm.

    Args:
        tmp_id: id of the node to add.
        G: graph to modify; the anchor node must have a ``position`` attribute.
        point_xyz: position of the new node (at least two components).
        anchor_node_id: existing node to connect to.
        node_prefix: unused; kept for call compatibility.

    Returns:
        *tmp_id*.

    Raises:
        KeyError: if *anchor_node_id* is not in *G* (the new node is already added by then).
    """
    G.add_node(tmp_id, position=point_xyz)

    px, py, *_ = point_xyz
    ax, ay, *_ = G.nodes[anchor_node_id]['position']
    weight = math.hypot(ax - px, ay - py)

    # Added in both directions, presumably so the link also exists if G is
    # directed; on an undirected graph the second call just rewrites the edge.
    for u, v in ((tmp_id, anchor_node_id), (anchor_node_id, tmp_id)):
        G.add_edge(u, v, weight=weight)

    return tmp_id

def check_collab_path_efficient_sim_graph(
    robot_2_region, 
    transfer_nodes,
    transfer_asset,
    collab_type,
    solo_cost,  # 
    r1_nodes,
    target_nodes,
    end_nodes,
    region_objects,
    G
    ):
    """Evaluate a two-robot hand-over plan on a weighted navigation graph against a solo baseline.

    Robot 2 starts at the node nearest the centroid of ``robot_2_region``
    (``cluster_center_node``). The hand-over point is the position of
    ``transfer_asset`` inside the region named by ``transfer_nodes``. It is
    inserted into ``G`` as node ``"transfer_<transfer_asset>"`` linked to the
    nearest nav node in that region. NOTE: this permanently mutates ``G``.

    ``collab_type``:
      * ``"Type-A1"``: robot 2 goes start -> target -> transfer while robot 1 goes
        start -> transfer; robot 1 then goes transfer -> end.
      * ``"Type-A2"``: robot 1 goes start -> target -> transfer while robot 2 goes
        start -> transfer; robot 2 then goes transfer -> end.
      * anything else: reported as not efficient.

    Args:
        robot_2_region: region label; its trailing ``_<id>`` selects robot 2's region.
        transfer_nodes: label of the region holding the transfer asset (``..._<id>``).
        transfer_asset: asset name with underscores. It is converted to spaces
            before lookup in ``region_objects``.
        collab_type: see above.
        solo_cost: single-robot cost used as the baseline.
        r1_nodes, target_nodes, end_nodes: single graph nodes for robot 1's
            start, the target, and the end point.
        region_objects: mapping as returned by ``parse_region_file``.
        G: networkx graph with ``position`` node and ``weight`` edge attributes.

    Returns:
        ``{"efficient": False}`` when the plan cannot be evaluated (missing robot-2
        region, transfer region/asset, unreachable leg, unknown type, or zero
        ``solo_cost``). Otherwise a dict with:
          * ``g_efficient``: whether the combined parallel cost beats ``solo_cost``;
          * ``r1_efficient``: whether robot 1's own travel beats ``solo_cost``;
          * ``efficient``: either of the two;
          * ``g_rate`` / ``r1_rate``: those costs divided by ``solo_cost``;
          * ``path_info``: the individual leg lengths.
    """

    def path_len(source, target):
        # Weighted shortest-path length, or None if unreachable / a node is missing.
        try:
            return nx.shortest_path_length(G, source=source, target=target, weight='weight')
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return None

    def summarize(g_parallel_cost, r1_parallel_cost, path_info):
        # Both ratios divide by solo_cost; a zero baseline cannot be improved upon.
        if solo_cost == 0:
            return {"efficient": False}
        g_rate = g_parallel_cost / solo_cost
        r1_rate = r1_parallel_cost / solo_cost
        print("g_rate: ", g_rate)
        print("r1_rate: ", r1_rate)
        g_efficient = g_parallel_cost < solo_cost
        r1_efficient = r1_parallel_cost < solo_cost
        return {
            "g_efficient": g_efficient,
            "r1_efficient": r1_efficient,
            "efficient": bool(g_efficient or r1_efficient),
            "g_rate": g_rate,
            "r1_rate": r1_rate,
            'path_info': path_info
        }

    positions = nx.get_node_attributes(G, 'position')

    r2_nodes = cluster_center_node('_' + robot_2_region.split("_")[-1], nx.nodes(G), positions)
    if r2_nodes is None:
        return {"efficient": False}
    print("r2_nodes:", r2_nodes)

    if not transfer_nodes:
        return {"efficient": False}

    print("transfer_asset: ", transfer_asset)
    transfer_region_key = '_' + transfer_nodes.split("_")[-1]
    transfer_asset_pos = find_target_position(
        region_objects, transfer_region_key, ' '.join(transfer_asset.split('_')))
    if transfer_asset_pos is None:
        return {"efficient": False}

    anchor_node, _, _ = nearest_navpoint_to_object_vec(
        transfer_asset_pos, transfer_region_key, nx.nodes(G), positions)
    if anchor_node is None:
        # There is no nav node in the transfer region to attach the hand-over
        # point to. Check before inserting so G is not left with a dangling node.
        return {"efficient": False}

    transfer_nodes = insert_temp_point(
        "transfer_" + transfer_asset, G, transfer_asset_pos, anchor_node)
    print("transfer_nodes", transfer_nodes)

    if collab_type == "Type-A1":
        r2_to_target = path_len(r2_nodes, target_nodes)
        target_to_transfer = path_len(target_nodes, transfer_nodes)
        r1_to_transfer = path_len(r1_nodes, transfer_nodes)
        transfer_to_end = path_len(transfer_nodes, end_nodes)
        legs = (r2_to_target, target_to_transfer, r1_to_transfer, transfer_to_end)
        if any(leg is None for leg in legs):
            return {"efficient": False}

        robot2_first_leg = r2_to_target + target_to_transfer
        parallel_leg = max(robot2_first_leg, r1_to_transfer)
        path_info = {
            "r2_to_target": r2_to_target,
            "target_to_transfer": target_to_transfer,
            "robot2_first_leg": robot2_first_leg,
            "r1_to_transfer": r1_to_transfer,
            "parallel_leg": parallel_leg,
            "transfer_to_end": transfer_to_end,
            "type": collab_type
        }
        # r1 cost: robot 1's own travel, start -> transfer -> end.
        return summarize(parallel_leg + transfer_to_end,
                         r1_to_transfer + transfer_to_end,
                         path_info)

    elif collab_type == "Type-A2":
        r1_to_target = path_len(r1_nodes, target_nodes)
        target_to_transfer = path_len(target_nodes, transfer_nodes)
        r2_to_transfer = path_len(r2_nodes, transfer_nodes)
        transfer_to_end = path_len(transfer_nodes, end_nodes)
        legs = (r1_to_target, target_to_transfer, r2_to_transfer, transfer_to_end)
        if any(leg is None for leg in legs):
            return {"efficient": False}

        robot1_first_leg = r1_to_target + target_to_transfer
        parallel_leg = max(robot1_first_leg, r2_to_transfer)
        path_info = {
            "r1_to_target": r1_to_target,
            "target_to_transfer": target_to_transfer,
            "robot1_first_leg": robot1_first_leg,
            "r2_to_transfer": r2_to_transfer,
            "parallel_leg": parallel_leg,
            "transfer_to_end": transfer_to_end,
            "type": collab_type
        }
        # r1 cost: robot 1's own travel, start -> target -> transfer.
        return summarize(parallel_leg + transfer_to_end,
                         robot1_first_leg,
                         path_info)

    else:
        return {"efficient": False}

def get_path_and_adjacency(region_a, single_agent, single_target, single_end, G, G_regionid):
    """Report connectivity/adjacency from ``region_a`` to three other regions.

    Each value is the result dict of ``check_path_and_adjacency`` (keys
    ``connect``, ``adjacency``, ``path_length``, ``path``). ``G`` and
    ``G_regionid`` must be forwarded to every call.

    Returns:
        ``{"robot2_to_robot1": region_a -> single_agent,
           "robot2_to_target": region_a -> single_target,
           "robot2_to_end":    region_a -> single_end}``
    """
    return {
        "robot2_to_robot1": check_path_and_adjacency(region_a, single_agent, G, G_regionid),
        "robot2_to_target": check_path_and_adjacency(region_a, single_target, G, G_regionid),
        "robot2_to_end": check_path_and_adjacency(region_a, single_end, G, G_regionid),
    }

def check_two_path_and_adjacency(start_region, target_region, end_region, G, G_regionid):
    """Check that both legs start->target and target->end are non-trivial.

    A leg is valid when its two regions are connected, not directly adjacent,
    and the shortest route makes at least two region-to-region transitions
    (as reported by ``check_path_and_adjacency``).

    Returns:
        dict with the raw leg results (``start_2_target``, ``target_2_end``),
        per-leg flags (``s2t_valid``, ``t2e_valid``) and ``valid`` (both legs valid).
    """
    def leg_is_valid(leg):
        # Connected, not adjacent, and at least two region hops.
        return bool(leg["connect"] and not leg["adjacency"] and leg["path_length"] >= 2)

    start_to_target = check_path_and_adjacency(start_region, target_region, G, G_regionid)
    s2t_ok = leg_is_valid(start_to_target)

    target_to_end = check_path_and_adjacency(target_region, end_region, G, G_regionid)
    t2e_ok = leg_is_valid(target_to_end)

    return {
        "start_2_target": start_to_target,
        "target_2_end": target_to_end,
        "s2t_valid": s2t_ok,
        "t2e_valid": t2e_ok,
        "valid": s2t_ok and t2e_ok
    }

def check_three_path_and_adjacency(start_region, target_region, end_region, G, G_regionid):
    """Check the legs start->target and target->end (same rules as the two-leg check).

    A leg is valid when its two regions are connected, not directly adjacent,
    and the shortest route makes at least two region-to-region transitions.

    NOTE(review): the result carries a ``start_2_end`` key, but it is always
    ``None`` and does not affect ``valid``. A start->end check may have been
    intended. Confirm before relying on it.

    Returns:
        dict with ``start_2_target``, ``target_2_end``, ``start_2_end`` (always
        None), ``s2t_valid``, ``t2e_valid`` and ``valid`` (both legs valid).
    """
    def leg_is_valid(leg):
        # Connected, not adjacent, and at least two region hops.
        return bool(leg["connect"] and not leg["adjacency"] and leg["path_length"] >= 2)

    start_to_target = check_path_and_adjacency(start_region, target_region, G, G_regionid)
    s2t_ok = leg_is_valid(start_to_target)

    target_to_end = check_path_and_adjacency(target_region, end_region, G, G_regionid)
    t2e_ok = leg_is_valid(target_to_end)

    return {
        "start_2_target": start_to_target,
        "target_2_end": target_to_end,
        "start_2_end": None,
        "s2t_valid": s2t_ok,
        "t2e_valid": t2e_ok,
        "valid": s2t_ok and t2e_ok
    }

