import numpy as np
from scipy.stats import norm
from scipy.optimize import fsolve, minimize
from sklearn.linear_model import LogisticRegression

## For quick update of Vinv
def sherman_morrison(X, V, w=1):
    """Return the rank-one updated inverse via the Sherman-Morrison identity.

    Given ``V = A^{-1}``, computes ``(A + w * X X^T)^{-1}`` as
    ``V - w (V X)(X^T V) / (1 + w X^T V X)`` without re-inverting.

    Args:
        X: 1-D vector of length d.
        V: (d, d) matrix, the current inverse.
        w: scalar weight on the rank-one term (default 1).

    Returns:
        The (d, d) updated inverse as a new array; ``V`` is not modified.
    """
    v_x = np.dot(V, X)   # column part: V x
    x_v = np.dot(X, V)   # row part: x^T V (kept separate so V need not be symmetric)
    denominator = 1. + w * np.dot(X, v_x)
    return V - (w * np.outer(v_x, x_v)) / denominator

def uniform_ref(contexts):
    """Return a reference policy that weights every arm equally.

    Args:
        contexts: sized collection with one entry per arm; only its length
            is used.

    Returns:
        1-D float array of length ``len(contexts)`` whose entries are all
        ``1 / len(contexts)``.
    """
    num_arms = len(contexts)
    share = 1.0 / num_arms
    return np.full(num_arms, share, dtype=float)


class UCB:
    """Linear bandit that picks the arm with the largest upper confidence bound.

    Keeps a ridge-regression estimate ``beta_hat = Binv @ yx`` where ``Binv``
    is the inverse of ``lam * I + sum_t x_t x_t^T`` and ``yx`` is
    ``sum_t r_t x_t``. Each arm is scored by
    ``x^T beta_hat + alpha * sqrt(x^T Binv x)``.
    """

    def __init__(self, d, alpha, lam=1):
        """Create an agent for d-dimensional contexts.

        Args:
            d: context dimension.
            alpha: multiplier on the exploration width.
            lam: ridge regularisation strength; must be positive.

        Raises:
            ValueError: if ``lam`` is not positive.
        """
        if lam <= 0:
            raise ValueError("lam must be positive")
        self.alpha = alpha
        self.d = d
        self.yx = np.zeros(d)
        # Before any data B = lam * I, so its inverse is I / lam (not lam * I).
        self.Binv = np.eye(d) / lam
        self.beta_hat = np.zeros(d)
        self.settings = {'alpha': self.alpha}
        # Context of the most recently selected arm; set by select_ac.
        self.X_a = None

    def select_ac(self, contexts):
        """Choose an arm and remember its context for the next ``update``.

        Args:
            contexts: array-like of shape (N, d), one row per arm.

        Returns:
            Index of the arm with the highest upper confidence bound
            (first one on ties, as returned by ``np.argmax``).
        """
        X = np.asarray(contexts, dtype=float)
        means = X @ self.beta_hat
        # Row-wise quadratic form x^T Binv x; clip tiny negative round-off
        # so the square root never yields NaN.
        quad = np.einsum('ij,jk,ik->i', X, self.Binv, X)
        stds = np.sqrt(np.maximum(quad, 0.0))
        ucbs = means + self.alpha * stds
        a_t = np.argmax(ucbs)
        self.X_a = X[a_t]
        return a_t

    def update(self, reward):
        """Fold the reward observed for the last selected arm into the estimate.

        Args:
            reward: scalar (or length-1 array) reward for the arm chosen by
                the preceding ``select_ac`` call.

        Raises:
            RuntimeError: if no arm has been selected yet.
        """
        if self.X_a is None:
            raise RuntimeError("select_ac must be called before update")
        self.Binv = sherman_morrison(self.X_a, self.Binv)
        self.yx = self.yx + reward * self.X_a
        self.beta_hat = self.Binv @ self.yx


