import torch

# Small positive constant: GiniSolver.water_filling uses it as the lower bound
# when clamping the initial rates and adds it to the denominator when it
# recomputes a rate, so that division never hits an exact zero.
EPS = 1e-8


def gini_swf(u, weights):
    """Return the weighted sum of the ascending-sorted entries of ``u``.

    ``weights`` must be non-increasing, so the first (largest) weight is
    applied to the smallest entry of ``u``.

    Args:
        u: tensor of values; it is sorted in ascending order along its last
            dimension before weighting (presumably 1-D utilities -- confirm
            against callers).
        weights: tensor of non-increasing weights, combined with the sorted
            values via ``@``.

    Returns:
        The result of ``weights @ sorted(u)``.

    Raises:
        AssertionError: if any weight is larger than the one before it.
    """
    non_increasing = torch.all(weights[:-1] >= weights[1:])
    assert non_increasing, "Weights must be non-increasing"

    ascending_u = torch.sort(u, descending=False).values
    return weights @ ascending_u


class GiniSolver:
    """Compute allocation probabilities for a rank-weighted (Gini-style) objective.

    The solver is configured with a vector of non-increasing rank weights and
    a total amount ``num_alloc`` to distribute; ``water_filling`` spreads that
    amount over the entries of a utility vector.

    Attributes set by ``__init__``:
        weights: the weight tensor passed in (stored as-is, never modified).
        num_alloc: the total amount to allocate (stored as-is, never modified).
        is_util: True when every weight equals the first one (the name
            suggests the utilitarian special case).
        is_egal: True when every weight after the first is zero (the name
            suggests the egalitarian special case).
    """

    def __init__(self, weights, num_alloc):
        """Store the configuration and flag the two special weight patterns.

        Args:
            weights: tensor of non-increasing weights.
            num_alloc: total amount that ``water_filling`` must hand out.

        Raises:
            AssertionError: if any weight is larger than the one before it.
        """
        assert torch.all(weights[1:] <= weights[:-1]), "Weights must be non-increasing"

        self.weights = weights
        self.num_alloc = num_alloc

        # These flags are not read anywhere in this class; they are exposed
        # for whoever holds the solver.
        self.is_util = bool(torch.all(weights == weights[0]))
        self.is_egal = bool(torch.all(weights[1:] == 0.0))

    def water_filling(self, u):
        """Distribute ``self.num_alloc`` over the entries of ``u`` by water-filling.

        Entries are processed in ascending order of ``u``. Each entry ``i`` in
        a block receives mass at speed ``1 / u[i]``. The block with the
        highest rate (sum of weights / sum of inverse utilities) is filled
        until its left-most entry reaches 1 or the remaining mass runs out.

        Args:
            u: 1-D tensor of utilities. NOTE(review): the code divides by
                ``u``, so entries are presumably expected to be strictly
                positive floats, with ``len(u) == len(self.weights)`` --
                confirm against callers.

        Returns:
            A tensor shaped like ``u``, in the original order of ``u``, whose
            entries sum to ``self.num_alloc``.

        Raises:
            ValueError: if ``self.num_alloc`` exceeds the number of entries,
                in which case no allocation capped at 1 per entry can exist.
            AssertionError: if the produced allocation does not sum to
                ``self.num_alloc`` within 1e-6.
        """
        n = u.shape[0]
        if self.num_alloc > n:
            # Without this check the loop below can never consume the
            # remaining mass and would spin forever.
            raise ValueError(
                f"num_alloc ({self.num_alloc}) exceeds the number of entries ({n})"
            )

        sorted_u, indices = torch.sort(u, descending=False)

        # precompute suffix sums for weights and inverse utilities
        weights_suffix_sum = torch.cumsum(self.weights.flip(0), dim=0).flip(0)
        inv_u = 1 / sorted_u
        inv_u_suffix_sum = torch.cumsum(inv_u.flip(0), dim=0).flip(0)

        # rate of objective change when an entry and everything to its right
        # (up to the end of its block) are filled to the same level
        rates = weights_suffix_sum / inv_u_suffix_sum
        rates = torch.clamp(rates, min=EPS)

        # `rem` is only ever rebound, never updated in place, so a tensor
        # `self.num_alloc` is not modified through this alias.
        rem = self.num_alloc
        res = torch.zeros_like(u)
        # right_inds[i] is the exclusive right end of the block starting at i
        right_inds = torch.full_like(u, n, dtype=torch.long)
        # entries that have been filled to 1 and must not be selected again
        saturated = torch.zeros_like(u, dtype=torch.bool)

        while rem > 0:
            if bool(saturated.all()):
                # Nothing left to fill; any residual is rounding noise and is
                # checked by the final assertion.
                break

            # find the unsaturated index whose block has the maximum rate
            candidate_rates = rates.masked_fill(saturated, float("-inf"))
            max_rate_ind = torch.argmax(candidate_rates).item()
            block_end = int(right_inds[max_rate_ind].item())
            block_inv_u = inv_u[max_rate_ind:block_end]

            # time at which the block's left-most entry becomes fully filled
            time = (1 - res[max_rate_ind]) / inv_u[max_rate_ind]

            # mass each block entry receives by that time
            max_mass = block_inv_u * time
            total_mass = max_mass.sum()

            if total_mass <= rem:
                # fill the whole block until its left-most entry is full
                res[max_rate_ind:block_end] += max_mass
                rem = rem - total_mass
                saturated[max_rate_ind] = True

                # entries to the left that shared this block now end at
                # max_rate_ind; make their suffix sums and rates relative to
                # that new boundary
                j = max_rate_ind - 1
                while (j >= 0) and (right_inds[j] == block_end):
                    right_inds[j] = max_rate_ind
                    weights_suffix_sum[j] -= weights_suffix_sum[max_rate_ind]
                    inv_u_suffix_sum[j] -= inv_u_suffix_sum[max_rate_ind]
                    rates[j] = weights_suffix_sum[j] / (inv_u_suffix_sum[j] + EPS)
                    j -= 1
            else:
                # not enough mass left: spread the remainder over the block
                time = rem / block_inv_u.sum()
                res[max_rate_ind:block_end] += block_inv_u * time
                rem = 0

        # rearrange results to original order
        final_res = torch.zeros_like(res)
        final_res[indices] = res

        assert torch.abs(final_res.sum() - self.num_alloc) < 1e-6, (
            f"Total allocation mismatch: {final_res.sum()} vs {self.num_alloc}"
        )
        return final_res

    def get_allocation_probabilities(self, u):
        """Return the allocation for ``u``.

        When ``num_alloc`` equals the number of entries, every entry gets 1;
        otherwise the result comes from ``water_filling``.

        Args:
            u: 1-D tensor of utilities (see ``water_filling``).

        Returns:
            A tensor shaped like ``u``.

        Raises:
            ValueError: if ``num_alloc`` exceeds the number of entries.
        """
        if self.num_alloc == u.shape[0]:
            return torch.ones_like(u)
        return self.water_filling(u)
