import torch

def kolm_swf(u, weights, pow):
    """Weighted exponential mean of ``u`` with non-positive exponent ``pow``.

    For finite ``pow < 0`` this returns
    ``(1 / pow) * log(sum_i weights_i * exp(pow * u_i))``, evaluated with
    ``torch.logsumexp`` over dim 0 for numerical stability. Two limits are
    special-cased:

    * ``pow == 0``: the weighted sum ``weights @ u``.
    * ``pow == -inf``: ``torch.min(u)`` (weights are ignored).

    Args:
        u: per-entry values; the logsumexp reduces over dim 0, so a 1-D
            tensor is presumably expected (same shape as ``weights``).
        weights: per-entry weights; ``log(weights)`` is taken, so zero
            weights contribute ``-inf`` inside the logsumexp.
        pow: exponent, must be ``<= 0``.

    Returns:
        The welfare value as a tensor.

    Raises:
        AssertionError: if ``pow > 0`` (or is NaN).
    """
    assert pow <= 0, "Power must be non-positive."

    if pow == 0:
        # Linear limit of the exponential mean.
        return weights @ u
    if pow == -torch.inf:
        # Egalitarian limit: the smallest value.
        return torch.min(u)

    # log(w_i) + pow * u_i == log(w_i * exp(pow * u_i)), summed in log space.
    log_terms = torch.log(weights) + pow * u
    return torch.logsumexp(log_terms, dim=0) / pow


class KolmSolver:
    """Distribute ``num_alloc`` units of mass over arms by water-filling.

    A common "water level" ``t`` is raised; each arm starts receiving mass at
    its *zero time*, at a constant per-arm *rate*, and stops (is capped) at
    its *one time*. The level rises until the total allocated mass equals
    ``num_alloc``. The event times and rates depend on ``pow``:

    * ``pow == -inf``: zero time ``0``, one time ``u``, rate ``1 / u``.
    * ``pow == 0``: no water-filling; the rates (a 0/1 indicator of the
      ``num_alloc`` largest ``weights * u``) are returned as-is.
    * ``pow < 0``: zero time ``-log(weights * u)``, one time
      ``-pow * u - log(weights * u)``, rate ``1 / (-pow * u)``.

    NOTE(review): the logs and divisions above assume ``u > 0`` and
    ``weights > 0`` elementwise — confirm with callers.
    """

    def __init__(self, weights, pow, num_alloc):
        """Store the solver configuration.

        Args:
            weights: per-arm weights tensor (presumably the same shape as
                the ``u`` later passed to the solver — verify with callers).
            pow: non-positive exponent; ``0`` and ``-torch.inf`` are
                handled as special cases.
            num_alloc: total mass to allocate. For ``pow == 0`` it is used
                as a slice bound, so it must then be an ``int``.

        Raises:
            AssertionError: if ``pow > 0``.
        """
        # Only pow <= 0 is supported (same contract as ``kolm_swf``): for
        # pow > 0 the rates 1 / (-pow * u) are negative and one-times fall
        # before zero-times, so the water-filling below breaks down.
        assert pow <= 0, "Power must be non-positive."

        self.weights = weights
        self.pow = pow
        self.num_alloc = num_alloc

    def _event_times(self, u):
        """Return ``(zero_times, one_times)``: when each arm starts/stops filling."""
        if self.pow == -torch.inf:
            # Computed separately so the finite-pow formula (inf * u, which
            # is NaN where u == 0) is never evaluated for this case.
            return torch.zeros_like(self.weights), u
        log_wu = torch.log(self.weights * u)
        return -log_wu, -self.pow * u - log_wu

    def water_filling(self, u, rates):
        """Raise the water level until ``num_alloc`` mass has been allocated.

        Args:
            u: per-arm values used to compute the event times.
            rates: per-arm fill rates (see ``get_allocation_rates``); 1-D,
                with one entry per arm.

        Returns:
            Tensor shaped like ``rates`` with the allocated mass per arm.
            For ``pow == 0`` this is ``rates`` itself (the same object).

        Raises:
            AssertionError: if the final total deviates from ``num_alloc``
                (e.g. when ``num_alloc`` exceeds the number of arms).
        """
        if self.pow == 0:
            return rates

        n = rates.shape[0]
        res = torch.zeros_like(rates)

        # Merge all events: indices < n are zero-times (arm starts filling),
        # indices >= n are one-times (arm n_idx - n is capped).
        zero_times, one_times = self._event_times(u)
        times = torch.cat((zero_times, one_times), dim=0)
        sorted_times, indices = torch.sort(times.view(-1))

        uncapped_mask = torch.ones_like(rates, dtype=torch.bool)
        positive_mask = torch.zeros_like(rates, dtype=torch.bool)
        # The earliest event is always a zero-time when one-times exceed
        # zero-times (u > 0), so this index is < n.
        positive_mask[indices[0]] = True

        t = sorted_times[0]
        for i in range(1, len(sorted_times)):
            # Arms that have started filling and are not yet capped.
            active = uncapped_mask & positive_mask
            delta_t = sorted_times[i] - t
            vol_incr = rates * delta_t * active
            filled = res.sum()

            if vol_incr.sum() + filled > self.num_alloc:
                # Budget is exhausted inside this interval: advance only as
                # far as needed to hit num_alloc exactly, then stop.
                final_delta_t = (self.num_alloc - filled) / (rates * active).sum()
                res += rates * final_delta_t * active
                break

            res += vol_incr
            t = sorted_times[i]

            # Apply the event at sorted position i.
            arm_ind = indices[i]
            if arm_ind < n:
                # Arm reaches its zero-time: it starts receiving mass.
                positive_mask[arm_ind] = True
            else:
                # Arm reaches its one-time: its mass is (about) one, cap it.
                uncapped_mask[arm_ind - n] = False

        total = res.sum()
        # Relative tolerance: an absolute 1e-6 is below float32 resolution
        # once num_alloc is about 8 or larger.
        tol = 1e-6 * max(1.0, float(self.num_alloc))
        assert torch.abs(total - self.num_alloc) < tol, (
            f"Total allocation mismatch: {total} vs {self.num_alloc}"
        )
        return res

    def get_allocation_rates(self, u):
        """Per-arm fill rates for the water-filling.

        Args:
            u: per-arm values tensor.

        Returns:
            ``1 / u`` for ``pow == -inf``; ``1 / (-pow * u)`` for ``pow < 0``;
            for ``pow == 0``, a 0/1 tensor marking the ``num_alloc`` arms with
            the largest ``weights * u``.
        """
        if self.pow == -torch.inf:
            rates = 1 / u
        elif self.pow == 0:
            rates = torch.zeros_like(u)
            inds = torch.argsort(self.weights * u, descending=True)[: self.num_alloc]
            rates[inds] = 1.0
        else:
            rates = 1 / (-self.pow * u)

        return rates

    def get_allocation_probabilities(self, u):
        """Compute the per-arm allocation: rates from ``u``, then water-filling."""
        rates = self.get_allocation_rates(u)
        probs = self.water_filling(u, rates)
        return probs
