import numpy as np
import math

import os
import uuid
import time

from bandit import Bandit

class DPTT:
    """
    Private Top-Two sampler for best-arm identification.

    Every arm keeps a cumulated reward sum. Each time the number of pulls of
    an arm reaches the next point of its geometric grid, Laplace noise is
    added to that sum, the mean estimate of the arm is refreshed, and the
    stopping rule is evaluated.
    """

    def __init__(self, config):
        """
        Initializes the Private TopTwo algorithm from a configuration mapping.

        Args:
            config (dict): hyperparameters and naming options.
                Required keys:
                    "K" (int): the number of arms.
                    "eps" (float): privacy parameter.
                    "delta" (float): risk parameter.
                    "beta" (float): probability of playing the empirical best arm.
                Optional keys:
                    "eta" (float): geometric batching parameter (default 1).
                    "kappa": stored as ``self.Kappa`` (default: the
                        module-level ``Kappa`` function).
                    "name" (str): name of the instance, used as log file name
                        (default: "DPTT_id_" followed by a random id).
                    "exp_name" (str): name of the experiment, used in the log
                        directory (default: "FromTerminal").

        Raises:
            KeyError: if one of the required keys is missing.
        """
        #hyperparameters
        self.K = config["K"]
        self.eps = config["eps"]
        self.delta = config["delta"]
        self.beta = config["beta"]
        self.eta = config.get("eta", 1)
        # Conditional (not .get) so the module-level default is only looked
        # up when it is actually needed.
        self.Kappa = config["kappa"] if "kappa" in config else Kappa

        #name of the instance (random id unless provided)
        self.name = config.get("name", "DPTT_id_" + uuid.uuid4().hex[:8])

        #name of the experiment
        self.exp_name = config.get("exp_name", "FromTerminal")

        #total counts, cumulated rewards and mean-reward estimates
        self.counts = np.zeros(self.K)
        self.rewards = np.zeros(self.K)
        self.values = np.zeros(self.K)

        #phase statistics
        self.phase = np.zeros(self.K)
        self.last_geom_grid = np.zeros(self.K)

        #current estimate of best arm
        self.best = None

        #useful to log info
        self.directory = "./experiments/" + self.exp_name + "/DPTT/"
        self.stopping = [] #store stats of stopping rule
        self.info = dict()

    def _pull(self, bandit, arm):
        """Pulls arm once on bandit and records the count and the raw reward."""
        reward = bandit.pull(arm)
        self.counts[arm] += 1
        self.rewards[arm] += reward

    def run(self, bandit):
        """
        Runs DPTT algorithm on bandit until the stopping condition holds.

        Args:
            bandit (Bandit): instance of the Bandit class; its ``local``
                attribute and ``pull(arm)`` method are used.

        Returns:
            tuple: (total number of pulls, index of the recommended arm).
        """
        if bandit.local:
            self.directory = "./experiments/" + self.exp_name + "/CTB-DPTT/"

        #initialization phase: one pull and one update per arm
        for arm in range(self.K):
            self._pull(bandit, arm)
            self.update(arm)

        #main loop
        while True:
            arm = self.select_arm()
            self._pull(bandit, arm)
            #estimates are only refreshed when the arm reaches its grid point
            if self.doubled_counts(arm):
                self.update(arm)
                # Check stopping condition
                if self.check_stopping():
                    self.save_logs()
                    return self.counts.sum(), self.best

    def update(self, arm):
        """
        Starts a new phase for arm and refreshes its noisy mean estimate.

        The phase counter of arm is incremented, its current count is stored
        as its last grid point, and Laplace noise of scale 1/eps is added to
        its cumulated reward sum. The noise stays in the sum, so draws from
        earlier phases accumulate. The leader is then recomputed from the
        estimates clipped to [0, 1].

        Args:
            arm (int): The index of the arm to update.
        """
        self.phase[arm] += 1
        self.last_geom_grid[arm] = self.counts[arm]

        # Sample noise from Laplace for DP and add it to the cumulated rewards
        self.rewards[arm] += np.random.laplace(scale=1 / self.eps)

        # Update the noisy empirical mean
        self.values[arm] = self.rewards[arm] / self.counts[arm]

        # Compute new leader
        self.best = np.argmax(np.clip(self.values, 0, 1))

    def doubled_counts(self, arm):
        """
        Checks whether the counts of arm reached its next geometric grid
        point, i.e. counts[arm] >= (1 + eta) ** phase[arm] (a doubling
        schedule when eta = 1).

        Args:
            arm (int): index of the arm

        Returns:
            bool: True when the grid point has been reached.
        """
        return self.counts[arm] >= (1 + self.eta) ** self.phase[arm]

    def select_arm(self):
        """
        Selects which arm to play next using the sampling rule of TopTwo.

        The leader is played with probability beta; otherwise the challenger
        is played. The challenger is only computed when it is needed, and
        the leader is returned when no challenger exists (single arm).

        Returns:
            int: The index of the arm to play.
        """
        leader = self.best
        if np.random.uniform() <= self.beta:
            return leader
        challenger = self.compute_challenger(leader)
        return leader if challenger is None else challenger

    def compute_challenger(self, best):
        """
        Computes the challenger arm based on the current estimates.

        The challenger minimises W(leader, j) + log(counts[j]) over the arms
        j different from best; the first minimiser wins ties.

        Args:
            best (int): index of the best arm

        Returns:
            int or None: index of the challenger, or None when no other arm
            has a cost below infinity (in particular when K == 1).
        """
        # stats of best arm
        n_best = self.counts[best]
        v_best = self.values[best]

        challenger = None
        min_cost = np.inf
        for j in range(self.K):
            # the leader can never be its own challenger
            if j == best:
                continue
            n_j = self.counts[j]
            #transportation cost plus log-count term
            cost_j = W(v_best, self.values[j], n_best, n_j, self.eps, prec=1e-15) + np.log(n_j)
            if cost_j < min_cost:
                challenger = j
                min_cost = cost_j
        return challenger

    def check_stopping(self):
        """
        Checks whether the stopping condition is verified.

        The algorithm stops when, for every arm j other than the leader, the
        transportation cost W(leader, j) is at least emp_threshold evaluated
        at the phase counters of both arms. The module also contains
        true_threshold and emp_threshold_priv, which are not used here.

        Returns:
            bool: True if the algorithm may stop.
        """
        #stats of best
        best = self.best
        n_best = self.counts[best]
        v_best = self.values[best]
        k_best = self.phase[best]

        for j in range(self.K):
            if j == best:
                continue
            cost_j = W(v_best, self.values[j], n_best, self.counts[j], self.eps, prec=1e-15)
            threshold_j = emp_threshold(k_best, self.phase[j], self.delta)
            if cost_j < threshold_j:
                return False
        return True

    def save_logs(self):
        """
        Saves the outcome of the run in ``<directory><name>.csv``.

        The values written are K, delta, eps, the total number of pulls and
        the recommended arm, in that order.
        """
        tau = self.counts.sum()
        a_star = self.best

        # exist_ok avoids failing when another run creates the directory
        # between a check and the creation.
        try:
            os.makedirs(self.directory, exist_ok=True)
        except FileExistsError:
            print(f"{self.name} file exists")
        filename = self.directory + self.name + ".csv"

        data = np.array([self.K, self.delta, self.eps, tau, a_star])
        np.savetxt(filename, data, delimiter=",")
        

