
import heapq
import itertools
from collections import defaultdict
import logging
from typing import Optional, Tuple, List, Dict
import random #  

import numba
import numpy as np
from numba.core import types
from numba.typed import Dict as NumbaDict

# Per-action (d_row, d_col) offsets used by the low-level A*:
# index 0 = wait in place, 1 = row-1, 2 = row+1, 3 = col-1, 4 = col+1.
ACTION_DELTAS = np.array([[0, 0], [-1, 0], [1, 0], [0, -1], [0, 1]], dtype=np.int8)
# Numba type of a space-time state key (row, col, timestep); used as the key
# (and came_from value) type of the typed dicts inside the jitted A*.
POS_KEY_TYPE = types.UniTuple(types.int64, 3)

class Constraint:
    """A CBS constraint on a single agent at a single timestep.

    Vertex constraint (``is_edge_constraint`` false): ``agent_id`` may not be at
    ``location`` at ``timestep``.
    Edge constraint (``is_edge_constraint`` true): ``agent_id`` may not move from
    ``prev_location`` to ``location`` arriving at ``timestep``.

    Instances compare and hash by value, so they can be deduplicated with
    ``in`` checks or stored in sets.
    """

    def __init__(self, agent_id, loc, timestep, is_edge_constraint=False, prev_loc=None):
        self.agent_id = agent_id
        self.location = loc
        self.timestep = timestep
        self.is_edge_constraint = is_edge_constraint
        self.prev_location = prev_loc

    def _identity(self):
        # Single source of truth for equality and hashing.
        return (self.agent_id, self.location, self.timestep,
                self.is_edge_constraint, self.prev_location)

    def __eq__(self, other):
        return isinstance(other, Constraint) and self._identity() == other._identity()

    def __hash__(self):
        return hash(self._identity())

    def __repr__(self):
        if not self.is_edge_constraint:
            return f"VertexConstraint(A{self.agent_id}: {self.location} @t={self.timestep})"
        return f"EdgeConstraint(A{self.agent_id}: {self.prev_location}->{self.location} @t={self.timestep})"

class Conflict:
    """A collision between two agents' paths at one timestep.

    ``type`` is ``Conflict.VERTEX`` (both agents in the same cell, stored in
    ``location1``) or ``Conflict.EDGE`` (the agents swap cells: ``location1`` is
    agent 1's ``(from, to)`` move and ``location2`` agent 2's ``(from, to)`` move).
    """

    VERTEX = 1
    EDGE = 2

    def __init__(self, conflict_type, agent1_id, agent2_id, loc1, timestep, loc2=None):
        self.type = conflict_type
        self.agent1_id, self.agent2_id = agent1_id, agent2_id
        self.location1, self.location2 = loc1, loc2
        self.timestep = timestep

    def __repr__(self):
        if self.type == Conflict.VERTEX:
            return (f"VertexConflict(A{self.agent1_id}, A{self.agent2_id} "
                    f"@ {self.location1}, t={self.timestep})")
        if self.type == Conflict.EDGE:
            src1, dst1 = self.location1[0], self.location1[1]
            src2, dst2 = self.location2[0], self.location2[1]
            return (f"EdgeConflict(A{self.agent1_id} {src1}->{dst1} vs "
                    f"A{self.agent2_id} {src2}->{dst2} @t={self.timestep})")
        return "UnknownConflict"

class CBSHighLevelNode:
    """Node of the CBS constraint tree.

    Holds the accumulated constraints, one path per agent (``paths``), the
    sum of the agents' path costs, and the conflicts found among those paths.
    Nodes order themselves for ``heapq`` by fewest conflicts, then lowest sum
    of costs, then creation order (lower ``id`` first).
    """

    # Counter shared by all nodes; each node's id is unique and increasing,
    # which makes it a deterministic final tie-breaker.
    _ids = itertools.count(0)

    def __init__(self, constraints: Optional[List[Constraint]] = None, paths=None, sum_of_costs=0):
        self.id = next(self._ids)
        # Fresh containers per node so nodes never share a default list/dict.
        self.constraints = [] if constraints is None else constraints
        self.paths = {} if paths is None else paths
        self.sum_of_costs = sum_of_costs
        self.conflicts = []

    def __lt__(self, other):
        mine = (len(self.conflicts), self.sum_of_costs, self.id)
        theirs = (len(other.conflicts), other.sum_of_costs, other.id)
        return mine < theirs

