import numpy as np
from numpy.linalg import inv, cholesky, solve
from scipy.linalg import cho_solve


def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), evaluated element-wise.

    Only ``exp(-|x|)`` is ever computed, so large-magnitude inputs saturate
    to 0 or 1 without triggering a floating-point overflow warning.

    Args:
        x: Scalar or array-like of real numbers. Non-floating input is
            converted to float first.

    Returns:
        A NumPy scalar for scalar input, otherwise an array shaped like ``x``.
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(float)
    decay = np.exp(-np.abs(x))  # never exceeds 1, so it cannot overflow
    out = np.where(x >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))
    # Indexing with an empty tuple unwraps a 0-d result to a scalar and
    # leaves higher-dimensional arrays untouched.
    return out[()]


class LogBanditAlg:
    """Shared state and IRLS estimator for the logistic bandit algorithms.

    Subclasses implement ``get_arm(t)``. This class stores the feature
    vectors of the pulled arms and their rewards, and fits a regularized
    logistic model to that history on demand.
    """

    def __init__(self, env, n, params):
        """Store environment handles and read hyperparameters from ``params``.

        Args:
            env: Environment object; ``rng``, ``K`` and ``d`` are read from
                it, and ``norm_theta`` is read only when ``params`` has no
                ``'S'`` entry.
            n: Stored as ``self.n``; not otherwise used by this class.
            params: Mapping of optional hyperparameters. Recognised keys are
                ``'forced_exploration'``, ``'crs'``, ``'delta'``, ``'lam0'``,
                ``'S'`` and ``'beta'``.
        """
        self.env = env
        self.rng = env.rng
        self.K = env.K
        self.d = env.d
        self.n = n
        # Hyperparameters
        self.forced_exploration = params.get('forced_exploration', self.d)
        self.crs = params.get('crs', 1.0)
        self.delta = params.get('delta', 0.01)
        self.lam0 = params.get('lam0', 1e-4)
        # Read env.norm_theta lazily so that an explicit 'S' also works with
        # environments that do not expose norm_theta.
        self.S = params['S'] if 'S' in params else env.norm_theta
        self.beta_flag = params.get('beta', True)
        # Data history
        self.A = []  # feature vectors of the arms pulled so far
        self.R = []  # rewards observed for those pulls
        # Initial estimate and inverse Gram matrix
        self.theta = np.zeros(self.d)
        self.Gram_inv = np.eye(self.d)

    def get_arms(self, X):
        """Store the arm feature matrix ``X`` for the current round."""
        self.X = X

    def update(self, t, arm_idx, reward):
        """Append row ``arm_idx`` of the stored arm matrix and ``reward``.

        ``t`` is accepted for interface compatibility and is not used.
        ``get_arms`` must have been called first so that ``self.X`` exists.
        """
        x = self.X[arm_idx]
        self.A.append(x)
        self.R.append(reward)

    def _irls(self, max_iter=10, tol=1e-3):
        """Fit the regularized logistic model to the history by IRLS.

        Iterations start from the current ``self.theta`` and stop after
        ``max_iter`` steps or once successive iterates differ by less than
        ``tol`` in Euclidean norm. ``self.theta`` is updated, and
        ``self.Gram_inv`` is set to the inverse of the weighted Gram matrix
        of the last executed iteration (evaluated at the iterate that
        entered that iteration).

        Args:
            max_iter: Maximum number of IRLS steps. With fewer than one step
                ``self.Gram_inv`` is left unchanged.
            tol: Convergence threshold on the step norm.

        Returns:
            The fitted parameter vector. With an empty history this is
            ``self.theta``, unchanged.

        Raises:
            numpy.linalg.LinAlgError: If the Gram matrix is singular.
        """
        if not self.A:
            return self.theta
        A = np.stack(self.A)
        R = np.array(self.R)
        theta = self.theta.copy()
        reg = self.lam0 * np.eye(self.d)  # ridge term, constant across steps
        Gram = None
        chol = None
        for _ in range(max_iter):
            z = A.dot(theta)
            mu = sigmoid(z)
            W = mu * (1 - mu)
            # Weighted Gram matrix plus ridge term
            Xw = A * np.sqrt(W)[:, None]
            Gram = Xw.T.dot(Xw) + reg
            # Working response. NOTE: the weight is floored at 1e-6 only in
            # the denominator, so for samples with W < 1e-6 the residual term
            # is scaled by W / 1e-6 when multiplied by W on the next line.
            z_tilde = z + (R - mu) / np.maximum(W, 1e-6)
            b = A.T.dot(W * z_tilde)
            try:
                chol = cholesky(Gram)
            except np.linalg.LinAlgError:
                # Not numerically positive definite: use a general solve.
                chol = None
                theta_new = solve(Gram, b)
            else:
                theta_new = cho_solve((chol, True), b)
            converged = np.linalg.norm(theta_new - theta) < tol
            theta = theta_new
            if converged:
                break
        self.theta = theta
        if Gram is None:
            # No iteration ran (max_iter < 1); nothing to invert.
            return theta
        # Reuse the factor from the last step instead of refactorizing.
        if chol is not None:
            self.Gram_inv = cho_solve((chol, True), np.eye(self.d))
        else:
            self.Gram_inv = inv(Gram)
        return theta


