#!/usr/bin/env python3
"""
Main execution script for Robust Optimal Transport (ROT) algorithm.

This script demonstrates the usage of the modular ROT implementation
with the same example as used in the original Jupyter notebook.
"""

import numpy as np
import sys
import os

# Add this script's own directory to sys.path so the local `src` package can be imported
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from src import (
    generate_grid_points, generate_mass_distribution, generate_target_points,
    compute_ROT, compute_transport_cost, analyze_untransported_mass,
    plot_mass_distribution, plot_transport_plan, plot_untransported_mass_heatmap,
    plot_statistics, print_statistics_summary
)


def _uncapped_plan_cost(plan, capped_cost, lambda_val, tol=1e-6):
    """Return the cost of ``plan`` restricted to entries below the cost cap.

    Sums ``plan * capped_cost`` over every entry whose capped cost differs
    from ``lambda_val`` by more than ``tol``; entries sitting at the cap
    contribute nothing.

    Args:
        plan: Transport plan array with the same shape as ``capped_cost``.
        capped_cost: Cost array already clipped from above at ``lambda_val``.
        lambda_val: The cap value applied to the cost array.
        tol: Absolute tolerance used to decide that an entry sits at the cap.

    Returns:
        The summed cost as a NumPy scalar (0.0 when no entry is below the cap).
    """
    plan = np.asarray(plan)
    capped_cost = np.asarray(capped_cost)
    # One vectorized mask replaces a Python-level loop over every
    # (source, target) pair, which is over a million iterations for the
    # default 200x200 grid and 30 targets.
    below_cap = np.abs(capped_cost - lambda_val) > tol
    return np.sum(plan[below_cap] * capped_cost[below_cap])


def main() -> None:
    """Run the ROT algorithm example end to end.

    Generates grid source points and target points, runs ``compute_ROT``,
    prints the resulting costs and untransported mass, calls the plotting
    helpers, and, when the POT library (``ot``) is importable, prints a
    reference cost computed with ``ot.emd`` on the lambda-capped cost matrix.

    Returns:
        None. All results are reported through ``print`` and the plotting
        helpers imported from ``src``.
    """
    banner = "=" * 60
    print(banner)
    print("Robust Optimal Transport (ROT) Algorithm Demo")
    print(banner)

    # Set random seed for reproducibility
    np.random.seed(42)

    # Parameters (same as in the notebook)
    n = 30                  # Number of target points
    grid_size = 200         # Grid size for source points
    lambda_val = 0.3 ** 2   # Lambda parameter
    eps = 1e-8              # Numerical precision (only reported below; not passed on here)

    print("Parameters:")
    print(f"  Target points (n): {n}")
    print(f"  Grid size: {grid_size}")
    print(f"  Lambda: {lambda_val}")
    print(f"  Epsilon: {eps}")
    print()

    # Step 1: Generate source points A on a grid
    print("Step 1: Generating source points...")
    A = generate_grid_points(grid_size)
    print(f"  Generated {len(A)} source points on {grid_size}x{grid_size} grid")

    # Step 2: Generate mass distribution for A
    print("Step 2: Generating mass distribution for source points...")
    A_mass = generate_mass_distribution(
        A, sigma=0.15, noise_factor=0.1,
        lambda_x=3.0, lambda_y=3.0, grid_size=grid_size
    )
    print(f"  Generated mass distribution (sum: {np.sum(A_mass):.6f})")

    # Step 3: Generate target points B
    print("Step 3: Generating target points...")
    B, B_mass = generate_target_points(n, sigma=0.15, noise_fraction=0.1)
    print(f"  Generated {len(B)} target points with uniform mass")

    # Step 4: Visualize initial mass distribution
    print("Step 4: Visualizing initial mass distribution...")
    plot_mass_distribution(A, A_mass, grid_size, "Initial Mass Distribution")

    # Step 5: Initialize weights and run ROT algorithm
    print("Step 5: Running ROT algorithm...")
    B_weights = np.zeros(n)  # Initialize weights to zero

    # Run the main ROT algorithm; the delta schedule is controlled by the
    # min_delta / initial_delta keyword arguments.
    results = compute_ROT(
        A, A_mass, B, B_mass, B_weights, lambda_val,
        min_delta=0.0002, initial_delta=1.0
    )

    (transport_plan_hat, B_weights_final, A_delta, sd_ot,
     pl_aug, cl_aug, pl_cons, cl_cons, i_aug, i_cons, regions, final_delta) = results

    print("  Algorithm completed!")
    print(f"  Final delta: {final_delta:.6f}")
    print(f"  Final number of regions: {len(A_delta)}")
    print(f"  Total transported mass: {np.sum(transport_plan_hat):.6f}")
    print()

    # Step 6: Compute and display costs
    print("Step 6: Computing transport costs...")
    cost = compute_transport_cost(A, B, sd_ot)
    print(f"  Total transport cost: {cost:.6f}")

    # Compare with sum of mass
    print(f"  Sum of transported mass (sd_ot): {np.sum(sd_ot):.6f}")
    print(f"  Sum of transported mass (plan_hat): {np.sum(transport_plan_hat):.6f}")
    print()

    # Step 7: Analyze untransported mass
    print("Step 7: Analyzing untransported mass...")
    untransported_A, untransported_B = analyze_untransported_mass(
        A, A_mass, B, B_mass, sd_ot
    )
    print(f"  Total untransported mass at source: {np.sum(untransported_A):.6f}")
    print(f"  Total untransported mass at target: {np.sum(untransported_B):.6f}")
    print()

    # Step 8: Visualize results
    print("Step 8: Visualizing results...")

    # Plot final transport plan
    plot_transport_plan(A_delta, B, transport_plan_hat, B_weights_final, lambda_val)

    # Plot untransported mass
    plot_untransported_mass_heatmap(A, A_mass, B, B_mass, sd_ot, grid_size)

    # Plot algorithm statistics
    plot_statistics(pl_aug, cl_aug, pl_cons, cl_cons, i_aug, i_cons, regions)

    # Print summary statistics
    print_statistics_summary(pl_aug, cl_aug, pl_cons, cl_cons, i_aug, i_cons)

    # Step 9: Optional comparison with POT library (if available)
    print("Step 9: Comparing with POT library (if available)...")
    try:
        # Only the import is guarded, so an ImportError raised later inside
        # the comparison is not misreported as "POT not available".
        import ot
    except ImportError:
        print("  POT library not available for comparison")
    else:
        # Rows of C index target points (B), columns index source points (A).
        C = ot.dist(B, A, metric='sqeuclidean')

        # Reference value: exact OT on the cost matrix capped at lambda,
        # counting only the entries that stay below the cap.
        capped_C = np.minimum(C, lambda_val)
        G, _ = ot.emd(B_mass, A_mass, capped_C, log=True)
        ot_cost = _uncapped_plan_cost(G, capped_C, lambda_val)

        # C is transposed to line up with sd_ot
        # NOTE(review): presumably sd_ot is indexed (source, target) - confirm.
        our_cost = np.sum(sd_ot * C.T)
        error = ot_cost - our_cost

        print(f"  POT library cost: {ot_cost:.6f}")
        print(f"  Our algorithm cost: {our_cost:.6f}")
        print(f"  Error: {error:.6f}")

    print()
    print(banner)
    print("ROT Algorithm Demo Completed Successfully!")
    print(banner)


# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()