class TS:
    """Linear bandit that scores arms with Gaussian draws around the ridge estimate.

    Maintains ``Binv`` (starting at the identity), the running sum ``f`` of
    reward-weighted contexts, and ``beta_hat = Binv @ f``. Every arm gets its
    own independent parameter sample from ``N(beta_hat, v^2 * Binv)``.
    """

    def __init__(self, d, v):
        """Set up a d-dimensional agent with sampling scale ``v``."""
        # Hyperparameters
        self.v = v
        self.settings = {'v': self.v}

        # Running state
        self.t = 0
        self.Binv = np.eye(d)
        self.f = np.zeros(d)
        self.beta_hat = np.zeros(d)

    def select_ac(self, contexts):
        """Sample one parameter vector per arm and pick the best sampled score.

        Args:
            contexts: sequence of N context vectors, one per arm.

        Returns:
            Index of the arm whose sampled score is largest; ``np.argmax``
            returns the first maximiser if several are equal.
        """
        n_arms = len(contexts)
        covariance = (self.v ** 2) * self.Binv
        # One independent draw per arm, shape (n_arms, d).
        draws = np.random.multivariate_normal(self.beta_hat, covariance, size=n_arms)
        sampled_scores = np.array(
            [np.dot(x, draw) for x, draw in zip(contexts, draws)]
        )
        chosen = np.argmax(sampled_scores)
        self.X_a = contexts[chosen]
        return chosen

    def update(self, reward):
        """Update the estimate with the reward for the last selected arm."""
        x = self.X_a
        self.Binv = sherman_morrison(x, self.Binv)
        self.f = self.f + reward * x
        self.beta_hat = self.Binv @ self.f


class PHE:
    """Greedy linear bandit refit on noise-perturbed past rewards.

    ``select_ac`` is purely greedy with respect to ``beta_hat``. Exploration
    comes from ``update``, which redraws Gaussian noise (standard deviation
    ``alpha``) for every reward seen so far and refits
    ``beta_hat = Binv @ sum_t x_t (r_t + noise_t)``.
    """

    def __init__(self, d, alpha, lam=1):
        """Create an agent for d-dimensional contexts.

        Args:
            d: context dimension.
            alpha: standard deviation of the reward perturbation.
            lam: ridge regularisation strength; must be positive.

        Raises:
            ValueError: if ``lam`` is not positive.
        """
        if lam <= 0:
            raise ValueError("lam must be positive")
        self.alpha = alpha
        self.d = d
        self.yx = np.zeros(d)
        # Before any data B = lam * I, so its inverse is I / lam (not lam * I).
        self.Binv = np.eye(d) / lam
        self.beta_hat = np.zeros(d)
        self.settings = {'alpha': self.alpha}
        # Parallel histories: context_list[i] was played and earned reward_list[i].
        self.context_list = []
        self.reward_list = []
        # Context of the most recently selected arm; set by select_ac.
        self.X_a = None

    def select_ac(self, contexts):
        """Pick the arm with the highest estimated reward.

        Args:
            contexts: array-like of shape (N, d), one row per arm.

        Returns:
            Index of the greedy arm (first one on ties).
        """
        X = np.asarray(contexts, dtype=float)
        a_t = np.argmax(X @ self.beta_hat)
        self.X_a = X[a_t]
        return a_t

    def update(self, reward):
        """Record the observed reward and refit on freshly perturbed history.

        Args:
            reward: scalar reward, or an array whose first element is the
                reward for the arm chosen by the preceding ``select_ac``.

        Raises:
            RuntimeError: if no arm has been selected yet.
        """
        if self.X_a is None:
            raise RuntimeError("select_ac must be called before update")

        # Append context and reward together so the two histories can never
        # fall out of step, even if select_ac is called more than once per
        # round.
        self.context_list.append(self.X_a)
        self.reward_list.append(float(np.asarray(reward, dtype=float).reshape(-1)[0]))

        # Fresh noise for the whole history on every update.
        self.noise = np.random.normal(0, self.alpha, size=len(self.reward_list))
        pseudo_reward = np.asarray(self.reward_list) + self.noise

        self.Binv = sherman_morrison(self.X_a, self.Binv)
        # sum_t x_t * pseudo_reward_t, written as a single matrix product.
        self.yx = np.asarray(self.context_list).T @ pseudo_reward
        self.beta_hat = self.Binv @ self.yx
        
        
