"""
Search algorithms for Robust Optimal Transport.

This module contains depth-first search algorithms used for finding
augmenting paths and cycles in the transport network, including
functions for both augmentation and consolidation phases.
"""

import numpy as np
from typing import List, Optional, Tuple, Union


def partial_dfs_r(
    U: List[int],
    backwards: List[List[int]],
    curr_r: int,
    curr_path: List[int],
    C: List[List[Tuple[int, float]]],
    distance: float,
    residuals_b: np.ndarray,
    B_weights: np.ndarray,
    lambda_val: float,
    distance_matrix: List[List[float]],
    transport_plan: np.ndarray,
) -> Optional[List[int]]:
    """
    Advance the augmentation-phase DFS from a representative point.

    The entries of ``C[curr_r]`` are visited in their stored order. An entry
    ``(b, d)`` is followed only if ``b`` is currently in ``U`` and ``d`` is
    strictly smaller than ``distance``.

    Args:
        U: Collection of unvisited B point indices
        backwards: Backward edges from B to representative points
            (forwarded unchanged to ``partial_dfs_b``)
        curr_r: Index of the representative point being expanded
        curr_path: Path walked so far; only its odd positions are compared
            against a candidate B point when looking for a repeat
        C: Per-representative lists of ``(b_index, distance)`` entries
        distance: Exclusive upper bound on the entry distance
        residuals_b: Forwarded unchanged to ``partial_dfs_b``
        B_weights: Forwarded unchanged to ``partial_dfs_b``
        lambda_val: Forwarded unchanged to ``partial_dfs_b``
        distance_matrix: Forwarded unchanged to ``partial_dfs_b``
        transport_plan: Forwarded unchanged to ``partial_dfs_b``

    Returns:
        If a followed B point already occurs at an odd position of
        ``curr_path``: the tail of ``curr_path`` starting at that position
        with the B point appended (a cycle). Otherwise the first non-None
        result produced by ``partial_dfs_b``; None if no entry yields one.
    """
    for entry in C[curr_r]:
        b_idx = entry[0]
        if not (b_idx in U and entry[1] < distance):
            continue

        # Look for the first odd position of the path that already holds
        # this B point; reaching it again closes a cycle.
        repeat_at = next(
            (pos for pos in range(1, len(curr_path), 2)
             if b_idx == curr_path[pos]),
            None,
        )
        if repeat_at is not None:
            return curr_path[repeat_at:] + [b_idx]

        # No repeat: extend the path with this B point and keep searching.
        found = partial_dfs_b(
            U, backwards, b_idx, curr_path + [b_idx], C, residuals_b,
            B_weights, lambda_val, distance_matrix, transport_plan,
        )
        if found is not None:
            return found

    return None


def partial_dfs_b(
    U: List[int],
    backwards: List[List[int]],
    curr_b: int,
    curr_path: List[int],
    C: List[List[Tuple[int, float]]],
    residuals_b: np.ndarray,
    B_weights: np.ndarray,
    lambda_val: float,
    distance_matrix: List[List[float]],
    transport_plan: np.ndarray,
) -> Optional[List[int]]:
    """
    Advance the augmentation-phase DFS from a B point.

    Args:
        U: Collection of unvisited B point indices; ``curr_b`` is removed
            from it when every backward edge turns out to be a dead end
        backwards: Backward edges from B to representative points; entries
            of ``backwards[curr_b]`` tried before the successful one are
            deleted when a path is found
        curr_b: Index of the B point being expanded
        curr_path: Path walked so far (ending at ``curr_b``)
        C: Sorted distance information (forwarded to ``partial_dfs_r``)
        residuals_b: Residual mass at B points
        B_weights: Weights for B points
        lambda_val: Lambda parameter
        distance_matrix: ``distance_matrix[r][curr_b]`` becomes the distance
            threshold for the search continuing from representative ``r``
        transport_plan: Current transport plan; a backward edge ``r`` is
            only followed when ``transport_plan[r][curr_b]`` exceeds 1e-6

    Returns:
        ``curr_path`` itself when ``curr_b`` has residual above 1e-6 or its
        weight is within 1e-6 of ``lambda_val``; otherwise the first
        non-None result from ``partial_dfs_r``; None on a dead end.
    """
    # Either stop condition ends the search at this B point.
    if residuals_b[curr_b] > 1e-6 or abs(B_weights[curr_b] - lambda_val) < 1e-6:
        return curr_path

    incoming = backwards[curr_b]
    exhausted = []  # edges already tried (or skipped for lack of flow)
    for rep in incoming:
        found = None
        if transport_plan[rep][curr_b] > 1e-6:
            found = partial_dfs_r(
                U, backwards, rep, curr_path + [rep], C,
                distance_matrix[rep][curr_b], residuals_b, B_weights,
                lambda_val, distance_matrix, transport_plan,
            )
        if found is not None:
            # Prune the edges that led nowhere; this is safe here because
            # the loop is left immediately afterwards.
            for dead in exhausted:
                if dead in incoming:
                    incoming.remove(dead)
            return found
        exhausted.append(rep)

    # Nothing reachable from here: mark this B point as visited.
    if curr_b in U:
        U.remove(curr_b)
    return None


