import torch
from distances_spectral import (
    pairwise_distance,
    batched_features_for_ensemble,
    FeatureConfig
)


def compute_coulomb_loss_ensemble(
    batch: torch.Tensor,
    epsilon: float,
    gamma: float,
    just_dist: bool,
    ensemble_id: int | None,
    num_ensembles: int,
    feature_config: FeatureConfig | None = None,
    proj_dim: int = 2,
) -> torch.Tensor:
    """Compute a Coulomb-style repulsion loss over pairwise feature distances.

    Features are extracted from ``batch``, pairwise distances between them
    are computed, and the loss is the mean over all ordered off-diagonal
    pairs ``i != j`` of ``1 / (d_ij + epsilon) ** gamma``.

    Args:
        batch: Input batch; its first dimension is treated as the number of
            samples ``B`` (the distance tensor is assumed to end in
            ``(B, B)`` — TODO confirm against ``pairwise_distance``).
        epsilon: Added to every distance before exponentiation, softening the
            singularity at zero distance.
        gamma: Exponent of the inverse-distance potential.
        just_dist: If True, return the pairwise distance tensor instead of
            the loss.
        ensemble_id: Ensemble member whose features to use, passed to
            ``batched_features_for_ensemble``; None selects the
            non-ensemble ``batched_features`` extractor.
        num_ensembles: Forwarded to ``batched_features_for_ensemble``.
        feature_config: Optional feature configuration forwarded to the
            feature extractor.
        proj_dim: Forwarded to ``batched_features_for_ensemble``.

    Returns:
        The pairwise distance tensor when ``just_dist`` is True, otherwise a
        scalar loss tensor. A batch with fewer than two samples has no pairs
        and yields a loss of 0 rather than NaN.
    """
    if ensemble_id is not None:
        features = batched_features_for_ensemble(
            batch, ensemble_id, num_ensembles=num_ensembles, config=feature_config,
            proj_dim=proj_dim,
        )
    else:
        # `batched_features` was previously used here without being imported,
        # so this branch always raised NameError. Imported lazily so a missing
        # export only affects this path.
        # NOTE(review): assumes distances_spectral exports `batched_features`
        # alongside `batched_features_for_ensemble` — confirm.
        from distances_spectral import batched_features

        features = batched_features(batch, config=feature_config)

    # Compute pairwise distances
    distance = pairwise_distance(features)

    if just_dist:
        return distance

    batch_size = batch.shape[0]
    off_diag = ~torch.eye(batch_size, dtype=torch.bool, device=distance.device)

    # Double-where: substitute 1 for the self-pair entries *before* the power,
    # then zero them out. Masking after the division (as before) produced
    # 0 / 0 = NaN on the diagonal whenever distance + epsilon == 0 there
    # (e.g. epsilon == 0), which poisoned both the loss and its gradient.
    safe_distance = torch.where(off_diag, distance + epsilon, torch.ones_like(distance))
    potential = torch.where(
        off_diag, 1.0 / safe_distance ** gamma, torch.zeros_like(distance)
    )

    # Average over the B * (B - 1) ordered off-diagonal pairs. Clamp the count
    # so B < 2 (no pairs, potential sums to 0) returns 0 instead of 0 / 0.
    num_pairs = max(batch_size * (batch_size - 1), 1)
    loss = potential.sum() / num_pairs

    return loss