class KL_EXP:
    """Linear bandit with a softmax policy tilted from a reference policy.

    Arms are drawn with probability proportional to
    ``p_ref(a) * exp(eta * x_a^T beta_hat)``, where ``beta_hat`` is a ridge
    estimate maintained by rank-one updates.
    """

    def __init__(self, d, eta, ref_policy='uniform', lam=1.0):
        """Create an agent for d-dimensional contexts.

        Args:
            d: context dimension.
            eta: multiplier on the estimated rewards inside the exponent.
            ref_policy: ``'uniform'`` or a callable mapping the contexts to a
                probability vector over arms.
            lam: ridge regularisation strength; must be positive.

        Raises:
            ValueError: if ``lam`` is not positive or ``ref_policy`` is
                neither ``'uniform'`` nor callable.
        """
        if lam <= 0:
            raise ValueError("lam must be positive")
        self.d = d
        self.eta = float(eta)
        # (X^T X + lam I)^{-1} with no data yet is I / lam (not lam * I).
        self.Binv = np.eye(d) / lam
        self.yx = np.zeros(d)             # sum r_t x_t
        self.beta_hat = np.zeros(d)
        if ref_policy == 'uniform':
            self.ref_policy = uniform_ref
            ref_name = 'uniform'
        elif callable(ref_policy):
            self.ref_policy = ref_policy
            ref_name = getattr(ref_policy, '__name__', 'callable_ref')
        else:
            raise ValueError("ref_policy must be 'uniform' or a callable(contexts)->prob vector")
        self.settings = {'eta': self.eta, 'lam': lam, 'ref_policy': ref_name}

        # Context of the most recently selected arm; set by select_ac.
        self.X_a = None

    def _policy_probs(self, contexts):
        """Return the sampling distribution over the N arms.

        Args:
            contexts: array-like of shape (N, d).

        Returns:
            1-D array of N probabilities summing to 1.
        """
        contexts = np.asarray(contexts)
        scores = contexts @ self.beta_hat  # shape (N,)
        p_ref = self.ref_policy(contexts)  # shape (N,), expected to sum to 1

        # Work in log-space: log w_a = log p_ref(a) + eta * score(a).
        # Clipping keeps log() finite for zero reference mass, and
        # subtracting the max keeps exp() from overflowing.
        logw = np.log(np.clip(p_ref, 1e-32, 1.0)) + self.eta * scores
        logw -= np.max(logw)
        w = np.exp(logw)
        return w / np.sum(w)

    def select_ac(self, contexts):
        """Sample an arm from the policy and remember its context.

        Args:
            contexts: array-like of shape (N, d), one row per arm.

        Returns:
            Index of the sampled arm.
        """
        probs = self._policy_probs(contexts)
        a_t = np.random.choice(len(contexts), p=probs)
        self.X_a = np.asarray(contexts[a_t], dtype=float)
        return a_t

    def update(self, reward):
        """Rank-one ridge update with the reward for the last sampled arm.

        Args:
            reward: scalar or single-element array.

        Raises:
            RuntimeError: if no arm has been selected yet.
            ValueError: if ``reward`` holds more than one element.
        """
        if self.X_a is None:
            raise RuntimeError("select_ac must be called before update")
        # .item() accepts scalars and size-1 arrays alike, avoiding the
        # deprecated float() conversion of arrays with ndim > 0.
        r = np.asarray(reward, dtype=float).item()
        self.Binv = sherman_morrison(self.X_a, self.Binv)
        self.yx += r * self.X_a
        self.beta_hat = self.Binv @ self.yx
        

