"""
Discrete set generation module for Robust Optimal Transport.

This module handles the creation and management of discrete sets
used in the ROT algorithm, including set partitioning and transport plan management.
"""

import numpy as np
from typing import Tuple, List
from .utils import distance_to_point, partition_points


def discrete_set(A: np.ndarray, A_mass: np.ndarray, B: np.ndarray, 
                B_weights: np.ndarray, delta: float, 
                transport_plan: np.ndarray) -> Tuple[List, List, np.ndarray, List, List, List, List]:
    """
    Generate discrete set from continuous distribution A.

    The source points A are partitioned by ``partition_points``. For each
    resulting cell this function records a representative point (the cell's
    first member), the cell's mass, the aggregated transport plan, and the
    weighted distances from the representative to every point of B.

    Args:
        A: Source points of shape (n, 2)
        A_mass: Mass distribution for A of shape (n,)
        B: Target points of shape (m, 2)
        B_weights: Weights for B points of shape (m,)
        delta: Approximation parameter (forwarded to ``partition_points``)
        transport_plan: Current transport plan of shape (n, m)

    Returns:
        Tuple containing:
        - out_points: Representative point (first member) of each partition
        - out_mass: Mass of each partition, taken from ``partition[-2]``
        - out_transport_plan: Array of shape (k, m) whose row ``i`` is the sum
          of the ``transport_plan`` rows of partition ``i``'s members
          (shape (0, m) when there are no partitions)
        - C: Per partition, a list of ``(b_index, weighted_distance)`` pairs
          sorted ascending by weighted distance (ties keep index order)
        - distance_matrix: Per partition, the array
          ``distance_to_point(B, representative) - B_weights``
        - arrangement: Partition arrangement returned by ``partition_points``
        - drawing_points: One uniformly random member of each partition,
          for visualization

    Raises:
        ValueError: If a partition has no members (raised by
            ``np.random.randint(0)``).

    Note:
        Drawing points consume NumPy's global random state, one draw per
        partition in arrangement order.
    """
    arrangement = partition_points(A, A_mass, B, B_weights, delta)
    num_b = len(B)

    out_points = []
    out_mass = []
    out_transport_plan = []
    C = []
    distance_matrix = []
    drawing_points = []

    for partition in arrangement:
        # As used below, partition[-1] holds the indices into A of the cell's
        # members and partition[-2] holds the cell's mass.
        members = partition[-1]

        # Draw first so an empty cell fails exactly as before (ValueError) and
        # the global RNG is consumed in the same order.
        rand_index = np.random.randint(len(members))
        representative = A[members[0]]

        out_points.append(representative)
        drawing_points.append(A[members[rand_index]])
        out_mass.append(partition[-2])

        # One vectorized reduction over the member rows replaces a per-column
        # Python loop; only the first len(B) columns are aggregated.
        out_transport_plan.append(transport_plan[members, :num_b].sum(axis=0))

        # Weighted distance = distance to each B point minus that point's
        # weight; computed once and reused for both outputs.
        weighted = distance_to_point(B, representative) - B_weights
        distance_matrix.append(weighted)

        # Python's sort is stable, so equal weighted distances keep index order.
        C.append(sorted(enumerate(weighted), key=lambda pair: pair[1]))

    # Reshape so an empty arrangement still yields a 2-D (0, len(B)) plan.
    out_transport_plan = np.array(out_transport_plan).reshape(len(arrangement), num_b)

    return (out_points, out_mass, out_transport_plan, C,
            distance_matrix, arrangement, drawing_points)


def reconstruct_full_transport_plan(arrangement: List, transport_plan_hat: np.ndarray,
                                  A_mass: np.ndarray, A_delta_mass: List,
                                  num_A: int, num_B: int) -> np.ndarray:
    """
    Reconstruct the full transport plan from the discrete representative plan.

    Each point ``j`` listed in partition ``i`` (``arrangement[i][-1]``) gets the
    partition's representative row, scaled by its share of the partition mass:

        plan[j, k] = transport_plan_hat[i][k] * A_mass[j] / A_delta_mass[i]

    A partition whose mass ``A_delta_mass[i]`` is zero contributes all-zero
    rows. This avoids writing NaN/inf from a 0/0 division into the plan.
    Rows of points not listed in any partition stay zero.

    Args:
        arrangement: Partition arrangement from discrete_set; the last entry
            of each partition is the list of member indices into A
        transport_plan_hat: Transport plan for representative points, shape
            (num_partitions, >= num_B)
        A_mass: Original mass distribution for A
        A_delta_mass: Mass for each representative point
        num_A: Number of original A points
        num_B: Number of B points

    Returns:
        Full transport plan of shape (num_A, num_B)
    """
    # asarray keeps list inputs working, as the original scalar indexing did.
    masses = np.asarray(A_mass)
    plan_hat = np.asarray(transport_plan_hat)
    sd_transport_plan = np.zeros((num_A, num_B))

    for i, partition in enumerate(arrangement):
        members = np.asarray(partition[-1], dtype=np.intp)
        cell_mass = A_delta_mass[i]

        if cell_mass == 0:
            # A massless cell cannot transport anything; avoid 0/0 -> NaN.
            sd_transport_plan[members] = 0.0
            continue

        # outer(mass, row)[j, k] == mass[j] * row[k]; dividing afterwards keeps
        # the original evaluation order (hat * mass) / delta.
        sd_transport_plan[members] = (
            np.outer(masses[members], plan_hat[i, :num_B]) / cell_mass
        )

    return sd_transport_plan