from abc import ABCMeta, abstractmethod
import numpy as np
from collections import defaultdict


class PACBAI(metaclass=ABCMeta):
    """PAC Best Arm Identifier.

    Base class holding the ``epsilon``/``delta`` parameters and the per-arm
    sampling statistics (number of samples and sum of samples).
    """
    def __init__(self, epsilon, delta):
        """Store the parameters and start with empty statistics.

        Args:
            epsilon: value kept as ``self._epsilon``.
            delta: value kept as ``self._delta``.
        """
        self._epsilon = epsilon
        self._delta = delta

        # Both arrays are indexed by arm and stay empty until `set_stats`
        # is called or a subclass assigns them directly.
        self._n_sample = np.empty(0)
        self._total_sample = np.empty(0)

    def get_n_sample(self):
        """Return the per-arm number of samples."""
        return self._n_sample

    def get_total_n_sample(self):
        """Return the number of samples summed over all arms."""
        return sum(self._n_sample)

    def get_total_sample(self):
        """Return the per-arm sum of the observed samples."""
        # Previously returned the sample counts by mistake.
        return self._total_sample

    def get_sample_average(self):
        """Return the per-arm sample mean; arms without samples report 0."""
        # Divide by at least 1 to avoid 0/0, then zero out unsampled arms.
        sample_average = self._total_sample / np.maximum(self._n_sample, 1)
        sample_average *= (self._n_sample != 0)
        return sample_average

    def set_stats(self, n_sample, total_sample, n_arms):
        """Copy statistics for arms ``0 .. n_arms - 1`` into this object.

        Args:
            n_sample: container indexable by arm giving the sample count.
            total_sample: container indexable by arm giving the sample sum.
            n_arms: number of arms to copy.
        """
        self._n_sample = np.zeros(n_arms, dtype=int)
        self._total_sample = np.zeros(n_arms)
        # Index element-wise so dict-like inputs work as well as sequences.
        for arm in range(n_arms):
            self._n_sample[arm] = n_sample[arm]
            self._total_sample[arm] = total_sample[arm]


def _AggressiveElimination(
    env,
    epsilon,
    delta,
    arm_set=None,
    total_sample=None,
    n_sample=None
):
    """Repeatedly sample the remaining arms and discard the worst ones.

    Args:
        env: object providing ``n_arms()`` and ``multi_pull(arm, n)``; the
            sum of what ``multi_pull`` returns is added to the arm's total.
        epsilon: sets the base sample size ``ceil(2 / epsilon**2 * log(1 / delta))``.
        delta: used in the base sample size and in the fraction of arms kept.
        arm_set: arms to start from; defaults to all arms of ``env``.
            The caller's collection is not modified.
        total_sample: per-arm sum of samples, indexable by arm and updated in
            place; a fresh ``defaultdict(float)`` is used when omitted.
        n_sample: per-arm sample count, indexable by arm and updated in
            place; a fresh ``defaultdict(int)`` is used when omitted.

    Returns:
        dict with keys ``"remaining_arms"`` (set of surviving arms),
        ``"total_sample"`` and ``"n_sample"``.
    """
    # Create the accumulators per call: a default created in the signature
    # would be shared by every call that omits the argument.
    if total_sample is None:
        total_sample = defaultdict(float)
    if n_sample is None:
        n_sample = defaultdict(int)
    if arm_set is None:
        arm_set = range(env.n_arms())
    # Work on a private copy so the caller's set is left untouched and the
    # arms actually passed in (not 0..len-1) are the ones being eliminated.
    arm_set = set(arm_set)
    n_arms = len(arm_set)

    # phi: small fraction s.t. top delta+phi fraction of arms include the best arm
    phi = np.sqrt(6 * np.log(n_arms) / n_arms**(3 / 4))

    # number of iterations to reduce the number of arms to n^(3/4) / 2
    n_iterations = (np.log(n_arms) + 4 * np.log(2)) / (4 * np.log(1 / (delta + phi)))
    n_iterations = int(np.ceil(n_iterations)) + 1

    n_base = (2 / epsilon**2) * np.log(1 / delta)
    n_base = int(np.ceil(n_base))

    for i in range(n_iterations):
        # take n samples from each remaining arm; n grows with the round
        n = (i + 1) * n_base
        for arm in arm_set:
            total_sample[arm] += np.sum(env.multi_pull(arm, n))
            n_sample[arm] += n

        # eliminate poor arms: drop the arms with the smallest totals
        # NOTE(review): ranking by total assumes the remaining arms hold
        # equal sample counts -- confirm for pre-populated statistics.
        n_remove = len(arm_set) * (1 - (delta + phi))
        n_remove = int(np.ceil(n_remove))
        arm_list = list(arm_set)
        sorted_idx = np.argsort([total_sample[arm] for arm in arm_list])
        to_remove = [arm_list[idx] for idx in sorted_idx[:n_remove]]
        arm_set -= set(to_remove)

    return {
        "remaining_arms": arm_set,
        "total_sample": total_sample,
        "n_sample": n_sample
    }


