"""
Transport optimization module for Robust Optimal Transport.

This module contains the main optimization functions for the ROT algorithm,
including search and augment, search and consolidate procedures, and
weight adjustment mechanisms.
"""

import numpy as np
from typing import List, Tuple
from .search_algorithms import (
    partial_dfs_r, partial_dfs_b_weights, partial_dfs_r_weights,
    partial_dfs_r_cons, partial_dfs_b_red_weights, partial_dfs_r_red_weights
)
from .discrete_set import discrete_set


def search_and_augment(A: List, A_mass: List[float], B: np.ndarray,
                      B_mass: np.ndarray, C: List[List[Tuple[int, float]]],
                      B_weights: np.ndarray, delta: float,
                      transport_plan: np.ndarray, lambda_val: float,
                      distance_matrix: List[List[float]]) -> Tuple[int, int]:
    """
    Search for augmenting paths and cycles to improve transport plan.

    Every representative point ``a`` whose residual mass (``A_mass[a]`` minus
    its row sum in ``transport_plan``) exceeds 1e-6 is used as a DFS root for
    ``partial_dfs_r``. The sequence it returns is applied by length:

    * odd length ``(b0, a1, b2, ..., b_last)``: a cycle. Flow on every
      ``(a_k, b_{k-1})`` edge is moved to ``(a_k, b_{k+1})``; the amount is
      the smallest such flow, capped at 1.
    * even length ``(a0, b1, a2, ..., b_last)``: an augmenting path. The
      amount is bounded by the residuals of ``a0`` and ``b_last`` and by the
      flow on every ``(a_k, b_{k-1})`` edge that is rerouted.

    After a successful search the same root is tried again; the scan only
    moves on when ``partial_dfs_r`` returns ``None`` or the root's residual
    is no longer above 1e-6. ``transport_plan`` is modified in place.

    Args:
        A: Representative points from discrete set
        A_mass: Mass for representative points
        B: Target points
        B_mass: Mass for target points
        C: Sorted distance information
        B_weights: Weights for B points
        delta: Approximation parameter
        transport_plan: Current transport plan (rows = A, columns = B)
        lambda_val: Lambda parameter
        distance_matrix: Distance matrix

    Returns:
        Tuple of (total_path_length, total_cycle_length); each augmenting
        path contributes ``len(P)`` and each cycle ``len(P) - 1``.
    """
    tol = 1e-6
    n_a = len(A)

    # B indices handed to the DFS helper for the whole phase; presumably the
    # helper prunes visited vertices from it (TODO confirm in search_algorithms).
    unvisited = list(range(len(B_mass)))
    # backwards[b] lists every a currently sending more than tol to b.
    backwards = [[a for a in range(n_a) if transport_plan[a][b] > tol]
                 for b in range(len(B))]

    residual_a = np.array(A_mass) - np.sum(transport_plan, axis=1)
    residual_b = B_mass - np.sum(transport_plan, axis=0)

    def reroute(path, first_row, flow):
        """Shift `flow` at each row position from its left column to its right one."""
        for pos in range(first_row, len(path), 2):
            row, left, right = path[pos], path[pos - 1], path[pos + 1]
            transport_plan[row][left] -= flow
            transport_plan[row][right] += flow
            # The (row, left) edge no longer carries meaningful flow.
            if transport_plan[row][left] < tol and row in backwards[left]:
                backwards[left].remove(row)

    path_total = 0
    cycle_total = 0
    root = 0
    while root < n_a:
        path = None
        if residual_a[root] > tol:
            path = partial_dfs_r(unvisited, backwards, root, [root], C,
                                 2 * delta, residual_b, B_weights, lambda_val,
                                 distance_matrix, transport_plan)
        if path is None:
            root += 1
            continue

        if len(path) % 2 == 1:
            # Cycle: A rows sit at the odd positions.
            flow = 1
            for pos in range(1, len(path), 2):
                flow = min(flow, transport_plan[path[pos]][path[pos - 1]])
            reroute(path, 1, flow)
            cycle_total += len(path) - 1
        else:
            # Augmenting path: A rows sit at the even positions.
            flow = min(residual_a[root], residual_b[path[-1]])
            for pos in range(2, len(path), 2):
                flow = min(flow, transport_plan[path[pos]][path[pos - 1]])
            reroute(path, 2, flow)
            transport_plan[path[0]][path[1]] += flow
            residual_a[root] -= flow
            residual_b[path[-1]] -= flow
            path_total += len(path)
        # Intentionally no increment: retry the same root.

    return path_total, cycle_total


