# local_cbs_solver_robust_v12_optimized.py (With Correct Conflict Lifecycle): agents disappear once they reach their final goal ("zuihou zhong dian hui xiaoshi").

import heapq
import itertools
from collections import defaultdict
import logging
from typing import Optional, Tuple, List, Dict
import random # NEW: 导入random模块

import numba
import numpy as np
from numba.core import types
from numba.typed import Dict as NumbaDict

# --- Global constants and type definitions ---
# Row/column offsets for the five grid actions, indexed 0..4 as
# STAY, UP (row - 1), DOWN (row + 1), LEFT (col - 1), RIGHT (col + 1).
# Read as a global by the Numba A* kernel `_numba_low_level_astar`.
ACTION_DELTAS = np.array([[0, 0], [-1, 0], [1, 0], [0, -1], [0, 1]], dtype=np.int8)
# Numba type of the (row, col, timestep) keys of the A* search dictionaries.
POS_KEY_TYPE = types.UniTuple(types.int64, 3)


# --- Class definitions ---
class Constraint:
    """A CBS constraint on one agent at one timestep.

    A vertex constraint forbids ``agent_id`` from being at ``location`` at
    ``timestep``. An edge constraint (``is_edge_constraint=True``) forbids the
    move ``prev_location -> location`` that arrives at ``timestep``.
    Constraints compare and hash by all five fields.
    """

    def __init__(self, agent_id, loc, timestep, is_edge_constraint=False, prev_loc=None):
        self.agent_id = agent_id
        self.location = loc
        self.timestep = timestep
        self.is_edge_constraint = is_edge_constraint
        self.prev_location = prev_loc

    def _identity(self):
        # Single source of truth for both equality and hashing.
        return (
            self.agent_id,
            self.location,
            self.timestep,
            self.is_edge_constraint,
            self.prev_location,
        )

    def __eq__(self, other):
        return isinstance(other, Constraint) and self._identity() == other._identity()

    def __hash__(self):
        return hash(self._identity())

    def __repr__(self):
        if not self.is_edge_constraint:
            return f"VertexConstraint(A{self.agent_id}: {self.location} @t={self.timestep})"
        return f"EdgeConstraint(A{self.agent_id}: {self.prev_location}->{self.location} @t={self.timestep})"

class Conflict:
    """A collision between two agents found in their current paths.

    ``VERTEX`` conflicts keep the shared cell in ``location1``. ``EDGE`` (swap)
    conflicts keep agent 1's move ``(from, to)`` in ``location1`` and agent 2's
    opposite move in ``location2``; ``timestep`` is the step the moves arrive.
    """

    VERTEX = 1
    EDGE = 2

    def __init__(self, conflict_type, agent1_id, agent2_id, loc1, timestep, loc2=None):
        self.type = conflict_type
        self.agent1_id = agent1_id
        self.agent2_id = agent2_id
        self.location1 = loc1
        self.timestep = timestep
        self.location2 = loc2

    def __repr__(self):
        kind = self.type
        if kind == Conflict.VERTEX:
            return f"VertexConflict(A{self.agent1_id}, A{self.agent2_id} @ {self.location1}, t={self.timestep})"
        if kind == Conflict.EDGE:
            src1, dst1 = self.location1[0], self.location1[1]
            src2, dst2 = self.location2[0], self.location2[1]
            return (
                f"EdgeConflict(A{self.agent1_id} {src1}->{dst1} vs "
                f"A{self.agent2_id} {src2}->{dst2} @t={self.timestep})"
            )
        return "UnknownConflict"

class CBSHighLevelNode:
    """One node of the CBS constraint tree.

    Holds the accumulated constraints, one path per agent, the summed path
    cost and the conflicts among those paths. Nodes order (for ``heapq``) by
    fewest conflicts first, then lowest ``sum_of_costs``, then creation order.
    """

    _ids = itertools.count(0)

    def __init__(self, constraints: Optional[List[Constraint]] = None, paths=None, sum_of_costs=0):
        self.id = next(self._ids)
        self.constraints = [] if constraints is None else constraints
        self.paths = {} if paths is None else paths
        self.sum_of_costs = sum_of_costs
        self.conflicts = []

    def _priority(self):
        # Lexicographic key: conflict count, then cost, then the unique id tie-break.
        return (len(self.conflicts), self.sum_of_costs, self.id)

    def __lt__(self, other):
        return self._priority() < other._priority()