def _NaiveElimination(
    env,
    epsilon,
    delta,
    arm_set=None,
    total_sample=None,
    n_sample=None
):
    """Sample every arm up to a common count and return the best average.

    Args:
        env: object providing ``n_arms()`` and ``multi_pull(arm, n)``; the
            sum of what ``multi_pull`` returns is added to the arm's total.
        epsilon: with ``delta`` sets the per-arm sample size
            ``ceil(2 / epsilon**2 * log(len(arm_set) / delta))``.
        delta: see ``epsilon``.
        arm_set: arms to compare; defaults to all arms of ``env``.
        total_sample: per-arm sum of samples, indexable by arm and updated in
            place; a fresh ``defaultdict(float)`` is used when omitted.
        n_sample: per-arm sample count, indexable by arm and updated in
            place; a fresh ``defaultdict(int)`` is used when omitted.

    Returns:
        dict with keys ``"best_arm"`` (arm with the highest sample average,
        ``None`` if no arm was examined), ``"total_sample"`` and ``"n_sample"``.
    """
    # Create the accumulators per call: a default created in the signature
    # would be shared by every call that omits the argument, so a second
    # call would see stale counts and skip sampling.
    if total_sample is None:
        total_sample = defaultdict(float)
    if n_sample is None:
        n_sample = defaultdict(int)
    if arm_set is None:
        arm_set = set(range(env.n_arms()))
    n_arms = len(arm_set)

    # determine the sample size
    n_sample_needed = 2 / epsilon**2 * np.log(n_arms / delta)
    n_sample_needed = int(np.ceil(n_sample_needed))

    # pull arms
    max_average = -np.inf  # np.infty no longer exists in NumPy 2.0
    best_arm = None
    for arm in arm_set:
        # only top up arms that have fewer samples than required
        if n_sample[arm] < n_sample_needed:
            total_sample[arm] += np.sum(
                env.multi_pull(arm, n_sample_needed - n_sample[arm])
            )
            n_sample[arm] = n_sample_needed
        sample_average = total_sample[arm] / n_sample[arm]
        if sample_average > max_average:
            max_average = sample_average
            best_arm = arm

    return {
        "best_arm": best_arm,
        "total_sample": total_sample,
        "n_sample": n_sample
    }


class NaiveElimination(PACBAI):
    """Naive PAC Best Arm Identifier"""
    def run(self, env):
        """Run naive elimination on ``env`` and return the selected arm.

        The per-arm statistics of the run are stored on the instance via
        ``set_stats``.
        """
        # Pass fresh accumulators explicitly so every run starts from zero
        # samples, regardless of what earlier calls to the helper recorded.
        output = _NaiveElimination(
            env,
            self._epsilon,
            self._delta,
            total_sample=defaultdict(float),
            n_sample=defaultdict(int)
        )

        self.set_stats(output["n_sample"], output["total_sample"], env.n_arms())

        return output["best_arm"]

class SuccessiveElimination(PACBAI):
    """Successive Elimination PAC Best Arm Identifier"""
    def run(self, env):
        """Pull every surviving arm once per step and drop clearly worse arms.

        Args:
            env: object providing ``n_arms()`` and ``pull(arm)``.

        Returns:
            The surviving arm with the highest sample average.

        Raises:
            ValueError: if ``env`` reports fewer than one arm.
        """
        const = np.pi**2 / 6
        n_arms = env.n_arms()
        if n_arms < 1:
            raise ValueError("env must provide at least one arm")

        step = 0
        total_sample = np.zeros(n_arms)
        n_sample = np.zeros(n_arms).astype(int)
        arm_set = set(range(n_arms))

        while len(arm_set) > 1:
            step += 1
            for arm in arm_set:
                total_sample[arm] += env.pull(arm)
                n_sample[arm] += 1

            # Compare surviving arms only: averages of eliminated arms are
            # stale and must not define the reference maximum, otherwise
            # every surviving arm could be removed.
            active = list(arm_set)
            averages = total_sample[active] / n_sample[active]
            max_average = averages.max()
            alpha = np.sqrt(np.log(const * n_arms * step**2 / self._delta) / (2 * step))
            arm_set -= {
                arm for arm, average in zip(active, averages)
                if max_average - average >= 2 * alpha
            }

            # the confidence radius is small enough for the requested epsilon
            if alpha <= self._epsilon / 2:
                break

        self._n_sample = n_sample
        self._total_sample = total_sample

        remaining_arms = list(arm_set)
        if len(remaining_arms) == 1:
            # Also covers a single-arm env, where no sample was ever taken.
            return remaining_arms[0]
        final_average = total_sample[remaining_arms] / n_sample[remaining_arms]
        return remaining_arms[int(np.argmax(final_average))]


