# EVOLVE-BLOCK-START
"""
Optimization-based circle packing for n=N_CIRCLES circles
using scipy.optimize.minimize
"""
import numpy as np
from scipy.optimize import minimize


### manually added function for shrinking radii, not generated by Evolving - START ###
def _shrink_radii_for_strict_feasibility(centers: np.ndarray,
                                         radii: np.ndarray,
                                         eps: float = 1e-9) -> np.ndarray:

    n = len(radii)

    inside_margins = []
    for i in range(n):
        x, y = centers[i]
        r = radii[i]
        inside_margins.extend([
            x - r,          # x >= r
            y - r,          # y >= r
            1 - (x + r),    # x + r <= 1
            1 - (y + r),    # y + r <= 1
        ])
    m_in = min(inside_margins) if inside_margins else float('inf')

    # pairwise circle margin
    pair_margins = []
    for i in range(n):
        for j in range(i + 1, n):
            dx = centers[i, 0] - centers[j, 0]
            dy = centers[i, 1] - centers[j, 1]
            center_distance = np.sqrt(dx*dx + dy*dy)
            pair_margins.append(center_distance - (radii[i] + radii[j]))
    m_pair = min(pair_margins) if pair_margins else float('inf')


    delta1 = max(0.0, eps - m_in)
    delta2 = max(0.0, (eps - m_pair) / 2.0)
    delta = max(delta1, delta2)

    if delta <= 0:
        return radii  # already have sufficient margin

    print(f'[INFO] Shrinking radii by {delta:.2e} for strict feasibility (m_in={m_in:.2e}, m_pair={m_pair:.2e})')
    new_radii = np.maximum(0.0, radii - delta)
    return new_radii
### manually added function for shrinking radii, not generated by Evolving - END ###


def _hexagonal_centers_26(d_spacing=0.15):
    """Return a (26, 2) array of hexagonal-lattice centres centred in the unit square.

    Row 0 holds 6 circles and rows 1-4 hold 5 circles each (6 + 4 * 5 = 26).
    Consecutive rows are ``sqrt(3) / 2 * d_spacing`` apart and odd rows are
    shifted right by half a spacing, so neighbouring centres are exactly
    ``d_spacing`` apart.
    """
    row_height = (np.sqrt(3) / 2) * d_spacing
    rows = []
    for row, count in enumerate((6, 5, 5, 5, 5)):
        offset = d_spacing / 2 if row % 2 else 0.0
        xs = offset + d_spacing * np.arange(count)
        ys = np.full(count, row * row_height)
        rows.append(np.column_stack([xs, ys]))
    centers = np.vstack(rows)
    # Translate so the lattice's bounding box is centred in [0, 1]^2.
    low = centers.min(axis=0)
    span = centers.max(axis=0) - low
    return centers - low + (1.0 - span) / 2


