import torch

from model_selection import register_model_selection_algorithm


@register_model_selection_algorithm("DEV")
def dev(weights: torch.Tensor, source_val_loss: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""
    Deep Embedded Validation (DEV) algorithm for unsupervised model selection.

    Computes a variance-reduced risk estimate by using the learned sample
    weights $w$ (estimated density ratios) as a control variate for the
    weighted validation loss $w\,\ell$:

    $$
      \eta = -\frac{\mathrm{Cov}(w\,\ell,\; w)}{\mathrm{Var}(w)},
      \quad
      R_{\mathrm{DEV}} = \mathbb{E}[\,w\,\ell\,] \;+\;\eta\,\mathbb{E}[\,w\,] \;-\;\eta
    $$

    Where:
    - $w$ (weights): shape $[M, N]$ for $M$ models, $N$ validation samples
    - $\ell$ (source_val_loss): shape $[M, N]$, validation losses
    - Expectations, the covariance and the variance are population (1/N)
      estimates over the sample (last) dimension; the common normalization
      cancels in $\eta$.
    - If $\mathrm{Var}(w) = 0$ for a model (constant weights), the control
      variate carries no information and $\eta$ is taken as 0, so
      $R_{\mathrm{DEV}}$ reduces to $\mathbb{E}[w\,\ell]$ instead of NaN.

    Selects the model with minimal $R_{\mathrm{DEV}}$ (the first one on ties).

    Reference:
    You et al. "Towards Accurate Model Selection in Deep Unsupervised Domain Adaptation"
    https://proceedings.mlr.press/v97/you19a.html

    Args:
        weights:
            Tensor of shape $[M, N]$. Learned relative weights
            for $M$ candidate models across $N$ validation samples.
        source_val_loss:
            Tensor of shape $[M, N]$. Validation losses where lower values
            indicate better model performance.
        **kwargs:
            Additional keyword arguments (ignored, maintained for API compatibility)

    Returns:
        Tensor of shape $[M]$: One-hot encoded selection vector where 1 indicates
        the chosen model. Example: For 3 models with minimum at index 1, returns
        tensor([0., 1., 0.])
    """
    _ = kwargs  # accepted only so all registered algorithms share a call signature

    # weights, source_val_loss: [n_models, n_samples]
    weighted_loss = weights * source_val_loss  # [n_models, n_samples]

    weights_mean = weights.mean(dim=-1, keepdim=True)  # [n_models, 1]
    weighted_loss_mean = weighted_loss.mean(dim=-1, keepdim=True)  # [n_models, 1]

    weights_centered = weights - weights_mean  # [n_models, n_samples]
    weighted_loss_centered = weighted_loss - weighted_loss_mean  # [n_models, n_samples]

    cov_lw = (weighted_loss_centered * weights_centered).mean(dim=-1)  # [n_models]
    var_w = weights_centered.square().mean(dim=-1)  # [n_models]

    # Constant weights give Var(w) == 0 (and Cov == 0), i.e. 0/0 = NaN, which
    # torch.argmin would propagate and could select. With no variance the
    # control variate is useless, so fall back to eta = 0 for those models.
    has_var = var_w > 0
    safe_var_w = torch.where(has_var, var_w, torch.ones_like(var_w))
    eta = torch.where(has_var, -cov_lw / safe_var_w, torch.zeros_like(var_w))

    R_dev = (
        weighted_loss_mean.squeeze(-1) + eta * weights_mean.squeeze(-1) - eta
    )  # [n_models]

    # One-hot vector marking the model with minimum DEV risk.
    min_index = torch.argmin(R_dev)
    model_weights = torch.zeros_like(R_dev)
    model_weights[min_index] = 1

    return model_weights