# --- Numba JIT A* kernel (space-time search with an additive per-cell cost map) ---
@numba.jit(nopython=True)
def _numba_low_level_astar(
    start_pos: Tuple[int, int],
    goal_pos: Tuple[int, int],
    obstacles_local: np.ndarray,
    constraints_array: np.ndarray,
    cost_map: np.ndarray, # extra cost added when a move enters each cell
    max_path_len: int,
):
    """Single-agent space-time A* on the local grid, compiled with Numba.

    Search states are ``(row, col, t)``. From each state the five
    ``ACTION_DELTAS`` moves (four moves plus staying in place) are tried; a
    move into ``(nr, nc)`` costs ``1 + int(cost_map[nr, nc])``. The heuristic
    is Manhattan distance to the goal. NOTE(review): it is admissible only if
    ``cost_map`` has no negative entries; this is not checked here.

    Args:
        start_pos: (row, col) start cell in local coordinates.
        goal_pos: (row, col) goal cell in local coordinates.
        obstacles_local: 2-D grid; truthy cells are blocked.
        constraints_array: integer rows ``[row, col, t, is_edge, prev_row,
            prev_col]``. ``is_edge == 0`` forbids being at ``(row, col)`` at
            ``t``; otherwise it forbids the move ``(prev_row, prev_col) ->
            (row, col)`` arriving at ``t``. Only arrival times ``t >= 1`` are
            ever checked.
        cost_map: extra per-cell move cost, indexed with the same (row, col)
            as ``obstacles_local``. NOTE(review): Numba does not bounds-check
            by default, so this assumes the map covers the whole grid.
        max_path_len: horizon; states with ``t >= max_path_len`` are not expanded.

    Returns:
        A list of (row, col) tuples from ``start_pos`` to ``goal_pos``, padded
        with ``goal_pos`` to ``max_path_len + 1`` entries, or ``None`` if the
        goal cannot be reached within the horizon.
    """
    map_height, map_width = obstacles_local.shape
    def h_func(p1_r, p1_c, p2_r, p2_c): return abs(p1_r - p2_r) + abs(p1_c - p2_c)
    
    # open_set stores (f_score, g_score, r, c, t)
    open_set = [(h_func(start_pos[0], start_pos[1], goal_pos[0], goal_pos[1]), 0, start_pos[0], start_pos[1], 0)]
    came_from = NumbaDict.empty(key_type=POS_KEY_TYPE, value_type=POS_KEY_TYPE)
    g_scores = NumbaDict.empty(key_type=POS_KEY_TYPE, value_type=types.int64)
    start_key = (start_pos[0], start_pos[1], 0)
    g_scores[start_key] = 0

    while len(open_set) > 0:
        f, g, r, c, t = heapq.heappop(open_set)
        current_key = (r, c, t)

        # Skip stale heap entries superseded by a cheaper route to the same state.
        if g > g_scores.get(current_key, np.iinfo(np.int64).max):
            continue
        
        # Goal test on pop. Constraints on the goal cell at later timesteps are
        # not consulted; the agent is treated as finished once it arrives.
        if (r, c) == goal_pos:
            path = []
            curr = current_key
            # Walk parents back; the start state (t == 0) is never a key in came_from.
            while curr in came_from:
                path.append((curr[0], curr[1]))
                curr = came_from[curr]
            path.append((start_pos[0], start_pos[1]))
            path.reverse()
            # Pad with the goal so every returned path has max_path_len + 1 steps.
            while len(path) < max_path_len + 1:
                path.append(goal_pos)
            return path

        if t >= max_path_len:
            continue
            
        # Moves are tried before waiting in place.
        for dr_dc_idx in [1, 2, 3, 4, 0]: # UP, DOWN, LEFT, RIGHT, STAY
            dr, dc = ACTION_DELTAS[dr_dc_idx]
            nr, nc = r + dr, c + dc
            nt = t + 1

            if not (0 <= nr < map_height and 0 <= nc < map_width and not obstacles_local[nr, nc]):
                continue

            # Linear scan of this agent's constraints for the arrival time nt.
            is_move_valid = True
            for i in range(constraints_array.shape[0]):
                constr = constraints_array[i]
                if constr[3] == 0: # Vertex constraint
                    if constr[2] == nt and constr[0] == nr and constr[1] == nc:
                        is_move_valid = False
                        break
                else: # Edge constraint
                    if constr[2] == nt and constr[4] == r and constr[5] == c and constr[0] == nr and constr[1] == nc:
                        is_move_valid = False
                        break
            if not is_move_valid:
                continue

            # --- A* COST MODIFICATION ---
            # Base cost is 1 (for one step), plus additional cost from the cost_map.
            move_cost = 1 + int(cost_map[nr, nc])
            new_g = g + move_cost
            # --- END MODIFICATION ---
            
            neighbor_key = (nr, nc, nt)
            if new_g < g_scores.get(neighbor_key, np.iinfo(np.int64).max):
                g_scores[neighbor_key] = new_g
                new_f = new_g + h_func(nr, nc, goal_pos[0], goal_pos[1])
                heapq.heappush(open_set, (new_f, new_g, nr, nc, nt))
                came_from[neighbor_key] = current_key

    return None