def true_threshold(k_a, k_b, delta, K):
    """
    Stopping threshold built as the sum of two calibrated terms.

    Args:
        k_a: first counter (presumably the phase index of the leader).
        k_b: second counter (presumably the phase index of the other arm).
        delta (float): risk parameter.
        K (int): number of arms.

    Returns:
        The sum of the calibrated terms obtained for k_a and for k_b.
    """
    zeta = math.pi**2/6

    def calibrated_term(k):
        # Same expression for both arms; only the counter changes.
        level = np.log( K * zeta / delta) + 2*np.log(k) + 3 - 2*np.log(2)
        return gaussian_calib(level) - 3 + 2*np.log(2)

    return calibrated_term(k_a) + calibrated_term(k_b)
           
def Kappa(t):
    """
    Returns log(1 + t) raised to the power -alpha/2, with alpha fixed to 0.5.

    Described by the original author as the exploration bonus used in the
    challenger formula.

    Args:
        t (int): total number of samples used so far.
    """
    alpha = 0.5
    log_term = np.log(1 + t)
    return log_term ** (-alpha / 2)

def emp_threshold(k_a, k_b, delta):
    """
    Empirical stopping threshold: log((k_a + k_b) / delta).

    Args:
        k_a: first counter.
        k_b: second counter.
        delta (float): risk parameter.
    """
    combined = k_a + k_b
    return np.log(combined / delta)

