import numpy as np
from collections import defaultdict
from scipy.optimize import minimize_scalar
from bandit.best_arm import UGapEc


class BMEfromBAI:
    """Best Mean-reward Estimation from Best Arm Identification.

    Wraps a best-arm identifier: the identifier is run first, then the arm it
    returns is pulled a fixed number of extra times and the average of those
    extra rewards is reported as the best-mean estimate.
    """

    def __init__(self, epsilon, delta, best_arm_identifier):
        """Store the accuracy/confidence targets and build the identifier.

        Parameters:
        epsilon              -- accuracy parameter
        delta                -- confidence parameter
        best_arm_identifier  -- callable invoked as
                                ``best_arm_identifier(2 / 3 * epsilon, delta / 2)``;
                                the returned object must offer ``run(env)`` and
                                ``get_n_sample()``
        """
        self._epsilon = epsilon
        self._delta = delta
        # Number of extra pulls spent on the identified arm (set by run()).
        self._additional_sample = 0
        self._best_arm_identifier = best_arm_identifier(2 / 3 * epsilon, delta / 2)

    def get_total_n_sample(self):
        """Return the identifier's pull counts summed, plus the extra pulls."""
        identification_cost = sum(self._best_arm_identifier.get_n_sample())
        return identification_cost + self._additional_sample

    def get_max_mu(self):
        """Return the estimate stored by the last run() (unset before that)."""
        return self._max_mu

    def run(self, env):
        """Identify the best arm, resample it, and return the sample mean.

        ``env`` must provide ``n_arms()`` and ``multi_pull(arm, n)``; the
        rewards returned by ``multi_pull`` are summed with ``sum``.
        """
        # Size of the extra batch drawn from the identified arm.
        half_delta = self._delta / 2
        log_term = np.log((env.n_arms() + half_delta**8) / half_delta)
        n_star = int(np.ceil(9 / (2 * self._epsilon)**2 * log_term))

        # Identify best arm
        chosen_arm = self._best_arm_identifier.run(env)

        # Take additional sample from the best arm
        rewards = env.multi_pull(chosen_arm, n_star)
        self._additional_sample = n_star
        self._max_mu = sum(rewards) / n_star

        return self._max_mu


class UGapEc2BME:
    """Best mean-reward estimator that plugs UGapEc into BMEfromBAI.

    Every public method simply forwards to the wrapped BMEfromBAI instance.
    """

    def __init__(self, epsilon, delta):
        """Remember the targets and build the underlying BMEfromBAI."""
        self._epsilon = epsilon
        self._delta = delta

        self._alg = BMEfromBAI(epsilon, delta, UGapEc)

    def get_total_n_sample(self):
        """Return the wrapped estimator's total number of pulls."""
        return self._alg.get_total_n_sample()

    def get_max_mu(self):
        """Return the wrapped estimator's stored best-mean estimate."""
        return self._alg.get_max_mu()

    def run(self, env):
        """Run the wrapped estimator on ``env`` and return its result."""
        return self._alg.run(env)
        