# --- A* wrapper: builds the constraint array and forwards an optional cost map ---
def low_level_astar_for_cbs(
    agent_id, start_pos, original_goal_pos, obstacles_local, constraints,
    max_plan_len: int,
    dynamic_cost_map: Optional[np.ndarray] = None,
    true_agent_global_goal_abs: Optional[Tuple[int, int]] = None,
    persistent_map_info: Optional[Dict] = None,
    verbose_cbs: bool = False
):
    """Plan one agent's path on the local grid under the CBS constraints.

    Only constraints whose ``agent_id`` matches are passed to the Numba A*
    kernel. They are encoded as rows ``[row, col, t, is_edge, prev_row,
    prev_col]``, with ``prev_*`` set to ``-1`` for vertex constraints.

    Args:
        agent_id: id of the agent being planned.
        start_pos: (row, col) start in local coordinates.
        original_goal_pos: (row, col) goal in local coordinates.
        obstacles_local: 2-D obstacle grid (truthy = blocked).
        constraints: iterable of ``Constraint``; other agents' entries are ignored.
        max_plan_len: planning horizon in timesteps.
        dynamic_cost_map: optional 2-D array of extra per-cell move costs,
            indexed with the same local (row, col) as ``obstacles_local``.
            ``None`` means no extra cost.
        true_agent_global_goal_abs, persistent_map_info, verbose_cbs:
            accepted for interface compatibility; not used here.

    Returns:
        A list of (row, col) tuples padded with the goal to
        ``max_plan_len + 1`` entries, or ``None`` if no path exists within
        the horizon.

    Raises:
        ValueError: if ``dynamic_cost_map`` is not 2-D or is smaller than
            ``obstacles_local`` in either dimension.
    """
    # No goal redirection happens here: A* always targets the given local goal.
    final_target_goal_for_astar = original_goal_pos

    # Normalise to plain-int tuples so the jitted kernel's `(r, c) == goal_pos`
    # test compares tuple to tuple and every call reuses one compiled
    # specialisation (numpy ints or lists would change the argument types).
    start = (int(start_pos[0]), int(start_pos[1]))
    goal = (int(final_target_goal_for_astar[0]), int(final_target_goal_for_astar[1]))

    rows = [
        [c.location[0], c.location[1], c.timestep, 1, c.prev_location[0], c.prev_location[1]]
        if c.is_edge_constraint
        else [c.location[0], c.location[1], c.timestep, 0, -1, -1]
        for c in constraints
        if c.agent_id == agent_id
    ]
    if rows:
        constraints_array = np.array(rows, dtype=np.int32)
    else:
        constraints_array = np.empty((0, 6), dtype=np.int32)

    if dynamic_cost_map is None:
        cost_map = np.zeros_like(obstacles_local, dtype=np.int32)
    else:
        cost_map = np.asarray(dynamic_cost_map)
        height, width = obstacles_local.shape
        # The kernel reads cost_map[nr, nc] for every reachable grid cell and
        # Numba does not bounds-check by default, so an undersized map would
        # read out-of-bounds memory instead of failing.
        if cost_map.ndim != 2 or cost_map.shape[0] < height or cost_map.shape[1] < width:
            raise ValueError(
                f"dynamic_cost_map shape {cost_map.shape} does not cover "
                f"the local grid shape {obstacles_local.shape}"
            )

    path_tuples = _numba_low_level_astar(
        start, goal, obstacles_local, constraints_array, cost_map, max_plan_len,
    )
    if path_tuples is None:
        return None
    return [tuple(p) for p in path_tuples]