class GeneralizedSuccessiveElimination(PACBAI):
    """Generalized Successive Elimination PAC Best Arm Identifier"""
    def run(self, env, tau):
        """Successive elimination that pulls each surviving arm ``tau`` times per step.

        Args:
            env: object providing ``n_arms()`` and ``multi_pull(arm, n)``.
            tau: number of pulls per arm and step.

        Returns:
            The surviving arm with the highest sample average.

        Raises:
            ValueError: if ``env`` reports fewer than one arm.
        """
        const = np.pi**2 / 6
        n_arms = env.n_arms()
        if n_arms < 1:
            raise ValueError("env must provide at least one arm")

        step = 0
        total_sample = np.zeros(n_arms)
        n_sample = np.zeros(n_arms).astype(int)
        arm_set = set(range(n_arms))

        while len(arm_set) > 1:
            step += 1
            log_term = np.log(const * n_arms * step**2 / self._delta)
            alpha = np.sqrt(log_term / (2 * step * tau))
            if alpha < self._epsilon / 2:
                # Final step: clamp the radius to epsilon / 2 and only top
                # each arm up to the sample count that radius requires.
                alpha = self._epsilon / 2
                max_n = int(np.ceil(2 / self._epsilon**2 * log_term))
                for arm in arm_set:
                    total_sample[arm] += sum(env.multi_pull(arm, max_n - n_sample[arm]))
                    n_sample[arm] = max_n
            else:
                for arm in arm_set:
                    # pull each arm for tau times
                    total_sample[arm] += sum(env.multi_pull(arm, tau))
                    n_sample[arm] += tau

            # Compare surviving arms only: averages of eliminated arms are
            # stale and must not define the reference maximum, otherwise
            # every surviving arm could be removed.
            active = list(arm_set)
            averages = total_sample[active] / n_sample[active]
            max_average = averages.max()
            arm_set -= {
                arm for arm, average in zip(active, averages)
                if max_average - average >= 2 * alpha
            }

            if alpha <= self._epsilon / 2:
                break

        self._n_sample = n_sample
        self._total_sample = total_sample

        remaining_arms = list(arm_set)
        if len(remaining_arms) == 1:
            # Also covers a single-arm env, where no sample was ever taken.
            return remaining_arms[0]
        final_average = total_sample[remaining_arms] / n_sample[remaining_arms]
        return remaining_arms[int(np.argmax(final_average))]


