import numpy as np
import torch


class SWFUCB:
    """UCB-style bandit that allocates several arms per round.

    Each round starts with :meth:`select_arms` and ends with :meth:`update`.
    The first rounds are a deterministic warm-up that walks through the arms
    in consecutive chunks of ``num_alloc`` so every arm is pulled once.
    Afterwards, per-arm upper confidence bounds are handed to ``solver`` to
    obtain allocation probabilities, and ``sampler`` turns those
    probabilities into a concrete set of arms.

    Args:
        n_arms: Number of arms.
        num_alloc: Number of arms pulled per warm-up round.
        solver: Object exposing ``get_allocation_probabilities(ucb_values)``.
        sampler: Object exposing ``sample(probs)``.
        delta: Total failure probability; it is split evenly across the arms
            (union bound) when the confidence bonus is computed.

    Attributes:
        arm_sums: float64 tensor with the summed rewards of every arm.
        arm_counts: float64 tensor with the number of pulls of every arm.
        t: Number of completed rounds, i.e. calls to :meth:`update`.
        probs: Allocation probabilities used by the latest selection
            (uniform until the first call to :meth:`select_arms`).
    """

    # Small constant keeping the divisions and the inner logarithm away from 0.
    _EPS = 1e-8

    def __init__(self, n_arms, num_alloc, solver, sampler, delta=0.05):
        self.n_arms = n_arms
        self.num_alloc = num_alloc
        self.delta = delta

        self.sampler = sampler
        self.solver = solver

        self.arm_sums = torch.zeros(n_arms, dtype=torch.float64)
        self.arm_counts = torch.zeros(n_arms, dtype=torch.float64)
        self.t = 0
        self.probs = torch.ones(n_arms, dtype=torch.float64) / n_arms

    def select_arms(self):
        """Choose the arms to pull in the current round.

        Returns:
            Indices of the selected arms. During warm-up this is a 1-D index
            tensor; in the UCB phase it is the 1-tuple of index tensors
            produced by ``torch.where``. Both forms are accepted as ``inds``
            by :meth:`update`.
        """
        first = self.num_alloc * self.t
        if first < self.n_arms:
            # Warm-up: take the next chunk of arms, truncated at the last arm.
            last = min(self.num_alloc * (self.t + 1), self.n_arms)
            inds = torch.arange(first, last)
            # float64 so that `self.probs` keeps one dtype across both phases.
            self.probs = torch.zeros(self.n_arms, dtype=torch.float64)
            self.probs[inds] = 1.0
            return inds

        safe_counts = self.arm_counts + self._EPS
        arm_means = self.arm_sums / safe_counts

        # Union bound over arms: every arm gets an equal share of delta.
        delta_i = self.delta / torch.tensor(self.n_arms, dtype=torch.float64)

        log_term = torch.log(0.1 / delta_i)
        loglogcnt = torch.log(torch.log(2 * self.arm_counts + 2) + self._EPS)

        # `log_term` is negative whenever delta / n_arms > 0.1, which can make
        # the numerator negative and the square root NaN. Clamping at zero
        # keeps the bonus finite and changes nothing when it was already >= 0.
        numerator = torch.clamp(log_term + loglogcnt, min=0.0)
        ucb_values = arm_means + 1.7 * torch.sqrt(numerator / safe_counts)

        self.probs = self.solver.get_allocation_probabilities(ucb_values)

        # Work on a copy so the sampler cannot overwrite `self.probs`.
        probs = self.probs.clone()

        # NOTE(review): the sampler's return value is ignored, so it is
        # presumably expected to round `probs` in place to 0/1 entries --
        # confirm against the sampler implementation.
        self.sampler.sample(probs)

        # Compare against 1 in `probs`' own dtype so the check does not
        # depend on the solver returning float64.
        inds = torch.where(torch.isclose(probs, torch.ones_like(probs)))
        return inds

    def update(self, inds, rewards):
        """Record the rewards observed for the arms pulled this round.

        Args:
            inds: Indices of the pulled arms, as returned by
                :meth:`select_arms`.
            rewards: Rewards added to ``arm_sums`` at ``inds``.
        """
        self.t += 1
        self.arm_counts[inds] += 1
        self.arm_sums[inds] += rewards