def emp_threshold_priv(k_a, k_b, delta, epsilon):
    """
    Empirical threshold with an additional privacy-dependent term:
    log((k_a + k_b) / delta) + (k_a + k_b) * log(1 + epsilon).

    Args:
        k_a: first counter.
        k_b: second counter.
        delta (float): risk parameter.
        epsilon (float): privacy parameter.
    """
    base_term = np.log(((k_a + k_b))/delta)
    privacy_term = (k_a + k_b)*np.log(1 + epsilon)
    return base_term + privacy_term

def emp_threshold2(k_a, k_b, delta):
    """
    Constant threshold log(1 / delta).

    k_a and k_b are accepted but not used.
    """
    inverse_risk = 1/delta
    return np.log(inverse_risk)

def emp_c(t, delta):
    """
    Returns log((1 + log(t)) / delta).

    Args:
        t: presumably the total number of samples.
        delta (float): risk parameter.
    """
    numerator = 1+np.log(t)
    return np.log( numerator / delta)

def c(n, m, K, t, delta):
    """
    Computes the non-private threshold of the stopping condition.

    The result is twice the calibration of
    0.5 * log((K - 1) * zeta * t**2 / delta), plus one term
    2 * log(4 + log(.)) for each of the two counts.

    Args:
        n (int): counts of the current best arm
        m (int): counts of the challenger
        K (int): number of arms
        t (int): total number of samples
        delta (float): risk parameter

    """
    zeta = math.pi**2/6
    level = 0.5 * np.log((K-1) * zeta * (t**2) / delta)
    calibrated = 2*gaussian_calib(level)
    bonus_n = 2*np.log(4+np.log(n))
    bonus_m = 2*np.log(4+np.log(m))
    return calibrated + bonus_n + bonus_m

def gaussian_calib(x):
    """
    Calibration function x + log(x).

    Described by the original author as based on the concentration of a sum
    of KLs of gaussians.
    """
    log_x = np.log(x)
    return x + log_x


def kl(x, y, prec = 1e-15):
    """
    Binary relative entropy x*log(x/y) + (1-x)*log((1-x)/(1-y)).

    Both arguments are first clipped to [prec, 1 - prec] so that the
    logarithms stay finite.
    """
    floor, ceiling = prec, 1 - prec
    x = min(max(x, floor), ceiling)
    y = min(max(y, floor), ceiling)
    success_part = x * np.log(x / y)
    failure_part = (1 - x) * np.log((1 - x) / (1 - y))
    return success_part + failure_part


def g_eps_plus(x, epsilon):
    """
    Returns x / (x * (1 - e^epsilon) + e^epsilon).
    """
    exp_eps = np.exp(epsilon)
    denominator = x*(1 - exp_eps) + exp_eps
    return x/denominator

def g_eps_moins(x, epsilon):
    """
    Returns x * e^epsilon / (x * (e^epsilon - 1) + 1).
    """
    exp_eps = np.exp(epsilon)
    numerator = x*exp_eps
    denominator = x*(exp_eps - 1) + 1
    return numerator / denominator

def d_eps_plus(x, y, epsilon, prec = 1e-15):
    """
    One-sided divergence that is zero when y <= x.

    After clipping x and y to [prec, 1 - prec]:
      - returns 0 if y <= x;
      - returns kl(x, y) if y <= g_eps_moins(x, epsilon);
      - otherwise returns -epsilon*x - log(1 - y*(1 - e^-epsilon)).
    """
    floor, ceiling = prec, 1 - prec
    x = min(max(x, floor), ceiling)
    y = min(max(y, floor), ceiling)
    if y <= x :
        return 0
    in_kl_regime = y <= g_eps_moins(x, epsilon)
    if in_kl_regime:
        return kl(x, y, prec)
    shrink = 1 - np.exp(-epsilon)
    return - epsilon * x - np.log(1 - y*shrink)

def d_eps_moins(x, y, epsilon, prec = 1e-15):
    """
    One-sided divergence that is zero when y >= x.

    After clipping x and y to [prec, 1 - prec]:
      - returns 0 if y >= x;
      - returns kl(x, y) if y >= g_eps_plus(x, epsilon);
      - otherwise returns epsilon*x - log(1 + y*(e^epsilon - 1)).
    """
    floor, ceiling = prec, 1 - prec
    x = min(max(x, floor), ceiling)
    y = min(max(y, floor), ceiling)
    if y >= x:
        return 0
    in_kl_regime = y >= g_eps_plus(x, epsilon)
    if in_kl_regime:
        return kl(x, y, prec)
    growth = np.exp(epsilon) - 1
    return epsilon * x - np.log(1 + y*growth)

