import numpy as np
import math

from bandit import Bandit
from algorithms.adapTT import AdaPTopTwo

class ImprovedAdaPTopTwo(AdaPTopTwo):
    """
    Variant of AdaPTopTwo with a modified challenger cost and stopping rule.

    Both rules use, for every arm j, the clipped empirical gap
    delta_tilde_j = max(v_best - v_j, 0) and the cost
        delta_tilde_j * min(3*eps, delta_tilde_j) / (1/n_best + 1/n_j).
    All other behaviour (including `run`) is inherited from AdaPTopTwo.
    """

    def __init__(self, config):
        super().__init__(config)
        # Variant-specific output directory. `self.exp_name` is presumably
        # set by AdaPTopTwo.__init__ from `config` -- confirm.
        self.directory = "./experiments/" + self.exp_name +"/ImpAdaPTT/"

    def compute_challenger(self, best):
        """
        Computes the challenger arm based on estimates from previous phase.

        The challenger is the arm j != best with the smallest cost (see the
        class docstring). Sample counts combine `self.counts` and
        `self.ph_counts`. Ties go to the lowest index.

        Args:
            best (int): index of the best arm

        Returns:
            int or None: index of the challenger; None when there is no other
            arm (K == 1) or every candidate cost is nan.
        """
        n_best = self.counts[best] + self.ph_counts[best]
        v_best = self.values[best]
        eps = self.eps

        challenger = None
        min_cost = np.inf
        for j in range(self.K):
            if j == best:
                continue
            n_j = self.counts[j] + self.ph_counts[j]
            v_j = self.values[j]
            delta_tilde_j = v_best - v_j if v_best - v_j > 0 else 0
            cost = (delta_tilde_j * min(3*eps, delta_tilde_j)) / (1/n_best + 1/n_j)
            if cost < min_cost:
                challenger = j
                min_cost = cost
        return challenger

    def check_stopping(self):
        """
        Checks whether the stopping condition is met.

        The rule uses the empirical best arm (argmax of `self.values`) and the
        counts in `self.last_ph_counts`. For every other arm j, the cost is
        compared with `emp_c` when delta_tilde_j < 3*eps and with `emp_c2`
        otherwise.

        Returns:
            bool: False as soon as some arm's cost is strictly below its
            threshold, or when the best arm or a competing arm has zero
            last-phase counts; True otherwise.
        """
        best = np.argmax(self.values)
        n_best = self.last_ph_counts[best]
        # With zero counts, np.log(0) turns the thresholds into nan. A nan
        # comparison is always False, so the rule would pass vacuously.
        # Never stop without evidence.
        if n_best <= 0:
            return False
        v_best = self.values[best]
        eps = self.eps
        delta = self.delta

        for j in range(self.K):
            if j == best:
                continue
            n_j = self.last_ph_counts[j]
            if n_j <= 0:
                return False
            v_j = self.values[j]
            delta_tilde_j = v_best - v_j if v_best - v_j > 0 else 0
            cost_j = (delta_tilde_j * min(3*eps, delta_tilde_j)) / (1/n_best + 1/n_j)

            # Small-gap regime uses emp_c, large-gap regime uses emp_c2.
            # Only the threshold that is actually needed is computed.
            if delta_tilde_j < 3*eps:
                threshold_j = emp_c(n_best, n_j, delta, eps)
            else:
                threshold_j = emp_c2(n_best, n_j, delta, eps)

            if cost_j < threshold_j:
                return False
        return True


def Kappa(t):
    """
    Computes the exploration bonus 1 / (1 + sqrt(t)).

    Not referenced in the visible part of this file. It is presumably used
    by a challenger rule elsewhere (e.g. the parent class) -- confirm.

    Args:
        t (int): total number of samples; NumPy arrays are handled
            element-wise.

    Returns:
        float: 1.0 at t = 0, decreasing towards 0 as t grows.
    """
    # NOTE: an earlier variant used np.log(1 + t) ** (-1/4) here.
    return 1 / (1 + np.sqrt(t))

