import numpy as np

class Benchmark:
    """Compute closed-form sample-complexity figures for two algorithms.

    On construction the Divide-and-Battle figure is computed (and printed)
    first, then the Trace-the-Best figure. Results are stored on the instance
    as ``sample_complexity_divide_and_battle`` and
    ``sample_complexity_trace_the_best``.

    Args:
        pbr: Problem description exposing ``n``, ``k``, ``eps`` and ``delta``
            (presumably item count, subset size, and the PAC accuracy /
            confidence parameters -- confirm against the caller).
    """

    def __init__(self, pbr):
        self.pbr = pbr
        self.calculate_sample_complexity_divide_and_battle()
        self.calculate_sample_complexity_trace_the_best()

    def calculate_sample_complexity_trace_the_best(self):
        """Compute, store and print the Trace-the-Best sample-complexity figure.

        Charges ``m = (2k / eps^2) * ln(2n / delta)`` rounds to each of
        ``ceil(n / k)`` subsets and stores the product in
        ``self.sample_complexity_trace_the_best``.
        """
        k = self.pbr.k
        m = (2*k /(self.pbr.eps**2)) * np.log((2*self.pbr.n)/self.pbr.delta)
        self.sample_complexity_trace_the_best = (np.ceil(self.pbr.n/k)) * m
        print(f'Sample complexity for Trace-the-Best: {self.sample_complexity_trace_the_best}')

    def calculate_sample_complexity_divide_and_battle(self):
        """Compute, store and print the Divide-and-Battle sample-complexity figure.

        Starting from ``delta / 2`` and ``eps / 8``, each round while the
        candidate count exceeds ``k`` does the following:
          * replaces the count with the group count ``ceil(count / k)``;
          * halves the per-round delta and scales the per-round eps by 3/4;
          * charges ``m = (2k / eps_r^2) * ln(k / delta_r)`` per group.
        A final round with ``eps * 2/3`` and the full ``delta`` is then
        charged for the remaining count. The total is stored in
        ``self.sample_complexity_divide_and_battle``.

        Raises:
            ValueError: if ``n > k`` and ``k < 2``. In that case
                ``ceil(n / k)`` never drops to ``k`` or below (or divides by
                zero), so the round loop would never terminate.
        """
        k = self.pbr.k
        n = self.pbr.n
        # Guard: with k < 2 the count cannot shrink below k, so the loop
        # below would spin forever.
        if n > k and k < 2:
            raise ValueError(
                f'Divide-and-Battle requires k >= 2 when n > k (got n={n}, k={k})'
            )
        delta = self.pbr.delta/2
        eps = self.pbr.eps/8
        sc = 0
        while n > k:
            # The candidate count becomes the number of groups this round;
            # the per-round cost m is then charged once per group.
            n = np.ceil(n/k)
            delta = delta/2
            eps = eps * 3 /4
            m = (2*k /(eps**2)) * np.log((k)/delta)
            sc += m * n
        # Final round over the remaining (at most k) candidates.
        # NOTE(review): this charges m * n, i.e. once per remaining candidate,
        # whereas the loop above charges m once per group. If the survivors are
        # played as a single subset, this may over-count by a factor of n --
        # confirm intent.
        delta = self.pbr.delta
        eps = self.pbr.eps*2/3
        m = (2*k /(eps**2)) * np.log((k)/delta)
        sc += m * n
        self.sample_complexity_divide_and_battle = sc
        print(f'Sample complexity for Divide-and-Battle: {self.sample_complexity_divide_and_battle}')
        