def r_plus(a, b, c):
    """
    Returns (sqrt(b^2 + 4*a*c) - b) / (2*a), i.e. the '+' root of
    a*u^2 + b*u - c = 0.
    """
    discriminant = b*b + 4*a*c
    numerator = np.sqrt(discriminant) - b
    return numerator/(2*a)

def W(lamda, mu, omega_1, omega_2, epsilon, prec = 1e-15):
    """
    Transportation cost between a mean lamda (weight omega_1) and a mean mu
    (weight omega_2).

    Returns 0 when lamda <= mu (after clipping both to [prec, 1 - prec]).
    Otherwise returns
        omega_1 * d_eps_moins(lamda, star) + omega_2 * d_eps_plus(mu, star)
    where star is selected in closed form:
      - the weighted average of lamda and mu, or u_3_star, when that
        candidate lies between the two g_eps boundaries;
      - the root of a quadratic when the candidate falls below the lower
        boundary or above the upper one.

    Note: when the comparisons are undefined (e.g. NaN inputs, or
    epsilon = 0 which makes u_3_star 0/0), no case matches and the final
    line raises UnboundLocalError.
    """
    floor, ceiling = prec, 1 - prec
    lamda = min(max(lamda, floor), ceiling)
    mu = min(max(mu, floor), ceiling)
    if lamda <= mu:
        return 0

    # Quantities shared by every case.
    upper_mu = g_eps_moins(mu, epsilon)
    lower_lamda = g_eps_plus(lamda, epsilon)
    u_star_np = (omega_1 * lamda + omega_2 * mu)/(omega_1 + omega_2)
    exp_eps_plus = np.exp(epsilon) - 1
    exp_eps_moins = 1 - np.exp(-epsilon)
    u_3_star = (omega_1 * (exp_eps_plus) - omega_2 * (exp_eps_moins))/((omega_1 + omega_2)*exp_eps_plus*exp_eps_moins)

    if upper_mu >= lamda:
        # The weighted average is admissible as is.
        star = u_star_np
    elif upper_mu < lamda:
        # Pick the candidate and the interval [low, high] it must fall in.
        if lower_lamda <= upper_mu:
            pivot, low, high = u_star_np, lower_lamda, upper_mu
        elif lower_lamda > upper_mu:
            pivot, low, high = u_3_star, upper_mu, lower_lamda

        if low <= pivot <= high:
            star = pivot
        elif pivot < low:
            quad_a = (omega_1 + omega_2)*exp_eps_plus
            quad_b = omega_2 - (mu*omega_2 + omega_1)*exp_eps_plus
            quad_c = omega_2 * mu
            star = r_plus(quad_a, quad_b, quad_c)
        elif pivot > high:
            quad_a = (omega_1 + omega_2)*exp_eps_plus
            quad_b = omega_1 - ((1 - lamda)*omega_1 + omega_2)*exp_eps_plus
            quad_c = omega_1 * (1 - lamda)
            star = 1 - r_plus(quad_a, quad_b, quad_c)

    return omega_1 * d_eps_moins(lamda, star, epsilon) + omega_2 * d_eps_plus(mu, star, epsilon)

