class GeometricCalibration:
    """
    Refines memberships using geometric grouping principles:
    - Proximity: nearby tokens should have similar memberships
    - Similarity: tokens with similar features should have similar memberships
    - Continuity: memberships should extend smoothly along structures

    The calibrated memberships are the initial memberships plus an adaptively
    weighted sum of the three adjustment terms, re-normalized per column.
    """
    def forward(self, raw_scores, token_features, H, W):
        """
        Calibrate raw token/hyperedge scores into memberships.

        Args:
            raw_scores: raw similarity scores, turned into initial memberships
                by ``column_softmax``. NOTE(review): exact layout (e.g. which
                axis is tokens vs. hyperedges) is not visible here — confirm.
            token_features: features used by the similarity adjustment.
            H, W: spatial grid size used for neighbor aggregation.

        Returns:
            A tuple ``(memberships, coeffs)`` where ``memberships`` is the fused
            result passed through ``normalize_columns`` and ``coeffs`` is a dict
            with the adaptive weights under keys ``w_p``, ``w_s``, ``w_c``.
        """
        # Step 1. Convert raw scores into initial memberships (column-softmax)
        memberships = column_softmax(raw_scores)

        # Step 2. Proximity adjustment: average memberships over spatial neighbors
        proximity_term = aggregate_neighbors(memberships, H, W)

        # Step 3. Similarity adjustment: weight neighbors by feature cosine similarity
        similarity_term = aggregate_feature_neighbors(memberships, token_features)

        # Step 4. Continuity adjustment: spatially aggregate the feature-similarity
        # evidence, so membership spreads between tokens that are both adjacent
        # and feature-similar, i.e. along connected structures. This term must be
        # computed before it is used below; without it `continuity_term` is an
        # unbound name and forward() raises NameError.
        # NOTE(review): no dedicated continuity helper is defined in this module;
        # this composition of existing helpers is an assumption — replace it with
        # the intended continuity operator if one exists.
        continuity_term = aggregate_neighbors(similarity_term, H, W)

        # Step 5. Compute adaptive weights (alpha_p, alpha_s, alpha_c)
        # Closed-form, parameter-free
        w_p, w_s, w_c = compute_adaptive_coeffs(
            memberships, proximity_term, similarity_term, continuity_term
        )

        # Step 6. Fuse calibrated memberships
        A_hat = memberships + w_p * proximity_term + w_s * similarity_term + w_c * continuity_term

        return normalize_columns(A_hat), dict(w_p=w_p, w_s=w_s, w_c=w_c)


class PrimitivGenerator:
    """
    Produce calibrated token-to-hyperedge memberships.

    Pipeline: global context -> context-adapted prototypes -> multi-head
    token/prototype similarity -> geometric calibration.
    """
    def __init__(self, num_hyperedges, num_heads, context="both"):
        # Sizes consumed by forward(): prototype count and similarity heads.
        self.num_hyperedges, self.num_heads = num_hyperedges, num_heads
        # Global pooling mode handed to compute_global_context
        # (the pipeline describes it as mean, max, or both).
        self.context_type = context
        self.calibrator = GeometricCalibration()

    def forward(self, tokens, H, W):
        """
        Compute memberships for ``tokens`` laid out on an ``H`` x ``W`` grid.

        Returns:
            ``(memberships, coeffs)`` exactly as produced by the calibrator's
            ``forward``.
        """
        # Summarize the whole token set into one context vector.
        pooled = compute_global_context(tokens, self.context_type)

        # Shift the base primitives by that context.
        # NOTE(review): helper name is spelled `adapt_prmitives`; kept as-is
        # because it must match its definition elsewhere.
        protos = adapt_prmitives(pooled, self.num_hyperedges)

        # Score every token against every prototype, per head.
        scores = multihead_similarity(tokens, protos, self.num_heads)

        # Geometric calibration refines the scores into final memberships.
        calibrated, weights = self.calibrator.forward(scores, tokens, H, W)
        return calibrated, weights


class HyperGraphConv:
    """
    Hypergraph convolution with adaptive hyperedges:
    - Vertex → Edge: pool tokens into hyperedge features
    - Edge → Vertex: redistribute evidence back to tokens
    - Residual connection preserves original token stream
    """
    def __init__(self, num_hyperedges, num_heads, context="both"):
        # The membership generator is PrimitivGenerator: it implements exactly
        # the primitives-competition + geometric-calibration pipeline that
        # forward() relies on, with the same constructor arguments and a
        # forward(tokens, H, W) -> (memberships, coeffs) contract.
        # (`HyperRelationGenerator`, used previously, is not defined in this module.)
        self.generator = PrimitivGenerator(num_hyperedges, num_heads, context)

    def forward(self, tokens, H, W):
        """
        Apply one hypergraph convolution to ``tokens`` on an ``H`` x ``W`` grid.

        Returns:
            ``(tokens + tokens_refined, coeffs, A)``: the residual-updated
            tokens, the calibration weights dict from the generator, and the
            membership matrix ``A`` used for both message-passing directions.
        """
        # Step 1. Generate memberships via primitives competition + Geometric calibration
        A, coeffs = self.generator.forward(tokens, H, W)

        # Step 2. Vertex-to-Edge: aggregate token features into hyperedge slots
        edge_features = aggregate_vertices_to_edges(A, tokens)

        # Step 3. Edge-to-Vertex: redistribute features back to tokens
        tokens_refined = redistribute_edges_to_vertices(A, edge_features)

        # Step 4. Residual connection
        return tokens + tokens_refined, coeffs, A


class HyperGraphBlock:
    """
    Adapter that lets HyperGraphConv plug into CNN/Transformer backbones.

    Input: feature maps (B, C, H, W) or a token sequence (B, N, C).
    Output: a tuple ``(features, coeffs, A)`` — the refined features reshaped
    by ``reshape_features``, the calibration weights, and the membership matrix.

    ``embed_dim`` is accepted for interface compatibility but is not used by
    this implementation.
    """
    def __init__(self, embed_dim, num_hyperedges=16, num_heads=4, context="both"):
        self.hgc = HyperGraphConv(num_hyperedges, num_heads, context)

    def forward(self, x):
        # Tokenize the input; flatten_features also reports the spatial grid size.
        flat, height, width = flatten_features(x)

        # One hypergraph convolution over the token grid.
        refined, calib_coeffs, incidence = self.hgc.forward(flat, height, width)

        # Restore the input layout. NOTE(review): reshape_features is called
        # unconditionally; presumably it is a no-op for token-sequence input — confirm.
        restored = reshape_features(refined, height, width)
        return restored, calib_coeffs, incidence
