import numpy as np
from math import log, sqrt
from numba import njit
from copy import deepcopy
from BanditAlgorithm import BanditAlgorithm

@njit
def _kl(p, q):
    """Return the Bernoulli KL divergence kl(p, q).

    Both arguments are clipped into ``[1e-12, 1 - 1e-12]`` first so that
    neither logarithm nor either division sees a zero.
    """
    lo = 1e-12
    hi = 1 - lo
    pc = min(max(p, lo), hi)
    qc = min(max(q, lo), hi)

    # Contribution of the "success" outcome, then of the "failure" outcome.
    success_term = pc * np.log(pc / qc)
    failure_term = (1 - pc) * np.log((1 - pc) / (1 - qc))
    return success_term + failure_term

@njit
def _beta(n, delta):
    """Return the detection threshold ``log(n ** 1.5 / delta)``.

    ``n`` is the number of samples under test and ``delta`` the value the
    caller passes through from the detector's configuration.
    """
    ratio = n ** 1.5 / delta
    return np.log(ratio)

@njit
def _detect(nb, cumsums, delta):
    """Scan one reward history for a change in mean.

    For every split point ``s`` in ``1 .. nb - 1`` the mean of the first
    ``s`` samples and the mean of the remaining ``nb - s`` samples are
    compared with the overall mean through ``_kl``; the weighted sum of the
    two divergences is tested against ``_beta(nb, delta)``.

    Parameters
    ----------
    nb : int
        Number of samples to consider.
    cumsums : array of float
        Running sums of the samples; ``cumsums[i]`` is the sum of the first
        ``i + 1`` samples. Only the first ``nb`` entries are read.
    delta : float
        Forwarded to ``_beta`` to compute the threshold.

    Returns
    -------
    int
        ``1`` as soon as some split exceeds the threshold, otherwise ``0``.
    """
    # With fewer than two samples there is no split into two non-empty
    # segments; returning early also keeps the hoisted reads below from
    # touching an empty prefix.
    if nb < 2:
        return 0

    # Neither the overall mean nor the threshold depends on the split
    # point, so compute them once instead of once per iteration.
    total = cumsums[nb - 1]
    mu = total / nb
    threshold = _beta(nb, delta)

    for s in range(1, nb):
        head = cumsums[s - 1]
        mu1 = head / s
        mu2 = (total - head) / (nb - s)
        val = s * _kl(mu1, mu) + (nb - s) * _kl(mu2, mu)
        if val > threshold:
            return 1
    return 0