class EllipsoidBME:
    """Ellipsoid PAC Best Mean-reward Estimator.

    The estimator pulls arms by an upper-confidence-bound rule and stops when
    the "ellipsoid" statistic ``sum_i N_i * (mu_i - U + 2 * eps)_+^2`` reaches
    the confidence radius ``beta`` (or, with ``tight=True``, when the running
    upper and lower bounds are within ``2 * eps``).  It returns ``U - eps``.
    """

    def __init__(self, epsilon, delta, R, S, reg=None, tight=False):
        """Store the configuration.

        Parameters:
        epsilon -- accuracy parameter
        delta   -- confidence parameter
        R       -- sub-Gaussian parameter of the rewards
        S       -- bound on max_i |mu_i|
        reg     -- regularization parameter; when None, run() sets it to
                   (R / S)**2
        tight   -- keep running min/max confidence bounds and use the
                   additional ``U - L <= 2 * eps`` stopping test
        """
        self._epsilon = epsilon
        self._delta = delta
        self._R = R            # R-sub-gaussian
        self._S = S            # max_i |mu_i|
        self._lambda = reg
        self._tight = tight    # tighter version

        # Per-step trace, filled only when run(..., detail=True).
        self._history = defaultdict(list)

    def get_total_n_sample(self):
        """Return the total number of pulls of the last run() (unset before)."""
        return sum(self._n_sample)

    def get_max_mu(self):
        """Return the largest per-arm mean estimate of the last run()."""
        return np.max(self._mu)

    def _record(self, sample, arm, U, beta, ellipsoid):
        """Append one step to the detail history."""
        history = self._history
        history["sample"].append(sample)
        history["arm"].append(arm)
        history["U"].append(U)
        history["beta"].append(beta)
        history["ellipsoid"].append(ellipsoid)

    def _update_confidence(self, t, mu, n_sample, n_arms):
        """Recompute beta(t) and the UCB vector; fold them into the running
        bounds when ``tight`` is enabled.  Returns ``(beta, ucb)``."""
        beta = self._get_bound(t, n_sample, n_arms)   # beta(t)
        ucb = self._get_ucb(mu, beta, n_sample)       # mu(t) + sqrt(beta(t)/N(t))
        if self._tight:
            lcb = self._get_lcb(mu, beta, n_sample)
            # Running bounds only ever tighten: the upper bound shrinks and
            # the lower bound grows.
            self._ucb = np.minimum(self._ucb, ucb)
            self._lcb = np.maximum(self._lcb, lcb)
        return beta, ucb

    def run(self, env, tmax=1000000000, detail=False):
        """Run the estimator on ``env`` and return the best-mean estimate.

        Parameters:
        env    -- environment providing ``n_arms()`` and ``pull(arm)``
        tmax   -- maximum number of pulls (warm-up pulls included)
        detail -- when True, record each step in the internal history

        Returns:
        ``U - epsilon`` where ``U`` is the largest upper confidence bound at
        the time the stopping rule fired, or at the last step if ``tmax`` was
        reached first.
        """
        n_arms = env.n_arms()
        if self._tight:
            self._ucb = np.full(n_arms, np.inf)
            self._lcb = np.full(n_arms, -np.inf)

        if self._lambda is None:
            # regularization parameter
            self._lambda = (self._R / self._S)**2

        n_sample = np.zeros(n_arms, dtype=int)  # count how many times each arm is pulled
        mu = np.zeros(n_arms)                   # estimated means

        phase = 1
        tentative_best = None   # set when phase 2 starts
        r = None                # set when phase 2 starts
        U = np.inf
        beta = 2 * self._R**2 * np.log(1 / self._delta)
        ellipsoid = 0.
        for t in range(tmax):

            if t < n_arms:
                # pull each arm once in the first K rounds
                arm = t
                sample = env.pull(arm)
                n_sample[arm] = 1
                mu[arm] = sample

                if detail:
                    self._record(sample, arm, U, beta, ellipsoid)

                if t == n_arms - 1:
                    # Every arm has one sample now, so the bounds are defined.
                    beta, ucb = self._update_confidence(t, mu, n_sample, n_arms)

                continue

            # select arm
            if phase == 2 and n_sample[tentative_best] <= r * beta:
                arm = tentative_best
            else:
                arm = np.argmax(ucb)

            # pull arm
            sample = env.pull(arm)

            # update -> N(t), mu(t)
            # NOTE(review): the count is incremented before it is used in the
            # running-mean formula below, and the warm-up stores the raw first
            # sample; confirm this is the intended (regularized) estimator.
            n_sample[arm] += 1
            mu[arm] = ((n_sample[arm] + self._lambda) * mu[arm] + sample) / (n_sample[arm] + self._lambda + 1)

            if detail:
                self._record(sample, arm, U, beta, ellipsoid)

            # UCB
            beta, ucb = self._update_confidence(t, mu, n_sample, n_arms)
            if self._tight:
                U = np.max(self._ucb)
                L = np.max(self._lcb)
            else:
                U = np.max(ucb)                               # U(t)
            ellipsoid = self._get_ellipsoid(n_sample, mu, U)  # sum_i (mu(t) - U(t) + 2 eps)^2_+
            stop = (ellipsoid >= beta)
            if self._tight:
                stop = stop or (U - L <= 2 * self._epsilon)
            if stop:
                best_mean_estimate = U - self._epsilon

                print("best mean estimated", best_mean_estimate) 
                self._n_sample = n_sample
                self._mu = mu

                return best_mean_estimate

            # check phase change
            if phase == 1 and U < np.max(mu) + 2 * self._epsilon:
                if self._get_lower_g(mu, U) <= 0 or self._get_upper_g(mu, U) <= 0:
                    print(f"Phase change at {t}")
                    phase = 2
                    tentative_best = np.argmax(mu)
                    # NOTE(review): this estimate is computed but not used;
                    # r is derived from the current U.  Confirm whether
                    # _get_r should receive U_estimate instead.
                    U_estimate = self._get_U_estimate(mu)
                    r = self._get_r(U, mu)

        print("reached max iteration")
        self._n_sample = n_sample
        self._mu = mu

        return U - self._epsilon

    def _get_bound(self, t, N, K):
        """Return the confidence radius beta for pull counts ``N``.

        ``t`` and ``K`` are accepted for interface compatibility but unused.
        """
        bound = 2 * self._R**2 * np.log(1 / self._delta)
        bound += self._R**2 * np.sum(np.log(1 + N / self._lambda))
        bound += self._lambda * self._S**2 * np.sum(N / (N + self._lambda))
        return bound

    def _get_ucb(self, mu, bound, N):
        """Return the per-arm upper bounds ``mu + sqrt(bound / N)``."""
        return mu + np.sqrt(bound / N)

    def _get_lcb(self, mu, bound, N):
        """Return the per-arm lower bounds ``mu - sqrt(bound / N)``."""
        return mu - np.sqrt(bound / N)

    def _get_ellipsoid(self, N, mu, U):
        """Return ``sum_i N_i * max(mu_i - U + 2 * eps, 0)**2``."""
        return np.sum(N * np.maximum((mu - U + 2 * self._epsilon), 0)**2)

    def _get_lower_g(self, mu, U):
        """Return ``1 - sum_i (mu_i - U + 2 eps)_+^2 / (U - mu_i)^2``.

        Returns ``-inf`` when ``U <= max(mu)`` (a denominator would be zero
        or the ratio meaningless there).
        """
        if U <= np.max(mu):
            return -np.inf
        g = 1
        g -= np.sum(np.maximum((mu - U + 2 * self._epsilon), 0)**2 / (U - mu)**2)
        return g

    def _get_upper_g(self, mu, U):
        """Return ``1 / (mu_max - U + 2 eps)^3 - sum_{i != best} 1 / (U - mu_i)^3``.

        Clamped to ``+inf`` for ``U >= mu_max + 2 eps`` and to ``-inf`` for
        ``U <= mu_max``.
        """
        best_arm = np.argmax(mu)
        mu_except_best = np.delete(mu, best_arm)
        mu_max = mu[best_arm]
        if U >= mu_max + 2 * self._epsilon:
            return np.inf
        elif U <= mu_max:
            return -np.inf
        g = 1 / (mu_max - U + 2 * self._epsilon)**3
        g -= np.sum(1 / (U - mu_except_best)**3)
        return g

    def _find_zero(self, f, lb, ub, tol=1e-6, max_iter=1000):
        """
        Finds a zero of the function f in the interval [lb, ub] using the bisection method.
        
        Parameters:
        f         -- function for which we want to find a root
        lb, ub    -- lower and upper bounds of the interval
        tol       -- tolerance for stopping condition
        max_iter  -- maximum number of iterations
        
        Returns:
        A value x in [lb, ub] such that f(x) is close to zero

        Raises:
        ValueError   -- if f(lb) and f(ub) have the same (non-zero) sign
        RuntimeError -- if max_iter bisection steps do not converge
        """
        f_lb = f(lb)
        f_ub = f(ub)

        # A root sitting exactly on an endpoint is returned directly; the
        # bisection below would otherwise move away from it.
        if f_lb == 0:
            return lb
        if f_ub == 0:
            return ub

        # Compare signs directly rather than via a product, which can
        # underflow to zero for tiny values.
        if (f_lb > 0 and f_ub > 0) or (f_lb < 0 and f_ub < 0):
            print(f_lb, f_ub)
            raise ValueError("f(lb) and f(ub) must have opposite signs")

        for _ in range(max_iter):
            mid = (lb + ub) / 2
            f_mid = f(mid)

            if abs(f_mid) < tol or (ub - lb) / 2 < tol:
                return mid

            if (f_lb < 0 < f_mid) or (f_mid < 0 < f_lb):
                # Sign change lies in [lb, mid].
                ub = mid
            else:
                lb = mid
                f_lb = f_mid   # cached so f(lb) is not re-evaluated each step

        raise RuntimeError("Max iterations exceeded without finding root")

    def _find_minimizer(self, f, lb, ub, tol=1e-6):
        """
        Finds the minimizer of the function f in the interval [lb, ub] using Brent's method.
        
        Parameters:
        f    -- function to minimize
        lb   -- lower bound of the interval
        ub   -- upper bound of the interval
        tol  -- tolerance for the minimizer location
        
        Returns:
        A value x in [lb, ub] such that f(x) is (approximately) minimized

        Raises:
        RuntimeError -- if the optimizer reports failure
        """
        result = minimize_scalar(f, bounds=(lb, ub), method='bounded', options={'xatol': tol})

        if result.success:
            return result.x
        else:
            raise RuntimeError("Minimization did not converge")

    def _get_U_estimate(self, mu):
        """Return an estimate of U from the current means ``mu``.

        The zero of the lower g-function on ``(mu_max, mu_max + 2 eps)`` is
        found first.  If the upper g-function is already non-negative there,
        that zero is returned; otherwise the objective ``f`` below is
        minimized between the zeros of the lower and upper g-functions.
        """
        best_arm = np.argmax(mu)
        mu_except_best = np.delete(mu, best_arm)
        mu_max = mu[best_arm]

        def lower_g(x):
            return self._get_lower_g(mu, x)

        def upper_g(x):
            return self._get_upper_g(mu, x)

        def f(x):
            # Infinite outside the open interval so the minimizer stays inside.
            if x <= mu_max:
                return np.inf
            elif x >= mu_max + 2 * self._epsilon:
                return np.inf
            else:
                _f = 1
                _f -= np.sum(np.maximum(mu_except_best - x + 2 * self._epsilon, 0)**2 / (x - mu_except_best)**2)
                _f /= (mu_max - x + 2 * self._epsilon)**2
                _f += np.sum(1 / (x - mu_except_best)**2)
                return _f

        lower_U = self._find_zero(lower_g, mu_max, mu_max + 2 * self._epsilon)
        if upper_g(lower_U) >= 0:
            # this implies lower_U >= upper_U
            return lower_U

        upper_U = self._find_zero(upper_g, mu_max, mu_max + 2 * self._epsilon)
        U = self._find_minimizer(f, lower_U, upper_U)

        return U

    def _get_r(self, U, mu):
        """Return the ratio compared against ``N[best] / beta`` in phase 2.

        Computes ``(1 - sum_{i != best} (mu_i - U + 2 eps)_+^2 / (U - mu_i)^2)
        / (mu_max - U + 2 eps)^2``.
        """
        best_arm = np.argmax(mu)
        mu_except_best = np.delete(mu, best_arm)
        mu_max = mu[best_arm]
        r = 1
        r -= np.sum(np.maximum(mu_except_best - U + 2 * self._epsilon, 0)**2 / (U - mu_except_best)**2)
        r /= (mu_max - U + 2 * self._epsilon)**2

        return r


