"""
Main Robust Optimal Transport (ROT) algorithm module.

This module contains the main compute_ROT function that orchestrates
the entire robust optimal transport computation process.
"""

import numpy as np
from typing import Tuple, List
from .discrete_set import discrete_set, reconstruct_full_transport_plan
from .transport_optimization import (
    search_and_augment_weights, search_and_consolidate_red_weights
)


def _reconstruct_sd_transport_plan(arrangement, transport_plan_hat, A_mass,
                                   A_delta_mass, n_sources: int,
                                   n_targets: int) -> np.ndarray:
    """
    Spread each region's transport row back onto its member source points.

    Every source index ``j`` listed in ``arrangement[i][-1]`` receives row
    ``i`` of ``transport_plan_hat`` scaled by ``A_mass[j] / A_delta_mass[i]``.
    Regions with zero mass are skipped (their member rows stay zero) rather
    than producing ``0/0`` NaNs. If an index appears in several regions, the
    last region wins, matching an element-wise assignment loop.

    Returns:
        Array of shape (n_sources, n_targets).
    """
    sd_plan = np.zeros((n_sources, n_targets))
    source_mass = np.asarray(A_mass)
    for i, region in enumerate(arrangement):
        members = list(region[-1])
        region_mass = A_delta_mass[i]
        if not members or region_mass == 0:
            continue
        row = np.asarray(transport_plan_hat[i])[:n_targets]
        # (mass * plan) / region_mass keeps the scalar formula's operation
        # order, so the values are bit-identical to a per-element loop.
        sd_plan[members, :] = np.outer(source_mass[members], row) / region_mass
    return sd_plan


