import torch


def wpm_swf(u, weights, pow):
    """Compute the weighted power mean of ``u`` along its last dimension.

    Args:
        u: Tensor of values to aggregate. The finite-``pow`` branches take
            ``torch.log(u)``, so entries are presumably strictly positive
            (not checked here).
        weights: Tensor that broadcasts against ``u``. Presumably
            non-negative and summing to one along the last dimension
            (not checked here). Not used when ``pow`` is ``-inf``.
        pow: Exponent of the power mean; must be ``<= 1``. (The name shadows
            the builtin but is kept so keyword callers continue to work.)
            ``-inf`` yields the minimum of ``u``, ``0`` the weighted
            geometric mean, and any other value
            ``(sum_i weights_i * u_i ** pow) ** (1 / pow)``.

    Returns:
        Tensor shaped like ``u`` with the last dimension reduced away.

    Raises:
        AssertionError: If ``pow > 1``.
    """
    assert pow <= 1, "Power must be less than or equal to 1."

    if pow == -torch.inf:
        # Reduce over the last dimension only, like the other branches do;
        # a bare torch.min(u) would collapse every batch dimension as well.
        wpm = torch.amin(u, dim=-1)
    elif pow == 0:
        # Weighted geometric mean: exp(sum_i w_i * log u_i).
        log_wpm = (torch.log(u) * weights).sum(dim=-1)
        wpm = torch.exp(log_wpm)
    else:
        # Evaluate log(sum_i w_i * u_i ** pow) / pow in log space via
        # logsumexp so large or small powers do not overflow/underflow.
        log_wpm = torch.logsumexp(torch.log(weights) + torch.log(u) * pow, dim=-1) / pow
        wpm = torch.exp(log_wpm)

    return wpm


class WPMSolver:
    """Turn per-item utilities into capped allocations summing to ``num_alloc``.

    ``get_allocation_rates`` maps utilities to non-negative fill rates that
    depend on ``pow``; ``water_filling`` then distributes ``num_alloc`` units
    in proportion to those rates while capping every item at 1.0.
    """

    def __init__(self, weights, pow, num_alloc):
        """Store the solver configuration.

        Args:
            weights: Per-item weight tensor used when computing rates.
            pow: Power-mean exponent; must be ``<= 1`` (checked when rates
                are computed, not here).
            num_alloc: Total amount to allocate across all items.
        """
        self.weights = weights
        self.pow = pow
        self.num_alloc = num_alloc

    def water_filling(self, rates):
        """Distribute ``self.num_alloc`` proportionally to ``rates``, capped at 1.

        Every item ("container") fills at its own rate. An item that would
        reach 1.0 before the budget runs out is frozen at 1.0; whatever
        budget remains is then split among the slower items in proportion
        to their rates.

        Args:
            rates: 1-D tensor of fill rates (indexed along dimension 0),
                presumably non-negative.

        Returns:
            Tensor shaped like ``rates`` holding the per-item allocation in
            the original item order.

        Raises:
            AssertionError: If the allocation does not sum to
                ``self.num_alloc`` within a dtype-aware tolerance, e.g.
                because ``num_alloc`` exceeds the number of items.
        """
        n = rates.shape[0]
        res = torch.zeros_like(rates)

        # sort filling times in ascending order (larger rate => smaller time)
        indices = torch.argsort(rates, descending=True)
        rates = rates[indices]
        # rates_suffix_sum[i] is the combined rate of items i..n-1, i.e. of
        # every container that is still filling once items < i are full.
        rates_suffix_sum = torch.cumsum(rates.flip(0), dim=0).flip(0)

        i, rem = 0, self.num_alloc
        while (rem > 0) and (i < n):
            # find time when next container fills up
            time = 1 / rates[i]

            if time * rates_suffix_sum[i] > rem:
                # The budget runs out before container i is full: partially
                # fill all remaining containers and stop.
                time = rem / rates_suffix_sum[i]
                res[i:n] = time * rates[i:n]
                break

            # fill this container to 1.0 and continue
            res[i] = 1.0
            rem -= 1
            i += 1

        # rearrange results to original order
        final_res = torch.zeros_like(res)
        final_res[indices] = res

        # A fixed absolute tolerance of 1e-6 is finer than the spacing of
        # float32 values once the total reaches ~16, which made this sanity
        # check fail on valid allocations. Scale it with the dtype's epsilon,
        # the number of accumulated terms and the magnitude of the total;
        # float64 inputs effectively keep the original 1e-6 bound.
        eps = torch.finfo(final_res.dtype).eps if final_res.is_floating_point() else 0.0
        tol = max(1e-6, 2 * n * eps * max(float(abs(self.num_alloc)), 1.0))
        assert torch.abs(final_res.sum() - self.num_alloc) < tol, (
            f"Total allocation mismatch: {final_res.sum()} vs {self.num_alloc}"
        )
        return final_res

    def get_allocation_rates(self, u):
        """Map utilities ``u`` to the fill rates consumed by ``water_filling``.

        Args:
            u: Tensor of per-item utilities; presumably positive, since the
                general branch takes ``torch.log(u)`` and the ``-inf``
                branch divides by ``u``.

        Returns:
            Tensor of rates: ``1 / u`` for ``pow == -inf``, ``self.weights``
            for ``pow == 0``, a 0/1 indicator of the ``num_alloc`` largest
            ``weights * u`` for ``pow == 1``, and otherwise a softmax over
            ``(log(weights) + pow * log(u)) / (1 - pow)``.

        Raises:
            AssertionError: If ``self.pow > 1``.
        """
        assert self.pow <= 1, "Power must be less than or equal to 1."

        if self.pow == -torch.inf:
            rates = 1 / u
        elif self.pow == 0:
            rates = self.weights
        elif self.pow == 1:
            # The general formula divides by (1 - pow), which is zero here;
            # instead put rate 1 on the num_alloc largest weighted utilities.
            rates = torch.zeros_like(u)
            inds = torch.argsort(self.weights * u, descending=True)[: self.num_alloc]
            rates[inds] = 1.0
        else:
            # Rates proportional to (weights * u ** pow) ** (1 / (1 - pow)),
            # normalised in log space by the softmax.
            log_rates = (torch.log(self.weights) + self.pow * torch.log(u)) / (1 - self.pow)
            rates = torch.softmax(log_rates, dim=-1)

        return rates

    def get_allocation_probabilities(self, u):
        """Return the capped allocation for utilities ``u``.

        Computes the rates with ``get_allocation_rates`` and passes them
        through ``water_filling``.
        """
        rates = self.get_allocation_rates(u)
        probs = self.water_filling(rates)
        return probs