class SuccessiveEliminationBME:
    """Successive Elimination PAC Best Mean-reward Estimator.

    Every surviving arm is pulled once per round; arms whose sample average
    falls at least ``2 * alpha`` below the best surviving average are
    dropped.  Sampling stops once the confidence width ``alpha`` is no larger
    than ``epsilon`` and the best surviving sample average is returned.
    """

    def __init__(self, epsilon, delta, support=(0, 1)):
        """Store the configuration.

        Parameters:
        epsilon -- accuracy parameter (target confidence width)
        delta   -- confidence parameter
        support -- pair ``(low, high)``; ``high - low`` scales the
                   confidence width
        """
        self._epsilon = epsilon
        self._delta = delta
        self._support = support

        # Per-arm traces, filled only when run(..., detail=True).
        self._sample = defaultdict(list)
        self._n = defaultdict(list)

    def get_total_n_sample(self):
        """Return the total number of pulls of the last run() (unset before)."""
        return sum(self._n_sample)

    def get_max_mu(self):
        """Return the largest per-arm sample mean of the last run()."""
        return np.max(self._mu)

    def run(self, env, detail=False):
        """Run successive elimination on ``env``.

        Parameters:
        env    -- environment providing ``n_arms()`` and ``pull(arm)``
        detail -- when True, record every sample and running count per arm

        Returns:
        The largest sample average among the arms still active at the end.
        """
        const = np.pi**2 / 6
        n_arms = env.n_arms()
        width = self._support[1] - self._support[0]

        step = 0
        total_sample = np.zeros(n_arms)
        n_sample = np.zeros(n_arms, dtype=int)
        active = np.ones(n_arms, dtype=bool)

        # At least one round is always played, so the averages are defined
        # even when epsilon is not smaller than the support width.
        while True:
            step += 1
            for arm in np.flatnonzero(active).tolist():
                sample = env.pull(arm)
                total_sample[arm] += sample
                n_sample[arm] += 1
                if detail:
                    self._sample[arm].append(sample)
                    self._n[arm].append(n_sample[arm])

            sample_average = total_sample / n_sample
            # Compare against the best *active* arm: averages of eliminated
            # arms are stale and must not drive further eliminations.
            best_average = np.max(sample_average[active])
            alpha = np.sqrt(np.log(2 * const * n_arms * step**2 / self._delta) / (2 * step))
            alpha *= width  # scaler
            active &= ~(best_average - sample_average >= 2 * alpha)

            if not alpha > self._epsilon:
                break

        self._n_sample = n_sample
        self._total_sample = total_sample

        self._mu = total_sample / n_sample

        best_reward = np.max(sample_average[active])
        return best_reward

    def get_mean_reward(self, arm):
        """Return the sample mean of ``arm`` from the last run().

        Returns 0 when no run has happened yet or the arm has no samples.
        The value comes from the per-run totals, so it does not depend on
        ``detail`` having been enabled.
        """
        n_sample = getattr(self, "_n_sample", None)
        if n_sample is None or n_sample[arm] == 0:
            return 0
        return self._total_sample[arm] / n_sample[arm]

    def get_best_mean_reward(self):
        """Return the largest per-arm sample mean of the last run().

        Raises:
        ValueError -- if run() has not been called yet
        """
        n_sample = getattr(self, "_n_sample", None)
        if n_sample is None:
            raise ValueError("run() must be called before get_best_mean_reward()")
        return max(self.get_mean_reward(arm) for arm in range(len(n_sample)))
