import numpy as np
from typing import Optional
from math import log, sqrt
from copy import deepcopy

from src.core.bandit import BanditAlgorithm
from src.core.detection import detect_change


class DAL_SCB(BanditAlgorithm):
    """Restarting wrapper adding forced exploration and change detection.

    A base bandit algorithm is built with ``base_factory(**base_kwargs)``.
    On each call to :meth:`select_arm`, the rounds elapsed since the last
    restart are mapped onto a periodic schedule: the first ``N_e`` rounds of
    every period of length ``expl_freq`` play the base algorithm's
    ``indep_arms`` in order, and all other rounds are delegated to the base
    algorithm. After every reward, ``change_detector`` is evaluated on the
    played arm's running reward sums; when it fires, the epoch counter is
    incremented and the base algorithm plus all per-arm statistics are rebuilt.
    """

    # Relative tolerance used both to decide whether the arm set changed and
    # to match exploration vectors against rows of the current arm set.
    _ARM_RTOL = 1e-4

    def __init__(
        self,
        T: int,
        delta: float,
        *,
        base_factory=None,
        base_kwargs=None,
        base_algorithm: Optional[BanditAlgorithm] = None,
        detector=None,
        alpha_scale: float = 0.001,
    ):
        """Create the wrapper and its first base-algorithm instance.

        Args:
            T: Passed to ``BanditAlgorithm.__init__`` and used in the
                exploration rate through ``log(T) / T`` (presumably the
                horizon -- confirm against ``BanditAlgorithm``).
            delta: Forwarded as the third argument of the change detector.
            base_factory: Callable that builds a fresh base algorithm; it is
                called again on every restart.
            base_kwargs: Keyword arguments for ``base_factory``.
            base_algorithm: Alternative to ``base_factory``; must expose an
                ``init_params`` attribute. Only its class and a deep copy of
                ``init_params`` are used -- the instance itself is not kept.
            detector: Callable ``(count, cumulative_sums, delta) -> bool``;
                defaults to ``detect_change``.
            alpha_scale: Multiplicative constant of the exploration rate
                ``alpha``; the default reproduces the original hard-coded value.

        Raises:
            ValueError: If no usable base algorithm description is given, or
                ``alpha_scale`` is negative.
        """
        if base_algorithm is not None:
            if hasattr(base_algorithm, "init_params"):
                base_factory = base_algorithm.__class__
                base_kwargs = deepcopy(base_algorithm.init_params)
            else:
                raise ValueError("Pass a factory/kwargs, or wrap an algo with "
                                 "an `init_params` attribute.")
        if base_factory is None:
            raise ValueError("Must supply `base_factory` or `base_algorithm`.")
        if alpha_scale < 0:
            raise ValueError("`alpha_scale` must be non-negative.")

        self._factory = base_factory
        self._kw = base_kwargs or {}
        self.base = self._make_base()

        super().__init__(self.base.num_actions, T)

        self.delta = delta
        self.alpha_scale = alpha_scale
        self.counts = np.zeros(self.num_actions, dtype=np.int64)
        # Per-arm list of cumulative reward sums (k-th entry = sum of the
        # first k+1 rewards observed for that arm since the last restart).
        self.sums = [[] for _ in range(self.num_actions)]
        self.tau = 0      # value of self.t at the most recent restart
        self.epoch = 1    # number of restarts so far, plus one

        self.indep_arms = None
        self.N_e = None
        self.alpha = None
        self.expl_freq = None
        # Private copy of the arm set that `indep_arms` was computed from.
        self._arms_snapshot = None
        self.change_detector = detector or detect_change

        self.init_params = dict(
            T=T,
            delta=delta,
            base_factory=base_factory,
            base_kwargs=deepcopy(self._kw),
            detector=self.change_detector,
            alpha_scale=alpha_scale,
        )

    def _make_base(self):
        """Instantiate a fresh base algorithm from the stored factory/kwargs."""
        return self._factory(**self._kw)

    def _arms_changed(self, arms):
        """Return True when the cached independent arms do not match *arms*."""
        if self.indep_arms is None or self._arms_snapshot is None:
            return True
        if self._arms_snapshot.shape != arms.shape:
            return True
        return not np.allclose(self._arms_snapshot, arms, rtol=self._ARM_RTOL)

    def _locate_arm(self, arms, target):
        """Return the row index of *arms* equal (or numerically close) to *target*.

        Exact matches win; otherwise the first row within ``_ARM_RTOL`` is used,
        which covers arm sets that drifted by float noise since the independent
        arms were computed.

        Raises:
            IndexError: If no row of *arms* matches *target*.
        """
        target = np.asarray(target, dtype=float)
        exact = np.flatnonzero((arms == target).all(axis=1))
        if exact.size:
            return int(exact[0])
        close = np.flatnonzero(
            np.isclose(target, arms, rtol=self._ARM_RTOL).all(axis=1)
        )
        if close.size:
            return int(close[0])
        raise IndexError("Exploration arm not found in the current arm set.")

    def select_arm(self, arms, change_points=None):
        """Choose the arm to play this round.

        Args:
            arms: Array-like with one row per arm (2-D after conversion).
            change_points: Stored on ``self.cps``; not otherwise used here.

        Returns:
            Index (row of *arms*) of the chosen arm.
        """
        arms = np.asarray(arms, dtype=float)

        if self._arms_changed(arms):
            self.base.all_arms = arms
            self.base.get_indep_arms()
            self.indep_arms = self.base.indep_arms
            self.N_e = len(self.indep_arms)
            # `np.asarray` does not copy float arrays; keep our own copy so a
            # caller mutating its buffer in place is still detected next round.
            self._arms_snapshot = arms.copy()

        self.alpha = self.alpha_scale * sqrt(
            self.epoch * self.N_e * log(self.T) / self.T
        )
        elapsed = self.t - self.tau
        if self.alpha > 0:
            self.expl_freq = max(1, int(np.ceil(self.N_e / self.alpha)))
            phase = elapsed % self.expl_freq
        else:
            # alpha == 0 (T == 1 or no independent arms): N_e / alpha is
            # undefined, so treat the exploration period as unbounded -- only
            # the first N_e rounds after a restart explore.
            self.expl_freq = None
            phase = elapsed

        if phase < self.N_e:
            chosen = self._locate_arm(arms, self.indep_arms[phase])
        else:
            chosen = self.base.select_arm(arms)
        self.cps = change_points

        self.chosen_arm = chosen
        return chosen

    def update_statistics(self, arm, reward):
        """Feed *reward* to the base algorithm and run change detection on *arm*.

        A positive detection triggers :meth:`_restart`.
        """
        self.base.update_statistics(arm, reward)

        self.counts[arm] += 1
        history = self.sums[arm]
        history.append(history[-1] + reward if history else reward)

        if self.change_detector(self.counts[arm], history, self.delta):
            self._restart()

    def _restart(self):
        """Start a new epoch: rebuild the base algorithm and clear statistics."""
        self.epoch += 1
        self.tau = self.t

        self.base = self._make_base()

        self.counts[:] = 0
        self.sums = [[] for _ in range(self.num_actions)]
        # Forces the next select_arm call to recompute the independent arms.
        self.indep_arms = None
        self.N_e = None
        self._arms_snapshot = None

    def reset(self):
        """Re-run ``__init__`` with the original constructor arguments."""
        self.__init__(**self.init_params)

    def re_init(self):
        """Alias for :meth:`reset`."""
        self.reset()

    def __str__(self):
        return f"DAL-SCB({self.base})"

    @property
    def theta_hat(self):
        """The base algorithm's ``theta_hat`` attribute, or None if absent."""
        return getattr(self.base, "theta_hat", None)