class MedianElimination(PACBAI):
    """Median Elimination PAC Best Arm Identifier"""
    def run(self, env):
        """Halve the set of arms each round until a single arm is left.

        Args:
            env: object providing ``n_arms()`` and ``multi_pull(arm, n)``.

        Returns:
            The last remaining arm.
        """
        epsilon = self._epsilon / 4
        delta = self._delta / 2
        n_arms = env.n_arms()

        total_sample = np.zeros(n_arms)
        n_sample = np.zeros(n_arms).astype(int)
        arm_set = set(range(n_arms))

        while len(arm_set) > 1:
            n = (2/epsilon)**2 * np.log(3/delta)
            n = int(np.ceil(n))
            for arm in arm_set:
                # earlier samples are reused: only top the arm up to n
                total_sample[arm] += np.sum(env.multi_pull(arm, n - n_sample[arm]))
                n_sample[arm] = n

            # Drop exactly the lower half of the surviving arms. A strict
            # "below the median" test removes nothing when averages tie,
            # which could keep the loop running indefinitely.
            ranked = sorted(arm_set, key=lambda arm: total_sample[arm] / n_sample[arm])
            arm_set -= set(ranked[:len(ranked) // 2])

            # tighten the parameters for the next round
            epsilon *= 3. / 4
            delta /= 2

        self._n_sample = n_sample
        self._total_sample = total_sample

        return arm_set.pop()


class SimpleApproximateBestArm(PACBAI):
    """This assumes
    - there is a single epsilon-best arm
    - delta <= 0.05
    - n >= max{1/delta^4, 10^5}
    """
    def run(self, env):
        """Run aggressive elimination, then naive elimination on the survivors.

        The per-arm statistics of the run are stored on the instance via
        ``set_stats``; the arm chosen by the naive stage is returned.
        """
        # Pass fresh accumulators explicitly so every run starts from zero
        # samples, regardless of what earlier calls to the helper recorded.
        intermediate = _AggressiveElimination(
            env,
            self._epsilon,
            self._delta / 2,
            total_sample=defaultdict(float),
            n_sample=defaultdict(int)
        )

        # the second stage continues from the first stage's statistics
        output = _NaiveElimination(
            env,
            self._epsilon,
            self._delta / np.e,
            arm_set=intermediate["remaining_arms"],
            total_sample=intermediate["total_sample"],
            n_sample=intermediate["n_sample"]
        )

        self.set_stats(output["n_sample"], output["total_sample"], env.n_arms())

        return output["best_arm"]


class ApproximateBestArm(PACBAI):
    """This assumes
    - delta <= 0.05
    """
    def run(self, env, rng):
        """Pick an arm, using the two-stage scheme only for many arms.

        Args:
            env: object providing ``n_arms()`` and the pull interface used by
                the elimination helpers.
            rng: object providing ``choice(a, size, replace=...)``.

        Returns:
            The arm chosen by the final naive elimination stage.
        """
        n_arms = env.n_arms()

        if n_arms < max(10**5, self._delta**(-4)):

            # Too few arms for the two-stage scheme: plain naive elimination.
            # Fresh accumulators keep repeated runs independent of each other.
            output = _NaiveElimination(
                env,
                self._epsilon,
                self._delta,
                total_sample=defaultdict(float),
                n_sample=defaultdict(int)
            )

        else:

            # alpha splits epsilon between the two stages
            alpha = 1 - 1 / np.e

            intermediate = _AggressiveElimination(
                env,
                self._epsilon * alpha,
                self._delta / 2,
                total_sample=defaultdict(float),
                n_sample=defaultdict(int)
            )

            # the sample size handed to rng.choice must be an integer
            n_random = int(n_arms**(7 / 8) / 2)
            random_arms = rng.choice([i for i in range(n_arms)], n_random, replace=False)

            output = _NaiveElimination(
                env,
                self._epsilon * (1 - alpha),
                self._delta / np.e,
                arm_set=intermediate["remaining_arms"].union(random_arms),
                total_sample=intermediate["total_sample"],
                n_sample=intermediate["n_sample"]
            )

        self.set_stats(output["n_sample"], output["total_sample"], env.n_arms())

        return output["best_arm"]


class ApproximateBestArmLikelihoodEstimation(PACBAI):
    """Three-stage identifier: uniform pre-sampling, aggressive elimination,
    then naive elimination over the survivors plus randomly drawn arms.
    """

    def run(self, env, rng, lam):
        """Run the three stages on ``env`` and return the selected arm.

        Args:
            env: object providing ``n_arms()`` and ``multi_pull(arm, n)``.
            rng: object providing ``choice(a, size, replace=...)``.
            lam: parameter controlling the initial sample size, the fraction
                of arms kept after pre-sampling and the epsilon split.

        Returns:
            The arm chosen by the final naive elimination stage.
        """
        all_arms = set(range(env.n_arms()))
        n_arms = len(all_arms)
        sums = np.zeros(n_arms)
        counts = np.zeros(n_arms, dtype=int)

        # Stage 1: pull every arm the same number of times.
        n_initial = int(
            (1 + lam / 2) / (2 * self._epsilon**2) * np.log(1 / self._delta)
        )
        for arm in range(n_arms):
            sums[arm] = np.sum(env.multi_pull(arm, n_initial))
            counts[arm] = n_initial

        # Keep the best lam * n_arms / 50 arms by discarding the lowest totals.
        ranking = np.argsort(sums)
        n_drop = int(n_arms * (1 - lam / 50))
        candidates = all_arms - set(ranking[:n_drop])

        # Stage 2: aggressive elimination, continuing from the stage-1 statistics.
        alpha = np.sqrt(1 - lam / 8)
        stage_two = _AggressiveElimination(
            env,
            self._epsilon * alpha,
            self._delta / 4,
            arm_set=candidates,
            total_sample=sums,
            n_sample=counts,
        )

        # Stage 3: naive elimination over the survivors and a random subset.
        n_extra = int(n_arms**(3 / 4))
        extra_arms = rng.choice(list(range(n_arms)), n_extra, replace=False)
        finalists = stage_two["remaining_arms"].union(extra_arms)

        result = _NaiveElimination(
            env,
            self._epsilon * (1 - alpha),
            self._delta / 4,
            arm_set=finalists,
            total_sample=stage_two["total_sample"],
            n_sample=stage_two["n_sample"],
        )

        self.set_stats(result["n_sample"], result["total_sample"], env.n_arms())

        return result["best_arm"]
    