"""
Visualization module for Robust Optimal Transport.

This module contains functions for plotting and visualizing
transport plans, mass distributions, and algorithm results.
"""

import numpy as np
import matplotlib.pyplot as plt
import math
from typing import List, Optional, Tuple


def plot_mass_distribution(A: np.ndarray, A_mass: np.ndarray, 
                          grid_size: int, title: str = "Mass Distribution") -> None:
    """
    Render a per-point mass vector as a square heatmap and display it.

    The mass values are reshaped (row-major) into a ``grid_size x grid_size``
    image and drawn with the first row at the bottom (``origin='lower'``)
    using the ``coolwarm`` colormap and a ``Mass`` colorbar.

    Args:
        A: Points array of shape (n, 2). Not used by this function; only the
            mass values are drawn.
        A_mass: Mass per point, shape (n,); ``reshape`` requires
            ``n == grid_size ** 2``.
        grid_size: Side length of the square image.
        title: Title for the plot.
    """
    plt.figure(figsize=(8, 6))

    heatmap = A_mass.reshape(grid_size, grid_size)
    plt.imshow(heatmap, cmap='coolwarm', interpolation='nearest',
               origin='lower')
    plt.colorbar(label='Mass')

    plt.title(title)
    plt.xlabel('X-axis')
    plt.ylabel('Y-axis')

    plt.show()


def plot_transport_plan(A_delta: List, B: np.ndarray, transport_plan: np.ndarray,
                       B_weights: np.ndarray, lambda_val: float = 0.1, 
                       figsize: Tuple[int, int] = (10, 10),
                       save_path: Optional[str] = None,
                       threshold: float = 1e-8) -> None:
    """
    Visualize a transport plan as source-to-target segments with weight circles.

    Every entry ``transport_plan[i][j]`` greater than ``threshold`` is drawn as
    a black segment from ``A_delta[i]`` to ``B[j]`` with line width
    ``max(1, 10 * flow)``. Each target point gets a translucent blue circle of
    radius ``sqrt(B_weights[j])``, so circle area is proportional to the
    weight. The axes are fixed to the unit square ``[0, 1] x [0, 1]``.

    Args:
        A_delta: Representative source points; each ``A_delta[i]`` is indexed
            as ``[0]`` (x) and ``[1]`` (y).
        B: Target points, indexed as ``B[:, 0]`` (x) and ``B[:, 1]`` (y).
        transport_plan: Transport plan matrix; row i belongs to ``A_delta[i]``,
            column j to ``B[j]``. Only the first ``len(B)`` columns are read.
        B_weights: Weights for B points (circle sizes); must be non-negative
            because the radius is ``math.sqrt(weight)``.
        lambda_val: Lambda parameter (shown in the title).
        figsize: Figure size tuple.
        save_path: If given, save the figure there at 300 dpi instead of
            calling ``plt.show()``. The figure is not closed after saving.
        threshold: Flows less than or equal to this value are not drawn.
            Defaults to 1e-8, the previously hard-coded cutoff.
    """
    plt.figure(figsize=figsize)
    n_targets = len(B)

    # One segment per flow above the threshold; heavier flows get thicker
    # lines, with a floor of 1 so small flows remain visible.
    for i, row in enumerate(transport_plan):
        for j in range(n_targets):
            flow = row[j]
            if flow > threshold:
                source = A_delta[i]
                plt.plot([source[0], B[j][0]], [source[1], B[j][1]],
                         'k-', linewidth=max(1, 10 * flow))

    # Translucent circles around targets; radius sqrt(w) => area ~ w.
    ax = plt.gca()
    for j in range(n_targets):
        ax.add_patch(plt.Circle((B[j, 0], B[j, 1]),
                                radius=math.sqrt(B_weights[j]),
                                color='blue', alpha=0.1))

    # Points are drawn last so they sit on top of the segments.
    plt.scatter(B[:, 0], B[:, 1], s=50, c='blue', label='Target points (B)')
    plt.scatter([p[0] for p in A_delta], [p[1] for p in A_delta],
                s=20, c='red', label='Source points (A)')

    plt.title(f'Transport Plan Visualization, λ={lambda_val}²')
    plt.xlabel('X-axis')
    plt.ylabel('Y-axis')
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.legend()
    plt.grid(True, alpha=0.3)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    else:
        plt.show()