def partial_dfs_r_weights(
    U: List[int],
    forwards: List[List[int]],
    curr_r: int,
    C: List[List[Tuple[int, float]]],
    distance: float,
    lambda_val: float,
    distance_matrix: List[List[float]],
    K: List[int],
) -> None:
    """
    Weight-increase DFS step taken from a representative point.

    Every entry ``(b, d)`` of ``C[curr_r]`` with ``b`` still in ``U`` and
    ``d`` strictly below ``distance`` is handed to
    ``partial_dfs_b_weights``. Membership in ``U`` is re-evaluated for each
    entry, so B points consumed by an earlier branch are not entered again.

    Args:
        U: Collection of unvisited B point indices
        forwards: Forward edges from B to representative points
            (forwarded unchanged)
        curr_r: Index of the representative point being expanded
        C: Per-representative lists of ``(b_index, distance)`` entries
        distance: Exclusive upper bound on the entry distance
        lambda_val: Forwarded unchanged
        distance_matrix: Forwarded unchanged
        K: Output list collecting reached B points (forwarded unchanged)
    """
    for entry in C[curr_r]:
        target_b = entry[0]
        if target_b not in U:
            continue
        if not (entry[1] < distance):
            continue
        partial_dfs_b_weights(
            U, forwards, target_b, C, lambda_val, distance_matrix, K
        )


def partial_dfs_b_weights(
    U: List[int],
    forwards: List[List[int]],
    curr_b: int,
    C: List[List[Tuple[int, float]]],
    lambda_val: float,
    distance_matrix: List[List[float]],
    K: List[int],
) -> None:
    """
    Weight-increase DFS step taken from a B point.

    Marks ``curr_b`` as visited, consumes each of its forward edges while
    continuing the search through the corresponding representative point,
    and finally records ``curr_b`` in ``K`` (so ``K`` is filled in
    post-order).

    Args:
        U: Collection of unvisited B point indices; ``curr_b`` is removed
        forwards: Forward edges from B to representative points;
            ``forwards[curr_b]`` is emptied edge by edge
        curr_b: Index of the B point being expanded
        C: Sorted distance information (forwarded unchanged)
        lambda_val: Forwarded unchanged
        distance_matrix: ``distance_matrix[r][curr_b]`` becomes the distance
            threshold for the search continuing from representative ``r``
        K: Output list; ``curr_b`` is appended after its edges are processed
    """
    if curr_b in U:
        U.remove(curr_b)

    outgoing = forwards[curr_b]
    # Walk a frozen copy so that deleting edges from the live list does not
    # disturb the iteration.
    for rep in tuple(outgoing):
        if rep in outgoing:
            outgoing.remove(rep)
        partial_dfs_r_weights(
            U, forwards, rep, C, distance_matrix[rep][curr_b],
            lambda_val, distance_matrix, K,
        )

    K.append(curr_b)


def partial_dfs_r_cons(U: List[int], backwards: List[List[int]], curr_r: int,
                      curr_path: List[int], C: List[List[Tuple[int, float]]], 
                      distance: float, residuals_b: np.ndarray, 
                      B_weights: np.ndarray, lambda_val: float,
                      distance_matrix: List[List[float]]) -> Tuple[Optional[List[int]], Optional[bool]]:
    """
    Partial DFS from representative point (consolidation phase).

    Entries ``(b, d)`` of ``C[curr_r]`` are visited in stored order and
    followed only when ``b`` is still in ``U`` and ``d`` is strictly below
    ``distance``.
    
    Args:
        U: Collection of unvisited B point indices
        backwards: Backward edges from B to representative points
        curr_r: Current representative point index
        curr_path: Current path being explored; only its odd positions are
            compared against a candidate B point when looking for a repeat
        C: Sorted distance information
        distance: Distance threshold (exclusive)
        residuals_b: Residual mass at B points
        B_weights: Weights for B points
        lambda_val: Lambda parameter
        distance_matrix: Distance matrix
        
    Returns:
        Tuple of (path, is_cycle). When a followed B point already occurs at
        an odd position of ``curr_path`` the result is the tail of the path
        from that position with the B point appended, and ``True``.
        Otherwise the first non-None result of ``partial_dfs_b_cons`` is
        passed through. ``(None, None)`` when nothing is found.
    """
    for entry in C[curr_r]:
        b_idx = entry[0]
        if b_idx in U and entry[1] < distance:
            # Check for cycle: has this B point been entered on the
            # current path already?
            for j in range(1, len(curr_path), 2):
                if b_idx == curr_path[j]:
                    return curr_path[j:] + [b_idx], True
            
            # Continue DFS
            P, is_cycle = partial_dfs_b_cons(U, backwards, b_idx,
                                           curr_path + [b_idx], C,
                                           residuals_b, B_weights, lambda_val, 
                                           distance_matrix)
            if P is not None:
                return P, is_cycle
    return None, None