@numba.jit(nopython=True)
def _numba_low_level_astar( #
    start_pos: Tuple[int, int],
    goal_pos: Tuple[int, int],
    obstacles_local: np.ndarray,
    constraints_array: np.ndarray,
    max_path_len: int,
):
    """Space-time A* for a single agent under CBS constraints (numba nopython).

    Search states are (row, col, t). Every step advances t by one and either
    moves to a 4-neighbour or waits in place (rows of ACTION_DELTAS). Each step
    costs 1, so a state's g always equals its t. The heuristic is the Manhattan
    distance to goal_pos.

    Args:
        start_pos: (row, col) start cell in obstacles_local coordinates.
        goal_pos: (row, col) goal cell.
        obstacles_local: 2-D grid; cells whose value is truthy are never entered.
        constraints_array: rows [row, col, t, is_edge, prev_row, prev_col].
            If is_edge == 0, the row forbids being at (row, col) at time t.
            Otherwise it forbids the move (prev_row, prev_col) -> (row, col)
            arriving at time t.
        max_path_len: last timestep the search may reach.

    Returns:
        List of (row, col) cells from start to goal, padded by repeating
        goal_pos up to length max_path_len + 1. Returns None if the goal cannot
        be reached by t == max_path_len.

    NOTE(review): the search returns as soon as a goal state is popped. The
    padded tail (waiting on the goal) is not checked against constraints. The
    start state is not checked against obstacles or constraints either.
    """

    map_height, map_width = obstacles_local.shape
    # Manhattan distance: admissible for unit-cost 4-connected moves.
    def h_func(p1_r, p1_c, p2_r, p2_c): return abs(p1_r - p2_r) + abs(p1_c - p2_c)
    # Heap entries are (f, g, row, col, t); ties on f prefer the smaller g.
    open_set = [(h_func(start_pos[0], start_pos[1], goal_pos[0], goal_pos[1]), 0, start_pos[0], start_pos[1], 0)]
    # came_from: state -> predecessor state; g_scores: best g recorded per state.
    came_from = NumbaDict.empty(key_type=POS_KEY_TYPE, value_type=POS_KEY_TYPE)
    g_scores = NumbaDict.empty(key_type=POS_KEY_TYPE, value_type=types.int64)
    start_key = (start_pos[0], start_pos[1], 0)
    g_scores[start_key] = 0
    while len(open_set) > 0:
        f, g, r, c, t = heapq.heappop(open_set)
        current_key = (r, c, t)
        # Skip heap entries whose g is worse than the best recorded for the state.
        if g > g_scores.get(current_key, np.iinfo(np.int64).max): continue
        if (r, c) == goal_pos:
            # Walk predecessors back to the start state (which has no
            # came_from entry), then reverse and pad by waiting at the goal.
            path = []
            curr = current_key
            while curr in came_from:
                path.append((curr[0], curr[1]))
                curr = came_from[curr]
            path.append((start_pos[0], start_pos[1]))
            path.reverse()
            while len(path) < max_path_len + 1: path.append(goal_pos)
            return path
        # Do not expand beyond the planning horizon.
        if t >= max_path_len: continue
        # Try the four moves first and "wait" (index 0) last.
        for dr_dc_idx in [1, 2, 3, 4, 0]:
            dr, dc = ACTION_DELTAS[dr_dc_idx]
            nr, nc = r + dr, c + dc
            nt = t + 1
            # Stay inside the grid and off obstacle cells.
            if not (0 <= nr < map_height and 0 <= nc < map_width and not obstacles_local[nr, nc]): continue
            # Linear scan of the constraints against the transition
            # (r, c, t) -> (nr, nc, nt): vertex rows test the target cell,
            # edge rows also test the source cell.
            is_move_valid = True
            for i in range(constraints_array.shape[0]):
                constr = constraints_array[i]
                if constr[3] == 0:
                    if constr[2] == nt and constr[0] == nr and constr[1] == nc: is_move_valid = False; break
                else:
                    if constr[2] == nt and constr[4] == r and constr[5] == c and constr[0] == nr and constr[1] == nc: is_move_valid = False; break
            if not is_move_valid: continue
            new_g = g + 1
            neighbor_key = (nr, nc, nt)
            # Since g always equals t, this only admits the first push of a state.
            if new_g < g_scores.get(neighbor_key, np.iinfo(np.int64).max):
                g_scores[neighbor_key] = new_g
                new_f = new_g + h_func(nr, nc, goal_pos[0], goal_pos[1])
                heapq.heappush(open_set, (new_f, new_g, nr, nc, nt))
                came_from[neighbor_key] = current_key
    return None