def plot_multiple_transport_plans(A_deltas: List[List], B: np.ndarray, 
                                 transport_plans: List[np.ndarray],
                                 figsize: Tuple[int, int] = (10, 80),
                                 threshold: float = 1e-6) -> None:
    """
    Plot multiple transport plans in vertically stacked subplots (one per scale).

    For scale k, every entry ``transport_plans[k][i][j]`` greater than
    ``threshold`` is drawn as a black segment from ``A_deltas[k][i]`` to
    ``B[j]`` with line width ``max(1, 10 * flow)``. Each subplot title shows
    the total mass of that scale's plan.

    Args:
        A_deltas: List of representative source points for each scale; each
            point is indexed as ``[0]`` (x) and ``[1]`` (y).
        B: Target points, indexed as ``B[:, 0]`` (x) and ``B[:, 1]`` (y).
        transport_plans: List of transport plans for each scale; plan k is
            paired with ``A_deltas[k]``.
        figsize: Figure size tuple (width, height) for the whole stack.
        threshold: Flows less than or equal to this value are not drawn.
            Defaults to 1e-6, the previously hard-coded cutoff.

    Raises:
        ValueError: If ``A_deltas`` is empty (there is nothing to plot).
    """
    n_scales = len(A_deltas)
    if n_scales == 0:
        raise ValueError("A_deltas must contain at least one scale to plot")

    # squeeze=False always returns a 2-D array of Axes, so a single scale
    # needs no special-casing.
    _, axes = plt.subplots(nrows=n_scales, ncols=1, figsize=figsize,
                           squeeze=False)
    n_targets = len(B)

    for k, ax in enumerate(axes[:, 0]):
        points = A_deltas[k]
        plan = transport_plans[k]

        # Draw transport connections above the threshold.
        for i, row in enumerate(plan):
            for j in range(n_targets):
                flow = row[j]
                if flow > threshold:
                    ax.plot([points[i][0], B[j][0]],
                            [points[i][1], B[j][1]],
                            'k-', linewidth=max(1, 10 * flow))

        # Plot points on top of the connections.
        ax.scatter(B[:, 0], B[:, 1], s=50, c='blue', label='Target points (B)')
        ax.scatter([p[0] for p in points], [p[1] for p in points],
                   s=20, c='red', label='Source points (A)')

        ax.set_title(f"Scale {k+1}: Transported mass = {np.sum(plan):.4f}")
        ax.set_xlabel('X-axis')
        ax.set_ylabel('Y-axis')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


def plot_untransported_mass(A: np.ndarray, A_mass: np.ndarray, B: np.ndarray, 
                           B_mass: np.ndarray, transport_plan: np.ndarray,
                           grid_size: int, scaling_factor_A: float = 10000,
                           scaling_factor_B: Optional[float] = None,
                           figsize: Tuple[int, int] = (10, 8)) -> None:
    """
    Visualize untransported mass at source and target points.

    Leftover source mass is ``A_mass`` minus the row sums of the plan, with
    negatives set to 0. Leftover target mass is ``B_mass`` minus the column
    sums, with values below 1e-4 set to 0. Each is multiplied by its scaling
    factor and capped at 1.0. Source points are drawn as tiny red dots whose
    opacity is the scaled leftover. Target points are drawn as blue dots whose
    opacity is the scaled leftover and whose marker size is 100 times it.

    Args:
        A: Source points, indexed as ``A[:, 0]`` (x) and ``A[:, 1]`` (y).
        A_mass: Original mass at source points.
        B: Target points, indexed as ``B[:, 0]`` (x) and ``B[:, 1]`` (y).
        B_mass: Original mass at target points.
        transport_plan: Transport plan matrix (rows = sources, columns =
            targets).
        grid_size: Not used by this function.
        scaling_factor_A: Scaling factor for source point visualization.
        scaling_factor_B: Scaling factor for target point visualization
            (default: ``len(B)``).
        figsize: Figure size tuple.
    """
    if scaling_factor_B is None:
        scaling_factor_B = len(B)

    plt.figure(figsize=figsize)

    # Source side: mass not shipped out by the plan, never negative.
    untransported_mass_A = np.clip(A_mass - np.sum(transport_plan, axis=1), 0, None)
    # Scaled values double as alpha, which matplotlib requires to be <= 1.
    untransported_mass_A_scaled = np.clip(untransported_mass_A * scaling_factor_A,
                                          None, 1.0)

    # Target side: mass not received; tiny residues (< 1e-4) are treated as zero.
    raw_B = B_mass - np.sum(transport_plan, axis=0)
    untransported_mass_B = np.where(raw_B < 0.0001, 0, raw_B)
    untransported_mass_B_scaled = np.clip(untransported_mass_B * scaling_factor_B,
                                          None, 1.0)

    # Plot source points as scatter with alpha based on untransported mass
    plt.scatter(A[:, 0], A[:, 1], s=1, alpha=untransported_mass_A_scaled, 
                color="red", label='Untransported source mass')

    # Plot target points as scatter with size and alpha based on untransported mass
    plt.scatter(B[:, 0], B[:, 1], s=100 * untransported_mass_B_scaled, 
                alpha=untransported_mass_B_scaled, c='blue', 
                label='Untransported target mass')

    plt.title('Untransported Mass Visualization')
    plt.xlabel('X-axis')
    plt.ylabel('Y-axis')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()


