import torch

from expground.types import Sequence
from expground.algorithms.base_policy import Policy


def min_population_dist(
    v_funcs: Sequence[Policy],
    sigma: Sequence[float],
    critic_state,
    p: int = 2,
    _lambda: float = 1e-3,
) -> torch.Tensor:
    """Return the sigma-weighted value sum plus a p-norm regulariser.

    Every callable in ``v_funcs`` is evaluated on ``critic_state``. The loss
    is::

        sum_i sigma[i] * v_i  +  _lambda * || (1 - sigma[i]) * v_i ||_p

    where both the sum and the norm run over the last dimension of the
    concatenated per-function outputs.

    Args:
        v_funcs: Callables invoked as ``v_func(critic_state)``. Their outputs
            are concatenated along the last dimension, so they must be
            tensors whose other dimensions match.
            NOTE(review): presumably each output is ``(batch, 1)`` -- confirm
            against the callers.
        sigma: One weight per entry of ``v_funcs``. May be a list/tuple of
            floats, a NumPy array, or a tensor.
        critic_state: Input forwarded unchanged to every value function.
        p: Order of the norm used by the regulariser.
        _lambda: Scale applied to the regulariser.

    Returns:
        A tensor with the last dimension of the concatenated values reduced.
    """
    # ``1.0 - sigma`` is undefined for plain Python sequences; converting
    # first makes lists/tuples work. An existing tensor is returned as-is,
    # so its dtype, device and autograd history are kept.
    weights = torch.as_tensor(sigma)
    ratio = 1.0 - weights

    values = [v_func(critic_state) for v_func in v_funcs]

    # Indexing (rather than zip) keeps the IndexError raised when ``sigma``
    # has fewer entries than ``v_funcs``.
    weighted = [weights[i] * v for i, v in enumerate(values)]
    v_exp = torch.sum(torch.cat(weighted, dim=-1), dim=-1)

    # Take |x| before the power: without it an odd ``p`` can produce a
    # negative sum whose fractional root is NaN. Even ``p`` is unaffected.
    powered = [torch.abs(ratio[i] * v) ** p for i, v in enumerate(values)]
    reg = torch.sum(torch.cat(powered, dim=-1), dim=-1) ** (1 / p)

    return _lambda * reg + v_exp
