import numpy as np
import torch

# Floor for arm means and sampled rewards: values in [0, 1] are mapped affinely
# onto [EPS, 1] via ``EPS + (1 - EPS) * x`` so nothing ends up below EPS.
EPS = 1e-1


class BetaArmSet:
    """A set of bandit arms whose rewards follow (rescaled) Beta distributions.

    Arm ``i`` draws rewards from Beta(``alphas[i]``, ``betas[i]``), mapped from
    [0, 1] onto [EPS, 1]. Arms are stored sorted by mean in ascending order, so
    index 0 has the lowest mean and index ``n_arms - 1`` the highest.

    Attributes:
        n_arms: Number of arms.
        gen: ``torch.Generator`` used to draw the Beta parameters.
        alphas: float64 tensor of shape ``(n_arms,)``, first Beta parameter.
        betas: float64 tensor of shape ``(n_arms,)``, second Beta parameter.
        means: float64 tensor of shape ``(n_arms,)``, rescaled mean of each arm.
        dist: batched ``torch.distributions.Beta`` over all arms.
        optimal_probs, optimal_wpm: set only after ``opt_probs`` is called.
    """

    def __init__(self, n_arms, gen, extreme_means=False):
        """Draw integer-valued Beta parameters for ``n_arms`` arms.

        Args:
            n_arms: Number of arms to create.
            gen: ``torch.Generator`` for the parameter draws; a fixed generator
                state gives a reproducible arm set.
            extreme_means: If True, ``int(0.9 * n_arms)`` arms get alpha in
                [45, 49] and beta in [1, 4] (means near 1) and the remaining
                arms get alpha in [1, 4] and beta in [45, 49] (means near 0).
                Otherwise alpha and beta are drawn independently from [1, 9].
        """
        self.n_arms = n_arms
        self.gen = gen

        def randint(low, high, size):
            # ``high`` is exclusive; float64 so the values can be used directly
            # as Beta concentrations.
            return torch.randint(low, high, (size,), dtype=torch.float64, generator=gen)

        # Draw order (alphas before betas, spikes before the rest) fixes how the
        # generator stream is consumed; keep it stable for reproducibility.
        if extreme_means:
            spike_frac = 0.9
            n_spikes = int(spike_frac * n_arms)
            n_rest = n_arms - n_spikes
            alphas = torch.cat([randint(45, 50, n_spikes), randint(1, 5, n_rest)])
            betas = torch.cat([randint(1, 5, n_spikes), randint(45, 50, n_rest)])
        else:
            alphas = randint(1, 10, n_arms)
            betas = randint(1, 10, n_arms)

        # Mean of Beta(a, b) is a / (a + b); rescale it exactly as sample()
        # rescales rewards so ``means`` stays the expectation of the samples.
        means = EPS + (1 - EPS) * (alphas / (alphas + betas))

        # Sort arms by ascending mean. torch.argsort keeps the computation in
        # torch (no tensor -> ndarray round trip), and stable=True makes the
        # ordering of tied means deterministic (by original draw order).
        order = torch.argsort(means, stable=True)
        self.alphas = alphas[order]
        self.betas = betas[order]
        self.means = means[order]
        self.dist = torch.distributions.Beta(self.alphas, self.betas)

    def sample(self, inds):
        """Draw rewards for the arms selected by ``inds``.

        Args:
            inds: Any index accepted by a 1-D tensor (int, list, index tensor,
                boolean mask, ...).

        Returns:
            Tensor of rewards in [EPS, 1], indexed like ``self.means[inds]``.
        """
        # NOTE(review): one reward is drawn for *every* arm and then indexed, so
        # duplicated entries in ``inds`` share the same value instead of being
        # independent draws. Beta.sample() also uses the global torch RNG, not
        # ``self.gen`` — confirm both are intended.
        samples = self.dist.sample()[inds]
        return EPS + (1 - EPS) * samples

    def opt_probs(self, solver, swf):
        """Compute and store the optimal allocation and its welfare value.

        Args:
            solver: Object providing ``get_allocation_probabilities(means)`` and
                ``water_filling(probs)``.
            swf: Callable applied to the per-arm products ``means * probs``;
                presumably a social welfare function (TODO confirm).

        Sets ``self.optimal_probs`` and ``self.optimal_wpm``; returns None.
        """
        probs = solver.get_allocation_probabilities(self.means)
        self.optimal_probs = solver.water_filling(probs)
        self.optimal_wpm = swf(self.means * self.optimal_probs)