class DAL(BanditAlgorithm):
    """Wrap a base bandit algorithm with forced exploration and restarts.

    Each call to ``select_arm`` either forces a pull of one of the base
    algorithm's ``indep_arms`` (on a periodic schedule) or defers to the
    base algorithm. ``update_statistics`` keeps a running reward sum per
    arm and runs ``_detect`` on the history of the arm just pulled; when a
    change is flagged, a new epoch starts with a freshly built base
    algorithm and cleared statistics.

    NOTE(review): ``self.t``, ``self.T`` and ``self.num_actions`` are read
    here but never assigned in this class — presumably they are maintained
    by ``BanditAlgorithm``; confirm against the base class.
    """

    # Multiplier applied to the forced-exploration rate in ``select_arm``.
    _EXPLORATION_SCALE = 0.001

    # Period used when the exploration rate is zero: large enough that the
    # schedule never wraps around, so only the first pass over the
    # independent arms after a (re)start is forced.
    _NO_REPEAT_PERIOD = int(np.iinfo(np.int64).max)

    def __init__(self,
                 T: int,
                 delta: float,
                 *,
                 base_factory=None,
                 base_kwargs=None,
                 base_algorithm: BanditAlgorithm = None):
        """Build the wrapper and its first base algorithm.

        Args:
            T: Forwarded to ``BanditAlgorithm.__init__`` and stored in
                ``init_params``.
            delta: Value handed to ``_detect`` on every update.
            base_factory: Callable invoked as ``base_factory(**base_kwargs)``
                to create a base algorithm (also after every restart).
            base_kwargs: Keyword arguments for ``base_factory``; ``None``
                means no arguments.
            base_algorithm: Existing instance to imitate. When given, its
                class and a deep copy of its ``init_params`` replace
                ``base_factory`` / ``base_kwargs``.

        Raises:
            ValueError: If ``base_algorithm`` lacks ``init_params``, or if
                neither ``base_factory`` nor ``base_algorithm`` is supplied.
        """
        if base_algorithm is not None:
            if hasattr(base_algorithm, "init_params"):
                base_factory = base_algorithm.__class__
                base_kwargs = deepcopy(base_algorithm.init_params)
            else:
                raise ValueError("Pass a factory/kwargs, or wrap an algo with "
                                 "an `init_params` attribute.")
        if base_factory is None:
            raise ValueError("Must supply `base_factory` or `base_algorithm`.")

        self._factory = base_factory
        self._kw = base_kwargs or {}
        # The base must exist before the parent constructor runs because
        # its ``num_actions`` is passed on.
        self.base = self._make_base()

        super().__init__(self.base.num_actions, T)

        self.delta = delta
        # Per-arm pull counts and running reward sums since the last restart.
        self.counts = np.zeros(self.num_actions, dtype=np.int64)
        self.sums = [[] for _ in range(self.num_actions)]
        self.tau = 0      # value of ``self.t`` at the most recent restart
        self.epoch = 1    # 1 + number of restarts so far

        # Filled in lazily by ``select_arm``.
        self.indep_arms = None
        self.N_e = None
        self.alpha = None
        self.expl_freq = None
        # Last ``change_points`` argument seen by ``select_arm``.
        self.cps = None

        # Everything ``reset`` needs to rebuild this object from scratch.
        self.init_params = dict(T=T, delta=delta,
                                base_factory=base_factory,
                                base_kwargs=deepcopy(self._kw))

    def _make_base(self):
        """Create a fresh base algorithm from the stored factory and kwargs."""
        return self._factory(**self._kw)

    @staticmethod
    def _locate_arm(arms, target_vec):
        """Return the index of the first row of ``arms`` matching ``target_vec``.

        An exact match is preferred. If there is none, rows equal within the
        same relative tolerance that ``select_arm`` uses to decide whether
        the arm set changed are accepted.

        Raises:
            IndexError: If no row matches even within tolerance.
        """
        hits = np.flatnonzero((arms == target_vec).all(axis=1))
        if hits.size == 0:
            hits = np.flatnonzero(
                np.isclose(arms, target_vec, rtol=1e-4).all(axis=1))
        if hits.size == 0:
            raise IndexError(
                "forced-exploration arm not found among the offered arms")
        return int(hits[0])

    def select_arm(self, arms, change_points=None):
        """Pick the arm to pull this round.

        Args:
            arms: Arm feature vectors, one per row; converted to a float
                array.
            change_points: Stored on ``self.cps`` and otherwise unused here.

        Returns:
            The chosen arm: a row index into ``arms`` on forced-exploration
            rounds, otherwise whatever ``self.base.select_arm(arms)``
            returns.

        Raises:
            IndexError: If a forced-exploration arm cannot be found in
                ``arms``.
        """
        arms = np.asarray(arms, dtype=float)

        # The exploration set is rebuilt on first use, after a restart, and
        # whenever the first arm vector differs from the one the base saw.
        stale = self.indep_arms is None or not np.allclose(
            self.base.all_arms[0], arms[0], rtol=1e-4)
        if stale:
            self.base.all_arms = arms
            self.base.get_indep_arms()
            self.indep_arms = self.base.indep_arms
            self.N_e = len(self.indep_arms)

        self.alpha = self._EXPLORATION_SCALE * sqrt(
            self.epoch * self.N_e * log(self.T) / self.T)
        if self.alpha > 0:
            self.expl_freq = max(1, int(np.ceil(self.N_e / self.alpha)))
        else:
            # alpha == 0 (e.g. T == 1 or no independent arms) would divide
            # by zero; treat it as "never repeat the exploration block".
            self.expl_freq = self._NO_REPEAT_PERIOD

        # The first N_e slots of every period are forced exploration, one
        # independent arm per slot; the rest belong to the base algorithm.
        phase = (self.t - self.tau) % self.expl_freq
        if phase < self.N_e:
            chosen = self._locate_arm(arms, self.indep_arms[phase])
        else:
            chosen = self.base.select_arm(arms)
        self.cps = change_points

        self.chosen_arm = chosen
        return chosen

    def update_statistics(self, arm, reward):
        """Record a reward and restart if a change is detected.

        The reward is forwarded to the base algorithm, appended to the
        running sum of ``arm``, and the history of that arm is tested with
        ``_detect`` once it holds more than two samples.
        """
        self.base.update_statistics(arm, reward)

        self.counts[arm] += 1
        if self.sums[arm]:
            self.sums[arm].append(self.sums[arm][-1] + reward)
        else:
            self.sums[arm].append(reward)

        nb = self.counts[arm]
        if nb > 2 and _detect(nb, np.array(self.sums[arm], dtype=np.float64),
                              self.delta):
            self._restart()

    def _restart(self):
        """Begin a new epoch: rebuild the base and clear all statistics."""
        self.epoch += 1
        self.tau = self.t

        self.base = self._make_base()

        self.counts[:] = 0
        self.sums = [[] for _ in range(self.num_actions)]
        # Forces ``select_arm`` to recompute the exploration set.
        self.indep_arms = None
        self.N_e = None

    def reset(self):
        """Re-run ``__init__`` with the arguments saved in ``init_params``."""
        self.__init__(**self.init_params)

    def re_init(self):
        """Alias for ``reset``."""
        self.reset()

    def __str__(self):
        # NOTE(review): the label reads "DAB" although the class is named
        # DAL; left untouched because it may be used as a label elsewhere.
        return f"DAB({self.base})"

    @property
    def theta_hat(self):
        """The base algorithm's ``theta_hat``, or ``None`` if it has none."""
        return getattr(self.base, "theta_hat", None)