def emp_c(n, m, delta, eps):
    """
    Stopping threshold for the small-gap regime (gap below 3*eps).

    Returns
        log(1/delta) + log(1 + log n) + log(1 + log m)
        + (1/n + 1/m) * log(1/delta)**2 / eps**2

    Args:
        n: sample count of the best arm (expected >= 1).
        m: sample count of the compared arm (expected >= 1).
        delta (float): risk parameter.
        eps (float): privacy parameter.
    """
    log_inv_delta = np.log(1/delta)
    base = log_inv_delta + np.log(1 + np.log(n)) + np.log(1 + np.log(m))
    privacy = (1/n + 1/m) * (1/eps**2) * log_inv_delta**2
    return base + privacy

def emp_c2(n, m, delta, eps):
    """
    Stopping threshold for the large-gap regime (gap at least 3*eps).

    Returns
        log(1/delta) + log(1 + log n) + log(1 + log m)
        + eps * (sqrt(n * (log(1/delta) + log(1 + log n)))
                 + sqrt(m * (log(1/delta) + log(1 + log m))))

    Args:
        n: sample count of the best arm (expected >= 1).
        m: sample count of the compared arm (expected >= 1).
        delta (float): risk parameter.
        eps (float): privacy parameter.
    """
    log_inv_delta = np.log(1/delta)
    loglog_n = np.log(1 + np.log(n))
    loglog_m = np.log(1 + np.log(m))

    base = log_inv_delta + loglog_n + loglog_m
    spread = np.sqrt(n * (log_inv_delta + loglog_n)) + np.sqrt(m * (log_inv_delta + loglog_m))
    return base + eps * spread

def dp_c(n, m, K, ph, delta, eps):
    """
    Computes the private threshold of the stopping condition.

    Sum of a calibrated concentration term and a privacy term that scales
    with (1/n + 1/m) / eps**2. Both terms are union-bounded over arms and
    phases via the factor zeta * ph**2, where zeta = pi**2 / 6.

    Args:
        n (int): counts of the current best arm
        m (int): counts of the challenger
        K (int): number of arms
        ph (int): index of the phase
        delta (float): risk parameter
        eps (float): privacy parameter
    """
    zeta = math.pi**2 / 6
    ph_sq = ph**2

    calib_arg = 0.5 * np.log((K-1) * zeta * ph_sq / delta)
    term_0 = 2*gaussian_calib(calib_arg) + 2*np.log(4+np.log(n)) + 2*np.log(4+np.log(m))

    privacy_log = np.log(2 * K * zeta * ph_sq / delta)
    dp_term = (1/n + 1/m) * (1/eps**2) * privacy_log**2
    return term_0 + dp_term
    
def gaussian_calib(x):
    """
    Calibration function based on concentration of sum of KLs of gaussians.

    Returns x + log(x). Only meaningful for x > 0: NumPy yields -inf at 0
    and nan (with a RuntimeWarning) for negative x.
    """
    log_x = np.log(x)
    return x + log_x



if __name__=="__main__":
    # Smoke run: K arms with evenly spaced means in [0, 1], repeated n_runs
    # times with a fresh bandit and learner each time.
    K = 5
    n_runs = 10
    # Derive the means from K so the two cannot drift apart.
    mu = np.linspace(0, 1, K)
    # NOTE(review): ImprovedAdaPTopTwo.__init__ reads self.exp_name. Confirm
    # that AdaPTopTwo provides a default, since "exp_name" is not in config.
    config = {"K": K, "beta": 0.5, "eps": 1.0, "delta": 0.1}
    for _ in range(n_runs):
        my_bandit = Bandit(K, mu)
        adap_top_two = ImprovedAdaPTopTwo(config)
        adap_top_two.run(my_bandit)