def compute_ROT(A: np.ndarray, A_mass: np.ndarray, B: np.ndarray, 
               B_mass: np.ndarray, B_weights: np.ndarray, 
               lambda_val: float, min_delta: float = 0.0002,
               initial_delta: float = 1.0) -> Tuple:
    """
    Compute Robust Optimal Transport between source A and target B distributions.
    
    This is the main function that implements the ROT algorithm using a
    multi-scale approach with decreasing delta values. Each scale runs an
    augmentation phase, halves ``delta``, lowers every weight by
    ``4 * delta`` (the new delta) and runs a consolidation phase. Once
    ``delta`` no longer exceeds ``min_delta``, one final augmentation and
    consolidation pass runs at that last delta.
    
    Args:
        A: Source points of shape (n, 2)
        A_mass: Mass distribution for A of shape (n,)
        B: Target points of shape (m, 2)
        B_mass: Mass distribution for B of shape (m,)
        B_weights: Initial weights for B points of shape (m,). The caller's
            array is not modified; a private floating-point copy is used.
        lambda_val: Regularization parameter (lambda^2)
        min_delta: Minimum delta value to stop iteration (must be > 0)
        initial_delta: Starting delta value
        
    Returns:
        Tuple containing:
        - transport_plan_hat: Final transport plan for representative points
        - B_weights: Final weights for B points
        - A_delta: Final representative points
        - sd_ot: Full transport plan reconstructed for original points
        - total_path_lengths_aug: Path lengths for augmentation phases
        - total_cycle_lengths_aug: Cycle lengths for augmentation phases  
        - total_path_lengths_cons: Path lengths for consolidation phases
        - total_cycle_lengths_cons: Cycle lengths for consolidation phases
        - iters_aug: Iteration counts for augmentation phases
        - iters_cons: Iteration counts for consolidation phases
        - regions: Number of regions (representative points) after each
          augmentation phase
        - final_delta: Final delta value used

    Raises:
        ValueError: If ``min_delta`` is not a positive number. ``delta`` is
            halved until it drops to ``min_delta``, so a non-positive floor
            would never (or only after underflow) terminate the loop.
    """
    if not (min_delta > 0):
        raise ValueError(f"min_delta must be a positive number, got {min_delta!r}")

    # Private float copy: the weights are updated in place below (and handed
    # to the search routines), which must neither mutate the caller's array
    # nor fail with a casting error on an integer dtype.
    B_weights = np.array(B_weights, copy=True)
    if not np.issubdtype(B_weights.dtype, np.floating):
        B_weights = B_weights.astype(float)

    delta = initial_delta
    
    # Statistics tracking
    total_path_lengths_aug = []
    total_cycle_lengths_aug = []
    total_path_lengths_cons = []
    total_cycle_lengths_cons = []
    iters_aug = []
    iters_cons = []
    regions = []
    
    # Initialize the representative set from an empty transport plan.
    transport_plan = np.zeros((len(A), len(B)))
    (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix,
     arrangement, _drawing_points) = discrete_set(
        A, A_mass, B, B_weights, delta, transport_plan)
    
    # Multi-scale optimization loop
    while delta > min_delta:
        print(f"Processing delta = {delta:.6f}")
        
        # Augmentation phase: search and augment with weight increases
        (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix, 
         arrangement, B_weights, path_lengths, cycle_lengths, iters) = search_and_augment_weights(
            A, A_mass, B, B_mass, B_weights, delta, lambda_val, A_delta, 
            A_delta_mass, transport_plan_hat, C, distance_matrix, arrangement)
        
        total_path_lengths_aug.append(path_lengths)
        total_cycle_lengths_aug.append(cycle_lengths)
        iters_aug.append(iters)
        regions.append(len(A_delta))
        
        # Halve the scale, then lower every weight by 4 * (new delta) before
        # consolidating at that scale.
        delta /= 2
        B_weights -= 4 * delta
        
        # Consolidation phase: search and consolidate with weight reductions
        (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix, 
         arrangement, B_weights, path_lengths, cycle_lengths, iters) = search_and_consolidate_red_weights(
            A, A_mass, B, B_mass, B_weights, delta, lambda_val, A_delta, 
            A_delta_mass, transport_plan_hat, C, distance_matrix, arrangement)
        
        total_path_lengths_cons.append(path_lengths)
        total_cycle_lengths_cons.append(cycle_lengths)
        iters_cons.append(iters)
        
        print(f"  Regions: {len(A_delta)}, Transported mass: {np.sum(transport_plan_hat):.6f}")
    
    # Final augmentation phase at the last (smallest) delta
    print(f"Final processing at delta = {delta:.6f}")
    (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix, 
     arrangement, B_weights, path_lengths, cycle_lengths, iters) = search_and_augment_weights(
        A, A_mass, B, B_mass, B_weights, delta, lambda_val, A_delta, 
        A_delta_mass, transport_plan_hat, C, distance_matrix, arrangement)
    
    total_path_lengths_aug.append(path_lengths)
    total_cycle_lengths_aug.append(cycle_lengths)
    iters_aug.append(iters)
    regions.append(len(A_delta))
    
    # Final consolidation phase
    (A_delta, A_delta_mass, transport_plan_hat, C, distance_matrix, 
     arrangement, B_weights, path_lengths, cycle_lengths, iters) = search_and_consolidate_red_weights(
        A, A_mass, B, B_mass, B_weights, delta, lambda_val, A_delta, 
        A_delta_mass, transport_plan_hat, C, distance_matrix, arrangement)
    
    total_path_lengths_cons.append(path_lengths)
    total_cycle_lengths_cons.append(cycle_lengths)
    iters_cons.append(iters)
    
    # Map the region-level plan back onto the original source points.
    sd_transport_plan = _reconstruct_sd_transport_plan(
        arrangement, transport_plan_hat, A_mass, A_delta_mass, len(A), len(B))
    
    print(f"Final: Regions: {len(A_delta)}, Transported mass: {np.sum(transport_plan_hat):.6f}")
    
    return (transport_plan_hat, B_weights, A_delta, sd_transport_plan, 
            total_path_lengths_aug, total_cycle_lengths_aug, 
            total_path_lengths_cons, total_cycle_lengths_cons, 
            iters_aug, iters_cons, regions, delta)