def plot_untransported_mass_heatmap(A: np.ndarray, A_mass: np.ndarray, 
                                   B: np.ndarray, B_mass: np.ndarray,
                                   transport_plan: np.ndarray, grid_size: int,
                                   figsize: Tuple[int, int] = (10, 8)) -> None:
    """
    Show leftover source mass as a heatmap with leftover target mass overlaid.

    Leftover source mass is ``A_mass`` minus the row sums of the plan, with
    values below 1e-5 zeroed. It is reshaped to ``grid_size x grid_size`` and
    drawn with the ``Reds`` colormap. Leftover target mass is ``B_mass`` minus
    the column sums, with values below 1e-4 zeroed. It is drawn as blue
    markers of area ``10000 * leftover`` at ``B * grid_size``, i.e. in the
    image's pixel-index coordinates.

    Args:
        A: Source points. Not used by this function.
        A_mass: Original mass at source points (``grid_size ** 2`` values).
        B: Target points, indexed as ``B[:, 0]`` (x) and ``B[:, 1]`` (y).
        B_mass: Original mass at target points.
        transport_plan: Transport plan matrix (rows = sources, columns =
            targets).
        grid_size: Grid size for reshaping source mass.
        figsize: Figure size tuple.
    """
    plt.figure(figsize=figsize)

    def leftover(mass, moved, floor):
        # Remaining mass with tiny (or negative) residues snapped to zero.
        remaining = mass - moved
        remaining[remaining < floor] = 0
        return remaining

    source_left = leftover(A_mass, np.sum(transport_plan, axis=1), 0.00001)
    target_left = leftover(B_mass, np.sum(transport_plan, axis=0), 0.0001)

    plt.imshow(source_left.reshape(grid_size, grid_size),
               cmap='Reds', interpolation='nearest', origin='lower')
    plt.colorbar(label='Untransported Source Mass')

    # NOTE(review): imshow centres pixel k at index k (cell spans k-0.5..k+0.5),
    # so B * grid_size lines up with the cells only if A's grid points sit at
    # k / grid_size -- confirm against how A is generated.
    target_x = B[:, 0] * grid_size
    target_y = B[:, 1] * grid_size
    plt.scatter(target_x, target_y,
                s=10000 * target_left, c='blue', alpha=0.7,
                label='Untransported target mass')

    plt.title('Untransported Mass Heatmap')
    plt.xlabel('X-axis (grid units)')
    plt.ylabel('Y-axis (grid units)')
    plt.legend()
    plt.show()


