"""
Core utility functions for Robust Optimal Transport (ROT) algorithm.

This module contains basic utility functions for distance calculations,
point generation, and other helper functions.
"""

from typing import List, Optional, Tuple

import numpy as np


def distance_to_point(B: np.ndarray, a: np.ndarray) -> np.ndarray:
    """
    Return the squared Euclidean distance from ``a`` to every row of ``B``.

    No square root is taken; callers compare squared distances directly.

    Args:
        B: Array of shape (n, 2) whose rows are the points to measure to
        a: Reference point of shape (2,), broadcast against each row of B

    Returns:
        Array of shape (n,) where entry k is ||B[k] - a||^2
    """
    offsets = B - a
    return np.square(offsets).sum(axis=1)


def generate_grid_points(grid_size: int) -> np.ndarray:
    """
    Build a regular grid_size x grid_size lattice inside the unit square.

    Each coordinate takes the values k / grid_size for k = 0 .. grid_size-1,
    so points lie in [0, 1) x [0, 1) (the value 1 itself is never produced).
    Rows are ordered with the first coordinate varying slowest.

    Args:
        grid_size: Number of points per dimension

    Returns:
        Array of shape (grid_size^2, 2) containing grid points; when
        grid_size <= 0 an empty 1-D array is returned instead.
    """
    axis_values = np.array([k / grid_size for k in range(grid_size)])
    if axis_values.size == 0:
        # Preserve the historical empty result (shape (0,), not (0, 2)).
        return np.array([])
    # 'ij' indexing makes the first output coordinate the slow-varying one,
    # matching the original outer/inner loop order.
    first, second = np.meshgrid(axis_values, axis_values, indexing="ij")
    return np.stack((first.ravel(), second.ravel()), axis=1)


def generate_mass_distribution(A: np.ndarray, sigma: float = 0.15,
                               noise_factor: float = 0.1,
                               lambda_x: float = 3.0,
                               lambda_y: float = 3.0,
                               grid_size: int = 100) -> np.ndarray:
    """
    Generate mass distribution for points A using Gaussian center + exponential noise.

    The unnormalised mass of point p = (x, y) is

        exp(-((x-0.5)^2 + (y-0.5)^2) / (2 sigma^2))
        + noise_factor * lambda_x * lambda_y * exp(-lambda_x x) * exp(-lambda_y y)

    and the result is normalised to sum to 1. Both terms are evaluated at
    each point's own coordinates, so A may be any set of 2-D points, not
    only a square grid.

    Args:
        A: Array of points of shape (n, 2)
        sigma: Standard deviation for Gaussian distribution centered at (0.5, 0.5)
        noise_factor: Weight of the exponential noise term
        lambda_x: Rate parameter applied to the x coordinate (A[:, 0])
        lambda_y: Rate parameter applied to the y coordinate (A[:, 1])
        grid_size: Retained for backward compatibility; no longer used. The
            noise density used to be sampled on a separate grid_size^2 grid,
            which only worked when A was exactly such a grid.

    Returns:
        Normalized mass distribution of shape (n,)
    """
    points = np.asarray(A, dtype=float)
    x, y = points[:, 0], points[:, 1]

    # Gaussian bump centred at (0.5, 0.5)
    mass = np.exp(-((x - 0.5) ** 2 + (y - 0.5) ** 2) / (2 * sigma ** 2))

    # Exponential noise evaluated at the points themselves. Previously it was
    # evaluated on an unrelated meshgrid, which transposed the x/y rates and
    # failed for non-grid A.
    density = lambda_x * lambda_y * np.exp(-lambda_x * x) * np.exp(-lambda_y * y)
    mass = mass + noise_factor * density

    # Normalize to sum to 1
    return mass / np.sum(mass)


