import torch

from model_selection import register_model_selection_algorithm


@register_model_selection_algorithm("IWV")
def iwv(weights: torch.Tensor, source_val_loss: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""
    Importance-Weighted Validation (IWV) algorithm for unsupervised model selection.

    Estimates each candidate model's risk as the importance-weighted average of
    its validation losses:

    $$
    R_{\mathrm{IWV}}^{(i)} = \frac{1}{N} \sum_{j=1}^{N} w_{ij}\,\ell_{ij},
    $$

    where for $M$ candidate models and $N$ validation samples:
    - $w_{ij}$: importance weight of sample $j$ for model $i$ (e.g. an estimate
      of the density ratio between target and source distributions at sample $j$)
    - $\ell_{ij}$: validation loss of model $i$ on sample $j$

    Selects the model with minimal $R_{\mathrm{IWV}}$. Ties resolve to the lowest
    index. A model whose risk is NaN is never preferred over a model with a
    non-NaN risk.

    Reference:
        Sugiyama, Krauledat and Müller, "Covariate Shift Adaptation by
        Importance Weighted Cross Validation", JMLR 8, 2007.
        https://jmlr.org/papers/v8/sugiyama07a.html

    Args:
        weights:
            Tensor of shape $[M, N]$ with per-model, per-sample importance
            weights. Any shape broadcastable against `source_val_loss` is
            accepted, e.g. $[N]$ when the weights are shared by all models.
        source_val_loss:
            Tensor of shape $[M, N]$. Validation losses where lower values
            indicate better model performance.
        **kwargs:
            Ignored arguments (maintained for API compatibility)

    Returns:
        Tensor of shape $[M]$ (same dtype/device as the computed risk): one-hot
        selection vector where 1 marks the chosen model. Example: for 3 models
        with minimum at index 1, returns tensor([0., 1., 0.])
    """
    del kwargs  # accepted only so all selection algorithms share one call signature

    # Importance-weighted validation risk of each model (mean over samples).
    iwv_risk = torch.mean(weights * source_val_loss, dim=-1)

    # torch.min/argmin propagate NaN, which would select a model whose loss
    # diverged; rank NaN risks as +inf so any finite candidate wins instead.
    ranked_risk = iwv_risk.masked_fill(torch.isnan(iwv_risk), float("inf"))

    # Keep only the model with the minimal IWV risk (first one on ties).
    min_index = torch.argmin(ranked_risk, dim=0)
    model_weights = torch.zeros_like(iwv_risk)
    model_weights[min_index] = 1

    return model_weights