def low_level_astar_for_cbs( #
    agent_id, start_pos, original_goal_pos, obstacles_local, constraints,
    max_plan_len: int,
    true_agent_global_goal_abs: Optional[Tuple[int, int]] = None,
    persistent_map_info: Optional[Dict] = None,
    verbose_cbs: bool = False ):
    """Plan one agent's path on the local CBS map, honouring its constraints.

    The search normally targets ``original_goal_pos``. It switches to the
    agent's true global goal when all of these hold:

    * ``true_agent_global_goal_abs`` and ``persistent_map_info`` are provided;
    * the true goal maps (via the map origin) inside ``obstacles_local``;
    * the true goal lies inside the persistent map and equals its FREE_CELL value;
    * the true goal is not an obstacle locally.

    Args:
        agent_id: agent being planned; only constraints whose ``agent_id``
            equals it are applied.
        start_pos: (row, col) start in local-map coordinates.
        original_goal_pos: (row, col) fallback goal in local-map coordinates.
        obstacles_local: local 2-D grid; truthy cells are blocked.
        constraints: iterable of Constraint objects (for any agent).
        max_plan_len: planning horizon in timesteps.
        true_agent_global_goal_abs: optional absolute (row, col) goal.
        persistent_map_info: optional dict read for the keys
            'persistent_known_map', 'map_global_origin_r',
            'map_global_origin_c' and 'FREE_CELL'.
        verbose_cbs: accepted for API compatibility; currently unused.

    Returns:
        List of (row, col) tuples from start to goal, padded with the goal to
        max_plan_len + 1 entries. Returns None if no path exists within the horizon.
    """
    target_goal = original_goal_pos
    if true_agent_global_goal_abs and persistent_map_info:
        pk_map = persistent_map_info.get('persistent_known_map')
        map_origin_r = persistent_map_info.get('map_global_origin_r')
        map_origin_c = persistent_map_info.get('map_global_origin_c')
        free_cell = persistent_map_info.get('FREE_CELL')
        if (pk_map is not None and map_origin_r is not None
                and map_origin_c is not None and free_cell is not None):
            global_r, global_c = true_agent_global_goal_abs[0], true_agent_global_goal_abs[1]
            local_r, local_c = global_r - map_origin_r, global_c - map_origin_c
            cbs_map_h, cbs_map_w = obstacles_local.shape
            pk_map_h, pk_map_w = pk_map.shape
            # The true goal must fall inside both maps, be known-free in the
            # persistent map, and not be an obstacle on the local map.
            if (0 <= local_r < cbs_map_h and 0 <= local_c < cbs_map_w
                    and 0 <= global_r < pk_map_h and 0 <= global_c < pk_map_w
                    and pk_map[global_r, global_c] == free_cell
                    and not obstacles_local[local_r, local_c]):
                target_goal = (local_r, local_c)

    # Encode this agent's constraints as rows
    # [row, col, t, is_edge, prev_row, prev_col] for the jitted kernel.
    rows = []
    for c in constraints:
        if c.agent_id != agent_id:
            continue
        if c.is_edge_constraint:
            rows.append([c.location[0], c.location[1], c.timestep, 1,
                         c.prev_location[0], c.prev_location[1]])
        else:
            rows.append([c.location[0], c.location[1], c.timestep, 0, -1, -1])
    if rows:
        constraints_array = np.array(rows, dtype=np.int32)
    else:
        constraints_array = np.empty((0, 6), dtype=np.int32)

    # Normalise coordinates to plain Python ints so the jitted kernel always
    # sees (int64, int64) tuples. numba's heapq requires every pushed item to
    # have exactly the heap's tuple type, and the kernel appends goal_pos to a
    # list of int tuples. Mixed integer widths (e.g. np.int32) or list inputs
    # can therefore fail numba typing, and each new input type would trigger a
    # recompile.
    start = (int(start_pos[0]), int(start_pos[1]))
    goal = (int(target_goal[0]), int(target_goal[1]))

    path_tuples = _numba_low_level_astar(
        start, goal, obstacles_local, constraints_array, int(max_plan_len),
    )
    if path_tuples is None:
        return None
    return [tuple(p) for p in path_tuples]