def increase_weights(A: np.ndarray, A_mass: np.ndarray, A_delta: List,
                    A_delta_mass: List[float], B: np.ndarray, B_mass: np.ndarray,
                    C: List[List[Tuple[int, float]]], B_weights: np.ndarray,
                    delta: float, lambda_val: float, transport_plan: np.ndarray,
                    arrangement: List, distance_matrix: List[List[float]]) -> Tuple:
    """
    Increase weights for active surplus points and update discrete set.

    A target point ``b`` is active surplus when its residual mass
    (``B_mass[b]`` minus its column sum in ``transport_plan``) exceeds 1e-6
    and ``B_weights[b] < lambda_val - 1e-6``; this matches the loop test in
    ``search_and_augment_weights``. ``partial_dfs_b_weights`` is run from each
    such point still present in ``U``, and every index it appends to ``K``
    has its weight raised by ``delta`` (``B_weights`` is modified in place).

    The plan on representative points is then spread back onto the original
    points: each index ``j`` in ``arrangement[i][-1]`` receives row ``i``
    scaled by ``A_mass[j] / A_delta_mass[i]``. ``discrete_set`` then
    rebuilds the discretisation from that plan.

    Args:
        A: Original source points
        A_mass: Original mass distribution
        A_delta: Current representative points
        A_delta_mass: Mass for representative points
        B: Target points
        B_mass: Mass for target points
        C: Sorted distance information
        B_weights: Weights for B points (updated in place)
        delta: Approximation parameter (weight increment)
        lambda_val: Lambda parameter (upper bound for weights)
        transport_plan: Current transport plan on representative points
        arrangement: Current partition arrangement
        distance_matrix: Distance matrix

    Returns:
        Updated (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix,
                arrangement, B_weights)
    """
    tol = 1e-6

    U = list(range(len(B_mass)))
    forwards = [list(range(len(A_delta))) for _ in range(len(B_mass))]
    K = []

    residual_b = B_mass - np.sum(transport_plan, axis=0)

    # Use the same 1e-6 residual tolerance as search_and_augment_weights; the
    # previous ``> 0`` let float round-off residuals count as active surplus.
    active_surplus = [b for b in range(len(B))
                      if residual_b[b] > tol and B_weights[b] < lambda_val - tol]

    for b in active_surplus:
        if b in U:
            partial_dfs_b_weights(U, forwards, b, C, lambda_val,
                                  distance_matrix, K)

    # Increase weights for reachable points.
    for b in K:
        B_weights[b] += delta

    # Reconstruct the full transport plan one row at a time (same element-wise
    # arithmetic as the former per-entry loop: plan * A_mass / A_delta_mass).
    sd_transport_plan = np.zeros((len(A), len(B)))
    for cell, cell_info in enumerate(arrangement):
        for point in cell_info[-1]:
            sd_transport_plan[point, :] = (np.asarray(transport_plan[cell])
                                           * A_mass[point]
                                           / A_delta_mass[cell])

    # Update discrete set.
    (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix,
     arrangement, _drawing_points) = discrete_set(
        A, A_mass, B, B_weights, delta, sd_transport_plan)

    return (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix,
            arrangement, B_weights)