def compute_transport_cost(A: np.ndarray, B: np.ndarray, 
                          transport_plan: np.ndarray) -> float:
    """
    Return the total cost of moving mass according to ``transport_plan``.

    For every source point ``A[idx]`` the distances to all points of ``B``
    (as computed by ``utils.distance_to_point``) are weighted by row ``idx``
    of the plan and summed; the per-row totals are accumulated in order
    starting from 0.0.

    Args:
        A: Source points of shape (n, 2)
        B: Target points of shape (m, 2)
        transport_plan: Transport plan of shape (n, m)

    Returns:
        Total transport cost (sum of distances * transported mass)
    """
    from .utils import distance_to_point

    row_costs = (
        np.sum(distance_to_point(B, source) * transport_plan[idx])
        for idx, source in enumerate(A)
    )
    return sum(row_costs, 0.0)


def ROT_ot_comparison(A_mass: np.ndarray, B_mass: np.ndarray, 
                     cost_matrix: np.ndarray, lambda_val: float) -> float:
    """
    Compute ROT cost using the POT (``ot``) library for comparison.

    The cost matrix is capped at ``lambda_val`` and solved as an exact OT
    problem with ``ot.emd``. Plan entries whose capped cost equals
    ``lambda_val`` (within 1e-6) are not counted in the returned cost.

    Args:
        A_mass: Mass distribution for source points (columns of cost_matrix)
        B_mass: Mass distribution for target points (rows of cost_matrix)
        cost_matrix: Cost matrix of shape (len(B_mass), len(A_mass)); entry
            [j, i] is the cost between target j and source i
        lambda_val: Lambda parameter for regularization (the cost cap)

    Returns:
        Optimal ROT cost, or 0.0 (after printing a warning) when the POT
        library is not installed.
    """
    # Guard only the import: an ImportError raised later from inside POT is
    # a real failure and must not be reported as "library not available".
    try:
        import ot
    except ImportError:
        print("Warning: POT library not available for comparison")
        return 0.0

    # Cap the cost matrix at lambda_val
    capped_C = np.minimum(cost_matrix, lambda_val)

    # Solve optimal transport (rows follow B_mass, columns follow A_mass)
    G = ot.emd(B_mass, A_mass, capped_C)

    # Only count transported mass whose capped cost is below lambda_val.
    counted = np.abs(capped_C - lambda_val) > 1e-6
    return float(np.sum(G[counted] * capped_C[counted]))


def analyze_untransported_mass(A: np.ndarray, A_mass: np.ndarray, 
                              B: np.ndarray, B_mass: np.ndarray,
                              transport_plan: np.ndarray, 
                              threshold: float = 1e-5) -> Tuple[np.ndarray, np.ndarray]:
    """
    Report how much mass each source and target point did not ship/receive.

    A point's residual is its total mass minus the matching row sum (sources)
    or column sum (targets) of ``transport_plan``. Any residual below
    ``threshold`` -- negative ones included -- is reported as 0. ``A`` and
    ``B`` are accepted for interface symmetry but are not read.

    Args:
        A: Source points (unused)
        A_mass: Mass distribution for source points
        B: Target points (unused)
        B_mass: Mass distribution for target points
        transport_plan: Transport plan, rows = sources, columns = targets
        threshold: Threshold below which mass is considered zero

    Returns:
        Tuple of (untransported_mass_A, untransported_mass_B)
    """
    def leftover(mass, axis):
        # axis=1 sums each source's row, axis=0 sums each target's column.
        remaining = mass - np.sum(transport_plan, axis=axis)
        remaining[remaining < threshold] = 0
        return remaining

    source_left = leftover(A_mass, 1)
    target_left = leftover(B_mass, 0)
    return source_left, target_left