import numpy as np
from collections import defaultdict


class BMEfromBAI:
    """Best Mean-reward Estimation from Best Arm Identification.

    Runs a best-arm identifier at accuracy ``2/3 * epsilon`` and confidence
    ``delta / 2``. It then tops up the chosen arm with extra pulls until that
    arm has at least ``n_star`` samples, and returns the arm's empirical mean.
    """

    def __init__(self, epsilon, delta, best_arm_identifier):
        """
        Args:
            epsilon: target accuracy of the returned mean estimate.
            delta: target failure probability.
            best_arm_identifier: factory called as
                ``best_arm_identifier(accuracy, confidence)``. The returned
                object must provide ``run(env)`` (returning an arm index) and
                ``get_n_sample()`` / ``get_total_sample()`` (per-arm sequences
                of pull counts and reward sums).
        """
        self._epsilon = epsilon
        self._delta = delta
        self._best_arm_identifier = best_arm_identifier(2 / 3 * epsilon, delta / 2)
        # Extra pulls of the best arm taken by the most recent ``run`` call.
        self._additional_sample = 0

    def get_total_n_sample(self):
        """Return the identifier's pulls plus the extra top-up pulls."""
        return sum(self._best_arm_identifier.get_n_sample()) + self._additional_sample

    def run(self, env):
        """Estimate the mean reward of the best arm of ``env``.

        Args:
            env: bandit environment exposing ``n_arms()`` and
                ``multi_pull(arm, n)`` (an iterable of ``n`` rewards).

        Returns:
            The empirical mean of the identified best arm, computed from at
            least ``n_star`` samples.
        """
        half_delta = self._delta / 2
        n_star = 9 / (2 * self._epsilon)**2 * np.log((env.n_arms() + half_delta**8) / half_delta)
        n_star = int(np.ceil(n_star))

        # Reset the counter so it describes this run only.
        self._additional_sample = 0

        # Identify the best arm and reuse the samples already taken from it.
        best_arm = self._best_arm_identifier.run(env)
        n_sample = self._best_arm_identifier.get_n_sample()[best_arm]
        total_sample = self._best_arm_identifier.get_total_sample()[best_arm]

        # Top up the best arm to n_star samples. The shortfall is computed
        # once, before n_sample is overwritten, so the bookkeeping is correct.
        shortfall = n_star - n_sample
        if shortfall > 0:
            total_sample += sum(env.multi_pull(best_arm, shortfall))
            n_sample = n_star
            self._additional_sample = shortfall

        return total_sample / n_sample


class SuccessiveEliminationBME:
    """Successive Elimination PAC Best Mean-reward Estimator.

    Each round, every active arm is pulled once. Arms whose empirical mean
    trails the best *active* arm by at least twice the current confidence
    radius are eliminated. Rounds stop once the radius is at most
    ``epsilon``. The largest empirical mean among the survivors is the
    estimate.
    """

    def __init__(self, epsilon, delta, support=(0, 1)):
        """
        Args:
            epsilon: target accuracy; rounds continue while the confidence
                radius exceeds it.
            delta: failure probability used in the confidence radius.
            support: ``(low, high)`` bounds of the rewards. The width
                ``high - low`` scales the confidence radius. A tuple default
                avoids a shared mutable default argument.
        """
        self._epsilon = epsilon
        self._delta = delta
        self._support = support

        # Per-arm reward histories and running pull counts; only populated
        # when run(detail=True). They are not cleared between runs.
        self._sample = defaultdict(list)
        self._n = defaultdict(list)
        # Per-arm pull counts and reward sums from the most recent run.
        self._n_sample = np.zeros(0, dtype=int)
        self._total_sample = np.zeros(0)

    def get_total_n_sample(self):
        """Return the total number of pulls made by the most recent run (0 before any run)."""
        return sum(self._n_sample)

    def _pull_round(self, env, arms, total_sample, n_sample, detail):
        """Pull each arm in ``arms`` once, updating the per-arm arrays in place."""
        for arm in arms:
            sample = env.pull(arm)
            total_sample[arm] += sample
            n_sample[arm] += 1
            if detail:
                self._sample[arm].append(sample)
                self._n[arm].append(n_sample[arm])

    def run(self, env, detail=False):
        """Estimate the best mean reward of ``env``.

        Args:
            env: bandit environment exposing ``n_arms()`` and ``pull(arm)``.
            detail: if true, record every reward in ``self._sample[arm]`` and
                the running pull count in ``self._n[arm]``.

        Returns:
            The largest empirical mean among the arms surviving elimination.
        """
        const = np.pi**2 / 6
        n_arms = env.n_arms()
        width = self._support[1] - self._support[0]

        step = 0
        total_sample = np.zeros(n_arms)
        n_sample = np.zeros(n_arms).astype(int)
        arm_set = set(range(n_arms))

        alpha = width
        while alpha > self._epsilon:
            step += 1
            self._pull_round(env, arm_set, total_sample, n_sample, detail)

            sample_average = total_sample / n_sample
            # Compare against the best *active* arm. Eliminated arms keep
            # stale averages, which could otherwise eliminate every survivor.
            max_average = max(sample_average[arm] for arm in arm_set)
            alpha = np.sqrt(np.log(2 * const * n_arms * step**2 / self._delta) / (2 * step))
            alpha *= width  # scale the radius to the reward range
            arm_set -= {arm for arm in arm_set
                        if max_average - sample_average[arm] >= 2 * alpha}

        if step == 0:
            # epsilon >= support width: no elimination round is required, but
            # each arm still needs one sample to form an estimate.
            self._pull_round(env, arm_set, total_sample, n_sample, detail)

        self._n_sample = n_sample
        self._total_sample = total_sample

        sample_average = total_sample / n_sample
        return np.max([sample_average[arm] for arm in arm_set])

    def get_mean_reward(self, arm):
        """Return the mean of the rewards recorded for ``arm`` in detail mode.

        Returns 0 when nothing was recorded for ``arm`` (e.g. ``run`` was
        called with ``detail=False``).
        """
        # .get avoids inserting an empty entry into the defaultdict, which
        # would then show up in get_best_mean_reward.
        samples = self._sample.get(arm)
        if not samples:
            return 0
        # Divide by the history length: histories accumulate across runs,
        # while self._n_sample only counts the latest run.
        return sum(samples) / len(samples)

    def get_best_mean_reward(self):
        """Return the largest recorded mean reward over arms with detail samples.

        Raises ValueError if no samples were recorded.
        """
        return max(self.get_mean_reward(arm) for arm in self._sample)