from abc import ABCMeta, abstractmethod
import numpy as np


class Arm(metaclass=ABCMeta):
    """Base class for a bandit arm that serves rewards from a pre-sampled buffer.

    Samples are generated in batches by the subclass-provided
    ``_sample_batch`` and handed out one at a time (``pull``) or in chunks
    (``multi_pull``). ``_count`` is the index of the next unused entry in
    ``_samples``.
    """

    def __init__(self, batch_size, rng):
        """Store the sampling configuration and start with an empty buffer.

        Args:
            batch_size: Number of samples requested on a regular refill.
            rng: Random source kept on the instance for ``_sample_batch``
                implementations; this base class only stores it.
        """
        self._batch_size = batch_size
        self._rng = rng
        self._samples = np.empty(0)
        self._count = 0

    @abstractmethod
    def _sample_batch(self, batch_size):
        """Replace ``self._samples`` with ``batch_size`` freshly drawn samples."""

    def pull(self):
        """Return the next sample, refilling the buffer once it is used up."""
        if self._count == len(self._samples):
            self._sample_batch(self._batch_size)
            self._count = 0
        value = self._samples[self._count]
        self._count += 1
        return value

    def multi_pull(self, n):
        """Return the next ``n`` samples as an array.

        Unused samples left in the buffer are served first; if they do not
        suffice, the buffer is refilled and the remainder is taken from the
        start of the new batch.

        Args:
            n: Number of samples to return; must be non-negative.

        Returns:
            Array of ``n`` samples. When the buffer already holds enough
            samples this is a slice of the internal buffer, otherwise it is
            a newly allocated array.

        Raises:
            ValueError: If ``n`` is negative.
        """
        if n < 0:
            # A negative n would move the cursor backwards and replay samples.
            raise ValueError("n must be non-negative, got {}".format(n))

        start = self._count
        n_need = n - (len(self._samples) - start)
        if n_need <= 0:
            # Enough unused samples are already buffered.
            self._count = start + n
            return self._samples[start:self._count]

        # Copy the leftovers before the refill replaces the buffer.
        leftover = np.copy(self._samples[start:])
        # Refill with a whole batch, or more if the request exceeds one batch.
        self._sample_batch(max(n_need, self._batch_size))
        fresh = self._samples[:n_need]
        self._count = n_need
        if leftover.size == 0:
            # Nothing to prepend: stacking the empty leftover (float64 while
            # the buffer is still the initial placeholder) would promote
            # e.g. boolean samples to float64.
            return np.copy(fresh)
        return np.hstack([leftover, fresh])


class BernoulliArm(Arm):
    """Arm whose samples are booleans that are True with probability ``mean``."""

    def __init__(self, mean, batch_size=1000, rng=np.random.default_rng(42)):
        """Create the arm.

        Args:
            mean: Threshold compared against uniform draws; a draw below it
                yields True.
            batch_size: Number of samples generated per refill.
            rng: Generator used for the draws. The default generator is
                built once, when this method is defined, so every arm that
                relies on the default shares that single generator.
        """
        super().__init__(batch_size, rng)
        self._mean = mean

    def _sample_batch(self, batch_size=None):
        """Refill the buffer with ``batch_size`` draws (default: own batch size)."""
        size = self._batch_size if batch_size is None else batch_size
        uniform_draws = self._rng.random(size)
        self._samples = uniform_draws < self._mean


class GaussianArm(Arm):
    """Arm whose samples are normally distributed with a fixed mean and std."""

    def __init__(self, mean, std, batch_size=1000, rng=np.random.default_rng(42)):
        """Create the arm.

        Args:
            mean: Passed as ``loc`` to the generator's ``normal``.
            std: Passed as ``scale`` to the generator's ``normal``.
            batch_size: Number of samples generated per refill.
            rng: Generator used for the draws. The default generator is
                built once, when this method is defined, so every arm that
                relies on the default shares that single generator.
        """
        super().__init__(batch_size, rng)
        self._mean = mean
        self._std = std

    def _sample_batch(self, batch_size=None):
        """Refill the buffer with ``batch_size`` draws (default: own batch size)."""
        size = self._batch_size if batch_size is None else batch_size
        self._samples = self._rng.normal(
            loc=self._mean,
            scale=self._std,
            size=size,
        )


class MAB(metaclass=ABCMeta):
    """Multi Armed Bandit: a fixed collection of arms addressed by index."""

    def __init__(self, arms):
        """Keep the given arms.

        Args:
            arms: Sized, indexable collection of objects providing ``pull()``
                and ``multi_pull(n)``.
        """
        self._arms = arms
        self._n_arms = len(arms)

    def _check_arm(self, arm):
        """Raise ``IndexError`` unless ``arm`` is in ``[0, n_arms)``."""
        # An explicit check instead of ``assert``: asserts are removed under
        # ``python -O``, and a negative index would otherwise silently select
        # an arm counted from the end.
        if not 0 <= arm < self._n_arms:
            raise IndexError(
                "arm index {} out of range for {} arms".format(arm, self._n_arms)
            )

    def pull(self, arm):
        """Return a single sample from the arm at index ``arm``.

        Raises:
            IndexError: If ``arm`` is not in ``[0, n_arms)``.
        """
        self._check_arm(arm)
        return self._arms[arm].pull()

    def multi_pull(self, arm, n):
        """Return ``n`` samples from the arm at index ``arm``.

        Raises:
            IndexError: If ``arm`` is not in ``[0, n_arms)``.
        """
        self._check_arm(arm)
        return self._arms[arm].multi_pull(n)

    def n_arms(self):
        """Return the number of arms."""
        return self._n_arms
        

class BernoulliMAB(MAB):
    """Bandit made of one ``BernoulliArm`` per entry of ``means``."""

    def __init__(self, means, rng, batch_size=1000):
        """Build the arms.

        Args:
            means: Iterable with one mean per arm.
            rng: Generator handed to every arm; all arms receive this same
                object.
            batch_size: Batch size handed to every arm.
        """
        bernoulli_arms = []
        for arm_mean in means:
            bernoulli_arms.append(
                BernoulliArm(arm_mean, batch_size=batch_size, rng=rng)
            )
        super().__init__(bernoulli_arms)


class GaussianMAB(MAB):
    """Bandit made of one ``GaussianArm`` per ``(mean, std)`` pair."""

    def __init__(self, means, stds, rng, batch_size=1000):
        """Build the arms.

        Args:
            means: Iterable with one mean per arm.
            stds: Iterable with one standard deviation per arm; must have as
                many entries as ``means``.
            rng: Generator handed to every arm; all arms receive this same
                object.
            batch_size: Batch size handed to every arm.

        Raises:
            ValueError: If ``means`` and ``stds`` differ in length.
        """
        means = list(means)
        stds = list(stds)
        # zip() stops at the shorter input, which would silently drop arms.
        if len(means) != len(stds):
            raise ValueError(
                "means and stds must have the same length, got {} and {}".format(
                    len(means), len(stds)
                )
            )
        arms = [
            GaussianArm(mean, std, batch_size=batch_size, rng=rng)
            for mean, std in zip(means, stds)
        ]
        super().__init__(arms)