class UCBLog(LogBanditAlg):
    """Upper-confidence-bound arm selection on the IRLS logistic estimate."""

    @staticmethod
    def print():
        """Return the display name of this algorithm."""
        return "UCB-GLM"

    def get_arm(self, t):
        """Return the index of the arm to pull in round ``t``.

        While ``t`` is below ``forced_exploration`` a uniformly random arm is
        drawn. Afterwards the model is refitted and the arm maximising
        predicted mean plus scaled uncertainty is chosen.
        """
        in_warmup = t < self.forced_exploration
        if in_warmup:
            return int(self.rng.integers(self.K))

        # Refresh the estimate (this also refreshes self.Gram_inv).
        theta_hat = self._irls()

        # Confidence radius; it is evaluated even when beta_flag is off.
        log_term = 2 * np.log(1 / self.delta)
        dim_term = self.d * np.log(1 + t / self.lam0)
        radius = np.sqrt(log_term + dim_term)
        multiplier = radius if self.beta_flag else 1
        width = self.crs * multiplier

        # Per-arm sqrt(x^T Gram_inv x).
        arms = self.X
        arms_times_inv = arms.dot(self.Gram_inv)
        quad_form = np.einsum('ij,ij->i', arms_times_inv, arms)
        uncertainty = np.sqrt(quad_form)

        # Optimistic index per arm.
        predicted = sigmoid(arms.dot(theta_hat))
        index = predicted + width * uncertainty
        return int(np.argmax(index))


class RandUCBLog(UCBLog):
    """UCB variant whose bonus multiplier is drawn at random each round."""

    @staticmethod
    def print():
        """Return the display name of this algorithm."""
        return "RandUCBLog"

    def __init__(self, env, n, params):
        """Read the sampling hyperparameters on top of the base ones.

        Extra keys read from ``params``: ``'M'`` (number of grid points),
        ``'pdist'`` (``'normal'`` or anything else for uniform),
        ``'pnormal_std'`` and ``'is_coupled'``.
        """
        super().__init__(env, n, params)
        self.M = params.get('M', 20)
        self.pdist = params.get('pdist', 'normal')
        self.pstd = params.get('pnormal_std', 0.125)
        self.is_coupled = params.get('is_coupled', True)

    def get_arm(self, t):
        """Return the index of the arm to pull in round ``t``.

        After the forced-exploration rounds, the bonus of every arm is
        multiplied by a value drawn from a discrete distribution on an
        ``M``-point grid over [-1, 1]; one shared draw when ``is_coupled``,
        otherwise ``K`` independent draws.
        """
        if t < self.forced_exploration:
            return int(self.rng.integers(self.K))

        theta_hat = self._irls()

        # Confidence radius; it is evaluated even when beta_flag is off.
        radius = np.sqrt(
            2 * np.log(1 / self.delta)
            + self.d * np.log(1 + t / self.lam0)
        )
        width = self.crs * (radius if self.beta_flag else 1)

        # Per-arm sqrt(x^T Gram_inv x).
        arms = self.X
        arms_times_inv = arms.dot(self.Gram_inv)
        uncertainty = np.sqrt(np.einsum('ij,ij->i', arms_times_inv, arms))

        # Discrete sampling distribution over the grid.
        grid = np.linspace(-1, 1, self.M)
        if self.pdist == 'normal':
            weights = np.exp(-grid**2 / (2 * self.pstd**2))
        else:
            weights = np.ones_like(grid)
        probs = weights / weights.sum()

        # size=None yields a single shared index, size=K one index per arm.
        n_draws = None if self.is_coupled else self.K
        picked = self.rng.choice(self.M, size=n_draws, p=probs)
        multiplier = grid[picked]

        predicted = sigmoid(arms.dot(theta_hat))
        index = predicted + multiplier * width * uncertainty
        return int(np.argmax(index))


class LogGreedy(LogBanditAlg):
    """Epsilon-greedy arm selection on the IRLS logistic estimate."""

    @staticmethod
    def print():
        """Return the display name of this algorithm."""
        return "log_e-greedy"

    def __init__(self, env, n, params):
        """Read ``'epsilon'`` (default 0.05) on top of the base parameters."""
        super().__init__(env, n, params)
        self.epsilon = params.get('epsilon', 0.05)

    def get_arm(self, t):
        """Return the index of the arm to pull in round ``t``.

        A uniformly random arm is returned during forced exploration and,
        afterwards, with probability ``epsilon``; otherwise the arm with the
        highest predicted mean is returned.
        """
        # Forced exploration is checked first so that no epsilon draw is
        # consumed from the generator during those rounds.
        if t < self.forced_exploration:
            return int(self.rng.integers(self.K))
        if self.rng.random() < self.epsilon:
            return int(self.rng.integers(self.K))

        estimate = self._irls()
        predicted = sigmoid(self.X.dot(estimate))
        return int(np.argmax(predicted))