# def W(lamda, mu, omega_1, omega_2, epsilon, prec = 1e-15):
#     lamda = min(max(lamda, prec), 1 - prec)
#     mu = min(max(mu, prec), 1 - prec)
#     if lamda <= mu:
#         return 0
#     u_star_np = (omega_1 * lamda + omega_2 * mu)/(omega_1 + omega_2)
#     if g_eps_moins(mu, epsilon) >= lamda or ( g_eps_moins(mu, epsilon) < lamda and g_eps_plus(lamda, epsilon) <=  g_eps_moins(mu, epsilon) and u_star_np >= g_eps_plus(lamda, epsilon) and u_star_np <= g_eps_moins(mu, epsilon) ):
#         return omega_1 * kl(lamda, u_star_np) + omega_2 * kl(mu, u_star_np)
#     exp_eps_plus = np.exp(epsilon) - 1
#     exp_eps_moins = 1 - np.exp(-epsilon)
#     u_3_star = (omega_1 * (exp_eps_plus) - omega_2 * (exp_eps_moins))/(omega_1 + omega_2)
#     if g_eps_moins(mu, epsilon) < lamda and  g_eps_plus(lamda, epsilon) >  g_eps_moins(mu, epsilon) and u_3_star <= g_eps_plus(lamda, epsilon) and u_3_star >= g_eps_moins(mu, epsilon):
#         return omega_1 * (epsilon * lamda - np.log(1 + u_3_star*(np.exp(epsilon) - 1))) + omega_2 * ( - epsilon * mu - np.log(1 - u_3_star*(1 - np.exp(-epsilon))))
#     if (g_eps_moins(mu, epsilon) < lamda and g_eps_plus(lamda, epsilon) <=  g_eps_moins(mu, epsilon) and u_star_np < g_eps_plus(lamda, epsilon)) or (g_eps_moins(mu, epsilon) < lamda and g_eps_plus(lamda, epsilon) >  g_eps_moins(mu, epsilon) and u_3_star < g_eps_moins(mu, epsilon)):
#         a = (omega_1 + omega_2)*(np.exp(epsilon) - 1)
#         b = omega_2 - (mu*omega_2 + omega_1)*(np.exp(epsilon) - 1)
#         c = omega_2 * mu
#         mu_star_dp = r_plus(a, b, c)
#         return omega_1 * (epsilon * lamda - np.log(1 + mu_star_dp*(np.exp(epsilon) - 1)) ) + omega_2 * kl(mu, mu_star_dp) 
#     if (g_eps_moins(mu, epsilon) < lamda and g_eps_plus(lamda, epsilon) <=  g_eps_moins(mu, epsilon) and u_star_np > g_eps_moins(mu, epsilon)) or (g_eps_moins(mu, epsilon) < lamda and g_eps_plus(lamda, epsilon) >  g_eps_moins(mu, epsilon) and u_3_star > g_eps_plus(lamda, epsilon)):
#         a = (omega_1 + omega_2)*(np.exp(epsilon) - 1)
#         b = omega_1 - ((1 - lamda)*omega_1 + omega_2)*(np.exp(epsilon) - 1)
#         c = omega_1 * (1 - lamda)
#         mu_star_dp = 1 - r_plus(a, b, c)
#         return omega_1 *  kl(lamda, mu_star_dp) + omega_2 * (- epsilon * mu - np.log(1 - mu_star_dp*(1 - np.exp(-epsilon))))
    
# def f_x_plus(x, lamda, mu, r, epsilon):
#     return np.log(1 + (x)/( g_eps_plus(mu, epsilon)*(1 - x -  g_eps_plus(mu, epsilon)) ) ) + epsilon * ( ( r*epsilon*(x + g_eps_plus(mu, epsilon) - lamda) )/( np.sqrt( ( r*epsilon*(x + g_eps_plus(mu, epsilon) - lamda))**2 + 1)  + 1 )  - 1 )

# def x_plus(lamda, mu, r, epsilon, max_iter = 100, error = 1e-5):
#     low = max(0, lamda - g_eps_plus(mu, epsilon))
#     hi = mu - g_eps_plus(mu, epsilon)
#     for i in range(max_iter):
#         mid = (low + hi)/2
#         if mid == low or mid == hi:
#             return mid
#         fmid = f_x_plus(mid, lamda, mu, r, epsilon)
#         if fmid < -error:
#             low = mid
#         elif fmid > error:
#             hi = mid
#         else:
#             return mid
#     return (low + hi)/2

# def h(x):
#     if x <= 1e-3:
#         return (x*x)/4
#     return np.sqrt(1 + x*x) - 1 + np.log( (2*(np.sqrt(1 + x*x) - 1))/(x*x) )

# def tilde_d_eps_plus(lamda, mu, r, epsilon, max_iter = 100, error = 1e-4, prec = 1e-15):
#     lamda_clip = min(max(lamda, prec), 1 - prec)
#     mu = min(max(mu, prec), 1 - prec)
#     if mu <= lamda_clip:
#         return 0
#     x_eps = x_plus(lamda, mu, r, epsilon, max_iter, error)
#     g_eps = g_eps_plus(mu, epsilon)
#     return kl(g_eps + x_eps, mu) +  (h( r*epsilon*( g_eps + x_eps - lamda ) ))/(r)    

# def f_x_moins(x, lamda, mu, r, epsilon):
#     return -np.log(1 + (x)/( (g_eps_moins(mu, epsilon) - x)*(1 -  g_eps_moins(mu, epsilon)) ) ) + epsilon * ( 1 - (( r*epsilon*(x + lamda - g_eps_moins(mu, epsilon) ) ) / ( np.sqrt( ( r*epsilon*(x + lamda - g_eps_moins(mu, epsilon) ))**2 + 1)  + 1 ) ) )

