from abc import ABCMeta, abstractmethod
import numpy as np
from collections import defaultdict


class PACBAI(metaclass=ABCMeta):
    """Base class for PAC best-arm identifiers.

    Stores the ``epsilon`` and ``delta`` parameters together with two
    per-arm arrays: the number of samples drawn from each arm and the
    running sum of the samples observed on each arm.
    """

    def __init__(self, epsilon, delta):
        """Record ``epsilon`` and ``delta`` and start with empty statistics."""
        # Zero-length placeholders until a run (or ``set_stats``) sizes them.
        self._n_sample = np.empty(0)
        self._total_sample = np.empty(0)

        self._epsilon = epsilon
        self._delta = delta

    def get_n_sample(self):
        """Return the per-arm sample-count array."""
        return self._n_sample

    def get_total_n_sample(self):
        """Return the number of samples summed over all arms."""
        counts = self._n_sample
        return sum(counts)

    def get_total_sample(self):
        """Return the per-arm array of summed samples."""
        return self._total_sample

    def get_sample_average(self):
        """Return the per-arm mean sample, reported as 0 for unsampled arms."""
        counts = self._n_sample
        # Clamp the divisor to 1 so unsampled arms do not divide by zero ...
        per_arm_mean = self._total_sample / np.maximum(counts, 1)
        # ... then zero those entries out through a 0/1 mask.
        sampled = counts != 0
        return per_arm_mean * sampled

    def set_stats(self, n_sample, total_sample, n_arms):
        """Replace the stored statistics with the first ``n_arms`` entries.

        ``n_sample`` and ``total_sample`` are read through ``obj[arm]`` for
        ``arm`` in ``range(n_arms)``, so anything indexable by arm number
        works.  Counts are stored in an integer array, sums in a float array.
        """
        self._n_sample = counts = np.zeros(n_arms, dtype=int)
        self._total_sample = totals = np.zeros(n_arms)
        for idx in range(n_arms):
            counts[idx] = n_sample[idx]
            totals[idx] = total_sample[idx]


class UGapEc(PACBAI):
    """UGapE-style arm selection with an epsilon stopping rule (m = 1 only).

    Each round builds upper/lower confidence bounds around every arm's
    sample mean, computes a gap index per arm, and either stops (when the
    smallest gap index falls below ``epsilon``) or pulls the less-explored
    of the two most informative arms.
    """

    def run(self, env, b=1., c=0.5, tmax=10000000):
        """Sample ``env`` until an arm is identified or the budget runs out.

        Args:
            env: Environment exposing ``n_arms()`` and ``pull(arm)``; the
                value returned by ``pull`` is accumulated as the reward.
            b: Rewards are assumed to be bounded in ``[0, b]``.
            c: Exploration parameter scaling the confidence width.
            tmax: Maximum number of rounds; every round pulls one arm,
                except the round in which an arm is returned.

        Returns:
            The index of the identified arm, or ``None`` when ``tmax``
            rounds elapse without the stopping rule firing.

        Raises:
            ValueError: If ``env`` reports fewer than one arm.
        """
        self._b = b
        self._c = c
        # NOTE(review): env.n_arms() may hand back a fixed-width NumPy
        # integer; a plain int keeps ``t**3`` below in arbitrary precision.
        self._n_arms = int(env.n_arms())
        if self._n_arms < 1:
            raise ValueError("env must expose at least one arm")
        self._n_sample = np.zeros(self._n_arms, dtype=int)
        self._total_sample = np.zeros(self._n_arms)

        # Initialisation: pull every arm once, as far as the budget allows.
        for arm in range(min(self._n_arms, tmax)):
            self._pull(env, arm)

        for t in range(self._n_arms, tmax):
            if self._n_arms == 1:
                # No competitor exists, so the gap index is undefined and
                # the only arm is the answer (np.intp matches np.argmin).
                return np.intp(0)

            # Confidence bounds around the empirical means.
            width = self._confidence_width(t)
            mean = self._total_sample / self._n_sample
            upper = mean + width
            lower = mean - width

            # Gap index B_k(t) = max_{i != k} U_i(t) - L_k(t).  For the arm
            # holding the unique largest U, the max over the others is the
            # runner-up; for every other arm (or on a tie) it is the max.
            best_upper = np.max(upper)
            is_best = upper == best_upper
            if np.count_nonzero(is_best) == 1:
                runner_up = np.max(upper[~is_best])
            else:
                runner_up = best_upper
            gap = np.where(is_best, runner_up, best_upper) - lower

            # Candidate answer: the arm with the smallest gap index.
            J = np.argmin(gap)
            if gap[J] < self._epsilon:
                return J

            # u: arm with the largest upper bound among the arms other than
            # J, ties broken in favour of the widest confidence interval.
            others = np.delete(np.arange(self._n_arms), J)
            top = others[upper[others] == np.max(upper[others])]
            u = top[np.argmax(width[top])]
            # ell: with m = 1 the selected set is {J}, so ell is J itself.
            ell = J

            # Pull whichever of the two is currently less well estimated.
            arm = u if width[u] > width[ell] else ell
            self._pull(env, arm)

        # Budget exhausted before the stopping rule fired.
        return None

    def _confidence_width(self, t):
        """Return b * sqrt(c * log(4 * K * t**3 / delta) / n_k) per arm."""
        log_term = self._c * np.log(4 * self._n_arms * t**3 / self._delta)
        return self._b * np.sqrt(log_term / self._n_sample)

    def _pull(self, env, arm):
        """Draw one sample from ``arm`` and fold it into the statistics."""
        sample = env.pull(arm)
        self._total_sample[arm] += sample
        self._n_sample[arm] += 1