# --- Space-time conflict detection ---
def _arrival_time(path):
    """Return the first timestep from which ``path`` stays on its final cell."""
    goal = path[-1]
    t = len(path) - 1
    while t > 0 and path[t - 1] == goal:
        t -= 1
    return t


def detect_all_conflicts_spacetime(paths_dict: Dict[int, List[Tuple[int, int]]], max_timestep: int):
    """Find vertex and swap (edge) conflicts among agents' paths.

    An agent occupies cells only up to its arrival time. Arrival time is the
    first timestep from which it stays on its final cell (``path[-1]``) for
    the rest of the path. After that the agent is treated as gone: it neither
    blocks cells nor collides. Arrival is measured from the end of the path,
    so a path that passes through its final cell earlier and comes back is
    not cut short at the earlier visit.

    Args:
        paths_dict: agent id -> list of (row, col) cells, one per timestep.
        max_timestep: unused; kept for interface compatibility.

    Returns:
        A list of ``Conflict``:
        - A vertex conflict for every pair of agents on the same cell at the
          same timestep.
        - One edge conflict per pair of agents swapping cells, with
          ``agent1_id < agent2_id`` and the edge timestep being the arrival
          step of the moves.

        An empty path occupies nothing. If any path is ``None`` the function
        returns ``[]`` immediately. NOTE(review): callers must not read that
        as "conflict-free"; the behaviour is preserved as-is.
    """
    if any(path is None for path in paths_dict.values()):
        return []

    vertex_reservation = defaultdict(list)
    edge_reservation = defaultdict(list)
    for agent_id, path in paths_dict.items():
        if not path:
            continue
        arrival = _arrival_time(path)
        # Cells held at t = 0..arrival (inclusive); afterwards the agent has left.
        for t in range(arrival + 1):
            vertex_reservation[(t, path[t])].append(agent_id)
        # Moves arriving at t + 1 <= arrival; waiting in place is not a move.
        for t in range(arrival):
            pos1, pos2 = path[t], path[t + 1]
            if pos1 != pos2:
                edge_reservation[(t + 1, pos1, pos2)].append(agent_id)

    all_conflicts = []
    for (timestep, pos), occupants in vertex_reservation.items():
        for agent1_id, agent2_id in itertools.combinations(occupants, 2):
            all_conflicts.append(Conflict(Conflict.VERTEX, agent1_id, agent2_id, pos, timestep))

    for (timestep, pos1, pos2), movers in edge_reservation.items():
        # `.get` (not indexing) so the defaultdict is not grown while iterating it.
        opposite = edge_reservation.get((timestep, pos2, pos1))
        if not opposite:
            continue
        for agent1_id in movers:
            for agent2_id in opposite:
                # Each swap is seen from both directions; keep one orientation.
                if agent1_id < agent2_id:
                    all_conflicts.append(
                        Conflict(Conflict.EDGE, agent1_id, agent2_id, (pos1, pos2), timestep, (pos2, pos1))
                    )
    return all_conflicts