# def x_moins(lamda, mu, r, epsilon, max_iter = 100, error = 1e-4):
#     low = max(0, g_eps_moins(mu, epsilon) - lamda)
#     hi = g_eps_moins(mu, epsilon) - mu
#     for i in range(max_iter):
#         mid = (low + hi)/2
#         if mid == low or mid == hi:
#             return mid
#         fmid = f_x_moins(mid, lamda, mu, r, epsilon)
#         if fmid < -error:
#             hi = mid
#         elif fmid > error:
#             low = mid
#         else:
#             return mid
#     return (low + hi)/2

# def tilde_d_eps_moins(lamda, mu, r, epsilon, max_iter = 100, error = 1e-4, prec = 1e-15):
#     lamda_clip = min(max(lamda, prec), 1 - prec)
#     mu = min(max(mu, prec), 1 - prec)
#     if mu >= lamda_clip:
#         return 0
#     x_eps = x_moins(lamda, mu, r, epsilon, max_iter, error)
#     g_eps = g_eps_moins(mu, epsilon)
#     return kl(g_eps - x_eps, mu) +  (h( r*epsilon*(x_eps + lamda - g_eps  ) ))/(r)    

# def r_eta(x, eta):
#     return (x)/(1 + np.emath.logn( 1+ eta, x))

# def f_x_W(u, lamda, mu, omega_1, omega_2, epsilon, eta,  max_iter = 100, error = 1e-4):
#     return u * (omega_1 + omega_2) - omega_1 * g_eps_moins(u, epsilon) - omega_2 * g_eps_plus(u, epsilon) + omega_1 * x_moins(lamda, u, r_eta(omega_1, eta), epsilon,  max_iter, error) - omega_2 * x_plus(mu, u, r_eta(omega_2, eta), epsilon,  max_iter, error)

# def mu_star(lamda, mu, omega_1, omega_2, epsilon, eta,  max_iter = 100, error = 1e-4, prec = 1e-15):
#     low = min(max(mu, prec), 1 - prec)
#     hi = min(max(lamda, prec), 1 - prec)
#     for i in range(max_iter):
#         mid = (low + hi)/2
#         if mid == low or mid == hi:
#             return mid
#         fmid = f_x_W(mid, lamda, mu, omega_1, omega_2, epsilon, eta,  max_iter, error)
#         if fmid < -error:
#             low = mid
#         elif fmid > error:
#             hi = mid
#         else:
#             return mid
#     return (low + hi)/2

# def tilde_W(lamda, mu, omega_1, omega_2, epsilon, eta, max_iter = 100, error = 1e-4, prec = 1e-15):
#     lamda_clip = min(max(lamda, prec), 1 - prec)
#     mu_clip = min(max(mu, prec), 1 - prec)
#     if lamda_clip <= mu_clip:
#         return 0
#     star = mu_star(lamda, mu, omega_1, omega_2, epsilon, eta,  max_iter, error, prec = 1e-15)
#     return omega_1 * tilde_d_eps_moins(lamda, star, r_eta(omega_1, eta), epsilon, max_iter, error, prec) +  omega_2 * tilde_d_eps_plus(mu, star, r_eta(omega_2, eta), epsilon, max_iter, error, prec)

def argmax_rd_tie(arr):
    """
    Returns the index of a maximal entry of arr, ties broken uniformly at
    random.

    Args:
        arr (array_like): values to compare; assumed to be one-dimensional
            (only the first-axis indices of the maxima are used).

    Returns:
        int: index of one of the entries equal to the maximum.
    """
    arr = np.asarray(arr)
    return np.random.choice(np.where(arr == arr.max())[0])


if __name__=="__main__":
    # Small smoke run: 10 independent DPTT runs on a K-armed bandit.
    K = 5
    # One value per arm, evenly spaced in [0, 1]; the length follows K so the
    # two cannot drift apart (presumably these are the arm means — see Bandit).
    mu = np.linspace(0, 1, K)
    config = {"K": K, "beta": 0.5, "eps": 1.0, "delta": 0.1, "eta": 1}
    for _ in range(10):
        my_bandit = Bandit(K, mu)
        dp_top_two = DPTT(config)
        dp_top_two.run(my_bandit)