def search_and_augment_weights(A: np.ndarray, A_mass: np.ndarray, B: np.ndarray,
                              B_mass: np.ndarray, B_weights: np.ndarray,
                              delta: float, lambda_val: float, A_delta: List,
                              A_delta_mass: List[float], transport_plan_hat: np.ndarray,
                              C: List[List[Tuple[int, float]]],
                              distance_matrix: List[List[float]],
                              arrangement: List) -> Tuple:
    """
    Main procedure for search and augment with weight increases.

    The discretisation is first rebuilt by ``discrete_set`` from an all-zero
    plan. The incoming ``A_delta``, ``A_delta_mass``, ``transport_plan_hat``,
    ``C``, ``distance_matrix`` and ``arrangement`` values are therefore
    overwritten, not read. Then, while some target point has weight below
    ``lambda_val - 1e-6`` and residual mass above 1e-6, one round of
    ``search_and_augment`` followed by ``increase_weights`` is run.

    Args:
        A: Original source points
        A_mass: Original mass distribution
        B: Target points
        B_mass: Mass for target points
        B_weights: Weights for B points
        delta: Approximation parameter
        lambda_val: Lambda parameter
        A_delta: Current representative points (overwritten)
        A_delta_mass: Mass for representative points (overwritten)
        transport_plan_hat: Current transport plan (overwritten)
        C: Sorted distance information (overwritten)
        distance_matrix: Distance matrix (overwritten)
        arrangement: Current partition arrangement (overwritten)

    Returns:
        (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix,
         arrangement, B_weights, total_path_length, total_cycle_length, iters)
    """
    tol = 1e-6

    empty_plan = np.zeros((len(A), len(B)))
    (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix,
     arrangement, _drawing_points) = discrete_set(
        A, A_mass, B, B_weights, delta, empty_plan)

    def active_surplus(weights, plan):
        """B indices still below the lambda cap and with residual mass left."""
        residual_b = B_mass - np.sum(plan, axis=0)
        return [b for b in range(len(B))
                if weights[b] < lambda_val - tol and residual_b[b] > tol]

    total_path_length = total_cycle_length = iters = 0
    while active_surplus(B_weights, transport_plan_hat):
        iters += 1
        path_length, cycle_length = search_and_augment(
            A_delta, A_delta_mass, B, B_mass, C, B_weights, delta,
            transport_plan_hat, lambda_val, distance_matrix)
        total_path_length += path_length
        total_cycle_length += cycle_length

        (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix,
         arrangement, B_weights) = increase_weights(
            A, A_mass, A_delta, A_delta_mass, B, B_mass, C, B_weights,
            delta, lambda_val, transport_plan_hat, arrangement, distance_matrix)

    return (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix,
            arrangement, B_weights, total_path_length, total_cycle_length, iters)


def search_and_consolidate(A: List, A_mass: List[float], B: np.ndarray,
                          B_mass: np.ndarray, C: List[List[Tuple[int, float]]],
                          B_weights: np.ndarray, delta: float,
                          transport_plan: np.ndarray, lambda_val: float,
                          distance_matrix: List[List[float]],
                          eps: float = 1e-8) -> Tuple[int, int]:
    """
    Search for consolidating paths and cycles to reduce transport cost.

    Each representative point ``i`` with residual mass above ``eps`` and
    ``C[i][0][1] < -eps`` is used as a DFS root for ``partial_dfs_r_cons``.
    The returned sequence ``P`` is applied according to its shape (rows of
    ``transport_plan`` are A indices, columns are B indices):

    * even length ``(a0, b1, a2, ..., b_last)``: augmenting path. At most
      ``min(residual_a[a0], residual_b[b_last])`` is pushed along it.
    * odd length with ``is_cycle``: cycle ``(b0, a1, b2, ..., b_last)``. Flow
      on each ``(a_k, b_{k-1})`` is moved to ``(a_k, b_{k+1})``; row and
      column totals are unchanged.
    * odd length without ``is_cycle``: consolidating path
      ``(a0, b1, a2, ..., a_last)``. Row ``a0`` gains outgoing flow and row
      ``a_last`` loses the same amount, so the amount is bounded by
      ``a0``'s residual.

    A root is retried after every successful search. ``transport_plan`` is
    modified in place. Any entry on a reduced edge that falls below ``eps``
    is snapped to zero and dropped from the backward adjacency.

    Args:
        A: Representative points from discrete set
        A_mass: Mass for representative points
        B: Target points
        B_mass: Mass for target points
        C: Sorted distance information
        B_weights: Weights for B points
        delta: Approximation parameter
        transport_plan: Current transport plan
        lambda_val: Lambda parameter
        distance_matrix: Distance matrix
        eps: Numerical precision threshold

    Returns:
        Tuple of (total_path_length, total_cycle_length)
    """
    U = [b for b in range(len(B_mass))]
    # backwards[b] lists every a with positive flow to b.
    backwards = [[a for a in range(len(A)) if transport_plan[a][b] > 0]
                 for b in range(len(B))]

    residual_a = np.array(A_mass) - np.sum(transport_plan, axis=1)
    residual_b = B_mass - np.sum(transport_plan, axis=0)

    total_cycle_length = 0
    total_path_length = 0

    i = 0
    # Iterate over all deficit points in A.
    while i < len(A):
        if not (residual_a[i] > eps and C[i][0][1] < -eps):
            i += 1
            continue
        P, is_cycle = partial_dfs_r_cons(U, backwards, i, [i], C, 2 * delta,
                                         residual_b, B_weights, lambda_val,
                                         distance_matrix)
        if P is None:
            i += 1
            continue

        if len(P) % 2 == 0:  # Augmenting path
            total_path_length += len(P)
            mass_to_transport = min(residual_a[i], residual_b[P[-1]])
            for j in range(1, len(P) - 1, 2):
                mass_to_transport = min(mass_to_transport,
                                        transport_plan[P[j + 1]][P[j]])

            for j in range(2, len(P), 2):
                transport_plan[P[j]][P[j - 1]] -= mass_to_transport
                transport_plan[P[j]][P[j + 1]] += mass_to_transport
                if transport_plan[P[j]][P[j - 1]] < eps:
                    if P[j] in backwards[P[j - 1]]:
                        backwards[P[j - 1]].remove(P[j])
                    transport_plan[P[j]][P[j - 1]] = 0

            transport_plan[P[0]][P[1]] += mass_to_transport
            residual_a[i] -= mass_to_transport
            residual_b[P[-1]] -= mass_to_transport

        elif is_cycle:  # Consolidating cycle
            total_cycle_length += len(P) - 1
            mass_to_transport = 1
            for j in range(1, len(P), 2):
                mass_to_transport = min(mass_to_transport,
                                        transport_plan[P[j]][P[j - 1]])

            for j in range(1, len(P), 2):
                transport_plan[P[j]][P[j - 1]] -= mass_to_transport
                transport_plan[P[j]][P[j + 1]] += mass_to_transport
                # Use the caller's eps like the other branches; a hard-coded
                # 1e-6 here silently discarded up to 1e-6 of mass per edge.
                if transport_plan[P[j]][P[j - 1]] < eps:
                    if P[j] in backwards[P[j - 1]]:
                        backwards[P[j - 1]].remove(P[j])
                    transport_plan[P[j]][P[j - 1]] = 0

        else:  # Consolidating path
            total_path_length += len(P)
            # Row P[0] gains whatever is pushed, so it may not exceed that
            # row's residual. Starting from 1 (and never updating residual_a)
            # let repeated retries of the same root overfill it.
            mass_to_transport = residual_a[P[0]]
            for j in range(1, len(P), 2):
                mass_to_transport = min(mass_to_transport,
                                        transport_plan[P[j + 1]][P[j]])

            for j in range(1, len(P), 2):
                transport_plan[P[j - 1]][P[j]] += mass_to_transport
                transport_plan[P[j + 1]][P[j]] -= mass_to_transport
                if transport_plan[P[j + 1]][P[j]] < eps:
                    if P[j + 1] in backwards[P[j]]:
                        backwards[P[j]].remove(P[j + 1])
                    transport_plan[P[j + 1]][P[j]] = 0

            # Keep residuals in sync: a0 absorbed mass, a_last released it.
            residual_a[P[0]] -= mass_to_transport
            residual_a[P[-1]] += mass_to_transport

        # Intentionally no increment: retry the same root.

    return total_path_length, total_cycle_length