class SupLinUCB:
    """Multi-level linear UCB with staged arm elimination.

    Keeps ``S`` independent ridge estimators ("levels"). Each round walks the
    levels from 0 upward, at each one either playing greedily (all widths
    tiny), discarding clearly sub-optimal arms and moving up (all widths
    below the level threshold), or playing a wide arm and training only that
    level. Comments citing "Alg.2/Alg.3 line N" refer to the pseudocode this
    class follows, which numbers levels 1..S while the code uses 0..S-1.
    """

    def __init__(self, d, T, alpha, lam=1.0, S=None, seed=None):
        """Create the agent.

        Args:
            d: context dimension.
            T: horizon; must be at least 1.
            alpha: multiplier on the confidence width.
            lam: ridge regularisation strength.
            S: number of levels; defaults to ``ceil(log(T))``, floored at 1.
            seed: seed for ``self.rng``.

        Raises:
            ValueError: if ``T`` or an explicit ``S`` is smaller than 1.
        """
        self.d, self.T = int(d), int(T)
        if self.T < 1:
            raise ValueError("T must be at least 1")
        self.alpha, self.lam = float(alpha), float(lam)
        if S is None:
            # ceil(log(1)) is 0, which would leave no level to select from.
            self.S = max(1, int(np.ceil(np.log(self.T))))
        else:
            self.S = int(S)
            if self.S < 1:
                raise ValueError("S must be at least 1")

        # per-level BaseLinUCB stats
        self.Binv = [(1.0/self.lam) * np.eye(self.d) for _ in range(self.S)]
        self.yx = [np.zeros(self.d) for _ in range(self.S)]
        self.beta = [np.zeros(self.d) for _ in range(self.S)]

        # Bookkeeping for the pending update: chosen context, the level to
        # train (None when no level should change), and which branch chose it.
        self._last_ctx = None
        self._last_level = None
        self._last_case = None
        self.rng = np.random.default_rng(seed)
        self.t = 0
        self.settings = {'alpha': self.alpha, 'lam': self.lam, 'S': self.S, 'T': self.T}

    # ---------- BaseLinUCB at level s ----------
    def _means_widths_ucb(self, X, s):
        """Return (means, widths, ucbs) for the rows of ``X`` under level ``s``."""
        beta_s, Binv_s = self.beta[s], self.Binv[s]
        means = X @ beta_s
        # Alg.2 line 6: w = alpha * sqrt(x^T A^{-1} x). Clip round-off
        # negatives so sqrt cannot produce NaN.
        quad = np.einsum('ij,jk,ik->i', X, Binv_s, X)
        widths = self.alpha * np.sqrt(np.maximum(quad, 0.0))
        ucbs = means + widths
        return means, widths, ucbs

    def _remember(self, context, level, case):
        """Store what ``update`` needs to know about the arm just chosen."""
        self._last_ctx = context
        self._last_level = level
        self._last_case = case

    # ---------- action selection (Alg.3) ----------
    def select_ac(self, contexts):
        """Choose an arm for this round.

        Args:
            contexts: array-like of shape (K, d), one row per arm.

        Returns:
            Index (Python int) of the selected arm.

        Raises:
            ValueError: if the context dimension differs from ``d``.
        """
        self.t += 1
        contexts = np.asarray(contexts, dtype=float)
        K, d = contexts.shape
        if d != self.d:
            raise ValueError("contexts must have d columns")

        s = 0                           # (paper uses 1..S; code uses 0..S-1)
        A_hat = np.arange(K)            # indices of arms still in play

        while True:
            X = contexts[A_hat]
            means, widths, ucbs = self._means_widths_ucb(X, s)

            # line 7: if all w^s_{t,a} <= 1/sqrt(T), exploit
            if np.all(widths <= 1.0 / np.sqrt(self.T)):
                a_t = int(A_hat[int(np.argmax(ucbs))])
                # keep Psi unchanged (Alg.3 line 9)
                self._remember(contexts[a_t], None, 'greedy')
                return a_t

            # line 10: if all w^s_{t,a} <= 2^{-s}
            thr = 2.0 ** (-(s+1))        # 0-based: 2^{-(s_code+1)}
            if np.all(widths <= thr):
                # line 11: keep near-optimal arms, ucb >= max - 2^{1-s}
                tol = 2.0 ** (-(s))      # 0-based for 2^{1-(s+1)} = 2^{-s}
                keep = ucbs >= (float(np.max(ucbs)) - tol)
                A_hat = A_hat[keep]
                s += 1
                if s >= self.S:
                    # Safety fallback once levels run out: play greedily
                    # among the survivors using this level's ucbs.
                    a_t = int(A_hat[np.argmax(ucbs[keep])])
                    self._remember(contexts[a_t], None, 'greedy')
                    return a_t
                continue

            # else (line 14): pick an arm with w^s_{t,a} > 2^{-s}; the
            # widest one is used here.
            candidates = np.where(widths > thr)[0]
            if candidates.size == 0:
                # Only reachable if a width is NaN; fall back to all arms.
                candidates = np.arange(A_hat.size)
            a_local = int(candidates[np.argmax(widths[candidates])])
            a_t = int(A_hat[a_local])

            # update only level s (Alg.3 line 15)
            self._remember(contexts[a_t], s, 'update')
            return a_t

    # ---------- update with observed reward ----------
    def update(self, reward):
        """Train the level recorded by the last ``select_ac``, if any.

        Does nothing when the last choice came from a greedy branch or when
        there is no pending choice; otherwise applies a rank-one ridge update
        to that single level and clears the pending choice.
        """
        # If a greedy branch was taken, nothing to update (Psi unchanged).
        if self._last_case != 'update':
            return

        s = self._last_level
        x = self._last_ctx

        Binv_s = sherman_morrison(x, self.Binv[s])
        yx_s = self.yx[s] + reward * x

        self.Binv[s] = Binv_s
        self.yx[s] = yx_s
        self.beta[s] = Binv_s @ yx_s

        self._last_case = self._last_level = self._last_ctx = None