def partial_dfs_b_cons(U: List[int], backwards: List[List[int]], curr_b: int,
                      curr_path: List[int], C: List[List[Tuple[int, float]]], 
                      residuals_b: np.ndarray, B_weights: np.ndarray, 
                      lambda_val: float, 
                      distance_matrix: List[List[float]]) -> Tuple[Optional[List[int]], Optional[bool]]:
    """
    Partial DFS from B point (consolidation phase).
    
    Args:
        U: Collection of unvisited B point indices; ``curr_b`` is removed
            when every backward edge is a dead end
        backwards: Backward edges from B to representative points; each
            edge of ``backwards[curr_b]`` that leads nowhere is deleted
        curr_b: Current B point index
        curr_path: Current path being explored
        C: Sorted distance information
        residuals_b: Residual mass at B points
        B_weights: Weights for B points
        lambda_val: Lambda parameter
        distance_matrix: ``distance_matrix[r][curr_b]`` becomes the distance
            threshold for the search continuing from representative ``r``
        
    Returns:
        Tuple of (path, is_cycle) where path is the found path/cycle and
        is_cycle indicates if it's a cycle; ``(None, None)`` on a dead end.
    """
    # Check termination conditions
    if residuals_b[curr_b] > 1e-6:
        return curr_path, False
    
    # Explore backward edges.  Iterate over a snapshot: dead-end edges are
    # deleted from backwards[curr_b] inside the loop, and removing from a
    # list while iterating over that same list makes the iterator skip the
    # element that follows each removal.
    for i in list(backwards[curr_b]):
        # NOTE(review): this tests the distance stored in the first entry of
        # C[i]; confirm that this is the intended termination criterion.
        if C[i][0][1] > 1e-6:
            return curr_path + [i], False
        
        P, is_cycle = partial_dfs_r_cons(U, backwards, i, curr_path + [i], C,
                                       distance_matrix[i][curr_b], residuals_b,
                                       B_weights, lambda_val, distance_matrix)
        if P is not None:
            return P, is_cycle
        
        # Dead end through i: drop the edge so it is not retried later.
        if i in backwards[curr_b]:
            backwards[curr_b].remove(i)
    
    if curr_b in U:
        U.remove(curr_b)
    return None, None


def partial_dfs_r_red_weights(
    U: List[int],
    backwards: List[List[int]],
    curr_r: int,
    C: List[List[Tuple[int, float]]],
    distance: float,
    lambda_val: float,
    distance_matrix: List[List[float]],
    K: List[int],
) -> None:
    """
    Weight-reduction DFS step taken from a representative point.

    Scans ``C[curr_r]`` in stored order and descends, via
    ``partial_dfs_b_red_weights``, into each listed B point that is still in
    ``U`` and whose stored distance is strictly less than ``distance``.

    Args:
        U: Collection of unvisited B point indices
        backwards: Backward edges from B to representative points
            (forwarded unchanged)
        curr_r: Index of the representative point being expanded
        C: Per-representative lists of ``(b_index, distance)`` entries
        distance: Exclusive upper bound on the entry distance
        lambda_val: Forwarded unchanged
        distance_matrix: Forwarded unchanged
        K: Output list collecting reached B points (forwarded unchanged)
    """
    for candidate in C[curr_r]:
        b_point = candidate[0]
        # U shrinks as the recursion proceeds, so membership has to be
        # tested at the moment each candidate is reached.
        eligible = b_point in U and candidate[1] < distance
        if eligible:
            partial_dfs_b_red_weights(
                U, backwards, b_point, C, lambda_val, distance_matrix, K
            )


def partial_dfs_b_red_weights(
    U: List[int],
    backwards: List[List[int]],
    curr_b: int,
    C: List[List[Tuple[int, float]]],
    lambda_val: float,
    distance_matrix: List[List[float]],
    K: List[int],
) -> None:
    """
    Weight-reduction DFS step taken from a B point.

    Removes ``curr_b`` from ``U``, then deletes each backward edge of
    ``curr_b`` and continues the search from the representative point at its
    other end. ``curr_b`` is appended to ``K`` once all of its edges have
    been handled, so ``K`` is built in post-order.

    Args:
        U: Collection of unvisited B point indices; ``curr_b`` is removed
        backwards: Backward edges from B to representative points;
            ``backwards[curr_b]`` is emptied edge by edge
        curr_b: Index of the B point being expanded
        C: Sorted distance information (forwarded unchanged)
        lambda_val: Forwarded unchanged
        distance_matrix: ``distance_matrix[r][curr_b]`` becomes the distance
            threshold for the search continuing from representative ``r``
        K: Output list; ``curr_b`` is appended after its edges are processed
    """
    if curr_b in U:
        U.remove(curr_b)

    live_edges = backwards[curr_b]
    snapshot = list(live_edges)  # the live list is mutated inside the loop
    for rep in snapshot:
        if rep in live_edges:
            live_edges.remove(rep)
        threshold = distance_matrix[rep][curr_b]
        partial_dfs_r_red_weights(
            U, backwards, rep, C, threshold, lambda_val, distance_matrix, K
        )

    K.append(curr_b)