def plot_statistics(path_lengths_aug: List[float], cycle_lengths_aug: List[float],
                   path_lengths_cons: List[float], cycle_lengths_cons: List[float],
                   iters_aug: List[int], iters_cons: List[int],
                   regions: List[int], figsize: Tuple[int, int] = (15, 10)) -> None:
    """
    Draw a 2x3 dashboard of per-scale algorithm statistics and display it.

    Layout:
        [0, 0] path lengths, [0, 1] cycle lengths, [0, 2] iteration counts:
            augmentation and consolidation bars at the same x positions
            (alpha 0.7, so they overlap);
        [1, 0] number of regions per scale (line plot with markers);
        [1, 1] summed (aug + cons) path lengths and cycle lengths, overlapping;
        [1, 2] text panel with mean ± std (np.std, population) of each series.

    The x positions are ``0 .. len(path_lengths_aug) - 1``, and every other
    per-scale series is plotted against them, so they should all share that
    length.

    Args:
        path_lengths_aug: Path lengths during augmentation phases
        cycle_lengths_aug: Cycle lengths during augmentation phases
        path_lengths_cons: Path lengths during consolidation phases  
        cycle_lengths_cons: Cycle lengths during consolidation phases
        iters_aug: Iteration counts for augmentation phases
        iters_cons: Iteration counts for consolidation phases
        regions: Number of regions at each scale
        figsize: Figure size tuple
    """
    _, panels = plt.subplots(2, 3, figsize=figsize)
    x = list(range(len(path_lengths_aug)))

    def mean_std(values) -> str:
        # "mean ± std" with two decimals, as shown in the summary panel.
        return f"{np.mean(values):.2f} ± {np.std(values):.2f}"

    # Top row: three panels sharing the same overlaid aug-vs-cons bar layout.
    bar_panels = (
        (panels[0, 0], path_lengths_aug, path_lengths_cons, 'Path Lengths', 'Path Length'),
        (panels[0, 1], cycle_lengths_aug, cycle_lengths_cons, 'Cycle Lengths', 'Cycle Length'),
        (panels[0, 2], iters_aug, iters_cons, 'Iterations', 'Number of Iterations'),
    )
    for ax, aug, cons, heading, y_label in bar_panels:
        ax.bar(x, aug, alpha=0.7, label='Augmentation')
        ax.bar(x, cons, alpha=0.7, label='Consolidation')
        ax.set_title(heading)
        ax.set_xlabel('Scale')
        ax.set_ylabel(y_label)
        ax.legend()

    # Regions per scale.
    region_ax = panels[1, 0]
    region_ax.plot(x, regions, 'o-', linewidth=2, markersize=8)
    region_ax.set_title('Number of Regions')
    region_ax.set_xlabel('Scale')
    region_ax.set_ylabel('Number of Regions')
    region_ax.grid(True, alpha=0.3)

    # Combined (aug + cons) totals.
    totals_ax = panels[1, 1]
    total_paths = np.asarray(path_lengths_aug) + np.asarray(path_lengths_cons)
    total_cycles = np.asarray(cycle_lengths_aug) + np.asarray(cycle_lengths_cons)
    totals_ax.bar(x, total_paths, alpha=0.7, label='Total Path Lengths')
    totals_ax.bar(x, total_cycles, alpha=0.7, label='Total Cycle Lengths')
    totals_ax.set_title('Total Path and Cycle Lengths')
    totals_ax.set_xlabel('Scale')
    totals_ax.set_ylabel('Total Length')
    totals_ax.legend()

    # Text-only summary panel.
    summary_ax = panels[1, 2]
    summary_ax.axis('off')
    summary_text = f"""Algorithm Summary:
    
Path Lengths (Aug):
  Mean: {mean_std(path_lengths_aug)}
  
Cycle Lengths (Aug):
  Mean: {mean_std(cycle_lengths_aug)}
  
Path Lengths (Cons):
  Mean: {mean_std(path_lengths_cons)}
  
Cycle Lengths (Cons):
  Mean: {mean_std(cycle_lengths_cons)}
  
Iterations (Aug):
  Mean: {mean_std(iters_aug)}
  
Iterations (Cons):
  Mean: {mean_std(iters_cons)}
  
Total Scales: {len(x)}
Final Regions: {regions[-1] if regions else 0}"""

    summary_ax.text(0.1, 0.9, summary_text, transform=summary_ax.transAxes,
                    fontsize=10, verticalalignment='top', fontfamily='monospace')

    plt.tight_layout()
    plt.show()


def print_statistics_summary(path_lengths_aug: List[float], cycle_lengths_aug: List[float],
                           path_lengths_cons: List[float], cycle_lengths_cons: List[float],
                           iters_aug: List[int], iters_cons: List[int]) -> None:
    """
    Print each statistic series as ``mean ± std`` (two decimals) to stdout.

    The standard deviation is ``np.std`` (population, ddof=0). The table is
    framed by a header line and two 40-character ``=`` rules.

    Args:
        path_lengths_aug: Path lengths during augmentation phases
        cycle_lengths_aug: Cycle lengths during augmentation phases
        path_lengths_cons: Path lengths during consolidation phases
        cycle_lengths_cons: Cycle lengths during consolidation phases
        iters_aug: Iteration counts for augmentation phases
        iters_cons: Iteration counts for consolidation phases
    """
    rule = "=" * 40
    rows = (
        ("Path Lengths (Augmentation)", path_lengths_aug),
        ("Cycle Lengths (Augmentation)", cycle_lengths_aug),
        ("Path Lengths (Consolidation)", path_lengths_cons),
        ("Cycle Lengths (Consolidation)", cycle_lengths_cons),
        ("Iterations (Augmentation)", iters_aug),
        ("Iterations (Consolidation)", iters_cons),
    )

    print("ROT Algorithm Statistics Summary:")
    print(rule)
    for label, series in rows:
        print(f"{label}: {np.mean(series):.2f} ± {np.std(series):.2f}")
    print(rule)