def detect_all_conflicts_spacetime(paths_dict: Dict[int, List[Tuple[int, int]]], max_timestep: int):
    """
    Detects vertex and edge conflicts from a dictionary of paths.
    Correctly handles agents that reach their destination and "disappear" from the grid.

    An agent occupies the grid only up to its arrival timestep. Arrival is the
    first timestep of the trailing run of its final cell, i.e. where the
    waiting-at-goal padding begins. Anything after arrival is ignored.

    Args:
        paths_dict: agent id -> list of (row, col) cells, one per timestep.
            Agent ids must be mutually comparable (used to de-duplicate swaps).
        max_timestep: currently unused; kept for API compatibility.

    Returns:
        List of Conflict objects: vertex conflicts first, then edge (swap)
        conflicts. Returns [] as soon as any path is None.
    """
    arrival_times = {}
    for aid, path in paths_dict.items():
        if path is None:
            # NOTE(review): a missing path is reported as "no conflicts";
            # callers must not mistake this [] for a conflict-free solution.
            return []
        # Start of the trailing run of the final cell. path.index(path[-1])
        # would pick the FIRST visit, so a path that passes through its goal,
        # leaves and returns would vanish too early and hide its later moves.
        # An empty path gets -1 and contributes no reservations.
        arrival = len(path) - 1
        while arrival > 0 and path[arrival - 1] == path[-1]:
            arrival -= 1
        arrival_times[aid] = arrival

    vertex_reservation = defaultdict(list)
    edge_reservation = defaultdict(list)
    for agent_id, path in paths_dict.items():
        arrival = arrival_times[agent_id]
        # Cells occupied at t = 0 .. arrival (inclusive).
        for t, pos in enumerate(path[:arrival + 1]):
            vertex_reservation[(t, pos)].append(agent_id)
        # Moves path[t] -> path[t + 1], keyed by their arrival time t + 1 <= arrival.
        for t in range(arrival):
            pos1, pos2 = path[t], path[t + 1]
            if pos1 != pos2:
                edge_reservation[(t + 1, pos1, pos2)].append(agent_id)

    all_conflicts = []
    for (timestep, pos), occupants in vertex_reservation.items():
        for agent1_id, agent2_id in itertools.combinations(occupants, 2):
            all_conflicts.append(Conflict(Conflict.VERTEX, agent1_id, agent2_id, pos, timestep))

    # Swap conflicts: one agent moves A -> B while another moves B -> A in the same step.
    for (timestep, pos1, pos2), occupants1 in edge_reservation.items():
        occupants2 = edge_reservation.get((timestep, pos2, pos1))  # .get: no insertion while iterating
        if not occupants2:
            continue
        for agent1_id in occupants1:
            for agent2_id in occupants2:
                # Each swap is seen from both directions; keep exactly one copy.
                if agent1_id < agent2_id:
                    all_conflicts.append(Conflict(Conflict.EDGE, agent1_id, agent2_id,
                                                  (pos1, pos2), timestep, (pos2, pos1)))
    return all_conflicts