class LogTS(LogBanditAlg):
    """Thompson-sampling style arm selection with a Gaussian approximation."""

    @staticmethod
    def print():
        """Return the display name of this algorithm."""
        return "GLM-TS"

    def __init__(self, env, n, params):
        """Initialise with the base-class parameters only."""
        super().__init__(env, n, params)

    def get_arm(self, t):
        """Return the index of the arm to pull in round ``t``.

        After forced exploration, a parameter vector is sampled around the
        IRLS estimate using ``Gram_inv`` as the (scaled) covariance, and the
        arm with the highest mean under that sample is returned.
        """
        if t < self.forced_exploration:
            return int(self.rng.integers(self.K))

        theta_hat = self._irls()

        # Inflation of the sampling covariance.
        if self.beta_flag:
            inflation = np.sqrt(self.d * np.log(1 + t))
        else:
            inflation = 1
        scale = self.crs * inflation

        # Draw the standard normal vector before factorizing, then colour it
        # with the Cholesky factor of Gram_inv.
        noise = self.rng.standard_normal(self.d)
        factor = cholesky(self.Gram_inv)
        theta_tilde = theta_hat + scale * (factor @ noise)

        sampled_means = sigmoid(self.X.dot(theta_tilde))
        return int(np.argmax(sampled_means))


class LogFPL(LogBanditAlg):
    """Follow-the-perturbed-leader: refit on reward-perturbed history."""

    @staticmethod
    def print():
        """Return the display name of this algorithm."""
        return "GLM-FPL"

    def __init__(self, env, n, params):
        """Read ``'a'`` (default 1.0), the half-width of the reward noise.

        The regularization strength ``lam0`` is set by the base class.
        """
        super().__init__(env, n, params)
        self.a = params.get('a', 1.0)

    def get_arm(self, t):
        """Return the index of the arm to pull in round ``t``.

        With an empty history a uniformly random arm is returned. Otherwise
        every past reward is perturbed with uniform noise in ``[-a, a]``,
        clipped to [0, 1], a logistic model is fitted to the perturbed data
        by IRLS (at most 10 steps, starting from zero), and the arm with the
        highest linear score is returned. ``t`` is not used, so
        ``forced_exploration`` has no effect on this algorithm.
        """
        if not self.A:
            # int() keeps the return type consistent with the fitted path.
            return int(self.rng.integers(self.K))

        past_R = np.array(self.R)
        noise = self.rng.uniform(-self.a, self.a, size=len(past_R))
        R_adj = np.clip(past_R + noise, 0, 1)

        A = np.stack(self.A)
        reg = self.lam0 * np.eye(self.d)  # ridge term, constant across steps
        theta_tilde = np.zeros(self.d)

        for _ in range(10):
            z = A.dot(theta_tilde)
            mu = sigmoid(z)
            # Floor the weights so the working response stays finite.
            weights = np.maximum(mu * (1 - mu), 1e-6)
            Xw = A * np.sqrt(weights)[:, None]
            Gram = Xw.T.dot(Xw) + reg
            z_tilde = z + (R_adj - mu) / weights
            b = A.T.dot(weights * z_tilde)

            try:
                L = cholesky(Gram)
            except np.linalg.LinAlgError:
                # `solve` is numpy.linalg.solve, which takes no `assume_a`
                # keyword; use the general solver once Cholesky has failed.
                theta_new = solve(Gram, b)
            else:
                theta_new = cho_solve((L, True), b)

            converged = np.linalg.norm(theta_new - theta_tilde) < 1e-3
            theta_tilde = theta_new
            if converged:
                break

        # The sigmoid is monotone, so ranking by linear score is enough.
        scores = self.X.dot(theta_tilde)
        return int(np.argmax(scores))


class LogFP(LogBanditAlg):
    """Arm selection by randomly perturbing the arm feature vectors."""

    @staticmethod
    def print():
        """Return the display name of this algorithm."""
        return "GLM-FP"

    def get_arm(self, t):
        """Return the index of the arm to pull in round ``t``.

        With an empty history a uniformly random arm is returned. Otherwise
        the model is refitted, every arm's feature vector is shifted along a
        shared Gaussian direction by an amount proportional to that arm's
        uncertainty, and the arm with the highest mean on the shifted
        features is returned.
        """
        # Without data the estimate is the zero vector, so every perturbed
        # score would equal 0.5 and argmax would always return arm 0.
        if not self.A:
            return int(self.rng.integers(self.K))

        theta_hat = self._irls()

        # Scaling factor for the perturbation.
        beta = np.sqrt(
            self.d * np.log(1 + t / self.lam0) + 2 * np.log(1 / self.delta)
        )
        scale = self.crs * (beta if self.beta_flag else 1.0)

        # Per-arm sqrt(x^T Gram_inv x).
        proj = self.X.dot(self.Gram_inv)
        u = np.sqrt(np.einsum('ij,ij->i', proj, self.X))

        # One shared direction eta; each arm moves along it by scale * u.
        eta = self.rng.standard_normal(self.d)
        pert = scale * u
        tilde_X = self.X + np.outer(pert, eta)

        # Select the arm with the highest mean on the perturbed features.
        mu = sigmoid(tilde_X.dot(theta_hat))
        return int(np.argmax(mu))