# --- Core CBS solver (per-agent cost maps are taken from agent_memories) ---
def solve_local_cbs_robust(
    agents_data: List[Dict], obstacles_local_map: np.ndarray, max_plan_len: int,
    max_cbs_iterations=500,
    agent_memories: Optional[Dict[int, 'AgentMemory']] = None, # objects exposing .dynamic_cost_map
    agents_true_global_goals_abs: Optional[Dict[int, Tuple[int,int]]] = None,
    persistent_map_bundle: Optional[Dict] = None,
    verbose_cbs_solver: bool = False,
    initial_constraints_list: Optional[List[Constraint]] = None
):
    """Run Conflict-Based Search on a local window and return conflict-free paths.

    Args:
        agents_data: one dict per agent, read for the keys ``'id'``,
            ``'start_local'`` and ``'goal_local'``.
        obstacles_local_map: local obstacle grid passed to the low-level planner.
        max_plan_len: planning horizon passed to the low-level planner and
            to conflict detection.
        max_cbs_iterations: maximum number of high-level nodes to pop.
        agent_memories: optional mapping agent id -> object whose
            ``dynamic_cost_map`` attribute is forwarded to the low-level
            planner for that agent.
        agents_true_global_goals_abs: optional mapping agent id -> goal,
            forwarded to the low-level planner.
        persistent_map_bundle: forwarded to the low-level planner.
        verbose_cbs_solver: enables logging of failures/success.
        initial_constraints_list: constraints the root node starts with. The
            root stores this list object itself (no copy); children copy
            their parent's list before appending.

    Returns:
        The ``paths`` dict (agent id -> list of (row, col)) of the first popped
        node with no conflicts. Returns ``None`` if any agent has no root path,
        if the open list empties, or after ``max_cbs_iterations`` pops.

    Notes:
        - Nodes are popped from a heap ordered by ``CBSHighLevelNode.__lt__``.
        - Path cost (for ``sum_of_costs`` and cardinality) is the index of the
          first occurrence of the path's final cell. Cost-map penalties are
          not included.
        - Conflict choice: each conflict is classified by replanning both of
          its agents with the constraint it would add (two low-level calls
          per conflict). It gets priority 0 if either replan fails or arrives
          later, else 1. The conflict with the smallest (priority, timestep,
          type, agent1_id, agent2_id) is split.
        - For priority-1 conflicts the child creation order is shuffled with
          the global ``random`` module. This affects node ids, which are the
          heap's final tie-break.
        - Each child adds one constraint for one agent, replans only that
          agent and copies the other paths. A child is dropped if its
          constraint is already present or the replan fails.
    """

    def is_cardinal(conflict: Conflict, agent_id: int, current_node: CBSHighLevelNode) -> bool:
        # Replan `agent_id` with the constraint this conflict would add and report
        # whether its arrival step gets later, or whether it becomes unplannable.
        original_path = current_node.paths.get(agent_id)
        if not original_path: return False
        original_goal = original_path[-1]
        # Defensive: path[-1] is always found, so the except branch does not trigger.
        try: original_cost = original_path.index(original_goal)
        except ValueError: original_cost = len(original_path) -1
        if conflict.type == Conflict.VERTEX:
            new_constr = Constraint(agent_id, conflict.location1, conflict.timestep)
        elif conflict.type == Conflict.EDGE:
            # Each agent is forbidden its own directed move (location1 for agent1, location2 for agent2).
            edge = conflict.location1 if agent_id == conflict.agent1_id else conflict.location2
            new_constr = Constraint(agent_id, edge[1], conflict.timestep, True, edge[0])
        else: return False
        temp_constraints = current_node.constraints + [new_constr]
        agent_info = next((a for a in agents_data if a['id'] == agent_id), None)
        if not agent_info: return False
        # Get agent's specific cost map if available
        dynamic_cost_map_agent = agent_memories[agent_id].dynamic_cost_map if agent_memories and agent_id in agent_memories else None
        new_path = low_level_astar_for_cbs(
            agent_id, agent_info['start_local'], agent_info['goal_local'],
            obstacles_local_map, temp_constraints, max_plan_len, dynamic_cost_map=dynamic_cost_map_agent,
            true_agent_global_goal_abs=agents_true_global_goals_abs.get(agent_id) if agents_true_global_goals_abs else None,
            persistent_map_info=persistent_map_bundle, verbose_cbs=False )
        if new_path is None: return True
        new_goal = new_path[-1]
        try: new_cost = new_path.index(new_goal)
        except ValueError: new_cost = len(new_path) - 1
        return new_cost > original_cost

    # --- CBS main loop ---
    open_list = []
    iteration_count = 0
    root_node = CBSHighLevelNode(constraints=initial_constraints_list if initial_constraints_list is not None else [])

    # Root: plan every agent independently under the initial constraints.
    for agent_info in agents_data:
        agent_id = agent_info['id']
        dynamic_cost_map_agent = agent_memories[agent_id].dynamic_cost_map if agent_memories and agent_id in agent_memories else None

        path_coords = low_level_astar_for_cbs(
            agent_id, agent_info['start_local'], agent_info['goal_local'],
            obstacles_local_map, root_node.constraints, max_plan_len,
            dynamic_cost_map=dynamic_cost_map_agent, # Pass the map here
            true_agent_global_goal_abs=agents_true_global_goals_abs.get(agent_id) if agents_true_global_goals_abs else None,
            persistent_map_info=persistent_map_bundle, verbose_cbs=verbose_cbs_solver
        )
        if path_coords is None:
            if verbose_cbs_solver: logging.warning(f"  [CBS RootFail] Agt {agent_id} initial path fail.")
            return None
        root_node.paths[agent_id] = path_coords
        goal = path_coords[-1]
        try: cost = path_coords.index(goal)
        except ValueError: cost = len(path_coords) - 1
        root_node.sum_of_costs += cost

    root_node.conflicts = detect_all_conflicts_spacetime(root_node.paths, max_plan_len)
    heapq.heappush(open_list, root_node)

    while open_list and iteration_count < max_cbs_iterations:
        current_node = heapq.heappop(open_list)
        iteration_count += 1

        if not current_node.conflicts:
            if verbose_cbs_solver: logging.info(f"[CBS] Solution found! Iterations={iteration_count}")
            return current_node.paths

        # Conflict selection: classify every conflict by cardinality (two
        # low-level replans each), then split the lowest-sorting one.
        # NOTE(review): the child replans below repeat an A* call that
        # is_cardinal already made for the chosen conflict; caching could save it.
        classified_conflicts = []
        for c in current_node.conflicts:
            is_cardinal_for_a1 = is_cardinal(c, c.agent1_id, current_node)
            is_cardinal_for_a2 = is_cardinal(c, c.agent2_id, current_node)
            priority = 0 if (is_cardinal_for_a1 or is_cardinal_for_a2) else 1
            classified_conflicts.append((priority, c))
        classified_conflicts.sort(key=lambda x: (x[0], x[1].timestep, x[1].type, x[1].agent1_id, x[1].agent2_id))
        conflict = classified_conflicts[0][1]
        agents_to_constrain = [conflict.agent1_id, conflict.agent2_id]
        if classified_conflicts[0][0] == 1: random.shuffle(agents_to_constrain)

        # Branching: one child per agent in the conflict, each constraining that agent.
        for agent_to_constrain in agents_to_constrain:
            new_constraints_for_child = list(current_node.constraints)
            if conflict.type == Conflict.VERTEX: new_constr = Constraint(agent_to_constrain, conflict.location1, conflict.timestep)
            elif conflict.type == Conflict.EDGE:
                edge = conflict.location1 if agent_to_constrain == conflict.agent1_id else conflict.location2
                new_constr = Constraint(agent_to_constrain, edge[1], conflict.timestep, True, edge[0])
            else: continue
            # An already-present constraint would reproduce the parent; skip this child.
            if new_constr in new_constraints_for_child: continue
            new_constraints_for_child.append(new_constr)
            child_node = CBSHighLevelNode(constraints=new_constraints_for_child, paths={}, sum_of_costs=0)
            path_valid_for_child = True; temp_child_paths = {}
            constrained_agent_info = next((a for a in agents_data if a['id'] == agent_to_constrain), None)

            if not constrained_agent_info: path_valid_for_child = False
            else:
                dynamic_cost_map_constrained = agent_memories[agent_to_constrain].dynamic_cost_map if agent_memories and agent_to_constrain in agent_memories else None
                new_path_constrained_agent = low_level_astar_for_cbs(
                    agent_to_constrain, constrained_agent_info['start_local'], constrained_agent_info['goal_local'],
                    obstacles_local_map, child_node.constraints, max_plan_len,
                    dynamic_cost_map=dynamic_cost_map_constrained, # Pass map in child node
                    true_agent_global_goal_abs=agents_true_global_goals_abs.get(agent_to_constrain) if agents_true_global_goals_abs else None,
                    persistent_map_info=persistent_map_bundle, verbose_cbs=verbose_cbs_solver)
                if new_path_constrained_agent is None: path_valid_for_child = False
                else:
                    # Keep the replanned path; copy every other agent's path from the parent
                    # and recompute the child's sum of costs over all agents.
                    temp_child_paths[agent_to_constrain] = new_path_constrained_agent
                    for other_agent_info in agents_data:
                        other_id = other_agent_info['id']
                        if other_id in current_node.paths:
                            path_to_copy = temp_child_paths.get(other_id, current_node.paths[other_id])
                            temp_child_paths[other_id] = path_to_copy
                            goal = path_to_copy[-1]
                            try: cost = path_to_copy.index(goal)
                            except ValueError: cost = len(path_to_copy) - 1
                            child_node.sum_of_costs += cost
                        else: path_valid_for_child = False; break
            if path_valid_for_child:
                child_node.paths = temp_child_paths
                child_node.conflicts = detect_all_conflicts_spacetime(child_node.paths, max_plan_len)
                heapq.heappush(open_list, child_node)

    if verbose_cbs_solver: logging.warning(f"[CBS] Failed. Iterations: {iteration_count}, OpenListEmpty: {not open_list}")
    return None