import numpy as np
from math import log, e, sqrt
from scipy.optimize import minimize
from src.core.bandit import BanditAlgorithm
from src.utils import sigmoid

def _nll_core(a, r):

   
    lsig     = -np.logaddexp(0.0, -a)
    log1msig = -a - np.logaddexp(0.0, -a)
    return -np.sum(r * lsig + (1.0 - r) * log1msig)


def _nll_grad_core(arms, a, r):
    sig  = 1.0 / (1.0 + np.exp(-a))
    diff = sig - r
    return arms.T @ diff


# numba is optional. When it is importable, the two numeric kernels above are
# replaced by JIT-compiled versions; otherwise the plain NumPy definitions
# remain in use unchanged.
try:
    from numba import njit
except ModuleNotFoundError:
    pass
else:
    _nll_core = njit(_nll_core, fastmath=True, cache=True)
    _nll_grad_core = njit(_nll_grad_core, fastmath=True, cache=True)

class OFUGLB(BanditAlgorithm):
    """Optimistic bandit over a likelihood-ratio confidence set (OFUGLB).

    At each round the confidence set is::

        {theta : nll(theta) - nll(theta_hat) <= ucb_bonus,  ||theta||^2 <= S^2}

    Here ``nll`` is the Bernoulli negative log-likelihood of the observed
    rewards under logits ``arms @ theta`` (see ``_nll_core``). An arm ``x``
    is scored by ``max x @ theta`` over that set, solved with SLSQP.

    To bound the per-round cost, only the ``n_fast`` arms with the highest
    cheap proxy score ``x @ theta_hat + S * ||x||`` are refined with the
    constrained optimisation. The best refined arm is played.
    """

    def __init__(self, num_actions, horizon, dim,
                 param_norm_ub, arm_norm_ub, failure_level,
                 tol=1e-3, lazy_update_fr=5, n_fast=1):
        """Create the algorithm state.

        Args:
            num_actions: forwarded to ``BanditAlgorithm.__init__``.
            horizon: forwarded to ``BanditAlgorithm.__init__``.
            dim: dimension of arm feature vectors / the parameter.
            param_norm_ub: radius ``S`` of the parameter ball.
            arm_norm_ub: stored in ``init_params`` only; not otherwise used here.
            failure_level: confidence level ``delta`` used in ``ucb_bonus``.
            tol: SLSQP tolerance for both optimisations.
            lazy_update_fr: refit ``theta_hat`` every this many rounds (>= 1).
            n_fast: number of arms refined per round (>= 1).
        """
        super().__init__(num_actions, horizon)
        self.d = dim
        self.S = param_norm_ub
        self.delta = failure_level
        self.tol = tol
        self.lazy_update_fr = max(1, int(lazy_update_fr))
        self.n_fast = max(1, int(n_fast))

        self.theta_hat = np.zeros(dim)      # current constrained MLE
        self.ucb_bonus = 0.0                # confidence radius in NLL units
        self.log_loss_hat = 0.0             # nll(theta_hat) on current data
        self.ctr = 1                        # 1-based round counter
        self.arms = np.empty((0, dim))      # played arm features, one per row
        self.rewards = np.empty((0,))       # observed rewards, aligned with arms

        # Original constructor arguments, used by reset() to rebuild state.
        self.init_params = dict(num_actions=num_actions, horizon=horizon,
                                dim=dim, param_norm_ub=param_norm_ub,
                                arm_norm_ub=arm_norm_ub,
                                failure_level=failure_level, tol=tol,
                                lazy_update_fr=lazy_update_fr, n_fast=n_fast)

    def _nll(self, th):
        """Negative log-likelihood of all observed rewards at ``th`` (0 if none)."""
        if self.rewards.size == 0:
            return 0.0
        a = self.arms @ th
        return _nll_core(a, self.rewards)

    def _nll_J(self, th):
        """Gradient of ``_nll`` at ``th`` (zero vector if no data)."""
        if self.rewards.size == 0:
            return np.zeros_like(th)
        a = self.arms @ th
        return _nll_grad_core(self.arms, a, self.rewards)

    def _optimistic_value(self, x):
        """Return ``max x @ theta`` over the current confidence set, via SLSQP.

        Falls back to the plug-in value ``x @ theta_hat`` when the solver
        reports failure.
        """
        def objective(th, x=x):
            return -x @ th

        def objective_jac(th, x=x):
            return -x

        def constraint(th):
            # Both entries must be >= 0: likelihood-ratio slack, norm slack.
            return np.array([
                self.ucb_bonus - (self._nll(th) - self.log_loss_hat),
                self.S ** 2 - th @ th,
            ])

        def constraint_jac(th):
            return -np.vstack((self._nll_J(th), 2.0 * th))

        cons = {'type': 'ineq', 'fun': constraint, 'jac': constraint_jac}
        res = minimize(objective, self.theta_hat, method='SLSQP',
                       jac=objective_jac, constraints=cons, tol=self.tol,
                       options={'maxiter': 40})
        return -res.fun if res.success else x @ self.theta_hat

    def select_arm(self, arm_set):
        """Return the index of the arm to play from ``arm_set``.

        ``arm_set`` is converted to a float array whose rows are arm feature
        vectors. It is cached in ``self.all_arms`` so that
        ``update_statistics`` can look the played arm up by index.

        Raises:
            ValueError: if ``arm_set`` contains no arms.
        """
        arms = np.asarray(arm_set, dtype=float)
        if len(arms) == 0:
            raise ValueError("arm_set must contain at least one arm")
        self.all_arms = arms
        self.update_ucb_bonus()
        self.log_loss_hat = self._nll(self.theta_hat)

        norms = np.linalg.norm(arms, axis=1)

        # First round: no data yet, so play the arm with the largest norm.
        if self.ctr == 1:
            return int(np.argmax(norms))

        quick_scores = arms @ self.theta_hat + self.S * norms

        # Indices of the k largest proxy scores. k is clamped to the number of
        # arms because argpartition requires kth < len(arms).
        k = min(self.n_fast, len(arms))
        top_idx = np.argpartition(-quick_scores, k - 1)[:k]

        refined = np.array([self._optimistic_value(arms[i]) for i in top_idx])

        # Compare refined values only against each other. The proxy scores of
        # the other arms are on a looser scale, and mixing the two would
        # systematically demote exactly the arms that were refined.
        return int(top_idx[np.argmax(refined)])

    def update_statistics(self, idx, reward):
        """Record the reward for arm ``idx`` of the last ``arm_set``.

        Every ``lazy_update_fr`` rounds, ``theta_hat`` is refit by minimising
        ``_nll`` subject to ``||theta||^2 <= S^2``. The estimate is only
        replaced when SLSQP reports success.
        """
        x = self.all_arms[idx].astype(float)
        r = float(reward)

        self.arms = np.vstack((self.arms, x))
        self.rewards = np.append(self.rewards, r)

        if self.ctr % self.lazy_update_fr == 0:
            cons = {'type': 'ineq',
                    'fun': lambda th: self.S ** 2 - th @ th,
                    'jac': lambda th: -2.0 * th}

            res = minimize(self._nll, self.theta_hat, jac=self._nll_J,
                           method='SLSQP', constraints=cons, tol=self.tol,
                           options={'maxiter': 60})
            if res.success:
                self.theta_hat = res.x

        self.ctr += 1

    def update_ucb_bonus(self):
        """Recompute the confidence radius ``ucb_bonus``.

        ucb_bonus = log(1/delta) + d * log(max(e, 2 e S L_t / d)),
        with L_t = (1 + S/2) * (n - 1) and n = number of observed rewards.
        """
        # NOTE(review): at selection time n = t - 1, so this uses (t - 2);
        # confirm the intended index of L_t against the OFUGLB reference.
        Lt = (1 + self.S / 2) * (len(self.rewards) - 1)
        self.ucb_bonus = log(1 / self.delta) + self.d * log(
            max(e, 2 * e * self.S * Lt / self.d)
        )

    def reset(self):
        """Reinitialise all state from the original constructor arguments."""
        self.__init__(**self.init_params)

    def re_init(self):
        """Alias for ``reset``."""
        self.reset()

    def __str__(self):
        return "OFUGLB"