def _grid_centers(n):
    """Return ``(centers, radius)`` for n circles on a k x k square grid, k = ceil(sqrt(n)).

    With ``radius = 0.5 / k`` the layout is feasible (circles at most touch),
    which gives SLSQP a non-degenerate starting point.
    """
    k = max(1, int(np.ceil(np.sqrt(n))))
    idx = np.arange(n)
    centers = np.column_stack([(idx % k + 0.5) / k, (idx // k + 0.5) / k])
    return centers, 0.5 / k


def _initial_guess(n):
    """Return the flat start vector ``[x0, y0, r0, x1, y1, r1, ...]`` for n circles."""
    if n == 26:
        centers = _hexagonal_centers_26(d_spacing=0.15)
        radius = 0.09
    else:
        centers, radius = _grid_centers(n)
    x0 = np.empty(3 * n)
    x0[0::3] = centers[:, 0]
    x0[1::3] = centers[:, 1]
    x0[2::3] = radius
    return x0


def _packing_constraints(n):
    """Build vectorised SLSQP inequality constraints (each component must be >= 0).

    * Non-overlap: ``(xi - xj)^2 + (yi - yj)^2 - (ri + rj)^2`` for every i < j.
    * Containment: ``x - r``, ``1 - x - r``, ``y - r``, ``1 - y - r`` per circle.

    Both come with analytic Jacobians, so SLSQP never falls back to finite
    differences on these constraints.
    """
    constraints = []

    ii, jj = np.triu_indices(n, k=1)
    n_pairs = len(ii)
    if n_pairs:
        rows = np.arange(n_pairs)

        def pair_terms(z):
            p = z.reshape(n, 3)
            dx = p[ii, 0] - p[jj, 0]
            dy = p[ii, 1] - p[jj, 1]
            rsum = p[ii, 2] + p[jj, 2]
            return dx, dy, rsum

        def pair_fun(z):
            dx, dy, rsum = pair_terms(z)
            return dx * dx + dy * dy - rsum * rsum

        def pair_jac(z):
            dx, dy, rsum = pair_terms(z)
            jac = np.zeros((n_pairs, 3 * n))
            # ii < jj always, so every (row, column) written below is distinct.
            jac[rows, 3 * ii] = 2 * dx
            jac[rows, 3 * jj] = -2 * dx
            jac[rows, 3 * ii + 1] = 2 * dy
            jac[rows, 3 * jj + 1] = -2 * dy
            jac[rows, 3 * ii + 2] = -2 * rsum
            jac[rows, 3 * jj + 2] = -2 * rsum
            return jac

        constraints.append({'type': 'ineq', 'fun': pair_fun, 'jac': pair_jac})

    # Containment is linear: A @ z + b >= 0.
    A = np.zeros((4 * n, 3 * n))
    b = np.tile([0.0, 1.0, 0.0, 1.0], n)
    for i in range(n):
        xi, yi, ri = 3 * i, 3 * i + 1, 3 * i + 2
        A[4 * i, [xi, ri]] = [1.0, -1.0]        # x - r >= 0
        A[4 * i + 1, [xi, ri]] = [-1.0, -1.0]   # 1 - x - r >= 0
        A[4 * i + 2, [yi, ri]] = [1.0, -1.0]    # y - r >= 0
        A[4 * i + 3, [yi, ri]] = [-1.0, -1.0]   # 1 - y - r >= 0
    constraints.append({'type': 'ineq',
                        'fun': lambda z: A @ z + b,
                        'jac': lambda z: A})
    return constraints


def _split_solution(z, n):
    """Split a flat ``[x, y, r] * n`` vector into fresh (n, 2) centres and (n,) radii."""
    p = np.asarray(z, dtype=float).reshape(n, 3)
    return p[:, :2].copy(), p[:, 2].copy()


def construct_packing():
    """
    Optimization-based approach to circle packing maximizing sum of radii.

    Maximises the sum of radii with SLSQP, subject to pairwise non-overlap and
    containment in the unit square. For n=26 the start point is a hexagonal
    lattice; for other n it is a feasible square grid. The final radii are
    passed through ``_shrink_radii_for_strict_feasibility`` so the returned
    packing strictly satisfies every constraint.

    If the solver raises, or does not report success, the start point (and,
    when finite, the solver's last iterate) are each repaired, and the one
    with the larger sum of radii is returned.

    Args:
        None (uses global N_CIRCLES)

    Returns:
        Tuple of (centers, radii, sum_of_radii)
        centers: np.array of shape (N_CIRCLES, 2) with (x, y) coordinates
        radii: np.array of shape (N_CIRCLES) with radius of each circle
        sum_of_radii: Sum of all radii
    """
    n = N_CIRCLES
    if n <= 0:
        empty = np.zeros(0)
        return np.zeros((0, 2)), empty, np.sum(empty)

    # Variables are laid out per circle as [x, y, r].
    bounds = [(0, 1), (0, 1), (0, 0.5)] * n
    cons = _packing_constraints(n)

    def objective(z):
        # Minimise the negative sum of radii (radii sit at indices 2, 5, 8, ...).
        return -np.sum(z[2::3])

    def objective_jac(z):
        grad = np.zeros(len(z))
        grad[2::3] = -1.0
        return grad

    x0 = _initial_guess(n)

    # 'xtol' is not an SLSQP option (scipy only warns about it), so it is omitted.
    # 'eps' only matters if a finite-difference Jacobian is ever needed.
    if n == 26:
        options = {'maxiter': 5000, 'disp': True, 'ftol': 1e-8, 'eps': 1e-6}
    else:
        options = {'maxiter': 15000, 'disp': True, 'ftol': 1e-8, 'eps': 1e-4}

    candidates = [x0]
    try:
        res = minimize(objective, x0, method='SLSQP', jac=objective_jac,
                       bounds=bounds, constraints=cons, options=options)
    except Exception as e:  # best-effort: fall back to the repaired start point
        print(f"Optimization failed: {e}")
    else:
        if res.success:
            candidates = [res.x]
        else:
            print("Optimization did not converge")
            # The last iterate is frequently near-optimal (e.g. iteration
            # limit or a line-search stall), so consider it as well.
            if np.all(np.isfinite(res.x)):
                candidates.insert(0, res.x)

    best = None
    for z in candidates:
        centers_array, radii_array = _split_solution(z, n)
        radii_array = _shrink_radii_for_strict_feasibility(centers_array, radii_array, eps=1e-15)
        sum_radii = np.sum(radii_array)
        if best is None or sum_radii > best[2]:
            best = (centers_array, radii_array, sum_radii)

    return best


# EVOLVE-BLOCK-END


def run_circle_packing():
    """Construct a packing, hand it to ``save_search_results``, and return it.

    Relies on the module-level ``N_CIRCLES``, ``TARGET_VALUE`` and
    ``save_search_results`` names being defined (they are set in the
    ``__main__`` section of this script).

    Returns:
        The ``(centers, radii, sum_radii)`` produced by ``construct_packing``.
    """
    centers, radii, total = construct_packing()
    save_search_results(
        best_perfect_solution=None,
        current_solution={'data': (centers.tolist(), radii.tolist())},
        n_circles=N_CIRCLES,
        target_value=TARGET_VALUE,
    )
    return centers, radii, total



if __name__ == "__main__":

    ######## get parameters from config ########
    from openevolve.modular_utils.file_io_controller import save_search_results
    from openevolve.modular_utils.evaluation_controller import get_current_problem_config
    PROBLEM_CONFIG = get_current_problem_config()
    N_CIRCLES = PROBLEM_CONFIG['core_parameters']['n_circles']
    TARGET_VALUE = PROBLEM_CONFIG['target_value']
    PROBLEM_TYPE = PROBLEM_CONFIG['problem_type']
    ###############################################

    centers, radii, sum_radii = run_circle_packing()
    # "\n" (not "\\n", which prints a literal backslash-n) starts a fresh line.
    print(f"\nGenerated {PROBLEM_TYPE} packing (constructor approach):")
    print(f"Sum of radii: {sum_radii:.10f}")
    print(f"Target: {TARGET_VALUE} ({100*sum_radii/TARGET_VALUE:.1f}% of target)")

    # Optional: Visualize (requires matplotlib). Only the imports are guarded,
    # so an ImportError raised while plotting is not misreported as a missing
    # matplotlib.
    try:
        import matplotlib.pyplot as plt
        from matplotlib.patches import Circle
    except ImportError:
        print("Matplotlib not available - skipping visualization")
    else:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_aspect("equal")
        ax.grid(True)

        for i, (center, radius) in enumerate(zip(centers, radii)):
            circle = Circle(center, radius, alpha=0.5)
            ax.add_patch(circle)
            ax.text(center[0], center[1], str(i), ha="center", va="center", fontsize=8)

        plt.title(f"Circle Packing Constructor (n={N_CIRCLES}, sum={sum_radii:.10f})")
        plt.savefig("circle_packing_constructor.png")
        plt.close(fig)  # release the figure's memory once it is written
        print("Visualization saved as circle_packing_constructor.png")