def generate_target_points(n: int, sigma: float = 0.15,
                           noise_fraction: float = 0.1,
                           rng: Optional[np.random.Generator] = None
                           ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate target points B with uniform mass distribution.

    The first int(n * (1 - noise_fraction)) points are drawn from a normal
    distribution centred at (0.5, 0.5). The remaining points are
    1 - Exponential(scale=0.1) per coordinate, i.e. they cluster near (1, 1).
    All coordinates are clipped to [0, 1].

    Args:
        n: Number of target points (must be >= 0)
        sigma: Standard deviation for normal distribution of main points
        noise_fraction: Fraction of points to generate from exponential noise,
            in [0, 1]
        rng: Optional numpy Generator for reproducible sampling. When None,
            the global ``np.random`` state is used, as before.

    Returns:
        Tuple of (B_points, B_mass) where:
        - B_points: Array of shape (n, 2) containing point coordinates
        - B_mass: Array of shape (n,) with uniform mass 1/n for each point

    Raises:
        ValueError: If n is negative or noise_fraction is outside [0, 1].
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    if not 0 <= noise_fraction <= 1:
        raise ValueError(f"noise_fraction must be in [0, 1], got {noise_fraction}")

    # np.random (module) and Generator expose the same normal/exponential API,
    # so rng=None reproduces the original global-state call sequence exactly.
    sampler = np.random if rng is None else rng

    # Generate main points from normal distribution
    n_main = int(n * (1 - noise_fraction))
    B_main = np.clip(sampler.normal(loc=0.5, scale=sigma, size=(n_main, 2)), 0, 1)

    # Generate noise points near the (1, 1) corner
    n_noise = n - n_main
    B_noise = np.clip(1 - sampler.exponential(scale=0.1, size=(n_noise, 2)), 0, 1)

    # Combine all points
    B = np.vstack([B_main, B_noise])
    B_mass = np.ones(n) / n

    return B, B_mass


def partition_points(A: np.ndarray, A_mass: np.ndarray, B: np.ndarray, 
                    B_weights: np.ndarray, delta: float) -> List[dict]:
    """
    Partition points of A into buckets based on weighted nearest neighbors in B.

    For each source point a, the weighted distance to B[i] is
    ||a - B[i]||^2 - B_weights[i]. Only B points with weighted distance
    < 2*delta are considered. Each such B[i] is assigned a category:

    - 0: it attains the minimum weighted distance AND that distance is < 0
    - 1: within delta of the minimum AND the distance is < delta
    - 2: within 2*delta of the minimum AND the distance is < 2*delta

    Points are grouped by the sequence of (B index, category) pairs, taken in
    increasing B index order. The grouping is stored as a trie of nested
    dicts. A node maps a B index i to a 3-element list
    [category-0 dict, category-1 dict, category-2 dict], and the point
    descends into the dict for its category. The dict where a point's path
    ends is its partition.

    Args:
        A: Source points of shape (n, 2)
        A_mass: Mass distribution for A points of shape (n,)
        B: Target points of shape (m, 2) 
        B_weights: Weights for B points of shape (m,)
        delta: Approximation parameter

    Returns:
        List of partition dicts, in order of the first A point that landed in
        each. Each dict holds (besides any child entries keyed by B index):
        - -1: list of A indices in this partition
        - -2: total A_mass of those points
        - -3: per-B-index category list (0/1/2, or -1 if B[i] was not
          categorized). It is recorded from the first point placed here;
          points sharing a partition followed the same path, so it applies
          to all of them.

    Notes:
        - If no B point is within 2*delta, the point is placed in the root
          dict itself.
        - A B index that passes the < 2*delta filter but matches no category
          still gets an (empty) child entry created, but the path does not
          descend into it.
        - Cost is O(len(A) * len(B)) Python-level iterations.
    """
    partitions = {}  # root of the trie shared by all points
    arrangement = []  # leaf dicts, in first-seen order
    
    for j in range(len(A)):
        a = A[j]
        # Compute weighted distances from a to all points in B
        distances = distance_to_point(B, a)
        weighted_distances = distances - B_weights
        min_distance = np.min(weighted_distances)
        
        curr_partition = partitions
        dists = [-1 for _ in range(len(B))]  # -1 = B[i] not categorized for this point
        
        for i in range(len(B)):
            if weighted_distances[i] < 2 * delta:
                # Node is created before categorizing, so it exists even if
                # none of the branches below match.
                if i not in curr_partition:
                    curr_partition[i] = [{}, {}, {}]
                
                # Categorize distance to B[i]
                if (weighted_distances[i] == min_distance and 
                    weighted_distances[i] < 0):
                    curr_partition = curr_partition[i][0]
                    dists[i] = 0
                elif (weighted_distances[i] <= min_distance + delta and 
                      weighted_distances[i] < delta):
                    curr_partition = curr_partition[i][1]
                    dists[i] = 1
                elif (weighted_distances[i] <= min_distance + 2 * delta and 
                      weighted_distances[i] < 2 * delta):
                    curr_partition = curr_partition[i][2]
                    dists[i] = 2
        
        # Initialize partition data if not exists
        # (negative keys cannot collide with B indices, which are >= 0)
        if -1 not in curr_partition:
            curr_partition[-1] = []
            curr_partition[-2] = 0
            curr_partition[-3] = dists
        
        # Add current point to partition
        curr_partition[-1].append(j)
        curr_partition[-2] += A_mass[j]
        
        # Add to arrangement if this is the first point in partition
        if len(curr_partition[-1]) == 1:
            arrangement.append(curr_partition)
    
    return arrangement