def solve_local_cbs_robust( #
    agents_data: List[Dict], obstacles_local_map: np.ndarray, max_plan_len: int,
    max_cbs_iterations=500,
    initial_paths_dict: Optional[Dict[int, List[Tuple[int, int]]]] = None,
    agents_true_global_goals_abs: Optional[Dict[int, Tuple[int,int]]] = None,
    persistent_map_bundle: Optional[Dict] = None,
    verbose_cbs_solver: bool = False,
    initial_constraints_list: Optional[List[Constraint]] = None
    ):
    """Solve a local multi-agent path-finding problem with Conflict-Based Search.

    The high level expands constraint-tree nodes from a heap ordered by
    CBSHighLevelNode.__lt__. At each node it picks one conflict and creates a
    child per involved agent; the child adds a constraint for that agent and
    replans only that agent.

    Conflict selection:
    * Conflicts are ordered by (timestep, type, agent1_id, agent2_id).
    * The first conflict whose constraint raises either agent's cost (or makes
      replanning fail) is chosen.
    * Otherwise the first conflict overall is chosen, and the child order is
      shuffled with the global ``random`` module.

    Args:
        agents_data: one dict per agent, read for 'id', 'start_local' and
            'goal_local' (local-map coordinates).
        obstacles_local_map: local occupancy grid passed to the low-level planner.
        max_plan_len: planning horizon passed to the low-level planner.
        max_cbs_iterations: maximum number of high-level nodes popped.
        initial_paths_dict: currently unused; kept for API compatibility.
        agents_true_global_goals_abs: optional agent id -> absolute goal,
            forwarded to the low-level planner.
        persistent_map_bundle: optional map info forwarded to the low-level planner.
        verbose_cbs_solver: log root failures, success and overall failure.
        initial_constraints_list: constraints placed on the root node. The list
            is used as-is, not copied; it is never mutated here.

    Returns:
        dict agent id -> path (list of (row, col)) for a conflict-free node.
        Returns None if a root path cannot be planned, or if the iteration
        budget or the open list runs out.
    """

    def _agent_info(agent_id):
        # First entry of agents_data with this id, or None.
        return next((a for a in agents_data if a['id'] == agent_id), None)

    def _plan(agent_id, agent_info, constraints, verbose):
        # Run the low-level planner for one agent under `constraints`.
        true_goal = agents_true_global_goals_abs.get(agent_id) if agents_true_global_goals_abs else None
        return low_level_astar_for_cbs(
            agent_id, agent_info['start_local'], agent_info['goal_local'],
            obstacles_local_map, constraints, max_plan_len,
            true_agent_global_goal_abs=true_goal,
            persistent_map_info=persistent_map_bundle, verbose_cbs=verbose)

    def _path_cost(path):
        # Timestep at which the agent settles on its final cell: the start of the
        # trailing run of path[-1], i.e. where the planner's goal padding begins.
        if not path:
            return 0
        cost = len(path) - 1
        while cost > 0 and path[cost - 1] == path[-1]:
            cost -= 1
        return cost

    def _constraint_for(conflict, agent_id):
        # Constraint forbidding `agent_id` its part of `conflict`; None for unknown types.
        if conflict.type == Conflict.VERTEX:
            return Constraint(agent_id, conflict.location1, conflict.timestep)
        if conflict.type == Conflict.EDGE:
            edge = conflict.location1 if agent_id == conflict.agent1_id else conflict.location2
            return Constraint(agent_id, edge[1], conflict.timestep, True, edge[0])
        return None

    def is_cardinal(conflict: Conflict, agent_id: int, current_node: CBSHighLevelNode) -> bool:
        # True if constraining `agent_id` on `conflict` makes its path costlier or infeasible.
        original_path = current_node.paths.get(agent_id)
        if not original_path:
            return False
        new_constr = _constraint_for(conflict, agent_id)
        if new_constr is None:
            return False
        agent_info = _agent_info(agent_id)
        if not agent_info:
            return False
        new_path = _plan(agent_id, agent_info, current_node.constraints + [new_constr], False)
        if new_path is None:
            return True
        return _path_cost(new_path) > _path_cost(original_path)

    open_list = []
    iteration_count = 0
    root_node = CBSHighLevelNode(constraints=initial_constraints_list if initial_constraints_list is not None else [])

    for agent_info in agents_data:
        agent_id = agent_info['id']
        path_coords = _plan(agent_id, agent_info, root_node.constraints, verbose_cbs_solver)
        if path_coords is None:
            if verbose_cbs_solver: logging.warning(f"  [CBS RootFail] Agt {agent_id} initial path fail.")
            return None
        root_node.paths[agent_id] = path_coords
        root_node.sum_of_costs += _path_cost(path_coords)

    root_node.conflicts = detect_all_conflicts_spacetime(root_node.paths, max_plan_len)
    heapq.heappush(open_list, root_node)

    while open_list and iteration_count < max_cbs_iterations:
        current_node = heapq.heappop(open_list)
        iteration_count += 1

        if not current_node.conflicts:
            if verbose_cbs_solver: logging.info(f"[CBS] Solution found! Iterations={iteration_count}")
            return current_node.paths

        # Stable sort, then take the first (semi-)cardinal conflict. This picks
        # the same conflict as sorting on (priority, timestep, type, ids), but
        # stops running the costly is_cardinal replans once one is found.
        ordered = sorted(current_node.conflicts,
                         key=lambda c: (c.timestep, c.type, c.agent1_id, c.agent2_id))
        conflict = next(
            (c for c in ordered
             if is_cardinal(c, c.agent1_id, current_node) or is_cardinal(c, c.agent2_id, current_node)),
            None)
        non_cardinal = conflict is None
        if non_cardinal:
            conflict = ordered[0]
        agents_to_constrain = [conflict.agent1_id, conflict.agent2_id]
        if non_cardinal:
            random.shuffle(agents_to_constrain)

        for agent_to_constrain in agents_to_constrain:
            new_constr = _constraint_for(conflict, agent_to_constrain)
            if new_constr is None or new_constr in current_node.constraints:
                continue
            child_node = CBSHighLevelNode(constraints=current_node.constraints + [new_constr],
                                          paths={}, sum_of_costs=0)
            constrained_agent_info = _agent_info(agent_to_constrain)
            if not constrained_agent_info:
                continue
            new_path = _plan(agent_to_constrain, constrained_agent_info,
                             child_node.constraints, verbose_cbs_solver)
            if new_path is None:
                continue

            # The replanned agent goes in first; every other agent keeps its
            # parent path. The costs of all agents are summed.
            child_paths = {agent_to_constrain: new_path}
            path_valid_for_child = True
            for other_agent_info in agents_data:
                other_id = other_agent_info['id']
                if other_id not in current_node.paths:
                    path_valid_for_child = False
                    break
                path_to_copy = child_paths.setdefault(other_id, current_node.paths[other_id])
                child_node.sum_of_costs += _path_cost(path_to_copy)
            if not path_valid_for_child:
                continue

            child_node.paths = child_paths
            child_node.conflicts = detect_all_conflicts_spacetime(child_node.paths, max_plan_len)
            heapq.heappush(open_list, child_node)

    if verbose_cbs_solver: logging.warning(f"[CBS] Failed. Iterations: {iteration_count}, OpenListEmpty: {not open_list}")
    return None