def decrease_weights(A: np.ndarray, A_mass: np.ndarray, A_delta: List,
                    A_delta_mass: List[float], B: np.ndarray, B_mass: np.ndarray,
                    C: List[List[Tuple[int, float]]], B_weights: np.ndarray,
                    delta: float, lambda_val: float, transport_plan: np.ndarray,
                    arrangement: List, distance_matrix: List[List[float]],
                    eps: float = 1e-8) -> Tuple:
    """
    Decrease weights for violating points and update discrete set.

    A representative point ``a`` is violating when its residual mass
    (``A_delta_mass[a]`` minus its row sum in ``transport_plan``) exceeds
    ``eps`` and ``C[a][0][1] < -eps``. ``partial_dfs_r_red_weights`` is run
    from each violating point, and every index it appends to ``K`` has its
    weight lowered by ``delta`` (``B_weights`` is modified in place).

    The plan on representative points is then spread back onto the original
    points: each index ``j`` in ``arrangement[i][-1]`` receives row ``i``
    scaled by ``A_mass[j] / A_delta_mass[i]``. ``discrete_set`` then
    rebuilds the discretisation from that plan.

    Args:
        A: Original source points
        A_mass: Original mass distribution
        A_delta: Current representative points
        A_delta_mass: Mass for representative points
        B: Target points
        B_mass: Mass for target points
        C: Sorted distance information
        B_weights: Weights for B points (updated in place)
        delta: Approximation parameter (weight decrement)
        lambda_val: Lambda parameter
        transport_plan: Current transport plan on representative points
        arrangement: Current partition arrangement
        distance_matrix: Distance matrix
        eps: Numerical precision threshold

    Returns:
        Updated (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix,
                arrangement, B_weights)
    """
    U = list(range(len(B_mass)))
    # backwards[b] lists every representative a with positive flow to b.
    backwards = [[a for a in range(len(A_delta)) if transport_plan[a][b] > 0]
                 for b in range(len(B))]
    K = []

    residual_a = np.array(A_delta_mass) - np.sum(transport_plan, axis=1)
    violating = [a for a in range(len(A_delta))
                 if residual_a[a] > eps and C[a][0][1] < -eps]

    for a in violating:
        partial_dfs_r_red_weights(U, backwards, a, C, 2 * delta, lambda_val,
                                  distance_matrix, K)

    # Decrease weights for reachable points.
    for b in K:
        B_weights[b] -= delta

    # Reconstruct the full transport plan one row at a time instead of one
    # entry at a time (same element-wise arithmetic: plan * A_mass / A_delta_mass).
    sd_transport_plan = np.zeros((len(A), len(B)))
    for cell, cell_info in enumerate(arrangement):
        for point in cell_info[-1]:
            sd_transport_plan[point, :] = (np.asarray(transport_plan[cell])
                                           * A_mass[point]
                                           / A_delta_mass[cell])

    # Update discrete set.
    (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix,
     arrangement, _drawing_points) = discrete_set(
        A, A_mass, B, B_weights, delta, sd_transport_plan)

    return (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix,
            arrangement, B_weights)


def search_and_consolidate_red_weights(A: np.ndarray, A_mass: np.ndarray,
                                      B: np.ndarray, B_mass: np.ndarray,
                                      B_weights: np.ndarray, delta: float,
                                      lambda_val: float, A_delta: List,
                                      A_delta_mass: List[float],
                                      transport_plan_hat: np.ndarray,
                                      C: List[List[Tuple[int, float]]],
                                      distance_matrix: List[List[float]],
                                      arrangement: List,
                                      eps: float = 1e-8) -> Tuple:
    """
    Main procedure for search and consolidate with weight reductions.

    The incoming plan is first lifted back onto the original points. Only
    (cell, b) pairs whose ``distance_matrix`` entry is below ``2 * delta``
    are kept, each scaled by ``A_mass[j] / A_delta_mass[cell]``.
    ``discrete_set`` then rebuilds the discretisation from that plan. While
    some representative point has ``C[a][0][1] < -eps`` and residual mass
    above ``eps``, one round of ``search_and_consolidate`` followed by
    ``decrease_weights`` is run.

    Args:
        A: Original source points
        A_mass: Original mass distribution
        B: Target points
        B_mass: Mass for target points
        B_weights: Weights for B points
        delta: Approximation parameter
        lambda_val: Lambda parameter
        A_delta: Current representative points
        A_delta_mass: Mass for representative points
        transport_plan_hat: Current transport plan
        C: Sorted distance information
        distance_matrix: Distance matrix
        arrangement: Current partition arrangement
        eps: Numerical precision threshold

    Returns:
        (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix,
         arrangement, B_weights, total_path_length, total_cycle_length, iters)
    """
    n_b = len(B)
    sd_transport_plan = np.zeros((len(A), n_b))
    for cell, cell_info in enumerate(arrangement):
        near = [b for b in range(n_b) if distance_matrix[cell][b] < 2 * delta]
        if not near:
            continue
        for point in cell_info[-1]:
            sd_transport_plan[point, near] = (
                np.asarray(transport_plan_hat[cell])[near]
                * A_mass[point] / A_delta_mass[cell])

    (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix,
     arrangement, _drawing_points) = discrete_set(
        A, A_mass, B, B_weights, delta, sd_transport_plan)

    def violating_points(points, masses, plan, costs):
        """Representatives with a negative leading C entry and unassigned mass."""
        residual_a = np.array(masses) - np.sum(plan, axis=1)
        return [a for a in range(len(points))
                if costs[a][0][1] < -eps and residual_a[a] > eps]

    total_path_length = total_cycle_length = iters = 0
    while violating_points(A_delta, A_delta_mass, transport_plan_hat, C):
        iters += 1
        path_length, cycle_length = search_and_consolidate(
            A_delta, A_delta_mass, B, B_mass, C, B_weights, delta,
            transport_plan_hat, lambda_val, distance_matrix, eps)
        total_path_length += path_length
        total_cycle_length += cycle_length

        (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix,
         arrangement, B_weights) = decrease_weights(
            A, A_mass, A_delta, A_delta_mass, B, B_mass, C, B_weights,
            delta, lambda_val, transport_plan_hat, arrangement,
            distance_matrix, eps)

    return (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix,
            arrangement, B_weights, total_path_length, total_cycle_length, iters)